diff options
Diffstat (limited to 'src')
278 files changed, 22531 insertions, 13088 deletions
diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 93108dd47..3373506f2 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -36,10 +36,10 @@ #include <zenremotestore/chunking/chunkedfile.h> #include <zenremotestore/chunking/chunkingcache.h> #include <zenremotestore/chunking/chunkingcontroller.h> -#include <zenremotestore/filesystemutils.h> #include <zenremotestore/jupiter/jupiterhost.h> -#include <zenremotestore/operationlogoutput.h> #include <zenremotestore/transferthreadworkers.h> +#include <zenutil/filesystemutils.h> +#include <zenutil/progress.h> #include <zenutil/wildcard.h> #include <zenutil/workerpools.h> #include <zenutil/zenserverprocess.h> @@ -387,7 +387,8 @@ namespace builds_impl { return CleanAndRemoveDirectory(WorkerPool, AbortFlag, PauseFlag, Directory); } - void ValidateBuildPart(OperationLogOutput& Output, + void ValidateBuildPart(LoggerRef Log, + ProgressBase& Progress, TransferThreadWorkers& Workers, BuildStorageBase& Storage, const Oid& BuildId, @@ -398,7 +399,8 @@ namespace builds_impl { ProgressBar::SetLogOperationName(ProgressMode, "Validate Part"); - BuildsOperationValidateBuildPart ValidateOp(Output, + BuildsOperationValidateBuildPart ValidateOp(Log, + Progress, Storage, AbortFlag, PauseFlag, @@ -434,7 +436,8 @@ namespace builds_impl { const std::vector<std::string>& ExcludeExtensions = DefaultExcludeExtensions; }; - std::vector<std::pair<Oid, std::string>> UploadFolder(OperationLogOutput& Output, + std::vector<std::pair<Oid, std::string>> UploadFolder(LoggerRef Log, + ProgressBase& Progress, TransferThreadWorkers& Workers, StorageInstance& Storage, const Oid& BuildId, @@ -452,7 +455,8 @@ namespace builds_impl { Stopwatch UploadTimer; BuildsOperationUploadFolder UploadOp( - Output, + Log, + Progress, Storage, AbortFlag, PauseFlag, @@ -1232,7 +1236,6 @@ namespace builds_impl { EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::Mixed; bool CleanTargetFolder = false; 
bool PostDownloadVerify = false; - bool PrimeCacheOnly = false; bool EnableOtherDownloadsScavenging = true; bool EnableTargetFolderScavenging = true; bool AllowFileClone = true; @@ -1244,7 +1247,8 @@ namespace builds_impl { std::vector<std::string> ExcludeFolders = DefaultExcludeFolders; }; - void DownloadFolder(OperationLogOutput& Output, + void DownloadFolder(LoggerRef InLog, + ProgressBase& Progress, TransferThreadWorkers& Workers, StorageInstance& Storage, const BuildStorageCache::Statistics& StorageCacheStats, @@ -1256,6 +1260,7 @@ namespace builds_impl { const DownloadOptions& Options) { ZEN_TRACE_CPU("DownloadFolder"); + ZEN_SCOPED_LOG(InLog); ProgressBar::SetLogOperationName(ProgressMode, "Download Folder"); @@ -1272,9 +1277,6 @@ namespace builds_impl { auto EndProgress = MakeGuard([&]() { ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::StepCount, TaskSteps::StepCount); }); - ZEN_ASSERT((!Options.PrimeCacheOnly) || - (Options.PrimeCacheOnly && (Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off))); - Stopwatch DownloadTimer; ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::CheckState, TaskSteps::StepCount); @@ -1306,7 +1308,7 @@ namespace builds_impl { ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::CompareState, TaskSteps::StepCount); - ChunkedFolderContent RemoteContent = GetRemoteContent(Output, + ChunkedFolderContent RemoteContent = GetRemoteContent(InLog, Storage, BuildId, AllBuildParts, @@ -1327,75 +1329,67 @@ namespace builds_impl { BuildSaveState LocalState; - if (!Options.PrimeCacheOnly) + if (IsDir(Path)) { - if (IsDir(Path)) + if (!ChunkController && !IsQuiet) { - if (!ChunkController && !IsQuiet) - { - ZEN_CONSOLE_INFO("Unspecified chunking algorithm, using default"); - ChunkController = CreateStandardChunkingController(StandardChunkingControllerSettings{}); - } - std::unique_ptr<ChunkingCache> ChunkCache(CreateNullChunkingCache()); - - LocalState = GetLocalContent(Workers, - 
LocalFolderScanStats, - ChunkingStats, - Path, - ZenStateFilePath(Path / ZenFolderName), - *ChunkController, - *ChunkCache); - - std::vector<std::filesystem::path> UntrackedPaths = GetNewPaths(LocalState.State.ChunkedContent.Paths, RemoteContent.Paths); - - BuildSaveState UntrackedLocalContent = GetLocalStateFromPaths(Workers, - LocalFolderScanStats, - ChunkingStats, - Path, - *ChunkController, - *ChunkCache, - UntrackedPaths); - - if (!UntrackedLocalContent.State.ChunkedContent.Paths.empty()) - { - LocalState.State.ChunkedContent = - MergeChunkedFolderContents(LocalState.State.ChunkedContent, - std::vector<ChunkedFolderContent>{UntrackedLocalContent.State.ChunkedContent}); - - // TODO: Helper - LocalState.FolderState.Paths.insert(LocalState.FolderState.Paths.begin(), - UntrackedLocalContent.FolderState.Paths.begin(), - UntrackedLocalContent.FolderState.Paths.end()); - LocalState.FolderState.RawSizes.insert(LocalState.FolderState.RawSizes.begin(), - UntrackedLocalContent.FolderState.RawSizes.begin(), - UntrackedLocalContent.FolderState.RawSizes.end()); - LocalState.FolderState.Attributes.insert(LocalState.FolderState.Attributes.begin(), - UntrackedLocalContent.FolderState.Attributes.begin(), - UntrackedLocalContent.FolderState.Attributes.end()); - LocalState.FolderState.ModificationTicks.insert(LocalState.FolderState.ModificationTicks.begin(), - UntrackedLocalContent.FolderState.ModificationTicks.begin(), - UntrackedLocalContent.FolderState.ModificationTicks.end()); - } + ZEN_CONSOLE_INFO("Unspecified chunking algorithm, using default"); + ChunkController = CreateStandardChunkingController(StandardChunkingControllerSettings{}); + } + std::unique_ptr<ChunkingCache> ChunkCache(CreateNullChunkingCache()); - if (Options.AppendNewContent) - { - RemoteContent = ApplyChunkedContentOverlay(LocalState.State.ChunkedContent, - RemoteContent, - Options.IncludeWildcards, - Options.ExcludeWildcards); - } -#if ZEN_BUILD_DEBUG - ValidateChunkedFolderContent(RemoteContent, - 
BlockDescriptions, - LooseChunkHashes, - Options.IncludeWildcards, - Options.ExcludeWildcards); -#endif // ZEN_BUILD_DEBUG + LocalState = GetLocalContent(Workers, + LocalFolderScanStats, + ChunkingStats, + Path, + ZenStateFilePath(Path / ZenFolderName), + *ChunkController, + *ChunkCache); + + std::vector<std::filesystem::path> UntrackedPaths = GetNewPaths(LocalState.State.ChunkedContent.Paths, RemoteContent.Paths); + + BuildSaveState UntrackedLocalContent = + GetLocalStateFromPaths(Workers, LocalFolderScanStats, ChunkingStats, Path, *ChunkController, *ChunkCache, UntrackedPaths); + + if (!UntrackedLocalContent.State.ChunkedContent.Paths.empty()) + { + LocalState.State.ChunkedContent = + MergeChunkedFolderContents(LocalState.State.ChunkedContent, + std::vector<ChunkedFolderContent>{UntrackedLocalContent.State.ChunkedContent}); + + // TODO: Helper + LocalState.FolderState.Paths.insert(LocalState.FolderState.Paths.begin(), + UntrackedLocalContent.FolderState.Paths.begin(), + UntrackedLocalContent.FolderState.Paths.end()); + LocalState.FolderState.RawSizes.insert(LocalState.FolderState.RawSizes.begin(), + UntrackedLocalContent.FolderState.RawSizes.begin(), + UntrackedLocalContent.FolderState.RawSizes.end()); + LocalState.FolderState.Attributes.insert(LocalState.FolderState.Attributes.begin(), + UntrackedLocalContent.FolderState.Attributes.begin(), + UntrackedLocalContent.FolderState.Attributes.end()); + LocalState.FolderState.ModificationTicks.insert(LocalState.FolderState.ModificationTicks.begin(), + UntrackedLocalContent.FolderState.ModificationTicks.begin(), + UntrackedLocalContent.FolderState.ModificationTicks.end()); } - else + + if (Options.AppendNewContent) { - CreateDirectories(Path); + RemoteContent = ApplyChunkedContentOverlay(LocalState.State.ChunkedContent, + RemoteContent, + Options.IncludeWildcards, + Options.ExcludeWildcards); } +#if ZEN_BUILD_DEBUG + ValidateChunkedFolderContent(RemoteContent, + BlockDescriptions, + LooseChunkHashes, + 
Options.IncludeWildcards, + Options.ExcludeWildcards); +#endif // ZEN_BUILD_DEBUG + } + else + { + CreateDirectories(Path); } if (AbortFlag) { @@ -1473,13 +1467,14 @@ namespace builds_impl { if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); + ZEN_INFO("Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); } ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Download, TaskSteps::StepCount); BuildsOperationUpdateFolder Updater( - Output, + InLog, + Progress, Storage, AbortFlag, PauseFlag, @@ -1504,7 +1499,6 @@ namespace builds_impl { .PreferredMultipartChunkSize = PreferredMultipartChunkSize, .PartialBlockRequestMode = Options.PartialBlockRequestMode, .WipeTargetFolder = Options.CleanTargetFolder, - .PrimeCacheOnly = Options.PrimeCacheOnly, .EnableOtherDownloadsScavenging = Options.EnableOtherDownloadsScavenging, .EnableTargetFolderScavenging = Options.EnableTargetFolderScavenging || Options.AppendNewContent, .ValidateCompletedSequences = Options.PostDownloadVerify, @@ -1524,40 +1518,37 @@ namespace builds_impl { VerifyFolderStatistics VerifyFolderStats; if (!AbortFlag) { - if (!Options.PrimeCacheOnly) + AddDownloadedPath(Options.SystemRootDir, + BuildsDownloadInfo{.Selection = LocalState.State.Selection, + .LocalPath = Path, + .StateFilePath = ZenStateFilePath(Options.ZenFolderPath), + .Iso8601Date = DateTime::Now().ToIso8601()}); + + ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Verify, TaskSteps::StepCount); + + VerifyFolder(Workers, + RemoteContent, + RemoteLookup, + Path, + Options.ExcludeFolders, + Options.PostDownloadVerify, + VerifyFolderStats); + + Stopwatch WriteStateTimer; + CbObject StateObject = CreateBuildSaveStateObject(LocalState); + + CreateDirectories(ZenStateFilePath(Options.ZenFolderPath).parent_path()); + TemporaryFile::SafeWriteFile(ZenStateFilePath(Options.ZenFolderPath), StateObject.GetView()); 
+ if (!IsQuiet) { - AddDownloadedPath(Options.SystemRootDir, - BuildsDownloadInfo{.Selection = LocalState.State.Selection, - .LocalPath = Path, - .StateFilePath = ZenStateFilePath(Options.ZenFolderPath), - .Iso8601Date = DateTime::Now().ToIso8601()}); - - ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Verify, TaskSteps::StepCount); - - VerifyFolder(Workers, - RemoteContent, - RemoteLookup, - Path, - Options.ExcludeFolders, - Options.PostDownloadVerify, - VerifyFolderStats); - - Stopwatch WriteStateTimer; - CbObject StateObject = CreateBuildSaveStateObject(LocalState); - - CreateDirectories(ZenStateFilePath(Options.ZenFolderPath).parent_path()); - TemporaryFile::SafeWriteFile(ZenStateFilePath(Options.ZenFolderPath), StateObject.GetView()); - if (!IsQuiet) - { - ZEN_CONSOLE("Wrote local state in {}", NiceTimeSpanMs(WriteStateTimer.GetElapsedTimeMs())); - } + ZEN_CONSOLE("Wrote local state in {}", NiceTimeSpanMs(WriteStateTimer.GetElapsedTimeMs())); + } #if 0 ExtendableStringBuilder<1024> SB; CompactBinaryToJson(StateObject, SB); WriteFile(ZenStateFileJsonPath(Options.ZenFolderPath), IoBuffer(IoBuffer::Wrap, SB.Data(), SB.Size())); #endif // 0 - } const uint64_t DownloadCount = Updater.m_DownloadStats.DownloadedChunkCount.load() + Updater.m_DownloadStats.DownloadedBlockCount.load() + Updater.m_DownloadStats.DownloadedPartialBlockCount.load(); @@ -1647,26 +1638,6 @@ namespace builds_impl { } } } - if (Options.PrimeCacheOnly) - { - if (Storage.CacheStorage) - { - Storage.CacheStorage->Flush(5000, [](intptr_t Remaining) { - if (!IsQuiet) - { - if (Remaining == 0) - { - ZEN_CONSOLE("Build cache upload complete"); - } - else - { - ZEN_CONSOLE("Waiting for build cache to complete uploading. 
{} blobs remaining", Remaining); - } - } - return !AbortFlag; - }); - } - } ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Cleanup, TaskSteps::StepCount); @@ -1784,6 +1755,7 @@ namespace builds_impl { { OptionalStructuredOutput->AddString("path"sv, fmt::format("{}", Path)); OptionalStructuredOutput->AddInteger("rawSize"sv, RawSize); + OptionalStructuredOutput->AddHash("rawHash"sv, RawHash); switch (Platform) { case SourcePlatform::Windows: @@ -2011,7 +1983,7 @@ namespace builds_impl { } // namespace builds_impl ////////////////////////////////////////////////////////////////////////////////////////////////////// -// BuildsCommand — Option-adding helpers +// BuildsCommand - Option-adding helpers // void @@ -2300,10 +2272,6 @@ BuildsCommand::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOptions*/) { ProgressMode = ProgressBar::Mode::Plain; } - else if (m_Verbose) - { - ProgressMode = ProgressBar::Mode::Plain; - } else if (IsQuiet) { ProgressMode = ProgressBar::Mode::Quiet; @@ -2416,13 +2384,8 @@ BuildsCommand::CreateBuildStorage(BuildStorageBase::Statistics& StorageStats, /*Hidden*/ false, m_Verbose); - BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*CreateConsoleLogOutput(ProgressMode), - ClientSettings, - m_Host, - m_OverrideHost, - m_ZenCacheHost, - ZenCacheResolveMode::All, - m_Verbose); + BuildStorageResolveResult ResolveRes = + ResolveBuildStorage(ConsoleLog(), ClientSettings, m_Host, m_OverrideHost, m_ZenCacheHost, ZenCacheResolveMode::All, m_Verbose); if (!ResolveRes.Cloud.Address.empty()) { ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2; @@ -2793,12 +2756,8 @@ BuildsCommand::ResolveZenFolderPath(const std::filesystem::path& DefaultPath) } EPartialBlockRequestMode -BuildsCommand::ParseAllowPartialBlockRequests(bool PrimeCacheOnly, cxxopts::Options& SubOpts) +BuildsCommand::ParseAllowPartialBlockRequests(cxxopts::Options& SubOpts) { - if (PrimeCacheOnly) - { - return EPartialBlockRequestMode::Off; - } 
EPartialBlockRequestMode Mode = PartialBlockRequestModeFromString(m_AllowPartialBlockRequests); if (Mode == EPartialBlockRequestMode::Invalid) { @@ -3284,10 +3243,11 @@ BuildsUploadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ? CreateNullChunkingCache() : CreateDiskChunkingCache(m_Parent.m_ChunkingCachePath, *ChunkController, 256u * 1024u); - std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); std::vector<std::pair<Oid, std::string>> UploadedParts = - UploadFolder(*Output, + UploadFolder(ConsoleLog(), + *Progress, Workers, Storage, BuildId, @@ -3314,7 +3274,7 @@ BuildsUploadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) { for (const auto& Part : UploadedParts) { - ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, Part.first, Part.second); + ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, Part.first, Part.second); } } } @@ -3364,13 +3324,6 @@ BuildsDownloadSubCmd::BuildsDownloadSubCmd(BuildsCommand& Parent) Parent.AddAppendNewContentOptions(Opts); Parent.AddExcludeFolderOption(Opts); - Opts.add_option("cache", - "", - "cache-prime-only", - "Only download blobs missing in cache and upload to cache", - cxxopts::value(m_PrimeCacheOnly), - "<cacheprimeonly>"); - Opts.add_option("", "l", "local-path", "Root file system folder for build", cxxopts::value(m_Path), "<local-path>"); Opts.add_option("", "", "build-id", "Build Id", cxxopts::value(m_BuildId), "<id>"); Opts.add_option("", @@ -3470,36 +3423,16 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) m_BuildId, /*RequireNamespace*/ true, /*RequireBucket*/ true, - /*BoostCacheBackgroundWorkerPool*/ m_PrimeCacheOnly, + /*BoostCacheBackgroundWorkerPool*/ false, Auth, Opts); const Oid BuildId = m_Parent.ParseBuildId(m_BuildId, Opts); - if (m_PostDownloadVerify && m_PrimeCacheOnly) - { - throw OptionParseException("'--cache-prime-only' 
conflicts with '--verify'", Opts.help()); - } - - if (m_Clean && m_PrimeCacheOnly) - { - ZEN_CONSOLE_WARN("Ignoring '--clean' option when '--cache-prime-only' is enabled"); - } - - if (m_Force && m_PrimeCacheOnly) - { - ZEN_CONSOLE_WARN("Ignoring '--force' option when '--cache-prime-only' is enabled"); - } - - if (m_Parent.m_AllowPartialBlockRequests != "false" && m_PrimeCacheOnly) - { - ZEN_CONSOLE_WARN("Ignoring '--allow-partial-block-requests' option when '--cache-prime-only' is enabled"); - } - std::vector<Oid> BuildPartIds = m_Parent.ParseBuildPartIds(m_BuildPartIds, Opts); std::vector<std::string> BuildPartNames = m_Parent.ParseBuildPartNames(m_BuildPartNames, Opts); - EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(m_PrimeCacheOnly, Opts); + EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts); if (m_Parent.m_AppendNewContent && m_Clean) { @@ -3510,10 +3443,11 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) std::vector<std::string> ExcludeExtensions = DefaultExcludeExtensions; m_Parent.ParseExcludeFolderAndExtension(ExcludeFolders, ExcludeExtensions); - std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); DownloadFolder( - *Output, + ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -3528,7 +3462,6 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = m_Clean, .PostDownloadVerify = m_PostDownloadVerify, - .PrimeCacheOnly = m_PrimeCacheOnly, .EnableOtherDownloadsScavenging = m_EnableScavenging && !m_Force, .EnableTargetFolderScavenging = !m_Force, .AllowFileClone = m_AllowFileClone, @@ -3898,9 +3831,10 @@ BuildsPrimeCacheSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ProgressBar::SetLogOperationName(ProgressMode, "Prime Cache"); - 
std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); - BuildsOperationPrimeCache PrimeOp(*Output, + BuildsOperationPrimeCache PrimeOp(ConsoleLog(), + *Progress, Storage, AbortFlag, PauseFlag, @@ -4056,9 +3990,9 @@ BuildsValidatePartSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) const Oid BuildPartId = m_BuildPartName.empty() ? Oid::Zero : m_Parent.ParseBuildPartId(m_BuildPartId, Opts); - std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); - ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); + ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); if (AbortFlag) { @@ -4131,7 +4065,7 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } }); - EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(false, Opts); + EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -4202,9 +4136,10 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ? 
CreateNullChunkingCache() : CreateDiskChunkingCache(m_Parent.m_ChunkingCachePath, *ChunkController, 256u * 1024u); - std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); - UploadFolder(*Output, + UploadFolder(ConsoleLog(), + *Progress, Workers, Storage, BuildId, @@ -4231,7 +4166,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) { ZEN_CONSOLE("Upload Build {}, Part {} ({}) from '{}' with chunking cache", m_BuildId, BuildPartId, m_BuildPartName, m_Path); - UploadFolder(*Output, + UploadFolder(ConsoleLog(), + *Progress, Workers, Storage, Oid::NewOid(), @@ -4256,7 +4192,7 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } } - ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); + ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); if (!m_Parent.m_IncludeWildcard.empty() || !m_Parent.m_ExcludeWildcard.empty()) { @@ -4268,7 +4204,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) std::vector<std::string> ExcludeWildcards; m_Parent.ParseFileFilters(IncludeWildcards, ExcludeWildcards); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4283,7 +4220,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = true, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = false, .AllowFileClone = m_AllowFileClone, @@ -4300,7 +4236,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4315,7 +4252,6 @@ BuildsTestSubCmd::Run(const 
ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = true, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone, @@ -4328,7 +4264,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nDownload Full Build {}, Part {} ({}) to '{}'", BuildId, BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4343,7 +4280,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone, @@ -4357,7 +4293,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}'", BuildId, BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4373,7 +4310,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = true, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = false, .AllowFileClone = m_AllowFileClone}); @@ -4383,7 +4319,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (identical target)", BuildId, BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4399,7 +4336,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& 
/*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4501,7 +4437,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ScrambleDir(DownloadPath); ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (scrambled target)", BuildId, BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4517,7 +4454,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4538,7 +4474,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ZEN_CONSOLE("\nUpload scrambled Build {}, Part {} ({})\n{}\n", BuildId2, BuildPartId2, m_BuildPartName, SB.ToView()); } - UploadFolder(*Output, + UploadFolder(ConsoleLog(), + *Progress, Workers, Storage, BuildId2, @@ -4562,10 +4499,11 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) throw std::runtime_error("Test aborted. 
(Upload scrambled)"); } - ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); + ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName); ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4581,7 +4519,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4591,7 +4528,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (scrambled)", BuildId2, BuildPartId2, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4606,7 +4544,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4616,7 +4553,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (scrambled)", BuildId2, BuildPartId2, m_BuildPartName, DownloadPath); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4631,7 +4569,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, 
- .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4641,7 +4578,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath2); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4656,7 +4594,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4666,7 +4603,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) } ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath3); - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4681,7 +4619,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = false, .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = true, .AllowFileClone = m_AllowFileClone}); @@ -4741,7 +4678,7 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) m_Parent.ResolveZenFolderPath(m_Path / ZenFolderName); - EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(false, Opts); + EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -4758,7 +4695,7 @@ 
BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) Auth, Opts); - std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); Stopwatch Timer; for (const std::string& BuildIdString : m_BuildIds) @@ -4768,7 +4705,8 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) { throw OptionParseException(fmt::format("'--build-id' ('{}') is malformed", BuildIdString), Opts.help()); } - DownloadFolder(*Output, + DownloadFolder(ConsoleLog(), + *Progress, Workers, Storage, StorageCacheStats, @@ -4783,7 +4721,6 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) .PartialBlockRequestMode = PartialBlockRequestMode, .CleanTargetFolder = BuildIdString == m_BuildIds.front(), .PostDownloadVerify = true, - .PrimeCacheOnly = false, .EnableOtherDownloadsScavenging = m_EnableScavenging, .EnableTargetFolderScavenging = false, .AllowFileClone = m_AllowFileClone}); diff --git a/src/zen/cmds/builds_cmd.h b/src/zen/cmds/builds_cmd.h index 7ef71e176..ef7500fd6 100644 --- a/src/zen/cmds/builds_cmd.h +++ b/src/zen/cmds/builds_cmd.h @@ -94,7 +94,6 @@ private: bool m_EnableScavenging = true; std::filesystem::path m_DownloadSpecPath; bool m_UploadToZenCache = true; - bool m_PrimeCacheOnly = false; bool m_AllowFileClone = true; }; @@ -280,7 +279,7 @@ public: cxxopts::Options& SubOpts); void ParsePath(std::filesystem::path& Path, cxxopts::Options& SubOpts); IoHash ParseBlobHash(const std::string& BlobHashStr, cxxopts::Options& SubOpts); - EPartialBlockRequestMode ParseAllowPartialBlockRequests(bool PrimeCacheOnly, cxxopts::Options& SubOpts); + EPartialBlockRequestMode ParseAllowPartialBlockRequests(cxxopts::Options& SubOpts); void ParseZenProcessId(int& ZenProcessId); void ParseFileFilters(std::vector<std::string>& OutIncludeWildcards, std::vector<std::string>& OutExcludeWildcards); void 
ParseExcludeFolderAndExtension(std::vector<std::string>& OutExcludeFolders, std::vector<std::string>& OutExcludeExtensions); diff --git a/src/zen/cmds/compute_cmd.cpp b/src/zen/cmds/compute_cmd.cpp new file mode 100644 index 000000000..01166cb0e --- /dev/null +++ b/src/zen/cmds/compute_cmd.cpp @@ -0,0 +1,96 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "compute_cmd.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/logging.h> +# include <zenhttp/httpclient.h> + +using namespace std::literals; + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStartSubCmd + +ComputeRecordStartSubCmd::ComputeRecordStartSubCmd() : ZenSubCmdBase("record-start", "Start recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); +} + +void +ComputeRecordStartSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/start"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording started: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to start recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStopSubCmd + +ComputeRecordStopSubCmd::ComputeRecordStopSubCmd() : ZenSubCmdBase("record-stop", "Stop recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, 
cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); +} + +void +ComputeRecordStopSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/stop"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording stopped: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to stop recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeCommand + +ComputeCommand::ComputeCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_RecordStartSubCmd); + AddSubCommand(m_RecordStopSubCmd); +} + +ComputeCommand::~ComputeCommand() = default; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/compute_cmd.h b/src/zen/cmds/compute_cmd.h new file mode 100644 index 000000000..b26f639c4 --- /dev/null +++ b/src/zen/cmds/compute_cmd.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "../zen.h" + +#include <string> + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen { + +class ComputeRecordStartSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStartSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeRecordStopSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStopSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeCommand : public ZenCmdWithSubCommands +{ +public: + static constexpr char Name[] = "compute"; + static constexpr char Description[] = "Compute service operations"; + + ComputeCommand(); + ~ComputeCommand(); + + cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{Name, Description}; + std::string m_SubCommand; + ComputeRecordStartSubCmd m_RecordStartSubCmd; + ComputeRecordStopSubCmd m_RecordStopSubCmd; +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 30e860a3f..dab53f13c 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -23,6 +23,8 @@ #include <zenhttp/httpclient.h> #include <zenhttp/packageformat.h> +#include "../progressbar.h" + #include <EASTL/hash_map.h> #include <EASTL/hash_set.h> #include <EASTL/map.h> @@ -114,7 +116,7 @@ namespace { } // namespace ////////////////////////////////////////////////////////////////////////// -// ExecSessionConfig — read-only configuration for a session run +// ExecSessionConfig - read-only configuration for a session run struct ExecSessionConfig { @@ -124,17 +126,18 @@ struct ExecSessionConfig std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce std::string_view OrchestratorUrl; const std::filesystem::path& OutputPath; - int Offset = 0; - int Stride = 1; - int Limit = 0; - bool Verbose = false; - bool Quiet = false; - bool 
DumpActions = false; - bool Binary = false; + int Offset = 0; + int Stride = 1; + int Limit = 0; + bool Verbose = false; + bool Quiet = false; + bool DumpActions = false; + bool Binary = false; + ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; }; ////////////////////////////////////////////////////////////////////////// -// ExecSessionRunner — owns per-run state, drives the session lifecycle +// ExecSessionRunner - owns per-run state, drives the session lifecycle class ExecSessionRunner { @@ -345,8 +348,6 @@ ExecSessionRunner::DrainCompletedJobs() } m_PendingJobs.Remove(CompleteLsn); - - ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize()); } } } @@ -897,17 +898,20 @@ ExecSessionRunner::Run() // Then submit work items - int FailedWorkCounter = 0; - size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount(); - int SubmittedWorkItems = 0; + std::atomic<int> FailedWorkCounter{0}; + std::atomic<size_t> RemainingWorkItems{m_Config.RecordingReader.GetActionCount()}; + std::atomic<int> SubmittedWorkItems{0}; + size_t TotalWorkItems = RemainingWorkItems.load(); - ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + ProgressBar SubmitProgress(m_Config.ProgressMode, "Submit"); + SubmitProgress.UpdateState({.Task = "Submitting work items", .TotalCount = TotalWorkItems, .RemainingCount = RemainingWorkItems.load()}, + false); int OffsetCounter = m_Config.Offset; int StrideCounter = m_Config.Stride; auto ShouldSchedule = [&]() -> bool { - if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit) + if (m_Config.Limit && SubmittedWorkItems.load() >= m_Config.Limit) { // Limit reached, ignore @@ -1005,17 +1009,14 @@ ExecSessionRunner::Run() { const int32_t LsnField = EnqueueResult.Lsn; - --RemainingWorkItems; - ++SubmittedWorkItems; + size_t Remaining = --RemainingWorkItems; + int Submitted = ++SubmittedWorkItems; - if (!m_Config.Quiet) - { - ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. 
{} remaining", - SubmittedWorkItems, - LsnField, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - RemainingWorkItems); - } + SubmitProgress.UpdateState({.Task = "Submitting work items", + .Details = fmt::format("#{} LSN {}", Submitted, LsnField), + .TotalCount = TotalWorkItems, + .RemainingCount = Remaining}, + false); if (!m_Config.OutputPath.empty()) { @@ -1055,22 +1056,36 @@ ExecSessionRunner::Run() }, TargetParallelism); + SubmitProgress.Finish(); + // Wait until all pending work is complete + size_t TotalPendingJobs = m_PendingJobs.GetSize(); + + ProgressBar CompletionProgress(m_Config.ProgressMode, "Execute"); + while (!m_PendingJobs.IsEmpty()) { - // TODO: improve this logic - zen::Sleep(500); + size_t PendingCount = m_PendingJobs.GetSize(); + CompletionProgress.UpdateState({.Task = "Executing work items", + .Details = fmt::format("{} completed, {} remaining", TotalPendingJobs - PendingCount, PendingCount), + .TotalCount = TotalPendingJobs, + .RemainingCount = PendingCount}, + false); + + zen::Sleep(GetUpdateDelayMS(m_Config.ProgressMode)); DrainCompletedJobs(); SendOrchestratorHeartbeat(); } + CompletionProgress.Finish(); + // Write summary files WriteSummaryFiles(); - if (FailedWorkCounter) + if (FailedWorkCounter.load()) { return 1; } @@ -1119,6 +1134,7 @@ ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent) { + m_SubOptions.add_option("managed", "", "managed", "Use managed local runner (if supported)", cxxopts::value(m_Managed), "<bool>"); } void @@ -1130,7 +1146,16 @@ ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) zen::compute::ComputeServiceSession ComputeSession(Resolver); std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - ComputeSession.AddLocalRunner(Resolver, TempPath); + if (m_Managed) + { + ZEN_CONSOLE_INFO("using managed local runner"); + 
ComputeSession.AddManagedLocalRunner(Resolver, TempPath); + } + else + { + ZEN_CONSOLE_INFO("using local runner"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + } Stopwatch ExecTimer; int ReturnValue = m_Parent.RunSession(ComputeSession); @@ -1413,6 +1438,12 @@ ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) int ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl) { + ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; + if (m_QuietLogging) + { + ProgressMode = ProgressBar::Mode::Quiet; + } + ExecSessionConfig Config{ .Resolver = *m_ChunkResolver, .RecordingReader = *m_RecordingReader, @@ -1427,6 +1458,7 @@ ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std .Quiet = m_QuietLogging, .DumpActions = m_DumpActions, .Binary = m_Binary, + .ProgressMode = ProgressMode, }; ExecSessionRunner Runner(ComputeSession, Config); diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index c55412780..a0bf201a1 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -61,6 +61,7 @@ public: private: ExecCommand& m_Parent; + bool m_Managed = false; }; class ExecBeaconSubCmd : public ZenSubCmdBase diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index d31c34fd0..7f94bf2df 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -24,10 +24,10 @@ #include <zenremotestore/builds/buildstorageutil.h> #include <zenremotestore/builds/jupiterbuildstorage.h> #include <zenremotestore/jupiter/jupiterhost.h> -#include <zenremotestore/operationlogoutput.h> #include <zenremotestore/projectstore/projectstoreoperations.h> #include <zenremotestore/projectstore/remoteprojectstore.h> #include <zenremotestore/transferthreadworkers.h> +#include <zenutil/progress.h> #include <zenutil/workerpools.h> #include "../progressbar.h" @@ -2500,10 +2500,6 @@ OplogDownloadCommand::Run(const ZenCliOptions& 
GlobalOptions, int argc, char** a { ProgressMode = ProgressBar::Mode::Plain; } - else if (m_Verbose) - { - ProgressMode = ProgressBar::Mode::Plain; - } else if (m_Quiet) { ProgressMode = ProgressBar::Mode::Quiet; @@ -2565,7 +2561,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a m_BoostWorkerMemory = true; } - std::unique_ptr<OperationLogOutput> OperationLogOutput(CreateConsoleLogOutput(ProgressMode)); + std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode)); TransferThreadWorkers Workers(m_BoostWorkerCount, /*SingleThreaded*/ false); if (!m_Quiet) @@ -2594,7 +2590,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a /*Hidden*/ false, m_Verbose); - BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*OperationLogOutput, + BuildStorageResolveResult ResolveRes = ResolveBuildStorage(ConsoleLog(), ClientSettings, m_Host, m_OverrideHost, @@ -2677,7 +2673,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a } ProjectStoreOperationOplogState State( - *OperationLogOutput, + ConsoleLog(), Storage, BuildId, {.IsQuiet = m_Quiet, .IsVerbose = m_Verbose, .ForceDownload = m_ForceDownload, .TempFolderPath = StorageTempPath}); @@ -2704,7 +2700,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a } std::atomic<bool> PauseFlag; - ProjectStoreOperationDownloadAttachments Op(*OperationLogOutput, + ProjectStoreOperationDownloadAttachments Op(ConsoleLog(), + *Progress, Storage, AbortFlag, PauseFlag, diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h index 1ba98b39e..41db36139 100644 --- a/src/zen/cmds/projectstore_cmd.h +++ b/src/zen/cmds/projectstore_cmd.h @@ -217,7 +217,7 @@ class SnapshotOplogCommand : public ProjectStoreCommand { public: static constexpr char Name[] = "oplog-snapshot"; - static constexpr char Description[] = "Snapshot project store oplog"; + static constexpr char Description[] = 
"Copy oplog's loose files on disk into zenserver"; SnapshotOplogCommand(); ~SnapshotOplogCommand(); diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp index 37baf5483..5e284cbdf 100644 --- a/src/zen/cmds/service_cmd.cpp +++ b/src/zen/cmds/service_cmd.cpp @@ -320,7 +320,7 @@ ServiceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("'verb' option is required", m_Options.help()); } - // Parse subcommand permissively — forward unrecognised options to the parent parser. + // Parse subcommand permissively - forward unrecognised options to the parent parser. std::vector<std::string> SubUnmatched; if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp index 4846b4d18..28ab6c45c 100644 --- a/src/zen/cmds/ui_cmd.cpp +++ b/src/zen/cmds/ui_cmd.cpp @@ -162,7 +162,7 @@ UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) Labels.push_back(fmt::format("(all {} instances)", Servers.size())); const int32_t Cols = static_cast<int32_t>(TuiConsoleColumns()); - constexpr int32_t kIndicator = 3; // " ▶ " or " " prefix + constexpr int32_t kIndicator = 3; // " > " or " " prefix constexpr int32_t kSeparator = 2; // " " before cmdline constexpr int32_t kEllipsis = 3; // "..." diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp index 10f5ad8e1..c027f0d67 100644 --- a/src/zen/cmds/wipe_cmd.cpp +++ b/src/zen/cmds/wipe_cmd.cpp @@ -4,6 +4,7 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> +#include <zencore/iohash.h> #include <zencore/logging.h> #include <zencore/parallelwork.h> #include <zencore/string.h> @@ -548,7 +549,7 @@ WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) Quiet = m_Quiet; IsVerbose = m_Verbose; - ProgressMode = (IsVerbose || m_PlainProgress) ? 
ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty; + ProgressMode = m_PlainProgress ? ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty; BoostWorkerThreads = m_BoostWorkerThreads; MakeSafeAbsolutePathInPlace(m_Directory); diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 9e49b464e..1ae2d58fa 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -137,7 +137,7 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("'verb' option is required", m_Options.help()); } - // Parse subcommand permissively — forward unrecognised options to the parent parser. + // Parse subcommand permissively - forward unrecognised options to the parent parser. std::vector<std::string> SubUnmatched; if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { @@ -403,7 +403,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** throw OptionParseException("'verb' option is required", m_Options.help()); } - // Parse subcommand permissively — forward unrecognised options to the parent parser. + // Parse subcommand permissively - forward unrecognised options to the parent parser. 
std::vector<std::string> SubUnmatched; if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 6581cd116..780b08707 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -5,10 +5,15 @@ #include "progressbar.h" +#include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/windows.h> -#include <zenremotestore/operationlogoutput.h> #include <zenutil/consoletui.h> +#include <zenutil/progress.h> + +#if !ZEN_PLATFORM_WINDOWS +# include <csignal> +#endif ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> @@ -18,6 +23,87 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { +// Global tracking for scroll region cleanup on abnormal termination (Ctrl+C etc.) +// Only one ProgressBar can own a scroll region at a time. +static std::atomic<ProgressBar*> g_ActiveScrollRegionOwner{nullptr}; +static std::atomic<uint32_t> g_ActiveScrollRegionRows{0}; + +static void +ResetScrollRegionRaw() +{ + // Signal-safe: emit raw escape sequences to restore terminal state. + // These are async-signal-safe on POSIX (write()) and safe in console + // ctrl handlers on Windows (WriteConsole is allowed). 
+ uint32_t Rows = g_ActiveScrollRegionRows.load(std::memory_order_acquire); + if (Rows >= 3) + { + // Move to status line, erase it, reset scroll region, move cursor to end of content + TuiMoveCursor(Rows, 1); + TuiEraseLine(); + TuiResetScrollRegion(); + TuiMoveCursor(Rows - 1, 1); + } + else + { + TuiResetScrollRegion(); + } + TuiShowCursor(true); + TuiFlush(); +} + +#if ZEN_PLATFORM_WINDOWS +static BOOL WINAPI +ScrollRegionCtrlHandler(DWORD CtrlType) +{ + if (CtrlType == CTRL_C_EVENT || CtrlType == CTRL_BREAK_EVENT) + { + ResetScrollRegionRaw(); + } + // Return FALSE so the default handler (process termination) still runs + return FALSE; +} +#else +static struct sigaction s_PrevSigIntAction; +static struct sigaction s_PrevSigTermAction; + +static void +ScrollRegionSignalHandler(int Signal) +{ + ResetScrollRegionRaw(); + + // Re-raise with the previous handler + struct sigaction* PrevAction = (Signal == SIGINT) ? &s_PrevSigIntAction : &s_PrevSigTermAction; + sigaction(Signal, PrevAction, nullptr); + raise(Signal); +} +#endif + +static void +InstallScrollRegionCleanupHandler() +{ +#if ZEN_PLATFORM_WINDOWS + SetConsoleCtrlHandler(ScrollRegionCtrlHandler, TRUE); +#else + struct sigaction Action = {}; + Action.sa_handler = ScrollRegionSignalHandler; + Action.sa_flags = SA_RESETHAND; // one-shot + sigemptyset(&Action.sa_mask); + sigaction(SIGINT, &Action, &s_PrevSigIntAction); + sigaction(SIGTERM, &Action, &s_PrevSigTermAction); +#endif +} + +static void +RemoveScrollRegionCleanupHandler() +{ +#if ZEN_PLATFORM_WINDOWS + SetConsoleCtrlHandler(ScrollRegionCtrlHandler, FALSE); +#else + sigaction(SIGINT, &s_PrevSigIntAction, nullptr); + sigaction(SIGTERM, &s_PrevSigTermAction, nullptr); +#endif +} + #if ZEN_PLATFORM_WINDOWS static HANDLE GetConsoleHandle() @@ -128,12 +214,18 @@ ProgressBar::ProgressBar(Mode InMode, std::string_view InSubTask) { PushLogOperation(InMode, m_SubTask); } + + if (m_Mode == Mode::Pretty) + { + SetupScrollRegion(); + } } 
ProgressBar::~ProgressBar() { try { + TeardownScrollRegion(); ForceLinebreak(); if (!m_SubTask.empty()) { @@ -147,6 +239,87 @@ ProgressBar::~ProgressBar() } void +ProgressBar::SetupScrollRegion() +{ + // Only one scroll region owner at a time; nested bars fall back to the inline \r path. + if (g_ActiveScrollRegionOwner.load(std::memory_order_acquire) != nullptr) + { + return; + } + + uint32_t Rows = TuiConsoleRows(0); + if (Rows < 3) + { + return; + } + + TuiEnableOutput(); + + // Ensure cursor is not on the last row before we install the region. + // Print a newline to push content up if needed, then set the region. + OutputToConsoleRaw("\n"); + TuiSetScrollRegion(1, Rows - 1); + + // Move cursor into the scroll region so normal output stays there + TuiMoveCursor(Rows - 1, 1); + + m_ScrollRegionActive = true; + m_ScrollRegionRows = Rows; + + g_ActiveScrollRegionRows.store(Rows, std::memory_order_release); + g_ActiveScrollRegionOwner.store(this, std::memory_order_release); + InstallScrollRegionCleanupHandler(); +} + +void +ProgressBar::TeardownScrollRegion() +{ + if (!m_ScrollRegionActive) + { + return; + } + m_ScrollRegionActive = false; + + RemoveScrollRegionCleanupHandler(); + g_ActiveScrollRegionOwner.store(nullptr, std::memory_order_release); + g_ActiveScrollRegionRows.store(0, std::memory_order_release); + + // Emit all teardown escape sequences as a single atomic write + ExtendableStringBuilder<128> Buf; + Buf << fmt::format("\x1b[{};1H", m_ScrollRegionRows) // move to status line + << "\x1b[2K" // erase it + << "\x1b[r" // reset scroll region + << fmt::format("\x1b[{};1H", m_ScrollRegionRows - 1); // move to end of content + OutputToConsoleRaw(Buf); + TuiFlush(); +} + +void +ProgressBar::RenderStatusLine(std::string_view Line) +{ + // Handle terminal resizes by re-querying row count + uint32_t CurrentRows = TuiConsoleRows(0); + if (CurrentRows >= 3 && CurrentRows != m_ScrollRegionRows) + { + // Terminal was resized - reinstall scroll region + 
TuiSetScrollRegion(1, CurrentRows - 1); + m_ScrollRegionRows = CurrentRows; + } + + // Build the entire escape sequence as a single string so the console write + // is atomic and log output from other threads cannot interleave. + ExtendableStringBuilder<512> Buf; + Buf << "\x1b" + "7" // ESC 7 - save cursor + << fmt::format("\x1b[{};1H", m_ScrollRegionRows) // move to bottom row + << "\x1b[2K" // erase entire line + << Line // progress bar content + << "\x1b" + "8"; // ESC 8 - restore cursor + OutputToConsoleRaw(Buf); +} + +void ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) { ZEN_ASSERT(NewState.TotalCount >= NewState.RemainingCount); @@ -226,7 +399,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) ExtendableStringBuilder<256> OutputBuilder; - OutputBuilder << "\r" << Task << " " << PercentString; + OutputBuilder << Task << " " << PercentString; if (OutputBuilder.Size() + 1 < ConsoleColumns) { size_t RemainingSpace = ConsoleColumns - (OutputBuilder.Size() + 1); @@ -257,34 +430,43 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) } } - std::string_view Output = OutputBuilder.ToView(); - std::string::size_type EraseLength = m_LastOutputLength > Output.length() ? (m_LastOutputLength - Output.length()) : 0; + if (m_ScrollRegionActive) + { + // Render on the pinned bottom status line + RenderStatusLine(OutputBuilder.ToView()); + } + else + { + // Fallback: inline \r-based overwrite (terminal too small for scroll region) + std::string_view Output = OutputBuilder.ToView(); + std::string::size_type EraseLength = + m_LastOutputLength > (Output.length() + 1) ? 
(m_LastOutputLength - Output.length() - 1) : 0; + ExtendableStringBuilder<256> LineToPrint; - ExtendableStringBuilder<256> LineToPrint; + if (Output.length() + 1 + EraseLength >= ConsoleColumns) + { + if (m_LastOutputLength > 0) + { + LineToPrint << "\n"; + } + LineToPrint << Output; + DoLinebreak = true; + } + else + { + LineToPrint << "\r" << Output << std::string(EraseLength, ' '); + } - if (Output.length() + EraseLength >= ConsoleColumns) - { - if (m_LastOutputLength > 0) + if (DoLinebreak) { LineToPrint << "\n"; } - LineToPrint << Output.substr(1); - DoLinebreak = true; - } - else - { - LineToPrint << Output << std::string(EraseLength, ' '); - } - if (DoLinebreak) - { - LineToPrint << "\n"; + OutputToConsoleRaw(LineToPrint); + m_LastOutputLength = DoLinebreak ? 0 : (Output.length() + 1); // +1 for \r prefix } - OutputToConsoleRaw(LineToPrint); - - m_LastOutputLength = DoLinebreak ? 0 : Output.length(); - m_State = NewState; + m_State = NewState; } else if (m_Mode == Mode::Log) { @@ -329,6 +511,8 @@ ProgressBar::ForceLinebreak() void ProgressBar::Finish() { + TeardownScrollRegion(); + if (m_LastOutputLength > 0 || m_State.RemainingCount > 0) { State NewState = m_State; @@ -353,64 +537,29 @@ ProgressBar::HasActiveTask() const return !m_State.Task.empty(); } -class ConsoleOpLogProgressBar : public OperationLogOutput::ProgressBar -{ -public: - ConsoleOpLogProgressBar(zen::ProgressBar::Mode InMode, std::string_view InSubTask) : m_Inner(InMode, InSubTask) {} - - virtual void UpdateState(const State& NewState, bool DoLinebreak) - { - zen::ProgressBar::State State = {.Task = NewState.Task, - .Details = NewState.Details, - .TotalCount = NewState.TotalCount, - .RemainingCount = NewState.RemainingCount, - .Status = ConvertStatus(NewState.Status)}; - m_Inner.UpdateState(State, DoLinebreak); - } - virtual void Finish() { m_Inner.Finish(); } - -private: - zen::ProgressBar::State::EStatus ConvertStatus(State::EStatus Status) - { - switch (Status) - { - case 
State::EStatus::Running: - return zen::ProgressBar::State::EStatus::Running; - case State::EStatus::Aborted: - return zen::ProgressBar::State::EStatus::Aborted; - case State::EStatus::Paused: - return zen::ProgressBar::State::EStatus::Paused; - default: - return (zen::ProgressBar::State::EStatus)Status; - } - } - zen::ProgressBar m_Inner; -}; - -class ConsoleOpLogOutput : public OperationLogOutput +class ConsoleOpLogOutput : public ProgressBase { public: ConsoleOpLogOutput(zen::ProgressBar::Mode InMode) : m_Mode(InMode) {} - virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override - { - logging::EmitConsoleLogMessage(Point, Args); - } virtual void SetLogOperationName(std::string_view Name) override { zen::ProgressBar::SetLogOperationName(m_Mode, Name); } virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { zen::ProgressBar::SetLogOperationProgress(m_Mode, StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() override { return GetUpdateDelayMS(m_Mode); } + virtual uint32_t GetProgressUpdateDelayMS() const override { return GetUpdateDelayMS(m_Mode); } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); } + virtual std::unique_ptr<ProgressBase::ProgressBar> CreateProgressBar(std::string_view InSubTask) override + { + return std::make_unique<zen::ProgressBar>(m_Mode, InSubTask); + } private: zen::ProgressBar::Mode m_Mode; }; -OperationLogOutput* -CreateConsoleLogOutput(ProgressBar::Mode InMode) +ProgressBase* +CreateConsoleProgress(ProgressBar::Mode InMode) { return new ConsoleOpLogOutput(InMode); } diff --git a/src/zen/progressbar.h b/src/zen/progressbar.h index b54c009e1..26bb9b9c4 100644 --- a/src/zen/progressbar.h +++ b/src/zen/progressbar.h @@ -4,46 +4,13 @@ #include <zencore/timer.h> #include <zencore/zencore.h> - -#include <string> +#include <zenutil/progress.h> namespace zen { -class 
OperationLogOutput; - -class ProgressBar +class ProgressBar : public ProgressBase::ProgressBar { public: - struct State - { - bool operator==(const State&) const = default; - std::string Task; - std::string Details; - uint64_t TotalCount = 0; - uint64_t RemainingCount = 0; - uint64_t OptionalElapsedTime = (uint64_t)-1; - enum class EStatus - { - Running, - Aborted, - Paused - }; - EStatus Status = EStatus::Running; - - static EStatus CalculateStatus(bool IsAborted, bool IsPaused) - { - if (IsAborted) - { - return EStatus::Aborted; - } - if (IsPaused) - { - return EStatus::Paused; - } - return EStatus::Running; - } - }; - enum class Mode { Plain, @@ -60,24 +27,30 @@ public: explicit ProgressBar(Mode InMode, std::string_view InSubTask); ~ProgressBar(); - void UpdateState(const State& NewState, bool DoLinebreak); + void UpdateState(const State& NewState, bool DoLinebreak) override; void ForceLinebreak(); - void Finish(); + void Finish() override; bool IsSameTask(std::string_view Task) const; bool HasActiveTask() const; private: + void SetupScrollRegion(); + void TeardownScrollRegion(); + void RenderStatusLine(std::string_view Line); + const Mode m_Mode; Stopwatch m_SW; uint64_t m_LastUpdateMS; uint64_t m_PausedMS; State m_State; const std::string m_SubTask; - size_t m_LastOutputLength = 0; + size_t m_LastOutputLength = 0; + bool m_ScrollRegionActive = false; + uint32_t m_ScrollRegionRows = 0; }; uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode); -OperationLogOutput* CreateConsoleLogOutput(ProgressBar::Mode InMode); +ProgressBase* CreateConsoleProgress(ProgressBar::Mode InMode); } // namespace zen diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 3277eb856..0229db4a8 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -9,6 +9,7 @@ #include "cmds/bench_cmd.h" #include "cmds/builds_cmd.h" #include "cmds/cache_cmd.h" +#include "cmds/compute_cmd.h" #include "cmds/copy_cmd.h" #include "cmds/dedup_cmd.h" #include "cmds/exec_cmd.h" @@ -368,7 +369,7 @@ 
ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** } } - // Parse subcommand args permissively — unrecognised options are collected + // Parse subcommand args permissively - unrecognised options are collected // and forwarded to the parent parser so that parent options (e.g. --path) // can appear after the subcommand name on the command line. std::vector<std::string> SubUnmatched; @@ -588,7 +589,8 @@ main(int argc, char** argv) DropCommand DropCmd; DropProjectCommand ProjectDropCmd; #if ZEN_WITH_COMPUTE_SERVICES - ExecCommand ExecCmd; + ComputeCommand ComputeCmd; + ExecCommand ExecCmd; #endif // ZEN_WITH_COMPUTE_SERVICES ExportOplogCommand ExportOplogCmd; FlushCommand FlushCmd; @@ -649,6 +651,7 @@ main(int argc, char** argv) {DownCommand::Name, &DownCmd, DownCommand::Description}, {DropCommand::Name, &DropCmd, DropCommand::Description}, #if ZEN_WITH_COMPUTE_SERVICES + {ComputeCommand::Name, &ComputeCmd, ComputeCommand::Description}, {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, #endif {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description}, diff --git a/src/zenbase/include/zenbase/atomic.h b/src/zenbase/include/zenbase/atomic.h deleted file mode 100644 index 4ad7962cf..000000000 --- a/src/zenbase/include/zenbase/atomic.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zenbase/zenbase.h> - -#if ZEN_COMPILER_MSC -# include <intrin.h> -#else -# include <atomic> -#endif - -#include <cinttypes> - -namespace zen { - -inline uint32_t -AtomicIncrement(volatile uint32_t& value) -{ -#if ZEN_COMPILER_MSC - return _InterlockedIncrement((long volatile*)&value); -#else - return ((std::atomic<uint32_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1; -#endif -} -inline uint32_t -AtomicDecrement(volatile uint32_t& value) -{ -#if ZEN_COMPILER_MSC - return _InterlockedDecrement((long volatile*)&value); -#else - return ((std::atomic<uint32_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1; -#endif -} - -inline uint64_t -AtomicIncrement(volatile uint64_t& value) -{ -#if ZEN_COMPILER_MSC - return _InterlockedIncrement64((__int64 volatile*)&value); -#else - return ((std::atomic<uint64_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1; -#endif -} -inline uint64_t -AtomicDecrement(volatile uint64_t& value) -{ -#if ZEN_COMPILER_MSC - return _InterlockedDecrement64((__int64 volatile*)&value); -#else - return ((std::atomic<uint64_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1; -#endif -} - -inline uint32_t -AtomicAdd(volatile uint32_t& value, uint32_t amount) -{ -#if ZEN_COMPILER_MSC - return _InterlockedExchangeAdd((long volatile*)&value, amount); -#else - return ((std::atomic<uint32_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst); -#endif -} -inline uint64_t -AtomicAdd(volatile uint64_t& value, uint64_t amount) -{ -#if ZEN_COMPILER_MSC - return _InterlockedExchangeAdd64((__int64 volatile*)&value, amount); -#else - return ((std::atomic<uint64_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst); -#endif -} - -} // namespace zen diff --git a/src/zenbase/include/zenbase/refcount.h b/src/zenbase/include/zenbase/refcount.h index 08bc6ae54..0da78ad91 100644 --- a/src/zenbase/include/zenbase/refcount.h +++ b/src/zenbase/include/zenbase/refcount.h @@ -1,9 +1,9 @@ // Copyright 
Epic Games, Inc. All Rights Reserved. #pragma once -#include <zenbase/atomic.h> #include <zenbase/concepts.h> +#include <atomic> #include <compare> namespace zen { @@ -13,6 +13,13 @@ namespace zen { * * When the reference count reaches zero, the object deletes itself. This class relies * on having a virtual destructor to ensure proper cleanup of derived classes. + * + * Release() is marked noexcept. Derived destructors are expected not to throw - if a + * derived destructor does throw, the exception crosses the noexcept boundary and the + * program terminates via std::terminate. This matches C++'s default destructor + * behaviour (destructors are implicitly noexcept) but is spelled out explicitly here + * because the delete happens inside Release() rather than at the call site, so the + * terminate point is not visually obvious. */ class RefCounted { @@ -20,10 +27,15 @@ public: RefCounted() = default; virtual ~RefCounted() = default; - inline uint32_t AddRef() const noexcept { return AtomicIncrement(const_cast<RefCounted*>(this)->m_RefCount); } - inline uint32_t Release() const + // AddRef uses relaxed ordering: a thread can only add a reference to an object it + // already has a reference to, so there is nothing that needs to synchronize here. + // Release uses acq_rel so that (a) all prior modifications of the object made on + // other threads happen-before the destructor, and (b) the ref-count decrement is + // visible to any thread that observes the new count. 
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; } + inline uint32_t Release() const noexcept { - const uint32_t RefCount = AtomicDecrement(const_cast<RefCounted*>(this)->m_RefCount); + const uint32_t RefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1; if (RefCount == 0) { delete this; @@ -39,10 +51,12 @@ public: RefCounted& operator=(RefCounted&&) = delete; protected: - inline uint32_t RefCount() const { return m_RefCount; } + // Diagnostic accessor: relaxed load is fine because the returned value is already + // stale the moment it is observed, so no extra ordering would make it more reliable. + inline uint32_t RefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); } private: - uint32_t m_RefCount = 0; + mutable std::atomic<uint32_t> m_RefCount = 0; }; /** @@ -52,6 +66,10 @@ private: * It has no virtual destructor, so it's important that you either don't derive from it further, * or ensure that the derived class has a virtual destructor. * + * As with RefCounted, Release() is noexcept and the derived destructor must not throw. + * A throwing destructor will cause std::terminate to be called when the refcount hits + * zero. + * * This class is useful when you want to avoid adding a vtable to a class just to implement * reference counting. */ @@ -60,15 +78,17 @@ template<typename T> class TRefCounted { public: - TRefCounted() = default; - ~TRefCounted() = default; + TRefCounted() = default; - inline uint32_t AddRef() const noexcept { return AtomicIncrement(const_cast<TRefCounted<T>*>(this)->m_RefCount); } - inline uint32_t Release() const + // See RefCounted::AddRef/Release for ordering rationale. 
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; } + inline uint32_t Release() const noexcept { - const uint32_t RefCount = AtomicDecrement(const_cast<TRefCounted<T>*>(this)->m_RefCount); + const uint32_t RefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1; if (RefCount == 0) { + // DeleteThis may be overridden as a non-const member in derived types, + // so we cast away const to support both signatures. const_cast<T*>(static_cast<const T*>(this))->DeleteThis(); } return RefCount; @@ -82,92 +102,21 @@ public: TRefCounted& operator=(TRefCounted&&) = delete; protected: - inline uint32_t RefCount() const { return m_RefCount; } - - void DeleteThis() const noexcept { delete static_cast<const T*>(this); } - -private: - uint32_t m_RefCount = 0; -}; - -/** - * Smart pointer for classes derived from RefCounted - */ - -template<class T> -class RefPtr -{ -public: - inline RefPtr() = default; - inline RefPtr(const RefPtr& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); } - inline RefPtr(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); } - inline ~RefPtr() { m_Ref && m_Ref->Release(); } - - [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; } - inline explicit operator bool() const { return m_Ref != nullptr; } - inline operator T*() const { return m_Ref; } - inline T* operator->() const { return m_Ref; } + // Non-virtual: destruction goes through DeleteThis(), which deletes the derived type. + // Protected so that external callers cannot `delete` through a TRefCounted<T>* and slice. + ~TRefCounted() = default; - inline std::strong_ordering operator<=>(const RefPtr& Rhs) const = default; + // Diagnostic accessor: see RefCounted::RefCount for ordering rationale. 
+ inline uint32_t RefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); } - inline RefPtr& operator=(T* Rhs) - { - Rhs && Rhs->AddRef(); - m_Ref && m_Ref->Release(); - m_Ref = Rhs; - return *this; - } - inline RefPtr& operator=(const RefPtr& Rhs) - { - if (&Rhs != this) - { - Rhs && Rhs->AddRef(); - m_Ref && m_Ref->Release(); - m_Ref = Rhs.m_Ref; - } - return *this; - } - inline RefPtr& operator=(RefPtr&& Rhs) noexcept - { - if (&Rhs != this) - { - m_Ref && m_Ref->Release(); - m_Ref = Rhs.m_Ref; - Rhs.m_Ref = nullptr; - } - return *this; - } - template<typename OtherType> - inline RefPtr& operator=(RefPtr<OtherType>&& Rhs) noexcept - { - if ((RefPtr*)&Rhs != this) - { - m_Ref && m_Ref->Release(); - m_Ref = Rhs.m_Ref; - Rhs.m_Ref = nullptr; - } - return *this; - } - inline RefPtr(RefPtr&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; } - template<typename OtherType> - explicit inline RefPtr(RefPtr<OtherType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref) - { - Rhs.m_Ref = nullptr; - } - - inline void Swap(RefPtr& Rhs) noexcept { std::swap(m_Ref, Rhs.m_Ref); } + void DeleteThis() const noexcept { delete static_cast<const T*>(this); } private: - T* m_Ref = nullptr; - template<typename U> - friend class RefPtr; + mutable std::atomic<uint32_t> m_RefCount = 0; }; /** - * Smart pointer for classes derived from RefCounted - * - * This variant does not decay to a raw pointer - * + * Smart pointer for classes derived from RefCounted (or TRefCounted) */ template<class T> @@ -177,12 +126,16 @@ public: inline Ref() = default; inline Ref(Ref&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; } inline Ref(const Ref& Rhs) noexcept : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); } - inline explicit Ref(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); } - inline ~Ref() { m_Ref && m_Ref->Release(); } + inline explicit Ref(T* Ptr) noexcept : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); } + inline ~Ref() noexcept { m_Ref && m_Ref->Release(); } template<typename 
DerivedType> requires DerivedFrom<DerivedType, T> - inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref) {} + inline Ref(const Ref<DerivedType>& Rhs) noexcept : Ref(Rhs.m_Ref) {} + + template<typename DerivedType> + requires DerivedFrom<DerivedType, T> + inline Ref(Ref<DerivedType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; } [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; } inline explicit operator bool() const { return m_Ref != nullptr; } @@ -192,14 +145,14 @@ public: inline std::strong_ordering operator<=>(const Ref& Rhs) const = default; - inline Ref& operator=(T* Rhs) + inline Ref& operator=(T* Rhs) noexcept { Rhs && Rhs->AddRef(); m_Ref && m_Ref->Release(); m_Ref = Rhs; return *this; } - inline Ref& operator=(const Ref& Rhs) + inline Ref& operator=(const Ref& Rhs) noexcept { if (&Rhs != this) { @@ -219,6 +172,20 @@ public: } return *this; } + template<typename DerivedType> + requires DerivedFrom<DerivedType, T> + inline Ref& operator=(Ref<DerivedType>&& Rhs) noexcept + { + if ((Ref*)&Rhs != this) + { + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + Rhs.m_Ref = nullptr; + } + return *this; + } + + inline void Swap(Ref& Rhs) noexcept { std::swap(m_Ref, Rhs.m_Ref); } private: T* m_Ref = nullptr; diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index a1a39fc3c..bb574edc2 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. 
-**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit. **Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. @@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra - `ActiveLsns` — for cancellation lookup (under `m_Lock`) - `FinishedLsns` — moved here when actions complete - `IdleSince` — used for 15-minute automatic expiry -- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit +- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default **Queue state machine (`QueueState` enum):** ``` @@ -216,11 +216,14 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han ## Concurrency Model -**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: -1. `m_ResultsLock` -2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") -3. `m_PendingLock` -4. `m_QueueLock` +**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. 
This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. + +**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks: +1. `m_ActionMapLock` (session action maps) +2. `QueueEntry::m_Lock` (per-queue state) +3. `m_ActionHistoryLock` (action history ring) + +Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`). **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp index eb4c05f9f..f1df18e8e 100644 --- a/src/zencompute/cloudmetadata.cpp +++ b/src/zencompute/cloudmetadata.cpp @@ -183,7 +183,7 @@ CloudMetadata::TryDetectAWS() m_Info.AvailabilityZone = std::string(AzResponse.AsText()); } - // "spot" vs "on-demand" — determines whether the instance can be + // "spot" vs "on-demand" - determines whether the instance can be // reclaimed by AWS with a 2-minute warning HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); if (LifecycleResponse.IsSuccess()) @@ -273,7 +273,7 @@ CloudMetadata::TryDetectAzure() std::string Priority = Compute["priority"].string_value(); m_Info.IsSpot = (Priority == "Spot"); - // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + // Check if part of a VMSS (Virtual Machine Scale Set) - indicates autoscaling std::string VmssName = Compute["vmScaleSetName"].string_value(); m_Info.IsAutoscaling = !VmssName.empty(); @@ -609,7 +609,7 @@ namespace zen::compute { TEST_SUITE_BEGIN("compute.cloudmetadata"); // 
--------------------------------------------------------------------------- -// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService +// Test helper - spins up a local ASIO HTTP server hosting a MockImdsService // --------------------------------------------------------------------------- struct TestImdsServer @@ -974,7 +974,7 @@ TEST_CASE("cloudmetadata.sentinel_files") SUBCASE("only failed providers get sentinels") { - // Switch to AWS — Azure and GCP never probed, so no sentinels for them + // Switch to AWS - Azure and GCP never probed, so no sentinels for them Imds.Mock.ActiveProvider = CloudProvider::AWS; auto Cloud = Imds.CreateCloud(); diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 92901de64..7f354a51c 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -8,6 +8,8 @@ # include "recording/actionrecorder.h" # include "runners/localrunner.h" # include "runners/remotehttprunner.h" +# include "runners/managedrunner.h" +# include "pathvalidation.h" # if ZEN_PLATFORM_LINUX # include "runners/linuxrunner.h" # elif ZEN_PLATFORM_WINDOWS @@ -119,6 +121,8 @@ struct ComputeServiceSession::Impl , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) { + m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool); + // Create a non-expiring, non-deletable implicit queue for legacy endpoints auto Result = CreateQueue("implicit"sv, {}, {}); m_ImplicitQueueId = Result.QueueId; @@ -195,13 +199,9 @@ struct ComputeServiceSession::Impl std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; + RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap + std::map<int, Ref<RunnerAction>> m_PendingActions; std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock 
m_ResultsLock; std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; metrics::Meter m_ResultRate; std::atomic<uint64_t> m_RetiredCount{0}; @@ -242,8 +242,9 @@ struct ComputeServiceSession::Impl // Recording - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; std::unique_ptr<ActionRecorder> m_Recorder; @@ -343,9 +344,12 @@ struct ComputeServiceSession::Impl ActionCounts GetActionCounts() { ActionCounts Counts; - Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + m_ActionMapLock.WithSharedLock([&] { + Counts.Pending = (int)m_PendingActions.size(); + Counts.Running = (int)m_RunningMap.size(); + Counts.Completed = (int)m_ResultsMap.size(); + }); + Counts.Completed += (int)m_RetiredCount.load(); Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { size_t Count = 0; for (const auto& [Id, Queue] : m_Queues) @@ -364,8 +368,10 @@ struct ComputeServiceSession::Impl { Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + m_ActionMapLock.WithSharedLock([&] { + Cbo << "actions_complete"sv << m_ResultsMap.size(); + Cbo << "actions_pending"sv << m_PendingActions.size(); + }); Cbo << "actions_submitted"sv << GetSubmittedActionCount(); EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); 
EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); @@ -443,34 +449,24 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) return true; } - // CAS failed, Current was updated — retry with the new value + // CAS failed, Current was updated - retry with the new value } } void ComputeServiceSession::Impl::AbandonAllActions() { - // Collect all pending actions and mark them as Abandoned + // Collect all pending and running actions under a single lock scope std::vector<Ref<RunnerAction>> PendingToAbandon; + std::vector<Ref<RunnerAction>> RunningToAbandon; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { PendingToAbandon.reserve(m_PendingActions.size()); for (auto& [Lsn, Action] : m_PendingActions) { PendingToAbandon.push_back(Action); } - }); - - for (auto& Action : PendingToAbandon) - { - Action->SetActionState(RunnerAction::State::Abandoned); - } - - // Collect all running actions and mark them as Abandoned, then - // best-effort cancel via the local runner group - std::vector<Ref<RunnerAction>> RunningToAbandon; - m_RunningLock.WithSharedLock([&] { RunningToAbandon.reserve(m_RunningMap.size()); for (auto& [Lsn, Action] : m_RunningMap) { @@ -478,6 +474,11 @@ ComputeServiceSession::Impl::AbandonAllActions() } }); + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + for (auto& Action : RunningToAbandon) { Action->SetActionState(RunnerAction::State::Abandoned); @@ -617,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() m_KnownWorkerUris.insert(UriStr); auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + NewRunner->SetRemoteHostname(Hostname); SyncWorkersToRunner(*NewRunner); m_RemoteRunnerGroup.AddRunner(NewRunner); } @@ -718,31 +720,51 @@ ComputeServiceSession::Impl::ShutdownRunners() m_RemoteRunnerGroup.Shutdown(); } -void +bool ComputeServiceSession::Impl::StartRecording(ChunkResolver& 
InCidStore, const std::filesystem::path& RecordingPath) { + if (m_Recorder) + { + ZEN_WARN("recording is already active"); + return false; + } + ZEN_INFO("starting recording to '{}'", RecordingPath); m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); ZEN_INFO("started recording to '{}'", RecordingPath); + return true; } -void +bool ComputeServiceSession::Impl::StopRecording() { + if (!m_Recorder) + { + ZEN_WARN("no recording is active"); + return false; + } + ZEN_INFO("stopping recording"); m_Recorder = nullptr; ZEN_INFO("stopped recording"); + return true; +} + +bool +ComputeServiceSession::Impl::IsRecording() const +{ + return m_Recorder != nullptr; } std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::Impl::GetRunningActions() { std::vector<ComputeServiceSession::RunningActionInfo> Result; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { Result.reserve(m_RunningMap.size()); for (const auto& [Lsn, Action] : m_RunningMap) { @@ -810,6 +832,11 @@ void ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) { ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + + // Validate all paths in the worker description upfront, before the worker is + // distributed to runners. This rejects malicious packages early at ingestion time. + ValidateWorkerDescriptionPaths(Worker.GetObject()); + RwLock::ExclusiveLockScope _(m_WorkerLock); const IoHash& WorkerId = Worker.GetObject().GetHash(); @@ -994,10 +1021,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke Pending->ActionObj = ActionObj; Pending->Priority = RequestPriority; - // For now simply put action into pending state, so we can do batch scheduling + // Insert into the pending map immediately so the action is visible to + // FindActionResult/GetActionResult right away. 
SetActionState will call + // PostUpdate which adds the action to m_UpdatedActions and signals the + // scheduler, but the scheduler's HandleActionUpdates inserts with + // std::map::insert which is a no-op for existing keys. ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); }); Pending->SetActionState(RunnerAction::State::Pending); if (m_Recorder) @@ -1043,11 +1075,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount() HttpResponseCode ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) { @@ -1058,25 +1086,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult return HttpResponseCode::OK; } + if (m_PendingActions.find(ActionLsn) != m_PendingActions.end()) { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } return HttpResponseCode::NotFound; @@ -1085,11 +1102,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult HttpResponseCode 
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) { @@ -1103,30 +1116,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& } } + for (const auto& [K, Pending] : m_PendingActions) { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) + if (Pending->ActionId == ActionId) { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + for (const auto& [K, v] : m_RunningMap) { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) + if (v->ActionId == ActionId) { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } @@ -1144,12 +1146,16 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) { Cbo.BeginArray("completed"); - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [Lsn, Action] : m_ResultsMap) { Cbo.BeginObject(); Cbo << "lsn"sv << Lsn; Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + if (!Action->FailureReason.empty()) + { + Cbo << "reason"sv << Action->FailureReason; + } Cbo.EndObject(); } }); @@ -1275,20 +1281,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) std::vector<Ref<RunnerAction>> PendingActionsToCancel; std::vector<int> RunningLsnsToCancel; - m_PendingLock.WithSharedLock([&] { + 
m_ActionMapLock.WithSharedLock([&] { for (int Lsn : LsnsToCancel) { if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) { PendingActionsToCancel.push_back(It->second); } - } - }); - - m_RunningLock.WithSharedLock([&] { - for (int Lsn : LsnsToCancel) - { - if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + else if (m_RunningMap.find(Lsn) != m_RunningMap.end()) { RunningLsnsToCancel.push_back(Lsn); } @@ -1307,7 +1307,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) // transition from the runner is blocked (Cancelled > Failed in the enum). for (int Lsn : RunningLsnsToCancel) { - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) { It->second->SetActionState(RunnerAction::State::Cancelled); @@ -1444,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { - Queue->m_Lock.WithSharedLock([&] { - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { + Queue->m_Lock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1475,15 +1475,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run return; } + bool WasActive = false; Queue->m_Lock.WithExclusiveLock([&] { - Queue->ActiveLsns.erase(Lsn); + WasActive = Queue->ActiveLsns.erase(Lsn) > 0; Queue->FinishedLsns.insert(Lsn); }); - const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); - if (PreviousActive == 1) + if (WasActive) { - Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } } switch (ActionState) @@ -1541,26 +1545,32 @@ ComputeServiceSession::Impl::SchedulePendingActions() { 
ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + size_t RunningCount = 0; + size_t PendingCount = 0; + size_t ResultCount = 0; + + m_ActionMapLock.WithSharedLock([&] { + RunningCount = m_RunningMap.size(); + PendingCount = m_PendingActions.size(); + ResultCount = m_ResultsMap.size(); + }); static Stopwatch DumpRunningTimer; auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); + ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); if (DumpRunningTimer.GetElapsedTimeMs() > 30000) { DumpRunningTimer.Reset(); std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [K, V] : m_RunningMap) { RunningList.insert(K); @@ -1602,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_PendingLock.WithExclusiveLock([&] { - if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) - { - return; - } + // Extract pending actions under a shared lock - we only need to read + // the map and take Ref copies. ActionState() is atomic so this is safe. + // Sorting and capacity trimming happen outside the lock to avoid + // blocking HTTP handlers on O(N log N) work with large pending queues. 
- if (m_PendingActions.empty()) + m_ActionMapLock.WithSharedLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; } @@ -1628,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: + case RunnerAction::State::Rejected: case RunnerAction::State::Cancelled: break; @@ -1638,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - // Sort by priority descending, then by LSN ascending (FIFO within same priority) - std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { - if (A->Priority != B->Priority) - { - return A->Priority > B->Priority; - } - return A->ActionLsn < B->ActionLsn; - }); + PendingCount = m_PendingActions.size(); + }); - if (ActionsToSchedule.size() > Capacity) + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { + if (A->Priority != B->Priority) { - ActionsToSchedule.resize(Capacity); + return A->Priority > B->Priority; } - - PendingCount = m_PendingActions.size(); + return A->ActionLsn < B->ActionLsn; }); + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + if (ActionsToSchedule.empty()) { _.Dismiss(); return; } - ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size()); Stopwatch SubmitTimer; std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule); @@ -1681,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", - ScheduledActionCount, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - NotAcceptedCount); + 
ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); ScheduledCount += ScheduledActionCount; PendingCount -= ScheduledActionCount; @@ -1701,7 +1712,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() { int TimeoutMs = 500; - auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); }); if (PendingCount) { @@ -1720,22 +1731,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() m_SchedulingThreadEvent.Reset(); } - ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", - m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), - PendingCount, - m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), - m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), - TimeoutMs); + m_ActionMapLock.WithSharedLock([&] { + ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}", + m_PendingActions.size(), + m_RunningMap.size(), + m_ResultsMap.size(), + TimeoutMs); + }); HandleActionUpdates(); - // Auto-transition Draining → Paused when all work is done + // Auto-transition Draining -> Paused when all work is done if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) { - size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); }); - if (Pending == 0 && Running == 0) + if (AllDrained) { SessionState Expected = SessionState::Draining; if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) @@ -1776,9 +1787,9 @@ 
ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) if (Config) { - int Value = Config["max_retries"].AsInt32(0); + int Value = Config["max_retries"].AsInt32(-1); - if (Value > 0) + if (Value >= 0) { return Value; } @@ -1797,7 +1808,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) // Find, validate, and remove atomically under a single lock scope to prevent // concurrent RescheduleAction calls from double-removing the same action. - m_ResultsLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { auto It = m_ResultsMap.find(ActionLsn); if (It == m_ResultsMap.end()) { @@ -1855,7 +1866,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) } } - // Reset action state — this calls PostUpdate() internally + // Reset action state - this calls PostUpdate() internally Action->ResetActionStateToPending(); int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); @@ -1871,26 +1882,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) bool WasRunning = false; // Look for the action in pending or running maps - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) { Action = It->second; WasRunning = true; } + else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end()) + { + Action = PIt->second; + } }); if (!Action) { - m_PendingLock.WithSharedLock([&] { - if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) - { - Action = It->second; - } - }); - } - - if (!Action) - { return {.Success = false, .Error = "Action not found in pending or running maps"}; } @@ -1912,18 +1917,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) void ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) { - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == 
m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + // Caller must hold m_ActionMapLock exclusively. + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } } void @@ -1946,7 +1948,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() // This is safe because state transitions are monotonically increasing by enum // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so // SetActionState rejects any transition to a lower-ranked state. By the time - // we read ActionState() here, it reflects the highest state reached — making + // we read ActionState() here, it reflects the highest state reached - making // the first occurrence per LSN authoritative and duplicates redundant. for (Ref<RunnerAction>& Action : UpdatedActions) { @@ -1956,7 +1958,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() { switch (Action->ActionState()) { - // Newly enqueued — add to pending map for scheduling + // Newly enqueued - add to pending map for scheduling case RunnerAction::State::Pending: // Guard against a race where the session is abandoned between // EnqueueAction (which calls PostUpdate) and this scheduler @@ -1973,35 +1975,44 @@ ComputeServiceSession::Impl::HandleActionUpdates() } else { - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); } break; - // Async submission in progress — remains in pending map + // Async submission in progress - remains in pending map case RunnerAction::State::Submitting: break; - // Dispatched to a runner — move from pending to running + // Dispatched to a runner - move from pending to running case RunnerAction::State::Running: - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - 
m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); }); ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; - // Retracted — pull back for rescheduling without counting against retry limit + // Retracted - pull back for rescheduling without counting against retry limit case RunnerAction::State::Retracted: { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); Action->ResetActionStateToPending(); ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); break; } - // Terminal states — move to results, record history, notify queue + // Rejected - runner was at capacity, reschedule without retry cost + case RunnerAction::State::Rejected: + { + Action->ResetActionStateToPending(); + ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn); + break; + } + + // Terminal states - move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: @@ -2010,7 +2021,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() auto TerminalState = Action->ActionState(); // Automatic retry for Failed/Abandoned actions with retries remaining. - // Skip retries when the session itself is abandoned — those actions + // Skip retries when the session itself is abandoned - those actions // were intentionally abandoned and should not be rescheduled. 
if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) && m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned) @@ -2019,7 +2030,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -2032,18 +2046,26 @@ ComputeServiceSession::Impl::HandleActionUpdates() MaxRetries); break; } + else + { + ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling", + Action->ActionId, + ActionLsn, + RunnerAction::ToString(TerminalState), + Action->RetryCount.load(std::memory_order_relaxed)); + } } - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + RemoveActionFromActiveMaps(ActionLsn); - // Update queue counters BEFORE publishing the result into - // m_ResultsMap. GetActionResult erases from m_ResultsMap - // under m_ResultsLock, so if we updated counters after - // releasing that lock, a caller could observe ActiveCount - // still at 1 immediately after GetActionResult returned OK. - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. GetActionResult erases from m_ResultsMap + // under m_ActionMapLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. 
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); - m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; // Append to bounded action history ring @@ -2124,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); std::vector<SubmitResult> Results(Actions.size()); - // First try submitting the batch to local runners in parallel + // First try submitting the batch to local runners std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); - std::vector<size_t> RemoteIndices; std::vector<Ref<RunnerAction>> RemoteActions; for (size_t i = 0; i < Actions.size(); ++i) @@ -2138,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& } else { - RemoteIndices.push_back(i); RemoteActions.push_back(Actions[i]); + Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"}; } } - // Submit remaining actions to remote runners in parallel + // Dispatch remaining actions to remote runners asynchronously. + // Mark actions as Submitting so the scheduler won't re-pick them. + // The remote runner will transition them to Running on success, or + // we mark them Failed on rejection so HandleActionUpdates retries. 
if (!RemoteActions.empty()) { - std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); - - for (size_t j = 0; j < RemoteIndices.size(); ++j) + for (const Ref<RunnerAction>& Action : RemoteActions) { - Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + Action->SetActionState(RunnerAction::State::Submitting); } + + m_RemoteSubmitPool.ScheduleWork( + [this, RemoteActions = std::move(RemoteActions)]() { + std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteResults.size(); ++j) + { + if (!RemoteResults[j].IsAccepted) + { + ZEN_DEBUG("remote submission rejected for action {} ({}): {}", + RemoteActions[j]->ActionId, + RemoteActions[j]->ActionLsn, + RemoteResults[j].Reason); + + RemoteActions[j]->SetActionState(RunnerAction::State::Rejected); + } + } + }, + WorkerThreadPool::EMode::EnableBacklog); } return Results; @@ -2217,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged() m_Impl->NotifyOrchestratorChanged(); } -void +bool ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { - m_Impl->StartRecording(InResolver, RecordingPath); + return m_Impl->StartRecording(InResolver, RecordingPath); } -void +bool ComputeServiceSession::StopRecording() { - m_Impl->StopRecording(); + return m_Impl->StopRecording(); +} + +bool +ComputeServiceSession::IsRecording() const +{ + return m_Impl->IsRecording(); } ComputeServiceSession::ActionCounts @@ -2282,6 +2329,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files } void +ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner"); + + auto* NewRunner = + new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, 
MaxConcurrentActions); + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) { ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index bdfd9d197..5ab189d89 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -21,12 +21,14 @@ # include <zencore/thread.h> # include <zencore/trace.h> # include <zencore/uid.h> -# include <zenstore/cidstore.h> +# include <zenstore/hashkeyset.h> +# include <zenstore/zenstore.h> # include <zentelemetry/stats.h> # include <algorithm> # include <span> # include <unordered_map> +# include <utility> # include <vector> using namespace std::literals; @@ -45,7 +47,9 @@ auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSe struct HttpComputeService::Impl { HttpComputeService* m_Self; - CidStore& m_CidStore; + ChunkStore& m_ActionStore; + ChunkStore& m_WorkerStore; + FallbackChunkResolver m_CombinedResolver; IHttpStatsService& m_StatsService; LoggerRef m_Log; std::filesystem::path m_BaseDir; @@ -58,6 +62,8 @@ struct HttpComputeService::Impl RwLock m_WsConnectionsLock; std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::function<void()> m_ShutdownCallback; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -72,13 +78,13 @@ struct HttpComputeService::Impl std::string ClientHostname; // empty if no hostname was provided }; - // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // Remote queue registry - all three maps share the same RemoteQueueInfo objects. // All maps are guarded by m_RemoteQueueLock. 
RwLock m_RemoteQueueLock; - std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info - std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info - std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info + std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token -> info + std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId -> info + std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key -> info LoggerRef Log() { return m_Log; } @@ -93,34 +99,38 @@ struct HttpComputeService::Impl uint64_t NewBytes = 0; }; - IngestStats IngestPackageAttachments(const CbPackage& Package); - bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); - void HandleWorkersGet(HttpServerRequest& HttpReq); - void HandleWorkersAllGet(HttpServerRequest& HttpReq); - void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); - void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); - void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); // WebSocket / observer - 
void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); void RegisterRoutes(); - Impl(HttpComputeService* Self, - CidStore& InCidStore, - IHttpStatsService& StatsService, - const std::filesystem::path& BaseDir, - int32_t MaxConcurrentActions) + Impl(HttpComputeService* Self, + ChunkStore& InActionStore, + ChunkStore& InWorkerStore, + IHttpStatsService& StatsService, + std::filesystem::path BaseDir, + int32_t MaxConcurrentActions) : m_Self(Self) - , m_CidStore(InCidStore) + , m_ActionStore(InActionStore) + , m_WorkerStore(InWorkerStore) + , m_CombinedResolver(InActionStore, InWorkerStore) , m_StatsService(StatsService) , m_Log(logging::Get("compute")) - , m_BaseDir(BaseDir) - , m_ComputeService(InCidStore) + , m_BaseDir(std::move(BaseDir)) + , m_ComputeService(m_CombinedResolver) { - m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.AddLocalRunner(m_CombinedResolver, m_BaseDir / "local", MaxConcurrentActions); m_ComputeService.WaitUntilReady(); m_StatsService.RegisterHandler("compute", *m_Self); RegisterRoutes(); @@ -182,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kPost); m_Router.RegisterRoute( + "session/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Draining from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + 
HttpVerb::kPost); + + m_Router.RegisterRoute( + "session/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + auto Counts = m_ComputeService.GetActionCounts(); + Cbo << "actions_pending"sv << Counts.Pending; + Cbo << "actions_running"sv << Counts.Running; + Cbo << "actions_completed"sv << Counts.Completed; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "session/sunset", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + + if (m_ShutdownCallback) + { + m_ShutdownCallback(); + } + return; + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Sunset from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "workers", [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, HttpVerb::kGet); @@ -373,7 +442,7 @@ HttpComputeService::Impl::RegisterRoutes() if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); ResponseCode != HttpResponseCode::OK) { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) if (ResponseCode == HttpResponseCode::NotFound) { @@ -498,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + std::filesystem::path RecordingPath = m_BaseDir / 
"recording"; + + if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath)) + { + CbObjectWriter Cbo; + Cbo << "error" + << "recording is already active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } - return HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -514,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StopRecording(); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; - return HttpReq.WriteResponse(HttpResponseCode::OK); + if (!m_ComputeService.StopRecording()) + { + CbObjectWriter Cbo; + Cbo << "error" + << "no recording is active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -583,7 +672,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kGet | HttpVerb::kPost); - // Queue creation routes — these remain separate since local creates a plain queue + // Queue creation routes - these remain separate since local creates a plain queue // while remote additionally generates an OID token for external access. 
m_Router.RegisterRoute( @@ -637,7 +726,7 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } - // Queue has since expired — clean up stale entries and fall through to create a new one + // Queue has since expired - clean up stale entries and fall through to create a new one m_RemoteQueuesByToken.erase(Existing->Token); m_RemoteQueuesByQueueId.erase(Existing->QueueId); m_RemoteQueuesByTag.erase(It); @@ -666,7 +755,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kPost); - // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // Unified queue routes - {queueref} accepts both local integer IDs and remote OID tokens. // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. m_Router.RegisterRoute( @@ -1016,7 +1105,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kPost); - // WebSocket upgrade endpoint — the handler logic lives in + // WebSocket upgrade endpoint - the handler logic lives in // HttpComputeService::OnWebSocket* methods; this route merely // satisfies the router so the upgrade request isn't rejected. 
m_Router.RegisterRoute( @@ -1027,11 +1116,12 @@ HttpComputeService::Impl::RegisterRoutes() ////////////////////////////////////////////////////////////////////////// -HttpComputeService::HttpComputeService(CidStore& InCidStore, +HttpComputeService::HttpComputeService(ChunkStore& InActionStore, + ChunkStore& InWorkerStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir, int32_t MaxConcurrentActions) -: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +: m_Impl(std::make_unique<Impl>(this, InActionStore, InWorkerStore, StatsService, BaseDir, MaxConcurrentActions)) { } @@ -1057,6 +1147,12 @@ HttpComputeService::GetActionCounts() return m_Impl->m_ComputeService.GetActionCounts(); } +void +HttpComputeService::SetShutdownCallback(std::function<void()> Callback) +{ + m_Impl->m_ShutdownCallback = std::move(Callback); +} + const char* HttpComputeService::BaseUri() const { @@ -1145,7 +1241,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin { if (OidMatcher(Capture)) { - // Remote OID token — accessible from any client + // Remote OID token - accessible from any client const Oid Token = Oid::FromHexString(Capture); const int QueueId = ResolveQueueToken(Token); @@ -1157,7 +1253,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return QueueId; } - // Local integer queue ID — restricted to local machine requests + // Local integer queue ID - restricted to local machine requests if (!HttpReq.IsLocalMachineRequest()) { HttpReq.WriteResponse(HttpResponseCode::Forbidden); @@ -1167,35 +1263,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return ParseInt<int>(Capture).value_or(0); } -HttpComputeService::Impl::IngestStats -HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +bool +HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& 
Attachment) { - IngestStats Stats; + const IoHash ClaimedHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + const IoHash HeaderHash = Buffer.DecodeRawHash(); + if (HeaderHash != ClaimedHash) + { + ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + IoHashStream Hasher; + + bool DecompressOk = Buffer.DecompressToStream( + 0, + Buffer.DecodeRawSize(), + [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool { + for (const SharedBuffer& Segment : Range.GetSegments()) + { + Hasher.Append(Segment.GetView()); + } + return true; + }); + + if (!DecompressOk) + { + ZEN_WARN("attachment {}: failed to decompress", ClaimedHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + const IoHash ActualHash = Hasher.GetHash(); + + if (ActualHash != ClaimedHash) + { + ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + return true; +} + +bool +HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats) +{ for (const CbAttachment& Attachment : Package.GetAttachments()) { ZEN_ASSERT(Attachment.IsCompressedBinary()); - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return false; + } - const uint64_t CompressedSize = DataView.GetCompressedSize(); + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + const uint64_t CompressedSize = DataView.GetCompressedSize(); - Stats.Bytes += CompressedSize; - ++Stats.Count; + OutStats.Bytes += 
CompressedSize; + ++OutStats.Count; - const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + const ChunkStore::InsertResult InsertResult = m_ActionStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { - Stats.NewBytes += CompressedSize; - ++Stats.NewCount; + OutStats.NewBytes += CompressedSize; + ++OutStats.NewCount; } } - return Stats; + return true; } bool @@ -1204,7 +1346,7 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto ActionObj.IterateAttachments([&](CbFieldView Field) { const IoHash FileHash = Field.AsHash(); - if (!m_CidStore.ContainsChunk(FileHash)) + if (!m_ActionStore.ContainsChunk(FileHash)) { NeedList.push_back(FileHash); } @@ -1253,7 +1395,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { CbPackage Package = HttpReq.ReadPayloadPackage(); Body = Package.GetObject(); - Stats = IngestPackageAttachments(Package); + if (!IngestPackageAttachments(HttpReq, Package, Stats)) + { + return; // validation failed, response already written + } break; } @@ -1268,8 +1413,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { // --- Batch path --- - // For CbObject payloads, check all attachments upfront before enqueuing anything - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + // Verify all action attachment references exist in the store { std::vector<IoHash> NeedList; @@ -1345,7 +1489,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que // --- Single-action path: Body is the action itself --- - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) { std::vector<IoHash> NeedList; @@ -1453,7 +1596,7 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const CbPackage WorkerPackage; WorkerPackage.SetObject(WorkerSpec); - m_CidStore.FilterChunks(ChunkSet); + 
m_WorkerStore.FilterChunks(ChunkSet); if (ChunkSet.IsEmpty()) { @@ -1491,15 +1634,19 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const { ZEN_ASSERT(Attachment.IsCompressedBinary()); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return; + } + const IoHash DataHash = Attachment.GetHash(); CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - ZEN_UNUSED(DataHash); TotalAttachmentBytes += Buffer.GetCompressedSize(); ++AttachmentCount; - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + const ChunkStore::InsertResult InsertResult = + m_WorkerStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { @@ -1537,9 +1684,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const // void -HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { - m_Impl->OnWebSocketOpen(std::move(Connection)); + m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri); } void @@ -1562,12 +1709,13 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati ////////////////////////////////////////////////////////////////////////// // -// Impl — WebSocket / observer +// Impl - WebSocket / observer // void -HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("compute WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 6cbe01e04..56eadcd57 100644 --- a/src/zencompute/httporchestrator.cpp +++ 
b/src/zencompute/httporchestrator.cpp @@ -7,6 +7,7 @@ # include <zencompute/orchestratorservice.h> # include <zencore/compactbinarybuilder.h> # include <zencore/logging.h> +# include <zencore/session.h> # include <zencore/string.h> # include <zencore/system.h> @@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn return Ann.Id; } +static OrchestratorService::WorkerAnnotator +MakeWorkerAnnotator(IProvisionerStateProvider* Prov) +{ + if (!Prov) + { + return {}; + } + return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) { + AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId); + if (Status != AgentProvisioningStatus::Unknown) + { + const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active"; + Cbo << "provisioner_status" << std::string_view(StatusStr); + } + }; +} + +bool +HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId) +{ + std::string_view SessionStr = Data["coordinator_session"].AsString(""); + if (SessionStr.empty()) + { + return true; // backwards compatibility: accept announcements without a session + } + Oid Session = Oid::TryFromHexString(SessionStr); + if (Session == m_SessionId) + { + return true; + } + ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString()); + return false; +} + HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) : m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) , m_Hostname(GetMachineName()) { + m_SessionId = zen::GetSessionId(); + ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString()); + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); @@ -95,13 +133,17 @@ 
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, [this](HttpRouterRequest& Req) { CbObjectWriter Cbo; Cbo << "hostname" << std::string_view(m_Hostname); + Cbo << "session_id" << m_SessionId.ToString(); Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kPost); m_Router.RegisterRoute( @@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, "characters and uri must start with http:// or https://"); } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session"); + } + m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); @@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, m_Router.RegisterRoute( "agents", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kGet); m_Router.RegisterRoute( @@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, }, HttpVerb::kGet); + // Provisioner endpoints + + m_Router.RegisterRoute( + "provisioner/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + if (IProvisionerStateProvider* Prov = 
m_Provisioner.load(std::memory_order_acquire)) + { + Cbo << "name" << Prov->GetName(); + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount(); + Cbo << "active_cores" << Prov->GetActiveCoreCount(); + Cbo << "agents" << Prov->GetAgentCount(); + Cbo << "agents_draining" << Prov->GetDrainingAgentCount(); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provisioner/target", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + int32_t Cores = Data["target_cores"].AsInt32(-1); + + ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false); + + if (Cores < 0) + { + ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field"); + } + + IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire); + if (!Prov) + { + ZEN_WARN("provisioner/target: no provisioner configured"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured"); + } + + ZEN_INFO("provisioner/target: setting target to {} cores", Cores); + Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores)); + + CbObjectWriter Cbo; + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + // Client tracking endpoints m_Router.RegisterRoute( @@ -375,7 +478,7 @@ HttpOrchestratorService::Shutdown() m_PushThread.join(); } - // Clean up worker WebSocket connections — collect IDs under lock, then + // Clean up worker WebSocket connections - collect IDs under lock, then // notify the service outside the lock to avoid lock-order inversions. 
std::vector<std::string> WorkerIds; m_WorkerWsLock.WithExclusiveLock([&] { @@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +void +HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); + m_Service->SetProvisionerStateProvider(Provider); +} + ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler @@ -418,8 +528,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) # if ZEN_WITH_WEBSOCKETS void -HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); if (!m_PushEnabled.load()) { return; @@ -471,7 +582,7 @@ std::string HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) { // Workers send CbObject in native binary format over the WebSocket to - // avoid the lossy CbObject↔JSON round-trip. + // avoid the lossy CbObject<->JSON round-trip. 
CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); if (!Data) { @@ -487,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms return {}; } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return {}; + } + m_Service->AnnounceWorker(Ann); return std::string(WorkerId); } @@ -562,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction() } // Build combined JSON with worker list, provisioning history, clients, and client history - CbObject WorkerList = m_Service->GetWorkerList(); + CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))); CbObject History = m_Service->GetProvisioningHistory(50); CbObject ClientList = m_Service->GetClientList(); CbObject ClientHistory = m_Service->GetClientHistory(50); @@ -614,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction() JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); } + // Emit provisioner stats if available + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + JsonBuilder.Append( + fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}" + ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}", + Prov->GetName(), + Prov->GetTargetCoreCount(), + Prov->GetEstimatedCoreCount(), + Prov->GetActiveCoreCount(), + Prov->GetAgentCount(), + Prov->GetDrainingAgentCount())); + } + JsonBuilder.Append("}"); std::string_view Json = JsonBuilder.ToView(); diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h index 3b9642ac3..280d794e7 100644 --- a/src/zencompute/include/zencompute/cloudmetadata.h +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -64,7 +64,7 @@ public: explicit CloudMetadata(std::filesystem::path DataDir); /** Synchronously probes cloud providers at the given IMDS endpoint. 
- * Intended for testing — allows redirecting all IMDS queries to a local + * Intended for testing - allows redirecting all IMDS queries to a local * mock HTTP server instead of the real 169.254.169.254 endpoint. */ CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 1ca78738a..97de4321a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -167,6 +167,7 @@ public: // Action runners void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); // Action submission @@ -278,7 +279,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[9] = {}; + uint64_t Timestamps[10] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -304,8 +305,9 @@ public: // Recording - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; private: void PostUpdate(RunnerAction* Action); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index b58e73a0d..32f54f293 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -15,7 +15,7 @@ # include 
<memory> namespace zen { -class CidStore; +class ChunkStore; } namespace zen::compute { @@ -26,7 +26,8 @@ namespace zen::compute { class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver { public: - HttpComputeService(CidStore& InCidStore, + HttpComputeService(ChunkStore& InActionStore, + ChunkStore& InWorkerStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir, int32_t MaxConcurrentActions = 0); @@ -34,6 +35,10 @@ public: void Shutdown(); + /** Set a callback to be invoked when the session/sunset endpoint is hit. + * Typically wired to HttpServer::RequestExit() to shut down the process. */ + void SetShutdownCallback(std::function<void()> Callback); + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); const char* BaseUri() const override; @@ -45,7 +50,7 @@ public: // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index da5c5dfc3..4e4f5f0f8 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,10 +2,12 @@ #pragma once +#include <zencompute/provisionerstate.h> #include <zencompute/zencompute.h> #include <zencore/logging.h> #include <zencore/thread.h> +#include <zencore/uid.h> #include <zenhttp/httpserver.h> #include <zenhttp/websocket.h> @@ -65,12 +67,22 @@ public: */ void Shutdown(); + /** Return the session ID generated at construction time. 
Provisioners + * pass this to spawned workers so the orchestrator can reject stale + * announcements from previous sessions. */ + Oid GetSessionId() const { return m_SessionId; } + + /** Register a provisioner whose target core count can be read and changed + * via the orchestrator HTTP API and dashboard. Caller retains ownership; + * the provider must outlive this service. */ + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; // IWebSocketHandler #if ZEN_WITH_WEBSOCKETS - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; #endif @@ -81,6 +93,11 @@ private: std::unique_ptr<OrchestratorService> m_Service; std::string m_Hostname; + Oid m_SessionId; + bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId); + + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; + // WebSocket push #if ZEN_WITH_WEBSOCKETS @@ -91,9 +108,9 @@ private: Event m_PushEvent; void PushThreadFunction(); - // Worker WebSocket connections (worker→orchestrator persistent links) + // Worker WebSocket connections (worker->orchestrator persistent links) RwLock m_WorkerWsLock; - std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr -> worker ID std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); #endif }; diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h index 704306913..6074240b9 100644 --- 
a/src/zencompute/include/zencompute/mockimds.h +++ b/src/zencompute/include/zencompute/mockimds.h @@ -1,5 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. -// Moved to zenutil — this header is kept for backward compatibility. +// Moved to zenutil - this header is kept for backward compatibility. #pragma once diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h index 071e902b3..2c49e22df 100644 --- a/src/zencompute/include/zencompute/orchestratorservice.h +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -6,7 +6,10 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include <zencompute/provisionerstate.h> # include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logbase.h> # include <zencore/thread.h> # include <zencore/timer.h> # include <zencore/uid.h> @@ -88,9 +91,16 @@ public: std::string Hostname; }; - CbObject GetWorkerList(); + /** Per-worker callback invoked during GetWorkerList serialization. + * The callback receives the worker ID and a CbObjectWriter positioned + * inside the worker's object, allowing the caller to append extra fields. 
*/ + using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>; + + CbObject GetWorkerList(const WorkerAnnotator& Annotate = {}); void AnnounceWorker(const WorkerAnnouncement& Announcement); + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + bool IsWorkerWebSocketEnabled() const; void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); @@ -164,7 +174,12 @@ private: void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); - bool m_EnableWorkerWebSocket = false; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log{"compute.orchestrator"}; + bool m_EnableWorkerWebSocket = false; + + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; std::thread m_ProbeThread; std::atomic<bool> m_ProbeThreadEnabled{true}; diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h new file mode 100644 index 000000000..e9af8a635 --- /dev/null +++ b/src/zencompute/include/zencompute/provisionerstate.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <string_view> + +namespace zen::compute { + +/** Per-agent provisioning status as seen by the provisioner. */ +enum class AgentProvisioningStatus +{ + Unknown, ///< Not known to the provisioner + Active, ///< Running and allocated + Draining, ///< Being gracefully deprovisioned +}; + +/** Abstract interface for querying and controlling a provisioner from the HTTP layer. + * This decouples the orchestrator service from specific provisioner implementations. */ +class IProvisionerStateProvider +{ +public: + virtual ~IProvisionerStateProvider() = default; + + virtual std::string_view GetName() const = 0; ///< e.g. 
"horde", "nomad" + virtual uint32_t GetTargetCoreCount() const = 0; + virtual uint32_t GetEstimatedCoreCount() const = 0; + virtual uint32_t GetActiveCoreCount() const = 0; + virtual uint32_t GetAgentCount() const = 0; + virtual uint32_t GetDrainingAgentCount() const { return 0; } + virtual void SetTargetCoreCount(uint32_t Count) = 0; + + /** Return the provisioning status for a worker by its orchestrator ID + * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */ + virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; } +}; + +} // namespace zen::compute diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp index 9ea695305..68199ab3c 100644 --- a/src/zencompute/orchestratorservice.cpp +++ b/src/zencompute/orchestratorservice.cpp @@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService() } CbObject -OrchestratorService::GetWorkerList() +OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate) { ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); CbObjectWriter Cbo; @@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList() Cbo << "ws_connected" << true; } Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + if (Annotate) + { + Annotate(WorkerId, Cbo); + } Cbo.EndObject(); } }); @@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) } } +void +OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); +} + bool OrchestratorService::IsWorkerWebSocketEnabled() const { @@ -170,11 +180,11 @@ OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool if (Connected) { - ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + ZEN_INFO("worker {} WebSocket connected - marking reachable", WorkerId); } else { - ZEN_WARN("worker {} WebSocket disconnected — marking 
unreachable", WorkerId); + ZEN_WARN("worker {} WebSocket disconnected - marking unreachable", WorkerId); } }); @@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction() continue; } + // Check if the provisioner knows this worker is draining - if so, + // unreachability is expected and should not be logged as a warning. + bool IsDraining = false; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining; + } + ReachableState NewState = ReachableState::Unreachable; try @@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction() } catch (const std::exception& Ex) { - ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + if (!IsDraining) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } } ReachableState PrevState = ReachableState::Unknown; @@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction() { ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); } + else if (IsDraining) + { + ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri); + } else if (PrevState == ReachableState::Reachable) { ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h new file mode 100644 index 000000000..d50ad4a2a --- /dev/null +++ b/src/zencompute/pathvalidation.h @@ -0,0 +1,118 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/except_fmt.h> +#include <zencore/string.h> + +#include <filesystem> +#include <string_view> + +namespace zen::compute { + +// Validate that a single path component contains only characters that are valid +// file/directory names on all supported platforms. Uses Windows rules as the most +// restrictive superset, since packages may be built on one platform and consumed +// on another. 
+inline void +ValidatePathComponent(std::string_view Component, std::string_view FullPath) +{ + // Reject control characters (0x00-0x1F) and characters forbidden on Windows + for (char Ch : Component) + { + if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' || + Ch == '*') + { + throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath); + } + } + + // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion) + if (Component.empty() || Component.back() == '.' || Component.back() == ' ') + { + throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath); + } + + // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9) + // These are reserved with or without an extension (e.g. "CON.txt" is still reserved). + std::string_view Stem = Component.substr(0, Component.find('.')); + + static constexpr std::string_view ReservedNames[] = { + "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", + "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", + }; + + for (std::string_view Reserved : ReservedNames) + { + if (zen::StrCaseCompare(Stem, Reserved) == 0) + { + throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved); + } + } +} + +// Validate that a path extracted from a package is a safe relative path. +// Rejects absolute paths, ".." components, and invalid platform filenames. 
+inline void +ValidateSandboxRelativePath(std::string_view Name) +{ + if (Name.empty()) + { + throw zen::invalid_argument("path traversal detected: empty path name"); + } + + std::filesystem::path Parsed(Name); + + if (Parsed.is_absolute()) + { + throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name); + } + + for (const auto& Component : Parsed) + { + std::string ComponentStr = Component.string(); + + if (ComponentStr == "..") + { + throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name); + } + + // Skip "." (current directory) - harmless in relative paths + if (ComponentStr != ".") + { + ValidatePathComponent(ComponentStr, Name); + } + } +} + +// Validate all path entries in a worker description CbObject. +// Checks path, executables[].name, dirs[], and files[].name fields. +// Throws an exception if any invalid paths are found. +inline void +ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription) +{ + using namespace std::literals; + + if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue()) + { + ValidateSandboxRelativePath(PathField.AsString()); + } + + for (auto& It : WorkerDescription["executables"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + ValidateSandboxRelativePath(It.AsString()); + } + + for (auto& It : WorkerDescription["files"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } +} + +} // namespace zen::compute diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 4f116e7d8..34bf065b4 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -6,9 +6,15 @@ # include <zencore/compactbinary.h> # include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/timer.h> 
# include <zencore/trace.h> +# include <zencore/workthreadpool.h> # include <fmt/format.h> +# include <future> # include <vector> namespace zen::compute { @@ -118,23 +124,34 @@ std::vector<SubmitResult> BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); - RwLock::SharedLockScope _(m_RunnersLock); - const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + // Snapshot runners and query capacity under the lock, then release + // before submitting - HTTP submissions to remote runners can take + // hundreds of milliseconds and we must not hold m_RunnersLock during I/O. - if (RunnerCount == 0) - { - return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); - } + std::vector<Ref<FunctionRunner>> Runners; + std::vector<size_t> Capacities; + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions; + size_t TotalCapacity = 0; - // Query capacity per runner and compute total - std::vector<size_t> Capacities(RunnerCount); - size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + Runners.assign(m_Runners.begin(), m_Runners.end()); + Capacities.resize(RunnerCount); + PerRunnerActions.resize(RunnerCount); - for (int i = 0; i < RunnerCount; ++i) + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + }); + + const int RunnerCount = gsl::narrow<int>(Runners.size()); + + if (RunnerCount == 0) { - Capacities[i] = m_Runners[i]->QueryCapacity(); - TotalCapacity += Capacities[i]; + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); } if (TotalCapacity == 0) @@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } // Distribute actions across runners proportionally to their available 
capacity - std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); - std::vector<size_t> ActionRunnerIndex(Actions.size()); - size_t ActionIdx = 0; + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; for (int i = 0; i < RunnerCount; ++i) { @@ -164,8 +180,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Assign any remaining actions to runners with capacity (round-robin) - for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + // Assign any remaining actions to runners with capacity (round-robin). + // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept. + for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount) { if (Capacities[i] > PerRunnerActions[i].size()) { @@ -175,22 +192,83 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Submit batches per runner + // Submit batches per runner - in parallel when a worker pool is available + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + int ActiveRunnerCount = 0; for (int i = 0; i < RunnerCount; ++i) { if (!PerRunnerActions[i].empty()) { - PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + ++ActiveRunnerCount; + } + } + + static constexpr uint64_t SubmitWarnThresholdMs = 500; + + auto SubmitToRunner = [&](int RunnerIndex) { + auto& Runner = Runners[RunnerIndex]; + Runner->m_LastSubmitStats.Reset(); + + Stopwatch Timer; + + PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]); + + uint64_t ElapsedMs = Timer.GetElapsedTimeMs(); + if (ElapsedMs >= SubmitWarnThresholdMs) + { + size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed); + uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed); + + ZEN_WARN("submit of {} actions ({} 
attachments, {}) to '{}' took {}ms", + PerRunnerActions[RunnerIndex].size(), + Attachments, + NiceBytes(AttachmentBytes), + Runner->GetDisplayName(), + ElapsedMs); + } + }; + + if (m_WorkerPool && ActiveRunnerCount > 1) + { + std::vector<std::future<void>> Futures(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); }); + + Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog); + } + } + + for (int i = 0; i < RunnerCount; ++i) + { + if (Futures[i].valid()) + { + Futures[i].get(); + } + } + } + else + { + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + SubmitToRunner(i); + } } } - // Reassemble results in original action order - std::vector<SubmitResult> Results(Actions.size()); + // Reassemble results in original action order. + // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity). 
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); std::vector<size_t> PerRunnerIdx(RunnerCount, 0); - for (size_t i = 0; i < Actions.size(); ++i) + for (size_t i = 0; i < ActionIdx; ++i) { size_t RunnerIdx = ActionRunnerIndex[i]; size_t Idx = PerRunnerIdx[RunnerIdx]++; @@ -307,10 +385,11 @@ RunnerAction::RetractAction() bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed, Abandoned, or Retracted states + // Only allow reset from Failed, Abandoned, Rejected, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected && + CurrentState != State::Retracted) { return false; } @@ -331,11 +410,12 @@ RunnerAction::ResetActionStateToPending() // Clear execution fields ExecutionLocation.clear(); + FailureReason.clear(); CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count (skip for Retracted — nothing failed) - if (CurrentState != State::Retracted) + // Increment retry count (skip for Retracted/Rejected - nothing failed) + if (CurrentState != State::Retracted && CurrentState != State::Rejected) { RetryCount.fetch_add(1, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index 56c3f3af0..371a60b7a 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -10,6 +10,10 @@ # include <filesystem> # include <vector> +namespace zen { +class WorkerThreadPool; +} + namespace zen::compute { struct SubmitResult @@ -37,6 +41,22 @@ public: [[nodiscard]] virtual bool IsHealthy() = 0; [[nodiscard]] virtual size_t QueryCapacity(); [[nodiscard]] virtual std::vector<SubmitResult> 
SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; } + + // Accumulated stats from the most recent SubmitActions call. + // Reset before each call, populated by the runner implementation. + struct SubmitStats + { + std::atomic<size_t> TotalAttachments{0}; + std::atomic<uint64_t> TotalAttachmentBytes{0}; + + void Reset() + { + TotalAttachments.store(0, std::memory_order_relaxed); + TotalAttachmentBytes.store(0, std::memory_order_relaxed); + } + }; + SubmitStats m_LastSubmitStats; // Best-effort cancellation of a specific in-flight action. Returns true if the // cancellation signal was successfully sent. The action will transition to Cancelled @@ -68,6 +88,8 @@ public: bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); + void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; } + size_t GetRunnerCount() { return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); @@ -79,6 +101,7 @@ protected: RwLock m_RunnersLock; std::vector<Ref<FunctionRunner>> m_Runners; std::atomic<int> m_NextSubmitIndex{0}; + WorkerThreadPool* m_WorkerPool = nullptr; }; /** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. @@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted CbObject ActionObj; int Priority = 0; std::string ExecutionLocation; // "local" or remote hostname + std::string FailureReason; // human-readable reason when action fails (empty on success) // CPU usage and total CPU time of the running process, sampled periodically by the local runner. // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. @@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted Completed, // Finished successfully with results available Failed, // Execution failed (transient error, eligible for retry) Abandoned, // Infrastructure termination (e.g. 
spot eviction, session abandon) + Rejected, // Runner declined (e.g. at capacity) - rescheduled without retry cost Cancelled, // Intentional user cancellation (never retried) Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count @@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted return "Failed"; case State::Abandoned: return "Abandoned"; + case State::Rejected: + return "Rejected"; case State::Cancelled: return "Cancelled"; case State::Retracted: diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index e79a6c90f..be4274823 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -195,7 +195,7 @@ namespace { WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); } - // /lib64 (optional — not all distros have it) + // /lib64 (optional - not all distros have it) { struct stat St; if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) @@ -208,7 +208,7 @@ namespace { } } - // /etc (required — for resolv.conf, ld.so.cache, etc.) + // /etc (required - for resolv.conf, ld.so.cache, etc.) 
if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) { WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); @@ -218,7 +218,7 @@ namespace { WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); } - // /worker — bind-mount worker directory (contains the executable) + // /worker - bind-mount worker directory (contains the executable) if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) { WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); @@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("namespace sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult @@ -428,11 +430,12 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process - lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { - // Close read end of error pipe — child only writes + // Close read end of error pipe - child only writes close(ErrorPipe[0]); SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]); @@ -459,7 +462,7 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (m_Sandboxed) { - // Close write end of error pipe — parent only reads + // Close write end of error pipe - parent only reads close(ErrorPipe[1]); // Read from error pipe. 
If execve succeeded, pipe was closed by O_CLOEXEC @@ -479,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; @@ -675,7 +679,7 @@ ReadProcStatCpuTicks(pid_t Pid) Buf[Len] = '\0'; - // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens const char* P = strrchr(Buf, ')'); if (!P) { @@ -705,7 +709,7 @@ LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) if (CurrentOsTicks == 0) { - // Process gone or /proc entry unreadable — record timestamp without updating usage + // Process gone or /proc entry unreadable - record timestamp without updating usage Running.LastCpuSampleTicks = NowTicks; Running.LastCpuOsTicks = 0; return; diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index b61e0a46f..259965e23 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -4,6 +4,8 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include "pathvalidation.h" + # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> @@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, ZEN_INFO("Cleanup complete"); } - m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; - # if ZEN_PLATFORM_WINDOWS // Suppress any error dialogs caused by missing dependencies UINT OldMode = ::SetErrorMode(0); @@ -337,7 
+337,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) SubmitResult LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) { - // Base class is not directly usable — platform subclasses override this + // Base class is not directly usable - platform subclasses override this ZEN_UNUSED(Action); return SubmitResult{.IsAccepted = false}; } @@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); - if (!std::filesystem::exists(WorkerDir)) + // worker.zcb is written as the last step of ManifestWorker, so its presence + // indicates a complete manifest. If the directory exists but the marker is + // missing, a previous manifest was interrupted and we need to start over. + bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb"); + + if (NeedsManifest) { _.ReleaseNow(); RwLock::ExclusiveLockScope $(m_WorkerLock); - if (!std::filesystem::exists(WorkerDir)) + if (!std::filesystem::exists(WorkerDir / "worker.zcb")) { + std::error_code Ec; + std::filesystem::remove_all(WorkerDir, Ec); ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); } } @@ -382,6 +389,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); const uint64_t Size = FileEntry["size"sv].AsUInt64(); + ValidateSandboxRelativePath(Name); + CompressedBuffer Compressed; if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) @@ -457,7 +466,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["dirs"sv]) { - std::string_view Name = It.AsString(); + std::string_view Name = It.AsString(); + ValidateSandboxRelativePath(Name); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; // Validate dir path stays within sandbox @@ -482,6 +492,8 @@ 
LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, } WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); + + ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath); } CbPackage @@ -540,6 +552,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) } void +LocalProcessRunner::StartMonitorThread() +{ + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; +} + +void LocalProcessRunner::MonitorThreadFunction() { SetCurrentThreadName("LocalProcessRunner_Monitor"); @@ -602,7 +620,7 @@ LocalProcessRunner::MonitorThreadFunction() void LocalProcessRunner::CancelRunningActions() { - // Base class is not directly usable — platform subclasses override this + // Base class is not directly usable - platform subclasses override this } void @@ -662,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com } catch (std::exception& Ex) { - ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what()); + ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); } } + else + { + Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode); + ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); + } // Failed - clean up the sandbox in the background. 
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index b8cff6826..d6589db43 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -67,6 +67,7 @@ protected: { Ref<RunnerAction> Action; void* ProcessHandle = nullptr; + int Pid = 0; int ExitCode = 0; std::filesystem::path SandboxPath; @@ -83,8 +84,6 @@ protected: std::filesystem::path m_SandboxPath; int32_t m_MaxRunningActions = 64; // arbitrary limit for testing - // if used in conjuction with m_ResultsLock, this lock must be taken *after* - // m_ResultsLock to avoid deadlocks RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; @@ -95,6 +94,7 @@ protected: std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; + void StartMonitorThread(); void MonitorThreadFunction(); virtual void SweepRunningActions(); virtual void CancelRunningActions(); diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index 5cec90699..ab24d4672 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("Seatbelt sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult @@ -209,18 +211,19 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process - lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { - // Close read end of error pipe — child only writes + // Close read end of error pipe - child only writes close(ErrorPipe[0]); // Apply Seatbelt sandbox profile char* ErrorBuf = nullptr; if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) { - // sandbox_init failed — write error to pipe and exit + // sandbox_init failed - write error to pipe and exit if (ErrorBuf) { WriteErrorAndExit(ErrorPipe[1], 
ErrorBuf, 0); @@ -259,7 +262,7 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (m_Sandboxed) { - // Close write end of error pipe — parent only reads + // Close write end of error pipe - parent only reads close(ErrorPipe[1]); // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC @@ -279,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; @@ -467,7 +471,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running) const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; const uint64_t NowTicks = GetHifreqTimerValue(); - // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + // Cumulative CPU seconds (absolute, available from first sample): ns -> seconds Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) @@ -476,7 +480,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running) if (ElapsedMs > 0) { const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; - // ns → ms: divide by 1,000,000; then as percent of elapsed ms + // ns -> ms: divide by 1,000,000; then as percent of elapsed ms const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/managedrunner.cpp 
b/src/zencompute/runners/managedrunner.cpp new file mode 100644 index 000000000..a4f586852 --- /dev/null +++ b/src/zencompute/runners/managedrunner.cpp @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "managedrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio/io_context.hpp> +# include <asio/executor_work_guard.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_IoContext(std::make_unique<asio::io_context>()) +, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext)) +{ + m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers"); + + // Run the io_context on a small thread pool so that exit callbacks and + // metrics sampling are dispatched without blocking each other. 
+ for (int i = 0; i < kIoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i] { + SetCurrentThreadName(fmt::format("mrunner_{}", i)); + + // work_guard keeps run() alive even when there is no pending work yet + auto WorkGuard = asio::make_work_guard(*m_IoContext); + + m_IoContext->run(); + }); + } +} + +ManagedProcessRunner::~ManagedProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what()); + } +} + +void +ManagedProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + CancelRunningActions(); + + // Tear down the SubprocessManager before stopping the io_context so that + // any in-flight callbacks are drained cleanly. + if (m_SubprocessManager) + { + m_SubprocessManager->DestroyGroup("compute-workers"); + m_ProcessGroup = nullptr; + m_SubprocessManager.reset(); + } + + if (m_IoContext) + { + m_IoContext->stop(); + } + + for (std::thread& Thread : m_IoThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_IoThreads.clear(); +} + +SubmitResult +ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + // Parse environment variables from worker descriptor ("KEY=VALUE" strings) + // into the key-value pairs expected by CreateProcOptions. 
+ std::vector<std::pair<std::string, std::string>> EnvPairs; + for (auto& It : WorkerDescription["environment"sv]) + { + std::string_view Str = It.AsString(); + size_t Eq = Str.find('='); + if (Eq != std::string_view::npos) + { + EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1))); + } + } + + // Build command line + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string()); + + ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath); + + CreateProcOptions Options; + Options.WorkingDirectory = &Prepared->SandboxPath; + Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority; + Options.Environment = std::move(EnvPairs); + + const int32_t ActionLsn = Prepared->ActionLsn; + + ManagedProcess* Proc = nullptr; + + try + { + Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) { + OnProcessExit(ActionLsn, ExitCode); + }); + } + catch (std::exception& Ex) + { + ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what()); + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath)); + return SubmitResult{.IsAccepted = false}; + } + + { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = static_cast<void*>(Proc); + NewAction->Pid = Proc->Pid(); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid()); + + return SubmitResult{.IsAccepted = true}; +} + +void 
+ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit"); + + Ref<RunningAction> Running; + + m_RunningLock.WithExclusiveLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end()) + { + Running = std::move(It->second); + m_RunningMap.erase(It); + } + }); + + if (!Running) + { + return; + } + + ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode); + + Running->ExitCode = ExitCode; + + // Capture final CPU metrics from the managed process before it is removed. + auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle); + if (Proc) + { + ProcessMetrics Metrics = Proc->GetLatestMetrics(); + float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs); + Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed); + + float CpuPct = Proc->GetCpuUsagePercent(); + if (CpuPct >= 0.0f) + { + Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running->ProcessHandle = nullptr; + + std::vector<Ref<RunningAction>> CompletedActions; + CompletedActions.push_back(std::move(Running)); + ProcessCompletedActions(CompletedActions); +} + +void +ManagedProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions"); + + std::unordered_map<int, Ref<RunningAction>> RunningMap; + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling {} running actions via process group", RunningMap.size()); + + Stopwatch Timer; + + // Kill all processes in the group atomically (TerminateJobObject on Windows, + // SIGTERM+SIGKILL on POSIX). 
+ if (m_ProcessGroup) + { + m_ProcessGroup->KillAll(); + } + + for (auto& [Lsn, Running] : RunningMap) + { + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +ManagedProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction"); + + ManagedProcess* Proc = nullptr; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr) + { + Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle); + } + }); + + if (!Proc) + { + return false; + } + + // Terminate the process. The exit callback will handle the rest + // (remove from running map, gather outputs or mark failed). + Proc->Terminate(222); + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + return true; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h new file mode 100644 index 000000000..21a44d43c --- /dev/null +++ b/src/zencompute/runners/managedrunner.h @@ -0,0 +1,64 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenutil/process/subprocessmanager.h> + +# include <memory> + +namespace asio { +class io_context; +} + +namespace zen::compute { + +/** Cross-platform process runner backed by SubprocessManager. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and shared action preparation. Replaces the polling-based + monitor thread with async exit callbacks driven by SubprocessManager, and + delegates CPU/memory metrics sampling to the manager's built-in round-robin + sampler. 
+ + A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is + used for bulk cancellation on shutdown. + + This runner does not perform any platform-specific sandboxing (AppContainer, + namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative + to the platform-specific runners for non-sandboxed workloads. + */ +class ManagedProcessRunner : public LocalProcessRunner +{ +public: + ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~ManagedProcessRunner(); + + void Shutdown() override; + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + [[nodiscard]] bool IsHealthy() override { return true; } + +private: + static constexpr int kIoThreadCount = 4; + + // Exit callback posted on an io_context thread. 
+ void OnProcessExit(int ActionLsn, int ExitCode); + + std::unique_ptr<asio::io_context> m_IoContext; + std::unique_ptr<SubprocessManager> m_SubprocessManager; + ProcessGroup* m_ProcessGroup = nullptr; + std::vector<std::thread> m_IoThreads; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index ce6a81173..08f381b7f 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -20,6 +20,7 @@ # include <zenstore/cidstore.h> # include <span> +# include <unordered_set> ////////////////////////////////////////////////////////////////////////// @@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_ChunkResolver{InChunkResolver} , m_WorkerPool{InWorkerPool} , m_HostName{HostName} +, m_DisplayName{HostName} , m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) @@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } +void +RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname) +{ + if (!Hostname.empty()) + { + m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname); + } +} + RemoteHttpRunner::~RemoteHttpRunner() { Shutdown(); @@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown() for (auto& [RemoteLsn, HttpAction] : Remaining) { ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->FailureReason = "remote runner shutdown"; HttpAction.Action->SetActionState(RunnerAction::State::Failed); } } @@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity() return 0; } - // Estimate how much more work we're ready to accept + // Estimate how much more work we're ready to accept. 
+ // Include actions currently being submitted over HTTP so we don't + // keep queueing new submissions while previous ones are still in flight. RwLock::SharedLockScope _{m_RunningLock}; - size_t RunningCount = m_RemoteRunningMap.size(); + size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed); if (RunningCount >= size_t(m_MaxRunningActions)) { @@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed); + auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); }); + if (Actions.size() <= 1) { std::vector<SubmitResult> Results; @@ -246,7 +263,7 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) // Collect distinct QueueIds and ensure remote queues exist once per queue - std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero) + std::unordered_map<int, Oid> QueueTokens; // QueueId -> remote token (0 stays as Zero) for (const Ref<RunnerAction>& Action : Actions) { @@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) } } - // Enqueue job. If the remote returns FailedDependency (424), it means it - // cannot resolve the worker/function — re-register the worker and retry once. + // Submit the action to the remote. In eager-attach mode we build a + // CbPackage with all referenced attachments upfront to avoid the 404 + // round-trip. In the default mode we POST the bare object first and + // only upload missing attachments if the remote requests them. + // + // In both modes, FailedDependency (424) triggers a worker re-register + // and a single retry. 
CbObject Result; HttpClient::Response WorkResponse; HttpResponseCode WorkResponseCode{}; - for (int Attempt = 0; Attempt < 2; ++Attempt) - { - WorkResponse = m_Http.Post(SubmitUrl, ActionObj); - WorkResponseCode = WorkResponse.StatusCode; - - if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) - { - ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", - m_Http.GetBaseUri(), - ActionId); - - (void)RegisterWorker(Action->Worker.Descriptor); - } - else - { - break; - } - } - - if (WorkResponseCode == HttpResponseCode::OK) - { - Result = WorkResponse.AsObject(); - } - else if (WorkResponseCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Not all attachments are present - - // Build response package including all required attachments - CbPackage Pkg; Pkg.SetObject(ActionObj); - CbObject Response = WorkResponse.AsObject(); + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash AttachHash = Field.AsHash(); - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + }); + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, Pkg); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} - 
re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + (void)RegisterWorker(Action->Worker.Descriptor); } else { - // No such attachment - - return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + break; } } + } + else + { + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; - // Post resulting package + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + (void)RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } - if (!PayloadResponse) + if (WorkResponseCode == HttpResponseCode::NotFound) { - ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + // Remote needs attachments - resolve them and retry with a CbPackage - // TODO: include more information about the failure in the response + CbPackage Pkg; + Pkg.SetObject(ActionObj); - return {.IsAccepted = false, .Reason = "HTTP request failed"}; - } - else if (PayloadResponse.StatusCode == HttpResponseCode::OK) - { - Result = PayloadResponse.AsObject(); - } - else - { - // Unexpected response - - const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - - ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", - ActionId, - m_Http.GetBaseUri(), - SubmitUrl, - ResponseStatusCode, - ToString(ResponseStatusCode)); - - return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}{}", - ResponseStatusCode, - ToString(ResponseStatusCode), - m_Http.GetBaseUri(), - SubmitUrl)}; + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : 
Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + + WorkResponse = std::move(PayloadResponse); + WorkResponseCode = WorkResponse.StatusCode; } } + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (!WorkResponse) + { + ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (!IsHttpSuccessCode(WorkResponseCode)) + { + const int Code = static_cast<int>(WorkResponseCode); + ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code)); + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)}; + } + if (Result) { if (const int32_t LsnField = Result["lsn"].AsInt32(0)) @@ -512,83 +562,111 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec CbObjectWriter Body; Body.BeginArray("actions"sv); + 
std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen; + for (const Ref<RunnerAction>& Action : Actions) { Action->ExecutionLocation = m_HostName; MaybeDumpAction(Action->ActionLsn, Action->ActionObj); Body.AddObject(Action->ActionObj); + + if (m_EagerAttach) + { + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); }); + } } Body.EndArray(); - // POST the batch - - HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - - if (Response.StatusCode == HttpResponseCode::OK) - { - return ParseBatchResponse(Response, Actions); - } + // In eager-attach mode, build a CbPackage with all referenced attachments + // so the remote can accept in a single round-trip. Otherwise POST a bare + // CbObject and handle the 404 need-list flow. - if (Response.StatusCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Server needs attachments — resolve them and retry with a CbPackage - - CbObject NeedObj = Response.AsObject(); - CbPackage Pkg; Pkg.SetObject(Body.Save()); - for (auto& Item : NeedObj["need"sv]) + for (const IoHash& AttachHash : AttachmentsSeen) { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); - - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); - } - else - { - ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); - return FallbackToIndividualSubmit(Actions); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); } } - 
HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg); - if (RetryResponse.StatusCode == HttpResponseCode::OK) + if (Response.StatusCode == HttpResponseCode::OK) { - return ParseBatchResponse(RetryResponse, Actions); + return ParseBatchResponse(Response, Actions); } - - ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", - (int)RetryResponse.StatusCode, - ToString(RetryResponse.StatusCode)); - return FallbackToIndividualSubmit(Actions); - } - - // Unexpected status or connection error — fall back to individual submission - - if (Response) - { - ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", - m_Http.GetBaseUri(), - SubmitUrl, - (int)Response.StatusCode, - ToString(Response.StatusCode)); } else { - ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + + if (Response.StatusCode == HttpResponseCode::NotFound) + { + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + ZEN_WARN("batch submit: missing attachment {} - falling back to 
individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} - falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } } + // Unexpected status or connection error - fall back to individual submission + + ZEN_WARN("batch submit to {}{} failed - falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + return FallbackToIndividualSubmit(Actions); } @@ -849,7 +927,7 @@ RemoteHttpRunner::MonitorThreadFunction() SweepOnce(); } - // Signal received — may be a WS wakeup or a quit signal + // Signal received - may be a WS wakeup or a quit signal SweepOnce(); } while (m_MonitorThreadEnabled); @@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - CbObjectView EntryObj = FieldIt.AsObjectView(); - const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); - std::string_view StateName = EntryObj["state"sv].AsString(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + std::string_view FailureReason = EntryObj["reason"sv].AsString(); RunnerAction::State RemoteState = RunnerAction::FromString(StateName); @@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions() { HttpRunningAction CompletedAction = std::move(CompleteIt->second); CompletedAction.RemoteState = RemoteState; + CompletedAction.FailureReason = std::string(FailureReason); if (RemoteState == RunnerAction::State::Completed && ResponseJob) { @@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - 
ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", - HttpAction.Action->ActionId, - ActionLsn, - HttpAction.RemoteActionLsn, - RunnerAction::ToString(HttpAction.RemoteState)); - if (HttpAction.RemoteState == RunnerAction::State::Completed) { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + m_HostName); HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned) + { + HttpAction.Action->FailureReason = HttpAction.FailureReason; + if (HttpAction.FailureReason.empty()) + { + ZEN_WARN("action {} ({}) {} on remote {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName); + } + else + { + ZEN_WARN("action {} ({}) {} on remote {}: {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName, + HttpAction.FailureReason); + } + } + else + { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + } HttpAction.Action->SetActionState(HttpAction.RemoteState); } diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index c17d0cf2a..521bf2f82 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -54,8 +54,10 @@ public: [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; virtual void CancelRemoteQueue(int QueueId) override; + [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; } std::string_view GetHostName() const { return m_HostName; } + void 
SetRemoteHostname(std::string_view Hostname); protected: LoggerRef Log() { return m_Log; } @@ -65,12 +67,15 @@ private: ChunkResolver& m_ChunkResolver; WorkerThreadPool& m_WorkerPool; std::string m_HostName; + std::string m_DisplayName; std::string m_BaseUrl; HttpClient m_Http; - std::atomic<bool> m_AcceptNewActions{true}; - int32_t m_MaxRunningActions = 256; // arbitrary limit for testing - int32_t m_MaxBatchSize = 50; + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; + bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry + std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP struct HttpRunningAction { @@ -78,6 +83,7 @@ private: int RemoteActionLsn = 0; // Remote LSN RunnerAction::State RemoteState = RunnerAction::State::Failed; CbPackage ActionResults; + std::string FailureReason; }; RwLock m_RunningLock; @@ -90,7 +96,7 @@ private: size_t SweepRunningActions(); RwLock m_QueueTokenLock; - std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token + std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId -> remote queue token // Stable identity for this runner instance, used as part of the idempotency key when // creating remote queues. Generated once at construction and never changes. diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index cd4b646e9..c6b3e82ea 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <sddl.h> ZEN_THIRD_PARTY_INCLUDES_END +// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be +// excluded by WIN32_LEAN_AND_MEAN. 
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + namespace zen::compute { using namespace std::literals; @@ -34,38 +40,67 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, : LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) , m_Sandboxed(Sandboxed) { - if (!m_Sandboxed) + // Create a job object shared by all child processes. Restricting the + // error-mode UI prevents crash dialogs (WER / Dr. Watson) from + // blocking the monitor thread when a worker process terminates + // abnormally. + m_JobObject = CreateJobObjectW(nullptr, nullptr); + if (m_JobObject) { - return; + JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; + ExtLimits.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS; + ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS; + SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); + + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on this process so children inherit it. The + // UILIMIT_ERRORMODE restriction above prevents them from clearing + // SEM_NOGPFAULTERRORBOX. 
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } - // Build a unique profile name per process to avoid collisions - m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + if (m_Sandboxed) + { + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); - // Clean up any stale profile from a previous crash - DeleteAppContainerProfile(m_AppContainerName.c_str()); + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); - PSID Sid = nullptr; + PSID Sid = nullptr; - HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), - m_AppContainerName.c_str(), // display name - m_AppContainerName.c_str(), // description - nullptr, // no capabilities - 0, // capability count - &Sid); + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); - if (FAILED(Hr)) - { - throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); - } + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } - m_AppContainerSid = Sid; + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + } - ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + StartMonitorThread(); } WindowsProcessRunner::~WindowsProcessRunner() { + if (m_JobObject) + { + CloseHandle(m_JobObject); + m_JobObject = nullptr; + } + if (m_AppContainerSid) { FreeSid(m_AppContainerSid); @@ -172,9 +207,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) LPSECURITY_ATTRIBUTES lpProcessAttributes = 
nullptr; LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = DETACHED_PROCESS; + DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS; - ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath); CommandLine.EnsureNulTerminated(); @@ -260,14 +295,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) } } - CloseHandle(ProcessInformation.hThread); + if (m_JobObject) + { + AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess); + } - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(Prepared->SandboxPath); + ResumeThread(ProcessInformation.hThread); + CloseHandle(ProcessInformation.hThread); { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->Pid = ProcessInformation.dwProcessId; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + RwLock::ExclusiveLockScope _(m_RunningLock); m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); @@ -275,6 +317,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) Action->SetActionState(RunnerAction::State::Running); + ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId); + return SubmitResult{.IsAccepted = true}; } @@ -294,6 +338,11 @@ WindowsProcessRunner::SweepRunningActions() if (IsSuccess && ExitCode != STILL_ACTIVE) { + ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), + Running->Action->ActionLsn, + Running->Pid, + ExitCode); + CloseHandle(Running->ProcessHandle); Running->ProcessHandle = INVALID_HANDLE_VALUE; Running->ExitCode = 
ExitCode; @@ -436,7 +485,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); const uint64_t NowTicks = GetHifreqTimerValue(); - // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + // Cumulative CPU seconds (absolute, available from first sample): 100ns -> seconds Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) @@ -445,7 +494,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) if (ElapsedMs > 0) { const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; - // 100ns → ms: divide by 10000; then as percent of elapsed ms + // 100ns -> ms: divide by 10000; then as percent of elapsed ms const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h index 9f2385cc4..adeaf02fc 100644 --- a/src/zencompute/runners/windowsrunner.h +++ b/src/zencompute/runners/windowsrunner.h @@ -46,6 +46,7 @@ private: bool m_Sandboxed = false; PSID m_AppContainerSid = nullptr; std::wstring m_AppContainerName; + HANDLE m_JobObject = nullptr; }; } // namespace zen::compute diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index 506bec73b..29ab93663 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, sigemptyset(&Action.sa_mask); Action.sa_handler = SIG_DFL; sigaction(SIGCHLD, &Action, nullptr); + + StartMonitorThread(); } SubmitResult @@ -94,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) 
{ - // Child process + // Child process - lower priority so workers don't starve the main server + nice(5); + if (chdir(SandboxPathStr.c_str()) != 0) { _exit(127); diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp index da560a449..5bfbd5e3e 100644 --- a/src/zencore/compactbinaryjson.cpp +++ b/src/zencore/compactbinaryjson.cpp @@ -11,6 +11,8 @@ #include <zencore/testing.h> #include <fmt/format.h> +#include <cmath> +#include <limits> #include <vector> ZEN_THIRD_PARTY_INCLUDES_START @@ -570,13 +572,37 @@ private: break; case Json::Type::NUMBER: { - if (FieldName.empty()) - { - Writer.AddFloat(Json.number_value()); + // If the JSON number has no fractional part and fits in an int64, + // store it as an integer so that AsInt32/AsInt64 work without + // requiring callers to go through AsFloat. + double Value = Json.number_value(); + double IntPart; + bool IsIntegral = (std::modf(Value, &IntPart) == 0.0) && + Value >= static_cast<double>(std::numeric_limits<int64_t>::min()) && + Value <= static_cast<double>(std::numeric_limits<int64_t>::max()); + + if (IsIntegral) + { + int64_t IntValue = static_cast<int64_t>(Value); + if (FieldName.empty()) + { + Writer.AddInteger(IntValue); + } + else + { + Writer.AddInteger(FieldName, IntValue); + } } else { - Writer.AddFloat(FieldName, Json.number_value()); + if (FieldName.empty()) + { + Writer.AddFloat(Value); + } + else + { + Writer.AddFloat(FieldName, Value); + } } } break; diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index 56a292ca6..87b58baf7 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -684,14 +684,22 @@ namespace legacy { Writer.Save(Ar); } - bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer InBuffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + 
bool ValidateHashes) { BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); - return TryLoadCbPackage(Package, Reader, Allocator, Mapper); + return TryLoadCbPackage(Package, Reader, Allocator, Mapper, ValidateHashes); } - bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + BinaryReader& Reader, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + bool ValidateHashes) { Package = CbPackage(); for (;;) @@ -708,7 +716,11 @@ namespace legacy { if (ValueField.IsBinary()) { const MemoryView View = ValueField.AsBinaryView(); - if (View.GetSize() > 0) + if (View.GetSize() == 0) + { + return false; + } + else { SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); CbField HashField = LoadCompactBinary(Reader, Allocator); @@ -748,7 +760,11 @@ namespace legacy { { const IoHash Hash = ValueField.AsHash(); - ZEN_ASSERT(Mapper); + if (!Mapper) + { + return false; + } + if (SharedBuffer AttachmentData = (*Mapper)(Hash)) { IoHash RawHash; @@ -763,6 +779,10 @@ namespace legacy { } else { + if (ValidateHashes && IoHash::HashBuffer(AttachmentData) != Hash) + { + return false; + } const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All); if (ValidationResult != CbValidateError::None) { @@ -801,13 +821,13 @@ namespace legacy { #if ZEN_WITH_TESTS void -usonpackage_forcelink() +cbpackage_forcelink() { } TEST_SUITE_BEGIN("core.compactbinarypackage"); -TEST_CASE("usonpackage") +TEST_CASE("cbpackage") { using namespace std::literals; @@ -997,7 +1017,7 @@ TEST_CASE("usonpackage") } } -TEST_CASE("usonpackage.serialization") +TEST_CASE("cbpackage.serialization") { using namespace std::literals; @@ -1303,7 +1323,7 @@ TEST_CASE("usonpackage.serialization") } } -TEST_CASE("usonpackage.invalidpackage") +TEST_CASE("cbpackage.invalidpackage") { const auto TestLoad = 
[](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) { const MemoryView RawView = MakeMemoryView(RawData); @@ -1345,6 +1365,90 @@ TEST_CASE("usonpackage.invalidpackage") } } +TEST_CASE("cbpackage.legacy.invalidpackage") +{ + const auto TestLegacyLoad = [](std::initializer_list<uint8_t> RawData) { + const MemoryView RawView = MakeMemoryView(RawData); + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(RawView.GetData()), RawView.GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + }; + + SUBCASE("Empty") { TestLegacyLoad({}); } + + SUBCASE("Zero size binary rejects") + { + // A binary field with zero payload size should be rejected (would desync the reader) + BinaryWriter Writer; + CbWriter Cb; + Cb.AddBinary(MemoryView()); // zero-size binary + Cb.Save(Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + } +} + +TEST_CASE("cbpackage.legacy.hashresolution") +{ + // Build a valid legacy package with an object, then round-trip it + CbObjectWriter RootWriter; + RootWriter.AddString("name", "test"); + CbObject RootObject = RootWriter.Save(); + + CbAttachment ObjectAttach(RootObject); + + CbPackage OriginalPkg; + OriginalPkg.SetObject(RootObject); + OriginalPkg.AddAttachment(ObjectAttach); + + BinaryWriter Writer; + legacy::SaveCbPackage(OriginalPkg, Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage LoadedPkg; + CHECK(legacy::TryLoadCbPackage(LoadedPkg, Buffer, &UniqueBuffer::Alloc)); + + // The hash-only path requires a mapper - without one it should fail + CbWriter HashOnlyCb; + HashOnlyCb.AddHash(ObjectAttach.GetHash()); + HashOnlyCb.AddNull(); + BinaryWriter HashOnlyWriter; + 
HashOnlyCb.Save(HashOnlyWriter); + + IoBuffer HashOnlyBuffer(IoBuffer::Wrap, + const_cast<void*>(MakeMemoryView(HashOnlyWriter).GetData()), + MakeMemoryView(HashOnlyWriter).GetSize()); + CbPackage HashOnlyPkg; + CHECK_FALSE(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, nullptr)); + + // With a mapper that returns valid data, it should succeed + CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer { + if (Hash == ObjectAttach.GetHash()) + { + return RootObject.GetBuffer(); + } + return {}; + }; + CHECK(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &Resolver)); + + // Build a different but structurally valid CbObject to use as mismatched data + CbObjectWriter DifferentWriter; + DifferentWriter.AddString("name", "different"); + CbObject DifferentObject = DifferentWriter.Save(); + + CbPackage::AttachmentResolver BadResolver = [&](const IoHash&) -> SharedBuffer { return DifferentObject.GetBuffer(); }; + CbPackage BadPkg; + + // With ValidateHashes enabled and a mapper that returns mismatched data, it should fail + CHECK_FALSE(legacy::TryLoadCbPackage(BadPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ true)); + + // Without ValidateHashes, the mismatched data is accepted (structure is still valid CB) + CbPackage UncheckedPkg; + CHECK(legacy::TryLoadCbPackage(UncheckedPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ false)); +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/crashhandler.cpp b/src/zencore/crashhandler.cpp index 31b8e6ce2..14904a4b2 100644 --- a/src/zencore/crashhandler.cpp +++ b/src/zencore/crashhandler.cpp @@ -56,7 +56,7 @@ CrashExceptionFilter(PEXCEPTION_POINTERS ExceptionInfo) HANDLE Process = GetCurrentProcess(); HANDLE Thread = GetCurrentThread(); - // SymInitialize is safe to call if already initialized — it returns FALSE + // SymInitialize is safe to call if already initialized - it returns FALSE // but 
existing state remains valid for SymFromAddr calls SymInitialize(Process, NULL, TRUE); diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 0d361801f..5160bfdc6 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -114,6 +114,20 @@ struct ScopedFd explicit operator bool() const { return Fd >= 0; } }; +# if ZEN_PLATFORM_LINUX +inline uint64_t +StatMtime100Ns(const struct stat& S) +{ + return uint64_t(S.st_mtim.tv_sec) * 10'000'000ULL + uint64_t(S.st_mtim.tv_nsec) / 100; +} +# elif ZEN_PLATFORM_MAC +inline uint64_t +StatMtime100Ns(const struct stat& S) +{ + return uint64_t(S.st_mtimespec.tv_sec) * 10'000'000ULL + uint64_t(S.st_mtimespec.tv_nsec) / 100; +} +# endif + #endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC #if ZEN_PLATFORM_WINDOWS @@ -2123,7 +2137,7 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr } else if (S_ISREG(Stat.st_mode)) { - Visitor.VisitFile(RootDir, FileName, Stat.st_size, gsl::narrow<uint32_t>(Stat.st_mode), gsl::narrow<uint64_t>(Stat.st_mtime)); + Visitor.VisitFile(RootDir, FileName, Stat.st_size, gsl::narrow<uint32_t>(Stat.st_mode), StatMtime100Ns(Stat)); } else { @@ -2507,7 +2521,7 @@ GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec) struct stat Stat; if (0 == fstat(Fd, &Stat)) { - return gsl::narrow<uint64_t>(Stat.st_mtime); + return StatMtime100Ns(Stat); } #endif Ec = MakeErrorCodeFromLastError(); @@ -2546,7 +2560,7 @@ GetModificationTickFromPath(const std::filesystem::path& Filename) { ThrowLastError(fmt::format("Failed to get mode of file {}", Filename)); } - return gsl::narrow<uint64_t>(Stat.st_mtime); + return StatMtime100Ns(Stat); #endif } @@ -2589,7 +2603,7 @@ TryGetFileProperties(const std::filesystem::path& Path, { return false; } - OutModificationTick = gsl::narrow<uint64_t>(Stat.st_mtime); + OutModificationTick = StatMtime100Ns(Stat); OutSize = size_t(Stat.st_size); OutNativeModeOrAttributes = (uint32_t)Stat.st_mode; return true; @@ 
-2963,6 +2977,35 @@ GetEnvVariable(std::string_view VariableName) return ""; } +std::string +ExpandEnvironmentVariables(std::string_view Input) +{ + std::string Result; + Result.reserve(Input.size()); + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '%') + { + size_t End = Input.find('%', i + 1); + if (End != std::string_view::npos && End > i + 1) + { + std::string_view VarName = Input.substr(i + 1, End - i - 1); + std::string Value = GetEnvVariable(VarName); + if (!Value.empty()) + { + Result += Value; + i = End; + continue; + } + } + } + Result += Input[i]; + } + + return Result; +} + std::error_code RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles) { @@ -3275,14 +3318,25 @@ MakeSafeAbsolutePathInPlace(std::filesystem::path& Path) { if (!Path.empty()) { - std::filesystem::path AbsolutePath = std::filesystem::absolute(Path).make_preferred(); + Path = std::filesystem::absolute(Path).make_preferred(); #if ZEN_PLATFORM_WINDOWS - const std::string_view Prefix = "\\\\?\\"; - const std::u8string PrefixU8(Prefix.begin(), Prefix.end()); - std::u8string PathString = AbsolutePath.u8string(); - if (!PathString.empty() && !PathString.starts_with(PrefixU8)) + const std::u8string_view LongPathPrefix = u8"\\\\?\\"; + const std::u8string_view UncPrefix = u8"\\\\"; + const std::u8string_view LongPathUncPrefix = u8"\\\\?\\UNC\\"; + + std::u8string PathString = Path.u8string(); + if (!PathString.empty() && !PathString.starts_with(LongPathPrefix)) { - PathString.insert(0, PrefixU8); + if (PathString.starts_with(UncPrefix)) + { + // UNC path: \\server\share -> \\?\UNC\server\share + PathString.replace(0, UncPrefix.size(), LongPathUncPrefix); + } + else + { + // Local path: C:\foo -> \\?\C:\foo + PathString.insert(0, LongPathPrefix); + } Path = PathString; } #endif // ZEN_PLATFORM_WINDOWS @@ -3408,7 +3462,7 @@ public: ZEN_UNUSED(SystemGlobal); std::string InstanceMapName = fmt::format("/{}", Name); - ScopedFd 
FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666)); + ScopedFd FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT, 0666)); if (!FdGuard) { return {}; @@ -4049,6 +4103,93 @@ TEST_CASE("SharedMemory") CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } +TEST_CASE("filesystem.MakeSafeAbsolutePath") +{ +# if ZEN_PLATFORM_WINDOWS + // Local path gets \\?\ prefix + { + std::filesystem::path Local = MakeSafeAbsolutePath("C:\\Users\\test"); + CHECK(Local.u8string().starts_with(u8"\\\\?\\")); + CHECK(Local.u8string().find(u8"C:\\Users\\test") != std::u8string::npos); + } + + // UNC path gets \\?\UNC\ prefix + { + std::filesystem::path Unc = MakeSafeAbsolutePath("\\\\server\\share\\path"); + std::u8string UncStr = Unc.u8string(); + CHECK_MESSAGE(UncStr.starts_with(u8"\\\\?\\UNC\\"), fmt::format("Expected \\\\?\\UNC\\ prefix, got '{}'", Unc)); + CHECK_MESSAGE(UncStr.find(u8"server\\share\\path") != std::u8string::npos, + fmt::format("Expected server\\share\\path in '{}'", Unc)); + // Must NOT produce \\?\\\server (double backslash after \\?\) + CHECK_MESSAGE(UncStr.find(u8"\\\\?\\\\\\") == std::u8string::npos, + fmt::format("Path contains invalid double-backslash after prefix: '{}'", Unc)); + } + + // Already-prefixed path is not double-prefixed + { + std::filesystem::path Already = MakeSafeAbsolutePath("\\\\?\\C:\\already\\prefixed"); + size_t Count = 0; + std::u8string Str = Already.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } + + // Already-prefixed UNC path is not double-prefixed + { + std::filesystem::path AlreadyUnc = MakeSafeAbsolutePath("\\\\?\\UNC\\server\\share"); + size_t Count = 0; + std::u8string Str = AlreadyUnc.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } +# endif // 
ZEN_PLATFORM_WINDOWS +} + +TEST_CASE("ExpandEnvironmentVariables") +{ + // No variables - pass-through + CHECK_EQ(ExpandEnvironmentVariables("plain/path"), "plain/path"); + CHECK_EQ(ExpandEnvironmentVariables(""), ""); + + // Single percent sign is not a variable reference + CHECK_EQ(ExpandEnvironmentVariables("50%"), "50%"); + + // Empty variable name (%%) is not expanded + CHECK_EQ(ExpandEnvironmentVariables("%%"), "%%"); + + // Known variable +# if ZEN_PLATFORM_WINDOWS + // PATH is always set on Windows + std::string PathValue = GetEnvVariable("PATH"); + CHECK(!PathValue.empty()); + CHECK_EQ(ExpandEnvironmentVariables("%PATH%"), PathValue); + CHECK_EQ(ExpandEnvironmentVariables("prefix/%PATH%/suffix"), "prefix/" + PathValue + "/suffix"); +# else + std::string HomeValue = GetEnvVariable("HOME"); + CHECK(!HomeValue.empty()); + CHECK_EQ(ExpandEnvironmentVariables("%HOME%"), HomeValue); + CHECK_EQ(ExpandEnvironmentVariables("prefix/%HOME%/suffix"), "prefix/" + HomeValue + "/suffix"); +# endif + + // Unknown variable is left unexpanded + CHECK_EQ(ExpandEnvironmentVariables("%ZEN_UNLIKELY_SET_VAR_12345%"), "%ZEN_UNLIKELY_SET_VAR_12345%"); + + // Multiple variables +# if ZEN_PLATFORM_WINDOWS + std::string OSValue = GetEnvVariable("OS"); + if (!OSValue.empty()) + { + CHECK_EQ(ExpandEnvironmentVariables("%PATH%/%OS%"), PathValue + "/" + OSValue); + } +# endif +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h index 64b62e2c0..148c0d3fd 100644 --- a/src/zencore/include/zencore/compactbinarypackage.h +++ b/src/zencore/include/zencore/compactbinarypackage.h @@ -278,10 +278,10 @@ public: * @return The attachment, or null if the attachment is not found. * @note The returned pointer is only valid until the attachments on this package are modified. 
*/ - const CbAttachment* FindAttachment(const IoHash& Hash) const; + [[nodiscard]] const CbAttachment* FindAttachment(const IoHash& Hash) const; /** Find an attachment if it exists in the package. */ - inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } + [[nodiscard]] const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } /** Add the attachment to this package. */ inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); } @@ -336,17 +336,26 @@ private: IoHash ObjectHash; }; +/** In addition to the above, we also support a legacy format which is used by + * the HTTP project store for historical reasons. Don't use the below functions + * for new code. + */ namespace legacy { void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar); - bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr); + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer Buffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, - CbPackage::AttachmentResolver* Mapper = nullptr); + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); } // namespace legacy -void usonpackage_forcelink(); // internal +void cbpackage_forcelink(); // internal } // namespace zen diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index 6dc159a83..73769cdb4 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -400,6 +400,10 @@ void 
GetDirectoryContent(const std::filesystem::path& RootDir, std::string GetEnvVariable(std::string_view VariableName); +// Expands %VAR% environment variable references in a string. +// Unknown or empty variables are left unexpanded. +std::string ExpandEnvironmentVariables(std::string_view Input); + std::filesystem::path SearchPathForExecutable(std::string_view ExecutableName); std::error_code RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles); diff --git a/src/zencore/include/zencore/fmtutils.h b/src/zencore/include/zencore/fmtutils.h index 4ec05f901..a263c6f04 100644 --- a/src/zencore/include/zencore/fmtutils.h +++ b/src/zencore/include/zencore/fmtutils.h @@ -3,10 +3,7 @@ #pragma once #include <zencore/filesystem.h> -#include <zencore/guid.h> -#include <zencore/iohash.h> #include <zencore/string.h> -#include <zencore/uid.h> ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> @@ -38,63 +35,49 @@ struct fmt::formatter<T> : fmt::formatter<std::string_view> } }; -// Custom formatting for some zencore types +// Generic formatter for any type that is explicitly convertible to std::string_view. +// This covers NiceNum, NiceBytes, ThousandsNum, StringBuilder, and similar types +// without needing per-type fmt::formatter specializations. 
template<typename T> -requires DerivedFrom<T, zen::StringBuilderBase> -struct fmt::formatter<T> : fmt::formatter<std::string_view> +concept HasStringViewConversion = std::is_class_v<T> && requires(const T& v) { - template<typename FormatContext> - auto format(const zen::StringBuilderBase& a, FormatContext& ctx) const { - return fmt::formatter<std::string_view>::format(a.ToView(), ctx); - } -}; + std::string_view(v) + } -> std::same_as<std::string_view>; +} && !HasFreeToString<T> && !std::is_same_v<T, std::string> && !std::is_same_v<T, std::string_view>; -template<typename T> -requires DerivedFrom<T, zen::NiceBase> +template<HasStringViewConversion T> struct fmt::formatter<T> : fmt::formatter<std::string_view> { template<typename FormatContext> - auto format(const zen::NiceBase& a, FormatContext& ctx) const + auto format(const T& Value, FormatContext& ctx) const { - return fmt::formatter<std::string_view>::format(std::string_view(a), ctx); + return fmt::formatter<std::string_view>::format(std::string_view(Value), ctx); } }; -template<> -struct fmt::formatter<zen::IoHash> : formatter<string_view> -{ - template<typename FormatContext> - auto format(const zen::IoHash& Hash, FormatContext& ctx) const - { - zen::IoHash::String_t String; - Hash.ToHexString(String); - return fmt::formatter<string_view>::format({String, zen::IoHash::StringLength}, ctx); - } -}; +// Generic formatter for any type with a ToString(StringBuilderBase&) member function. +// This covers Guid, IoHash, Oid, and similar types without needing per-type +// fmt::formatter specializations. 
-template<> -struct fmt::formatter<zen::Oid> : formatter<string_view> +template<typename T> +concept HasMemberToStringBuilder = std::is_class_v<T> && requires(const T& v, zen::StringBuilderBase& sb) { - template<typename FormatContext> - auto format(const zen::Oid& Id, FormatContext& ctx) const { - zen::StringBuilder<32> String; - Id.ToString(String); - return fmt::formatter<string_view>::format({String.c_str(), zen::Oid::StringLength}, ctx); - } -}; + v.ToString(sb) + } -> std::same_as<zen::StringBuilderBase&>; +} && !HasFreeToString<T> && !HasStringViewConversion<T>; -template<> -struct fmt::formatter<zen::Guid> : formatter<string_view> +template<HasMemberToStringBuilder T> +struct fmt::formatter<T> : fmt::formatter<std::string_view> { template<typename FormatContext> - auto format(const zen::Guid& Id, FormatContext& ctx) const + auto format(const T& Value, FormatContext& ctx) const { - zen::StringBuilder<48> String; - Id.ToString(String); - return fmt::formatter<string_view>::format({String.c_str(), zen::Guid::StringLength}, ctx); + zen::ExtendableStringBuilder<64> String; + Value.ToString(String); + return fmt::formatter<std::string_view>::format(String.ToView(), ctx); } }; diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h index 8abfd4b6e..e253d7015 100644 --- a/src/zencore/include/zencore/hashutils.h +++ b/src/zencore/include/zencore/hashutils.h @@ -4,6 +4,8 @@ #include <cstddef> #include <functional> +#include <string> +#include <string_view> #include <type_traits> namespace zen { @@ -35,4 +37,21 @@ CombineHashes(const Types&... Args) return Seed; } +/** Transparent string hash for use with std::unordered_map/set. + Enables heterogeneous lookup so that a std::string_view can be used to + probe a std::string-keyed container without allocating a temporary std::string. 
+ + Usage: + std::unordered_map<std::string, V, TransparentStringHash, std::equal_to<>> Map; + Map.find(some_string_view); // no allocation + */ +struct TransparentStringHash +{ + using is_transparent = void; + + size_t operator()(std::string_view Sv) const noexcept { return std::hash<std::string_view>{}(Sv); } + size_t operator()(const std::string& S) const noexcept { return std::hash<std::string_view>{}(S); } + size_t operator()(const char* S) const noexcept { return std::hash<std::string_view>{}(S); } +}; + } // namespace zen diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h index 82c201edd..c6ba90692 100644 --- a/src/zencore/include/zencore/iobuffer.h +++ b/src/zencore/include/zencore/iobuffer.h @@ -109,10 +109,11 @@ public: // Reference counting - inline uint32_t AddRef() const { return AtomicIncrement(const_cast<IoBufferCore*>(this)->m_RefCount); } - inline uint32_t Release() const + // See zen::RefCounted::AddRef/Release for ordering rationale. 
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; } + inline uint32_t Release() const noexcept { - const uint32_t NewRefCount = AtomicDecrement(const_cast<IoBufferCore*>(this)->m_RefCount); + const uint32_t NewRefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1; if (NewRefCount == 0) { DeleteThis(); @@ -130,7 +131,7 @@ public: // void Materialize() const; - void DeleteThis() const; + void DeleteThis() const noexcept; void MakeOwned(bool Immutable = true); inline void EnsureDataValid() const @@ -228,14 +229,14 @@ public: return ZenContentType((m_Flags.load(std::memory_order_relaxed) >> kContentTypeShift) & kContentTypeMask); } - inline uint32_t GetRefCount() const { return m_RefCount; } + inline uint32_t GetRefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); } protected: - uint32_t m_RefCount = 0; + mutable std::atomic<uint32_t> m_RefCount = 0; mutable std::atomic<uint32_t> m_Flags{0}; mutable const void* m_DataPtr = nullptr; size_t m_DataBytes = 0; - RefPtr<const IoBufferCore> m_OuterCore; + Ref<const IoBufferCore> m_OuterCore; enum { @@ -413,9 +414,9 @@ public: private: // We have a shared "null" buffer core which we share, this is initialized static and never released which will // cause a memory leak at exit. 
This does however save millions of memory allocations for null buffers - static RefPtr<IoBufferCore> NullBufferCore; + static Ref<IoBufferCore> NullBufferCore; - RefPtr<IoBufferCore> m_Core = NullBufferCore; + Ref<IoBufferCore> m_Core = NullBufferCore; IoBuffer(IoBufferCore* Core) : m_Core(Core) {} diff --git a/src/zencore/include/zencore/iohash.h b/src/zencore/include/zencore/iohash.h index a619b0053..50c439b70 100644 --- a/src/zencore/include/zencore/iohash.h +++ b/src/zencore/include/zencore/iohash.h @@ -54,6 +54,7 @@ struct IoHash static bool TryParse(std::string_view Str, IoHash& Hash); const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const; StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; + StringBuilderBase& ToString(StringBuilderBase& outBuilder) const { return ToHexString(outBuilder); } std::string ToHexString() const; static constexpr int StringLength = 40; diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h index ad2ab218d..65f8a9dbe 100644 --- a/src/zencore/include/zencore/logbase.h +++ b/src/zencore/include/zencore/logbase.h @@ -8,7 +8,7 @@ #include <string_view> namespace zen::logging { -enum LogLevel : int +enum LogLevel : int8_t { Trace, Debug, @@ -22,6 +22,7 @@ enum LogLevel : int LogLevel ParseLogLevelString(std::string_view String); std::string_view ToStringView(LogLevel Level); +std::string_view ShortToStringView(LogLevel Level); void SetLogLevel(LogLevel NewLogLevel); LogLevel GetLogLevel(); @@ -49,11 +50,16 @@ struct SourceLocation */ struct LogPoint { - SourceLocation Location; + const char* Filename; + int Line; LogLevel Level; std::string_view FormatString; + + [[nodiscard]] SourceLocation Location() const { return SourceLocation{Filename, Line}; } }; +static_assert(sizeof(LogPoint) <= 32); + class Logger; /** This is the base class for all loggers @@ -91,6 +97,7 @@ struct LoggerRef { LoggerRef() = default; explicit LoggerRef(logging::Logger& InLogger); 
+ explicit LoggerRef(std::string_view LogCategory); // This exists so that logging macros can pass LoggerRef or LogCategory // to ZEN_LOG without needing to know which one it is @@ -104,6 +111,8 @@ struct LoggerRef bool ShouldLog(logging::LogLevel Level) const { return m_Logger->ShouldLog(Level); } void SetLogLevel(logging::LogLevel NewLogLevel) { m_Logger->SetLevel(NewLogLevel); } logging::LogLevel GetLogLevel() { return m_Logger->GetLevel(); } + std::string_view GetLogLevelString() { return logging::ToStringView(GetLogLevel()); } + std::string_view GetShortLogLevelString() { return logging::ShortToStringView(GetLogLevel()); } void Flush(); diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index 3427991d2..cf011fb1a 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -90,6 +90,34 @@ using zen::ConsoleLog; using zen::ErrorLog; using zen::Log; +//////////////////////////////////////////////////////////////////////// +// Color helpers + +#define ZEN_RED(str) "\033[31m" str "\033[0m" +#define ZEN_GREEN(str) "\033[32m" str "\033[0m" +#define ZEN_YELLOW(str) "\033[33m" str "\033[0m" +#define ZEN_BLUE(str) "\033[34m" str "\033[0m" +#define ZEN_MAGENTA(str) "\033[35m" str "\033[0m" +#define ZEN_CYAN(str) "\033[36m" str "\033[0m" +#define ZEN_WHITE(str) "\033[37m" str "\033[0m" + +#define ZEN_BRIGHT_RED(str) "\033[91m" str "\033[0m" +#define ZEN_BRIGHT_GREEN(str) "\033[92m" str "\033[0m" +#define ZEN_BRIGHT_YELLOW(str) "\033[93m" str "\033[0m" +#define ZEN_BRIGHT_BLUE(str) "\033[94m" str "\033[0m" +#define ZEN_BRIGHT_MAGENTA(str) "\033[95m" str "\033[0m" +#define ZEN_BRIGHT_CYAN(str) "\033[96m" str "\033[0m" +#define ZEN_BRIGHT_WHITE(str) "\033[97m" str "\033[0m" + +#define ZEN_BOLD(str) "\033[1m" str "\033[0m" +#define ZEN_UNDERLINE(str) "\033[4m" str "\033[0m" +#define ZEN_DIM(str) "\033[2m" str "\033[0m" +#define ZEN_ITALIC(str) "\033[3m" str "\033[0m" +#define ZEN_STRIKETHROUGH(str) 
"\033[9m" str "\033[0m" +#define ZEN_INVERSE(str) "\033[7m" str "\033[0m" + +//////////////////////////////////////////////////////////////////////// + #if ZEN_BUILD_DEBUG # define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \ while (false) \ @@ -103,31 +131,31 @@ using zen::Log; } #endif -#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit ZEN_LOG_SECTION(".zlog$l") \ - zen::logging::LogPoint LogPoint{zen::logging::SourceLocation{__FILE__, __LINE__}, InLevel, std::string_view(fmtstr)}; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") \ + zen::logging::LogPoint LogPoint{__FILE__, __LINE__, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); -#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) 
\ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{0, 0, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); #define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ @@ -147,13 +175,18 @@ using zen::Log; #define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr, ##__VA_ARGS__) #define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ +// Routes ZEN_INFO / ZEN_WARN / ZEN_DEBUG etc. in the enclosing scope through the given logger expression +// (a LoggerRef or something convertible, e.g. a member or a context field) instead of the namespace default. +// Expand at block scope; the resulting local `Log` shadows `zen::Log()` for the rest of the block. +#define ZEN_SCOPED_LOG(Expr) auto Log = [&]() { return (Expr); } + +#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{0, 0, InLevel, std::string_view(fmtstr)}; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ } while (false) #define ZEN_CONSOLE(fmtstr, ...) 
ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__) diff --git a/src/zencore/include/zencore/logging/broadcastsink.h b/src/zencore/include/zencore/logging/broadcastsink.h index c2709d87c..474662888 100644 --- a/src/zencore/include/zencore/logging/broadcastsink.h +++ b/src/zencore/include/zencore/logging/broadcastsink.h @@ -17,7 +17,7 @@ namespace zen::logging { /// sink is immediately visible to all of them. This is the recommended way /// to manage "default" sinks that should be active on most loggers. /// -/// Each child sink owns its own Formatter — BroadcastSink::SetFormatter() is +/// Each child sink owns its own Formatter - BroadcastSink::SetFormatter() is /// intentionally a no-op so that per-sink formatting is not accidentally /// overwritten by registry-wide formatter changes. class BroadcastSink : public Sink @@ -63,7 +63,7 @@ public: } } - /// No-op — child sinks manage their own formatters. + /// No-op - child sinks manage their own formatters. void SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/) override {} void AddSink(SinkPtr InSink) diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h index 765aa59e3..1092e7095 100644 --- a/src/zencore/include/zencore/logging/helpers.h +++ b/src/zencore/include/zencore/logging/helpers.h @@ -116,7 +116,7 @@ ShortFilename(const char* Path) inline std::string_view LevelToShortString(LogLevel Level) { - return ToStringView(Level); + return ShortToStringView(Level); } inline std::string_view diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h index 4a777c71e..644af2730 100644 --- a/src/zencore/include/zencore/logging/logmsg.h +++ b/src/zencore/include/zencore/logging/logmsg.h @@ -7,49 +7,50 @@ #include <chrono> #include <string_view> +namespace zen { +int GetCurrentThreadId(); +} + namespace zen::logging { using LogClock = std::chrono::system_clock; +/** + * This represents a single log event, with 
all the data needed to format and + * emit the final log message. + * + * LogMessage is what gets passed to Sinks, and it's what contains all the + * contextual information about the log event. + */ + struct LogMessage { - LogMessage() = default; - LogMessage(const LogPoint& InPoint, std::string_view InLoggerName, std::string_view InPayload) : m_LoggerName(InLoggerName) - , m_Level(InPoint.Level) , m_Time(LogClock::now()) - , m_Source(InPoint.Location) + , m_ThreadId(zen::GetCurrentThreadId()) , m_Payload(InPayload) , m_Point(&InPoint) { } - std::string_view GetPayload() const { return m_Payload; } - int GetThreadId() const { return m_ThreadId; } - LogClock::time_point GetTime() const { return m_Time; } - LogLevel GetLevel() const { return m_Level; } - std::string_view GetLoggerName() const { return m_LoggerName; } - const SourceLocation& GetSource() const { return m_Source; } - const LogPoint& GetLogPoint() const { return *m_Point; } + [[nodiscard]] std::string_view GetPayload() const { return m_Payload; } + [[nodiscard]] int GetThreadId() const { return m_ThreadId; } + [[nodiscard]] LogClock::time_point GetTime() const { return m_Time; } + [[nodiscard]] LogLevel GetLevel() const { return m_Point->Level; } + [[nodiscard]] std::string_view GetLoggerName() const { return m_LoggerName; } + [[nodiscard]] SourceLocation GetSource() const { return m_Point->Location(); } + [[nodiscard]] const LogPoint& GetLogPoint() const { return *m_Point; } void SetThreadId(int InThreadId) { m_ThreadId = InThreadId; } - void SetPayload(std::string_view InPayload) { m_Payload = InPayload; } - void SetLoggerName(std::string_view InName) { m_LoggerName = InName; } - void SetLevel(LogLevel InLevel) { m_Level = InLevel; } void SetTime(LogClock::time_point InTime) { m_Time = InTime; } - void SetSource(const SourceLocation& InSource) { m_Source = InSource; } private: - static constexpr LogPoint s_DefaultPoint{{}, Off, {}}; - - std::string_view m_LoggerName; - LogLevel m_Level = Off; - 
std::chrono::system_clock::time_point m_Time; - SourceLocation m_Source; - std::string_view m_Payload; - const LogPoint* m_Point = &s_DefaultPoint; - int m_ThreadId = 0; + std::string_view m_LoggerName; + LogClock::time_point m_Time; + int m_ThreadId; + std::string_view m_Payload; + const LogPoint* m_Point; }; } // namespace zen::logging diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h index d97c433fd..38a0bc14f 100644 --- a/src/zencore/include/zencore/mpscqueue.h +++ b/src/zencore/include/zencore/mpscqueue.h @@ -11,7 +11,7 @@ using std::hardware_constructive_interference_size; using std::hardware_destructive_interference_size; #else -// 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned │ ... +// 64 bytes on x86-64 | L1_CACHE_BYTES | L1_CACHE_SHIFT | __cacheline_aligned | ... constexpr std::size_t hardware_constructive_interference_size = 64; constexpr std::size_t hardware_destructive_interference_size = 64; #endif diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index 5ae7fad68..fd24a6d7d 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -34,7 +34,7 @@ public: /// Throws std::system_error on failure. explicit ProcessHandle(int Pid); - /// Construct from an existing native process handle. Takes ownership — + /// Construct from an existing native process handle. Takes ownership - /// the caller must not close the handle afterwards. Windows only. #if ZEN_PLATFORM_WINDOWS explicit ProcessHandle(void* NativeHandle); @@ -56,7 +56,7 @@ public: /// Same as Initialize(int) but reports errors via @p OutEc instead of throwing. void Initialize(int Pid, std::error_code& OutEc); - /// Initialize from an existing native process handle. Takes ownership — + /// Initialize from an existing native process handle. Takes ownership - /// the caller must not close the handle afterwards. Windows only. 
#if ZEN_PLATFORM_WINDOWS void Initialize(void* ProcessHandle); @@ -174,13 +174,18 @@ struct CreateProcOptions // allocated and no conhost.exe is spawned. Stdout/stderr still work when redirected // via pipes. Prefer this for headless worker processes. Flag_NoConsole = 1 << 3, - // Create the child in a new process group (CREATE_NEW_PROCESS_GROUP on Windows). - // Allows sending CTRL_BREAK_EVENT to the child group without affecting the parent. - Flag_Windows_NewProcessGroup = 1 << 4, + // Spawn the child as a new process group leader (its pgid = its own pid). + // On Windows: CREATE_NEW_PROCESS_GROUP, enables CTRL_BREAK_EVENT targeting. + // On POSIX: child calls setpgid(0,0) / posix_spawn with POSIX_SPAWN_SETPGROUP+pgid=0. + // Mutually exclusive with ProcessGroupId > 0. + Flag_NewProcessGroup = 1 << 4, // Allocate a hidden console for the child (CREATE_NO_WINDOW on Windows). Unlike // Flag_NoConsole the child still gets a console (and a conhost.exe) but no visible // window. Use this when the child needs a console for stdio but should not show a window. Flag_NoWindow = 1 << 5, + // Launch the child at below-normal scheduling priority. + // On Windows: BELOW_NORMAL_PRIORITY_CLASS. On POSIX: nice(5). + Flag_BelowNormalPriority = 1 << 6, }; const std::filesystem::path* WorkingDirectory = nullptr; @@ -190,16 +195,16 @@ struct CreateProcOptions StdoutPipeHandles* StderrPipe = nullptr; // Optional separate pipe for stderr. When null, stderr shares StdoutPipe. /// Additional environment variables for the child process. These are merged - /// with the parent's environment — existing variables are inherited, and + /// with the parent's environment - existing variables are inherited, and /// entries here override or add to them. std::vector<std::pair<std::string, std::string>> Environment; #if ZEN_PLATFORM_WINDOWS JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed #else - /// POSIX process group id. 
When > 0, the child is placed into this process - /// group via setpgid() before exec. Use the pid of the first child as the - /// pgid to create a group, then pass the same pgid for subsequent children. + /// When > 0, child joins this existing process group. Mutually exclusive with + /// Flag_NewProcessGroup; use that flag on the first spawn to create the group, + /// then pass the resulting pid here for subsequent spawns to join it. int ProcessGroupId = 0; #endif }; diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h index 3d4c19282..3183c7c0c 100644 --- a/src/zencore/include/zencore/sharedbuffer.h +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -65,7 +65,7 @@ public: private: // This may be null, for a default constructed UniqueBuffer only - RefPtr<IoBufferCore> m_Buffer; + Ref<IoBufferCore> m_Buffer; friend class SharedBuffer; }; @@ -81,7 +81,7 @@ public: inline explicit SharedBuffer(IoBufferCore* Owner) : m_Buffer(Owner) {} explicit SharedBuffer(IoBuffer&& Buffer) : m_Buffer(std::move(Buffer.m_Core)) {} explicit SharedBuffer(const IoBuffer& Buffer) : m_Buffer(Buffer.m_Core) {} - explicit SharedBuffer(RefPtr<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {} + explicit SharedBuffer(Ref<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {} [[nodiscard]] const void* GetData() const { @@ -143,7 +143,7 @@ public: /** Returns true if this points to a buffer owner. 
*/ [[nodiscard]] inline explicit operator bool() const { return !IsNull(); } - [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer); } + [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer.Get()); } SharedBuffer& operator=(UniqueBuffer&& Rhs) { @@ -171,7 +171,7 @@ public: [[nodiscard]] static SharedBuffer Clone(MemoryView View); private: - RefPtr<IoBufferCore> m_Buffer; + Ref<IoBufferCore> m_Buffer; }; void sharedbuffer_forcelink(); diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index 60293a313..b4926070c 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -402,6 +402,12 @@ public: inline std::string_view ToView() const { return std::string_view(m_Base, m_CurPos - m_Base); } inline std::string ToString() const { return std::string{Data(), Size()}; } + /// Append a zero-padded decimal integer. MinWidth is the minimum number of digits (zero-padded on the left). + void AppendPaddedInt(int64_t Value, int MinWidth); + + /// Append a single character repeated Count times. + void AppendFill(char C, size_t Count); + inline void AppendCodepoint(uint32_t cp) { if (cp < 0x80) // one octet @@ -435,6 +441,24 @@ public: } }; +/// Output iterator adapter for StringBuilderBase, enabling direct use with fmt::format_to / fmt::format_to_n. 
+class StringBuilderAppender +{ + StringBuilderBase* m_Builder; + +public: + explicit StringBuilderAppender(StringBuilderBase& Builder) : m_Builder(&Builder) {} + + StringBuilderAppender& operator=(char C) + { + m_Builder->Append(C); + return *this; + } + StringBuilderAppender& operator*() { return *this; } + StringBuilderAppender& operator++() { return *this; } + StringBuilderAppender operator++(int) { return *this; } +}; + template<size_t N> class StringBuilder : public StringBuilderBase { @@ -609,6 +633,17 @@ ParseHexBytes(std::string_view InputString, uint8_t* OutPtr) return ParseHexBytes(InputString.data(), InputString.size(), OutPtr); } +/** Parse hex string into a byte buffer, validating that the hex string is exactly ExpectedByteCount * 2 characters. */ +inline bool +ParseHexBytes(std::string_view InputString, uint8_t* OutPtr, size_t ExpectedByteCount) +{ + if (InputString.size() != ExpectedByteCount * 2) + { + return false; + } + return ParseHexBytes(InputString.data(), InputString.size(), OutPtr); +} + inline void ToHexBytes(const uint8_t* InputData, size_t ByteCount, char* OutString) { @@ -722,6 +757,32 @@ struct NiceNum : public NiceBase inline NiceNum(uint64_t Num) { NiceNumToBuffer(Num, m_Buffer); } }; +size_t ThousandsToBuffer(uint64_t Num, std::span<char> Buffer); + +/// Integer formatted with comma thousands separators (e.g. 
"1,234,567") +struct ThousandsNum +{ + inline ThousandsNum(UnsignedIntegral auto Number) { ThousandsToBuffer(uint64_t(Number), m_Buffer); } + inline ThousandsNum(SignedIntegral auto Number) + { + if (Number < 0) + { + m_Buffer[0] = '-'; + ThousandsToBuffer(uint64_t(-Number), std::span<char>(m_Buffer + 1, sizeof(m_Buffer) - 1)); + } + else + { + ThousandsToBuffer(uint64_t(Number), m_Buffer); + } + } + + inline const char* c_str() const { return m_Buffer; } + inline operator std::string_view() const { return std::string_view(m_Buffer); } + +private: + char m_Buffer[28]; // max uint64: "18,446,744,073,709,551,615" (26) + NUL + sign +}; + struct NiceBytes : public NiceBase { inline NiceBytes(uint64_t Num) { NiceBytesToBuffer(Num, m_Buffer); } diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index 52dafc18b..efc9bb6d2 100644 --- a/src/zencore/include/zencore/system.h +++ b/src/zencore/include/zencore/system.h @@ -46,6 +46,11 @@ struct ExtendedSystemMetrics : SystemMetrics SystemMetrics GetSystemMetrics(); +/// Lightweight query that only refreshes fields that change at runtime +/// (available memory, uptime). Topology fields (CPU/core counts, total memory) +/// are left at their default values and must be filled from a cached snapshot. 
+void RefreshDynamicSystemMetrics(SystemMetrics& InOutMetrics); + void SetCpuCountForReporting(int FakeCpuCount); SystemMetrics GetSystemMetricsForReporting(); diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h index 2a789d18f..68461deb2 100644 --- a/src/zencore/include/zencore/testutils.h +++ b/src/zencore/include/zencore/testutils.h @@ -62,24 +62,24 @@ struct TrueType namespace utf8test { // 2-byte UTF-8 (Latin extended) - static constexpr const char kLatin[] = u8"café_résumé"; - static constexpr const wchar_t kLatinW[] = L"café_résumé"; + static constexpr const char kLatin[] = u8"caf\xC3\xA9_r\xC3\xA9sum\xC3\xA9"; + static constexpr const wchar_t kLatinW[] = L"caf\u00E9_r\u00E9sum\u00E9"; // 2-byte UTF-8 (Cyrillic) - static constexpr const char kCyrillic[] = u8"данные"; - static constexpr const wchar_t kCyrillicW[] = L"данные"; + static constexpr const char kCyrillic[] = u8"\xD0\xB4\xD0\xB0\xD0\xBD\xD0\xBD\xD1\x8B\xD0\xB5"; + static constexpr const wchar_t kCyrillicW[] = L"\u0434\u0430\u043D\u043D\u044B\u0435"; // 3-byte UTF-8 (CJK) - static constexpr const char kCJK[] = u8"日本語"; - static constexpr const wchar_t kCJKW[] = L"日本語"; + static constexpr const char kCJK[] = u8"\xE6\x97\xA5\xE6\x9C\xAC\xE8\xAA\x9E"; + static constexpr const wchar_t kCJKW[] = L"\u65E5\u672C\u8A9E"; // Mixed scripts - static constexpr const char kMixed[] = u8"zen_éд日"; - static constexpr const wchar_t kMixedW[] = L"zen_éд日"; + static constexpr const char kMixed[] = u8"zen_\xC3\xA9\xD0\xB4\xE6\x97\xA5"; + static constexpr const wchar_t kMixedW[] = L"zen_\u00E9\u0434\u65E5"; - // 4-byte UTF-8 (supplementary plane) — string tests only, NOT filesystem - static constexpr const char kEmoji[] = u8"📦"; - static constexpr const wchar_t kEmojiW[] = L"📦"; + // 4-byte UTF-8 (supplementary plane) - string tests only, NOT filesystem + static constexpr const char kEmoji[] = u8"\xF0\x9F\x93\xA6"; + static constexpr const wchar_t kEmojiW[] = L"\U0001F4E6"; // 
BMP-only test strings suitable for filesystem use static constexpr const char* kFilenameSafe[] = {kLatin, kCyrillic, kCJK, kMixed}; diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h index 56ce5904b..0f7733df5 100644 --- a/src/zencore/include/zencore/thread.h +++ b/src/zencore/include/zencore/thread.h @@ -14,7 +14,7 @@ namespace zen { -void SetCurrentThreadName(std::string_view ThreadName); +void SetCurrentThreadName(std::string_view ThreadName, int32_t SortHint = 0); /** * Reader-writer lock diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h index a31950b0b..57c7e20fa 100644 --- a/src/zencore/include/zencore/zencore.h +++ b/src/zencore/include/zencore/zencore.h @@ -94,7 +94,7 @@ protected: // With no extra args: ZEN_ASSERT_MSG_("expr") -> "expr" // With a message arg: ZEN_ASSERT_MSG_("expr", "msg") -> "expr" ": " "msg" // With fmt-style args: ZEN_ASSERT_MSG_("expr", "msg", args...) -> "expr" ": " "msg" -// The extra fmt args are silently discarded here — use ZEN_ASSERT_FORMAT for those. +// The extra fmt args are silently discarded here - use ZEN_ASSERT_FORMAT for those. #define ZEN_ASSERT_MSG_SELECT_(_1, _2, N, ...) N #define ZEN_ASSERT_MSG_1_(expr) expr #define ZEN_ASSERT_MSG_2_(expr, msg, ...) 
expr ": " msg diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index c47c54981..529afe341 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -107,7 +107,7 @@ IoBufferCore::~IoBufferCore() } void -IoBufferCore::DeleteThis() const +IoBufferCore::DeleteThis() const noexcept { // We do this just to avoid paying for the cost of a vtable if (const IoBufferExtendedCore* _ = ExtendedCore()) @@ -210,7 +210,12 @@ IoBufferExtendedCore::~IoBufferExtendedCore() // Mark file for deletion when final handle is closed FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; - SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); + if (!SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi)) + { + ZEN_WARN("SetFileInformationByHandle(DeleteOnClose) failed for file handle {}, reason '{}'", + m_FileHandle, + GetLastErrorAsString()); + } #else std::error_code Ec; std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); @@ -447,7 +452,7 @@ GetNullBufferCore() return Core; } -RefPtr<IoBufferCore> IoBuffer::NullBufferCore(GetNullBufferCore()); +Ref<IoBufferCore> IoBuffer::NullBufferCore(GetNullBufferCore()); IoBuffer::IoBuffer(size_t InSize) : m_Core(new IoBufferCore(InSize)) { @@ -475,7 +480,7 @@ IoBuffer::IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t Size) } else { - m_Core = new IoBufferCore(OuterBuffer.m_Core, reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size); + m_Core = new IoBufferCore(OuterBuffer.m_Core.Get(), reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size); } } diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp index 3e58fb97d..a5a82717d 100644 --- a/src/zencore/jobqueue.cpp +++ b/src/zencore/jobqueue.cpp @@ -93,7 +93,7 @@ public: { NewJobId = IdGenerator.fetch_add(1); } - RefPtr<Job> NewJob(new Job()); + Ref<Job> NewJob(new Job()); NewJob->Queue = this; NewJob->Name = Name; NewJob->Callback = std::move(JobFunc); @@ 
-124,7 +124,7 @@ public: QueueLock.WithExclusiveLock([&]() { if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), - [NewJobId](const RefPtr<Job>& Job) { return Job->Id.Id == NewJobId; }); + [NewJobId](const Ref<Job>& Job) { return Job->Id.Id == NewJobId; }); It != QueuedJobs.end()) { QueuedJobs.erase(It); @@ -156,7 +156,7 @@ public: Result = true; return; } - if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const RefPtr<Job>& Job) { return Job->Id.Id == Id.Id; }); + if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const Ref<Job>& Job) { return Job->Id.Id == Id.Id; }); It != QueuedJobs.end()) { ZEN_DEBUG("Cancelling queued background job {}:'{}'", (*It)->Id.Id, (*It)->Name); @@ -301,7 +301,7 @@ public: AbortedJobs.erase(It); return; } - if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const RefPtr<Job>& Job) { return Job->Id.Id == Id.Id; }); + if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const Ref<Job>& Job) { return Job->Id.Id == Id.Id; }); It != QueuedJobs.end()) { Result = Convert(JobStatus::Queued, *(*It)); @@ -340,20 +340,20 @@ public: std::atomic_uint64_t IdGenerator = 1; - std::atomic_bool InitializedFlag = false; - RwLock QueueLock; - std::deque<RefPtr<Job>> QueuedJobs; - std::unordered_map<uint64_t, RefPtr<Job>> RunningJobs; - std::unordered_map<uint64_t, RefPtr<Job>> CompletedJobs; - std::unordered_map<uint64_t, RefPtr<Job>> AbortedJobs; + std::atomic_bool InitializedFlag = false; + RwLock QueueLock; + std::deque<Ref<Job>> QueuedJobs; + std::unordered_map<uint64_t, Ref<Job>> RunningJobs; + std::unordered_map<uint64_t, Ref<Job>> CompletedJobs; + std::unordered_map<uint64_t, Ref<Job>> AbortedJobs; WorkerThreadPool WorkerPool; Latch WorkerCounter; void Worker() { - int CurrentThreadId = GetCurrentThreadId(); - RefPtr<Job> CurrentJob; + int CurrentThreadId = GetCurrentThreadId(); + Ref<Job> CurrentJob; QueueLock.WithExclusiveLock([&]() { if 
(!QueuedJobs.empty()) { diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index 5ada0cac7..aa95db950 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -26,7 +26,7 @@ namespace { // Bootstrap logger: a minimal stdout logger that exists for the entire lifetime // of the process. TheDefaultLogger points here before InitializeLogging() runs // (and is restored here after ShutdownLogging()) so that log macros always have -// a usable target — no null checks or lazy init required on the common path. +// a usable target - no null checks or lazy init required on the common path. zen::Ref<zen::logging::Logger> s_BootstrapLogger = [] { zen::logging::SinkPtr Sink(new zen::logging::AnsiColorStdoutSink()); return zen::Ref<zen::logging::Logger>(new zen::logging::Logger("", Sink)); @@ -112,6 +112,14 @@ constinit std::string_view LevelNames[] = {std::string_view("trace", 5), std::string_view("critical", 8), std::string_view("off", 3)}; +constinit std::string_view ShortNames[] = {std::string_view("trc", 3), + std::string_view("dbg", 3), + std::string_view("inf", 3), + std::string_view("wrn", 3), + std::string_view("err", 3), + std::string_view("crt", 3), + std::string_view("off", 3)}; + LogLevel ParseLogLevelString(std::string_view Name) { @@ -139,12 +147,27 @@ ParseLogLevelString(std::string_view Name) std::string_view ToStringView(LogLevel Level) { + using namespace std::literals; + if (int(Level) < LogLevelCount) { return LevelNames[int(Level)]; } - return "None"; + return "None"sv; +} + +std::string_view +ShortToStringView(LogLevel Level) +{ + using namespace std::literals; + + if (int(Level) < LogLevelCount) + { + return ShortNames[int(Level)]; + } + + return "None"sv; } } // namespace zen::logging @@ -476,6 +499,10 @@ LoggerRef::LoggerRef(logging::Logger& InLogger) : m_Logger(static_cast<logging:: { } +LoggerRef::LoggerRef(std::string_view LogCategory) : m_Logger(zen::logging::Get(LogCategory).m_Logger) +{ +} + void LoggerRef::Flush() { diff 
--git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp index 03aae068a..fb127bede 100644 --- a/src/zencore/logging/ansicolorsink.cpp +++ b/src/zencore/logging/ansicolorsink.cpp @@ -5,6 +5,7 @@ #include <zencore/logging/messageonlyformatter.h> #include <zencore/thread.h> +#include <zencore/timer.h> #include <cstdio> #include <cstdlib> @@ -22,48 +23,37 @@ namespace zen::logging { -// Default formatter replicating spdlog's %+ pattern: -// [YYYY-MM-DD HH:MM:SS.mmm] [logger_name] [level] message\n +// Default formatter for console output: +// [HH:MM:SS.mmm] [logger_name] [level] message\n +// Timestamps show elapsed time since process launch. class DefaultConsoleFormatter : public Formatter { public: + DefaultConsoleFormatter() : m_Epoch(std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart())) {} + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override { - // timestamp - auto Secs = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); - if (Secs != m_LastLogSecs) - { - m_LastLogSecs = Secs; - m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); - } + // Elapsed time since process launch + auto Elapsed = Msg.GetTime() - m_Epoch; + auto TotalSecs = std::chrono::duration_cast<std::chrono::seconds>(Elapsed); + int Count = static_cast<int>(TotalSecs.count()); + int LogSecs = Count % 60; + Count /= 60; + int LogMins = Count % 60; + int LogHours = Count / 60; Dest.push_back('['); - helpers::AppendInt(m_CachedLocalTm.tm_year + 1900, Dest); - Dest.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mon + 1, Dest); - Dest.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mday, Dest); - Dest.push_back(' '); - helpers::Pad2(m_CachedLocalTm.tm_hour, Dest); + helpers::Pad2(LogHours, Dest); Dest.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_min, Dest); + helpers::Pad2(LogMins, Dest); Dest.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_sec, Dest); + 
helpers::Pad2(LogSecs, Dest); Dest.push_back('.'); auto Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); Dest.push_back(']'); Dest.push_back(' '); - // logger name - if (Msg.GetLoggerName().size() > 0) - { - Dest.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), Dest); - Dest.push_back(']'); - Dest.push_back(' '); - } - // level Dest.push_back('['); if (IsColorEnabled()) @@ -78,6 +68,25 @@ public: Dest.push_back(']'); Dest.push_back(' '); + using namespace std::literals; + + // logger name + if (Msg.GetLoggerName().size() > 0) + { + if (IsColorEnabled()) + { + Dest.append("\033[97m"sv); + } + Dest.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + Dest.push_back(']'); + if (IsColorEnabled()) + { + Dest.append("\033[0m"sv); + } + Dest.push_back(' '); + } + // message (align continuation lines with the first line) size_t AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0; size_t LinePrefixCount = Dest.size() - AnsiBytes; @@ -128,8 +137,7 @@ public: } private: - std::chrono::seconds m_LastLogSecs{0}; - std::tm m_CachedLocalTm{}; + LogClock::time_point m_Epoch; }; bool @@ -197,7 +205,7 @@ IsColorTerminal() // Windows console supports ANSI color by default in modern versions return true; #else - // Unknown terminal — be conservative + // Unknown terminal - be conservative return false; #endif } diff --git a/src/zencore/logging/registry.cpp b/src/zencore/logging/registry.cpp index 383a5d8ba..0f552aced 100644 --- a/src/zencore/logging/registry.cpp +++ b/src/zencore/logging/registry.cpp @@ -137,7 +137,7 @@ struct Registry::Impl { if (Pattern.find_first_of("*?") == std::string::npos) { - // Exact match — fast path via map lookup. + // Exact match - fast path via map lookup. 
auto It = m_Loggers.find(Pattern); if (It != m_Loggers.end()) { @@ -146,7 +146,7 @@ struct Registry::Impl } else { - // Wildcard pattern — iterate all loggers. + // Wildcard pattern - iterate all loggers. for (auto& [Name, CurLogger] : m_Loggers) { if (MatchLoggerPattern(Pattern, Name)) diff --git a/src/zencore/memory/memory.cpp b/src/zencore/memory/memory.cpp index 9e19c5db7..8dbb04e64 100644 --- a/src/zencore/memory/memory.cpp +++ b/src/zencore/memory/memory.cpp @@ -44,10 +44,10 @@ InitGMalloc() // when using sanitizers, we should use the default/ansi allocator #if ZEN_OVERRIDE_NEW_DELETE -# if ZEN_RPMALLOC_ENABLED +# if ZEN_MIMALLOC_ENABLED if (Malloc == MallocImpl::None) { - Malloc = MallocImpl::Rpmalloc; + Malloc = MallocImpl::Mimalloc; } # endif #endif diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index e7baa3f8e..66062df4d 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -28,6 +28,7 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <pthread.h> # include <signal.h> # include <sys/file.h> +# include <sys/resource.h> # include <sys/sem.h> # include <sys/stat.h> # include <sys/syscall.h> @@ -37,7 +38,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #endif #if ZEN_PLATFORM_MAC +# include <crt_externs.h> # include <libproc.h> +# include <spawn.h> # include <sys/types.h> # include <sys/sysctl.h> #endif @@ -135,8 +138,68 @@ IsZombieProcess(int pid, std::error_code& OutEc) } return false; } + +static char** +GetEnviron() +{ + return *_NSGetEnviron(); +} #endif // ZEN_PLATFORM_MAC +#if ZEN_PLATFORM_LINUX +static char** +GetEnviron() +{ + return environ; +} +#endif // ZEN_PLATFORM_LINUX + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +// Holds a null-terminated envp array built by merging the current process environment with +// a set of overrides. When Overrides is empty, Data points directly to environ (no allocation). +// Must outlive any posix_spawn / execve call that receives Data. 
+struct EnvpHolder +{ + char** Data = GetEnviron(); + + explicit EnvpHolder(const std::vector<std::pair<std::string, std::string>>& Overrides) + { + if (Overrides.empty()) + { + return; + } + std::map<std::string, std::string> EnvMap; + for (char** E = GetEnviron(); *E; ++E) + { + std::string_view Entry(*E); + const size_t EqPos = Entry.find('='); + if (EqPos != std::string_view::npos) + { + EnvMap[std::string(Entry.substr(0, EqPos))] = std::string(Entry.substr(EqPos + 1)); + } + } + for (const auto& [Key, Value] : Overrides) + { + EnvMap[Key] = Value; + } + for (const auto& [Key, Value] : EnvMap) + { + m_Strings.push_back(Key + "=" + Value); + } + for (std::string& S : m_Strings) + { + m_Ptrs.push_back(S.data()); + } + m_Ptrs.push_back(nullptr); + Data = m_Ptrs.data(); + } + +private: + std::vector<std::string> m_Strings; + std::vector<char*> m_Ptrs; +}; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ////////////////////////////////////////////////////////////////////////// // Pipe creation for child process stdout capture @@ -444,7 +507,7 @@ ProcessHandle::Kill() std::error_code Ec; if (!Wait(5000, Ec)) { - // Graceful shutdown timed out — force-kill + // Graceful shutdown timed out - force-kill kill(pid_t(m_Pid), SIGKILL); Wait(1000, Ec); } @@ -691,6 +754,7 @@ BuildArgV(std::vector<char*>& Out, char* CommandLine) ++Cursor; } } + #endif // !WINDOWS || TESTS #if ZEN_PLATFORM_WINDOWS @@ -766,10 +830,14 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma { CreationFlags |= CREATE_NO_WINDOW; } - if (Options.Flags & CreateProcOptions::Flag_Windows_NewProcessGroup) + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) { CreationFlags |= CREATE_NEW_PROCESS_GROUP; } + if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority) + { + CreationFlags |= BELOW_NORMAL_PRIORITY_CLASS; + } if (AssignToJob) { CreationFlags |= CREATE_SUSPENDED; @@ -980,6 +1048,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, 
std::string_view C { CreateProcFlags |= CREATE_NO_WINDOW; } + if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority) + { + CreateProcFlags |= BELOW_NORMAL_PRIORITY_CLASS; + } if (AssignToJob) { CreateProcFlags |= CREATE_SUSPENDED; @@ -1070,23 +1142,30 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine } return CreateProcNormal(Executable, CommandLine, Options); -#else +#elif ZEN_PLATFORM_LINUX + // vfork uses CLONE_VM|CLONE_VFORK: the child shares the parent's address space and the + // parent is suspended until the child calls exec or _exit. This avoids page-table duplication + // and the ENOMEM that fork() produces on systems with strict overcommit (vm.overcommit_memory=2). + // All child-side setup uses only syscalls that do not modify user-space memory. + // Environment overrides are merged into envp before vfork so that setenv() is never called + // from the child (which would corrupt the shared address space). std::vector<char*> ArgV; std::string CommandLineZ(CommandLine); BuildArgV(ArgV, CommandLineZ.data()); ArgV.push_back(nullptr); - int ChildPid = fork(); + EnvpHolder Envp(Options.Environment); + + int ChildPid = vfork(); if (ChildPid < 0) { - ThrowLastError("Failed to fork a new child process"); + ThrowLastError("Failed to vfork a new child process"); } else if (ChildPid == 0) { if (Options.WorkingDirectory != nullptr) { - int Result = chdir(Options.WorkingDirectory->c_str()); - ZEN_UNUSED(Result); + chdir(Options.WorkingDirectory->c_str()); } if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0) @@ -1118,23 +1197,109 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine } } - if (Options.ProcessGroupId > 0) + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) + { + setpgid(0, 0); + } + else if (Options.ProcessGroupId > 0) { setpgid(0, Options.ProcessGroupId); } - for (const auto& [Key, Value] : Options.Environment) + execve(Executable.c_str(), 
ArgV.data(), Envp.Data); + _exit(127); + } + + if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority) + { + setpriority(PRIO_PROCESS, ChildPid, 5); + } + + return ChildPid; +#else // macOS + std::vector<char*> ArgV; + std::string CommandLineZ(CommandLine); + BuildArgV(ArgV, CommandLineZ.data()); + ArgV.push_back(nullptr); + + posix_spawn_file_actions_t FileActions; + posix_spawnattr_t Attr; + + int Err = posix_spawn_file_actions_init(&FileActions); + if (Err != 0) + { + ThrowSystemError(Err, "posix_spawn_file_actions_init failed"); + } + auto FileActionsGuard = MakeGuard([&] { posix_spawn_file_actions_destroy(&FileActions); }); + + Err = posix_spawnattr_init(&Attr); + if (Err != 0) + { + ThrowSystemError(Err, "posix_spawnattr_init failed"); + } + auto AttrGuard = MakeGuard([&] { posix_spawnattr_destroy(&Attr); }); + + if (Options.WorkingDirectory != nullptr) + { + Err = posix_spawn_file_actions_addchdir_np(&FileActions, Options.WorkingDirectory->c_str()); + if (Err != 0) { - setenv(Key.c_str(), Value.c_str(), 1); + ThrowSystemError(Err, "posix_spawn_file_actions_addchdir_np failed"); } + } + + if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0) + { + const int StdoutWriteFd = Options.StdoutPipe->WriteFd; + ZEN_ASSERT(StdoutWriteFd > STDERR_FILENO); + posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDOUT_FILENO); - if (execv(Executable.c_str(), ArgV.data()) < 0) + if (Options.StderrPipe != nullptr && Options.StderrPipe->WriteFd >= 0) { - ThrowLastError("Failed to exec() a new process image"); + const int StderrWriteFd = Options.StderrPipe->WriteFd; + ZEN_ASSERT(StderrWriteFd > STDERR_FILENO && StderrWriteFd != StdoutWriteFd); + posix_spawn_file_actions_adddup2(&FileActions, StderrWriteFd, STDERR_FILENO); + posix_spawn_file_actions_addclose(&FileActions, StderrWriteFd); } + else + { + posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDERR_FILENO); + } + + posix_spawn_file_actions_addclose(&FileActions, 
StdoutWriteFd); + } + else if (!Options.StdoutFile.empty()) + { + posix_spawn_file_actions_addopen(&FileActions, STDOUT_FILENO, Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + posix_spawn_file_actions_adddup2(&FileActions, STDOUT_FILENO, STDERR_FILENO); } - return ChildPid; + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) + { + posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP); + posix_spawnattr_setpgroup(&Attr, 0); + } + else if (Options.ProcessGroupId > 0) + { + posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP); + posix_spawnattr_setpgroup(&Attr, Options.ProcessGroupId); + } + + EnvpHolder Envp(Options.Environment); + + pid_t ChildPid = 0; + Err = posix_spawn(&ChildPid, Executable.c_str(), &FileActions, &Attr, ArgV.data(), Envp.Data); + if (Err != 0) + { + ThrowSystemError(Err, "Failed to posix_spawn a new child process"); + } + + if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority) + { + setpriority(PRIO_PROCESS, ChildPid, 5); + } + + return int(ChildPid); #endif } @@ -1252,14 +1417,28 @@ JobObject::Initialize() } JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {}; - LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo))) { ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError()); CloseHandle(m_JobHandle); m_JobHandle = nullptr; + return; } + + // Prevent child processes from clearing SEM_NOGPFAULTERRORBOX, which + // suppresses WER/Dr. Watson crash dialogs. Without this, a crashing + // child can pop a modal dialog and block the monitor thread. 
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobHandle, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on the current process so children inherit it. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } bool @@ -1576,7 +1755,7 @@ GetProcessCommandLine(int Pid, std::error_code& OutEc) ++p; // skip null terminator of argv[0] } - // Build result: remaining entries joined by spaces (inter-arg nulls → spaces) + // Build result: remaining entries joined by spaces (inter-arg nulls -> spaces) std::string Result; Result.reserve(static_cast<size_t>(End - p)); for (const char* q = p; q < End; ++q) @@ -1914,7 +2093,7 @@ GetProcessMetrics(const ProcessHandle& Handle, ProcessMetrics& OutMetrics) { Buf[Len] = '\0'; - // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens const char* P = strrchr(Buf, ')'); if (P) { diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp index f19afe715..674b154e0 100644 --- a/src/zencore/refcount.cpp +++ b/src/zencore/refcount.cpp @@ -35,29 +35,29 @@ refcount_forcelink() TEST_SUITE_BEGIN("core.refcount"); -TEST_CASE("RefPtr") +TEST_CASE("Ref") { - RefPtr<TestRefClass> Ref; - Ref = new TestRefClass; + Ref<TestRefClass> RefA; + RefA = new TestRefClass; bool IsDestroyed = false; - Ref->OnDestroy = [&] { IsDestroyed = true; }; + RefA->OnDestroy = [&] { IsDestroyed = true; }; CHECK(IsDestroyed == false); - CHECK(Ref->RefCount() == 1); + CHECK(RefA->RefCount() == 1); - RefPtr<TestRefClass> Ref2; - Ref2 = Ref; + Ref<TestRefClass> RefB; + RefB = RefA; CHECK(IsDestroyed == false); - CHECK(Ref->RefCount() == 2); + CHECK(RefA->RefCount() == 2); - RefPtr<TestRefClass> 
Ref3; - Ref2 = Ref3; + Ref<TestRefClass> RefC; + RefB = RefC; CHECK(IsDestroyed == false); - CHECK(Ref->RefCount() == 1); - Ref = Ref3; + CHECK(RefA->RefCount() == 1); + RefA = RefC; CHECK(IsDestroyed == true); } diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 8491bef64..7e3f33191 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -250,7 +250,7 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine if (SentryOptions == nullptr) { - // OOM — skip sentry entirely rather than crashing on the subsequent set calls + // OOM - skip sentry entirely rather than crashing on the subsequent set calls m_SentryErrorCode = -1; m_IsInitialized = true; return; diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp index 8dc6d49d8..48730e670 100644 --- a/src/zencore/sharedbuffer.cpp +++ b/src/zencore/sharedbuffer.cpp @@ -100,7 +100,7 @@ SharedBuffer::MakeView(MemoryView View, SharedBuffer OuterBuffer) return OuterBuffer; } - IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer, View.GetData(), View.GetSize()); + IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer.Get(), View.GetData(), View.GetSize()); NewCore->SetIsImmutable(true); return SharedBuffer(NewCore); } diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index 358722b0b..44f78aa75 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -381,6 +381,34 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format) } size_t +ThousandsToBuffer(uint64_t Num, std::span<char> Buffer) +{ + // Format into a temporary buffer without separators + char Tmp[24]; + int Len = snprintf(Tmp, sizeof(Tmp), "%llu", (unsigned long long)Num); + + // Insert comma separators + int SepCount = (Len - 1) / 3; + int TotalLen = Len + SepCount; + ZEN_ASSERT(TotalLen < (int)Buffer.size()); + + int Src = Len - 1; + int Dst = TotalLen; + Buffer[Dst--] = '\0'; + + for (int i = 
0; Src >= 0; i++) + { + if (i > 0 && i % 3 == 0) + { + Buffer[Dst--] = ','; + } + Buffer[Dst--] = Tmp[Src--]; + } + + return TotalLen; +} + +size_t NiceNumToBuffer(uint64_t Num, std::span<char> Buffer) { return NiceNumGeneral(Num, Buffer, kNicenum1024); @@ -515,6 +543,40 @@ template class StringBuilderImpl<wchar_t>; ////////////////////////////////////////////////////////////////////////// void +StringBuilderBase::AppendPaddedInt(int64_t Value, int MinWidth) +{ + char Buf[24]; + char* End = Buf + sizeof(Buf); + char* Ptr = End; + bool Negative = Value < 0; + uint64_t Abs = Negative ? uint64_t(-Value) : uint64_t(Value); + do + { + *--Ptr = '0' + char(Abs % 10); + Abs /= 10; + } while (Abs > 0); + while ((End - Ptr) < MinWidth) + { + *--Ptr = '0'; + } + if (Negative) + { + *--Ptr = '-'; + } + AppendRange(Ptr, End); +} + +void +StringBuilderBase::AppendFill(char C, size_t Count) +{ + EnsureCapacity(Count); + std::memset(m_CurPos, C, Count); + m_CurPos += Count; +} + +////////////////////////////////////////////////////////////////////////// + +void UrlDecode(std::string_view InUrl, StringBuilderBase& OutUrl) { std::string_view::size_type i = 0; diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index 6909e1a9b..486050d83 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -148,6 +148,18 @@ GetSystemMetrics() return Metrics; } +void +RefreshDynamicSystemMetrics(SystemMetrics& Metrics) +{ + MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)}; + GlobalMemoryStatusEx(&MemStatus); + + Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024; + Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024; + Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; + Metrics.UptimeSeconds = GetTickCount64() / 1000; +} + std::vector<std::string> GetLocalIpAddresses() { @@ -324,6 +336,51 @@ GetSystemMetrics() return Metrics; } + +void +RefreshDynamicSystemMetrics(SystemMetrics& Metrics) +{ + long PageSize = 
sysconf(_SC_PAGE_SIZE); + long AvailPages = sysconf(_SC_AVPHYS_PAGES); + + if (AvailPages > 0 && PageSize > 0) + { + Metrics.AvailSystemMemoryMiB = (AvailPages * PageSize) / 1024 / 1024; + Metrics.AvailVirtualMemoryMiB = Metrics.AvailSystemMemoryMiB; + } + + if (FILE* UptimeFile = fopen("/proc/uptime", "r")) + { + double UptimeSec = 0; + if (fscanf(UptimeFile, "%lf", &UptimeSec) == 1) + { + Metrics.UptimeSeconds = static_cast<uint64_t>(UptimeSec); + } + fclose(UptimeFile); + } + + if (FILE* MemInfo = fopen("/proc/meminfo", "r")) + { + char Line[256]; + long SwapFree = 0; + + while (fgets(Line, sizeof(Line), MemInfo)) + { + if (strncmp(Line, "SwapFree:", 9) == 0) + { + sscanf(Line, "SwapFree: %ld kB", &SwapFree); + break; + } + } + fclose(MemInfo); + + if (SwapFree > 0) + { + Metrics.AvailPageFileMiB = SwapFree / 1024; + } + } +} + #elif ZEN_PLATFORM_MAC std::string GetMachineName() @@ -398,6 +455,36 @@ GetSystemMetrics() return Metrics; } + +void +RefreshDynamicSystemMetrics(SystemMetrics& Metrics) +{ + vm_size_t PageSize = 0; + host_page_size(mach_host_self(), &PageSize); + + vm_statistics64_data_t VmStats; + mach_msg_type_number_t InfoCount = sizeof(VmStats) / sizeof(natural_t); + host_statistics64(mach_host_self(), HOST_VM_INFO64, (host_info64_t)&VmStats, &InfoCount); + + uint64_t FreeMemory = (uint64_t)(VmStats.free_count + VmStats.inactive_count) * PageSize; + Metrics.AvailSystemMemoryMiB = FreeMemory / 1024 / 1024; + Metrics.AvailVirtualMemoryMiB = Metrics.VirtualMemoryMiB; + + xsw_usage SwapUsage; + size_t Size = sizeof(SwapUsage); + sysctlbyname("vm.swapusage", &SwapUsage, &Size, nullptr, 0); + Metrics.AvailPageFileMiB = (SwapUsage.xsu_total - SwapUsage.xsu_used) / 1024 / 1024; + + struct timeval BootTime + { + }; + Size = sizeof(BootTime); + if (sysctlbyname("kern.boottime", &BootTime, &Size, nullptr, 0) == 0) + { + Metrics.UptimeSeconds = static_cast<uint64_t>(time(nullptr) - BootTime.tv_sec); + } +} + #else # error "Unknown platform" #endif @@ -655,11 
+742,16 @@ struct SystemMetricsTracker::Impl std::mutex Mutex; CpuSampler Sampler; + SystemMetrics CachedMetrics; float CachedCpuPercent = 0.0f; Clock::time_point NextSampleTime = Clock::now(); std::chrono::milliseconds MinInterval; - explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) {} + explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) + { + // Capture topology and total memory once; these don't change at runtime + CachedMetrics = GetSystemMetrics(); + } float SampleCpu() { @@ -683,7 +775,8 @@ ExtendedSystemMetrics SystemMetricsTracker::Query() { ExtendedSystemMetrics Metrics; - static_cast<SystemMetrics&>(Metrics) = GetSystemMetrics(); + static_cast<SystemMetrics&>(Metrics) = m_Impl->CachedMetrics; + RefreshDynamicSystemMetrics(Metrics); std::lock_guard Lock(m_Impl->Mutex); Metrics.CpuUsagePercent = m_Impl->SampleCpu(); diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 9f88a3365..67285dcf1 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -39,7 +39,7 @@ PrintCrashCallstack([[maybe_unused]] const char* SignalName) // Use write() + backtrace_symbols_fd() which are async-signal-safe write(STDERR_FILENO, "\n*** Caught ", 12); write(STDERR_FILENO, SignalName, strlen(SignalName)); - write(STDERR_FILENO, " — callstack:\n", 15); + write(STDERR_FILENO, " - callstack:\n", 15); void* Frames[64]; int FrameCount = backtrace(Frames, 64); diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp index c9908aec8..44446bd40 100644 --- a/src/zencore/testutils.cpp +++ b/src/zencore/testutils.cpp @@ -30,11 +30,15 @@ ScopedTemporaryDirectory::ScopedTemporaryDirectory() : m_RootPath(CreateTemporar { } -ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) : m_RootPath(Directory) +ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) +: m_RootPath(Directory.empty() ? 
CreateTemporaryDirectory() : Directory) { - std::error_code Ec; - DeleteDirectories(Directory, Ec); - CreateDirectories(Directory); + if (!Directory.empty()) + { + std::error_code Ec; + DeleteDirectories(Directory, Ec); + CreateDirectories(Directory); + } } ScopedTemporaryDirectory::~ScopedTemporaryDirectory() diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp index 067e66c0d..fd72afaa7 100644 --- a/src/zencore/thread.cpp +++ b/src/zencore/thread.cpp @@ -99,7 +99,7 @@ SetNameInternal(DWORD thread_id, const char* name) #endif void -SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) +SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName, [[maybe_unused]] int32_t SortHint) { constexpr std::string_view::size_type MaxThreadNameLength = 255; std::string_view LimitedThreadName = ThreadName.substr(0, MaxThreadNameLength); @@ -108,7 +108,7 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) const int ThreadId = GetCurrentThreadId(); #if ZEN_WITH_TRACE - trace::ThreadRegister(ThreadNameZ.c_str(), /* system id */ ThreadId, /* sort id */ 0); + trace::ThreadRegister(ThreadNameZ.c_str(), /* system id */ ThreadId, /* sort id */ SortHint); #endif // ZEN_WITH_TRACE #if ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp index 7c195e69f..d7084bbd1 100644 --- a/src/zencore/trace.cpp +++ b/src/zencore/trace.cpp @@ -6,6 +6,7 @@ # include <zencore/zencore.h> # include <zencore/commandline.h> # include <zencore/string.h> +# include <zencore/thread.h> # include <zencore/logging.h> # define TRACE_IMPLEMENT 1 @@ -121,7 +122,7 @@ TraceInit(std::string_view ProgramName) const char* CommandLineString = ""; # endif - trace::ThreadRegister("main", /* system id */ 0, /* sort id */ 0); + trace::ThreadRegister("main", /* system id */ GetCurrentThreadId(), /* sort id */ -1); trace::DescribeSession(ProgramName, # if ZEN_BUILD_DEBUG trace::Build::Debug, diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp 
index 8c29a8962..c1ac63621 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -273,7 +273,7 @@ zencore_forcelinktests() zen::uid_forcelink(); zen::uson_forcelink(); zen::usonbuilder_forcelink(); - zen::usonpackage_forcelink(); + zen::cbpackage_forcelink(); zen::cbjson_forcelink(); zen::cbyaml_forcelink(); zen::workthreadpool_forcelink(); diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md new file mode 100644 index 000000000..13beaa968 --- /dev/null +++ b/src/zenhorde/README.md @@ -0,0 +1,17 @@ +# Horde Compute integration + +Zen compute can use Horde to provision runner nodes. + +## Launch a coordinator instance + +Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to. + +```bash +zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio +``` + +## Use a coordinator + +```bash +zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558 +``` diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp index 819b2d0cb..029b98e55 100644 --- a/src/zenhorde/hordeagent.cpp +++ b/src/zenhorde/hordeagent.cpp @@ -8,290 +8,479 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cstring> -#include <unordered_map> namespace zen::horde { -HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) -{ - ZEN_TRACE_CPU("HordeAgent::Connect"); +// --- AsyncHordeAgent --- - auto Transport = std::make_unique<TcpComputeTransport>(Info); - if (!Transport->IsValid()) +static const char* +GetStateName(AsyncHordeAgent::State S) +{ + switch (S) { - ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), 
Info.GetConnectionPort()); - return; + case AsyncHordeAgent::State::Idle: + return "idle"; + case AsyncHordeAgent::State::Connecting: + return "connect"; + case AsyncHordeAgent::State::WaitAgentAttach: + return "agent-attach"; + case AsyncHordeAgent::State::SentFork: + return "fork"; + case AsyncHordeAgent::State::WaitChildAttach: + return "child-attach"; + case AsyncHordeAgent::State::Uploading: + return "upload"; + case AsyncHordeAgent::State::Executing: + return "execute"; + case AsyncHordeAgent::State::Polling: + return "poll"; + case AsyncHordeAgent::State::Done: + return "done"; + default: + return "unknown"; } +} - // The 64-byte nonce is always sent unencrypted as the first thing on the wire. - // The Horde agent uses this to identify which lease this connection belongs to. - Transport->Send(Info.Nonce, sizeof(Info.Nonce)); +AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async")) +{ +} - std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); - if (Info.EncryptionMode == Encryption::AES) +AsyncHordeAgent::~AsyncHordeAgent() +{ + Cancel(); +} + +void +AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone) +{ + m_Config = std::move(Config); + m_OnDone = std::move(OnDone); + m_State = State::Connecting; + DoConnect(); +} + +void +AsyncHordeAgent::Cancel() +{ + m_Cancelled = true; + if (m_Socket) { - FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); - if (!FinalTransport->IsValid()) - { - ZEN_WARN("failed to create AES transport"); - return; - } + m_Socket->Close(); + } + else if (m_TcpTransport) + { + // Cancelled before handshake completed - tear down the pending TCP connect. 
+ m_TcpTransport->Close(); } +} - // Create multiplexed socket and channels - m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); +void +AsyncHordeAgent::DoConnect() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect"); - // Channel 0 is the agent control channel (handles Attach/Fork handshake). - // Channel 100 is the child I/O channel (handles file upload and remote execution). - Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); - Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext); + + auto Self = shared_from_this(); + m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); }); +} - if (!AgentComputeChannel || !ChildComputeChannel) +void +AsyncHordeAgent::OnConnected(const std::error_code& Ec) +{ + if (Ec || m_Cancelled) { - ZEN_WARN("failed to create compute channels"); + if (Ec) + { + ZEN_WARN("connect failed: {}", Ec.message()); + } + Finish(false); return; } - m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); - m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + // Optionally wrap with AES encryption + std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport); + if (m_Config.Machine.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext); + } + + // Create the multiplexed socket and register channels. Ownership of the transport + // moves into the socket here - no need to retain a separate m_Transport field. 
+ m_Socket = std::make_shared<AsyncComputeSocket>(std::move(FinalTransport), m_IoContext); + + m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext); + m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext); + + m_Socket->RegisterChannel( + 0, + [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); }, + [this]() { m_AgentChannel->OnDetach(); }); - m_IsValid = true; + m_Socket->RegisterChannel( + 100, + [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); }, + [this]() { m_ChildChannel->OnDetach(); }); + + m_Socket->StartRecvPump(); + + m_State = State::WaitAgentAttach; + DoWaitAgentAttach(); } -HordeAgent::~HordeAgent() +void +AsyncHordeAgent::DoWaitAgentAttach() { - CloseConnection(); + auto Self = shared_from_this(); + m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnAgentResponse(Type, Data, Size); + }); } -bool -HordeAgent::BeginCommunication() +void +AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) { - ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); - - if (!m_IsValid) + if (m_Cancelled) { - return false; + Finish(false); + return; } - // Start the send/recv pump threads - m_Socket->StartCommunication(); - - // Wait for Attach on agent channel - AgentMessageType Type = m_AgentChannel->ReadResponse(5000); if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on agent channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - // Fork tells the remote agent to create child channel 100 with a 4MB buffer. - // After this, the agent will send an Attach on the child channel. 
+ m_State = State::SentFork; + DoSendFork(); +} + +void +AsyncHordeAgent::DoSendFork() +{ m_AgentChannel->Fork(100, 4 * 1024 * 1024); - // Wait for Attach on child channel - Type = m_ChildChannel->ReadResponse(5000); + m_State = State::WaitChildAttach; + DoWaitChildAttach(); +} + +void +AsyncHordeAgent::DoWaitChildAttach() +{ + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnChildAttachResponse(Type, Data, Size); + }); +} + +void +AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) +{ + if (m_Cancelled) + { + Finish(false); + return; + } + if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on child channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - return true; + m_State = State::Uploading; + m_CurrentBundleIndex = 0; + DoUploadNext(); } -bool -HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +void +AsyncHordeAgent::DoUploadNext() { - ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + if (m_Cancelled) + { + Finish(false); + return; + } + + if (m_CurrentBundleIndex >= m_Config.Bundles.size()) + { + // All bundles uploaded - proceed to execute + m_State = State::Executing; + DoExecute(); + return; + } - m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + m_ChildChannel->UploadFiles("", Locator.c_str()); - std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + // Enter the ReadBlob/Blob upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, 
Data, Size); + }); +} - auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { - std::string Key(Locator); +void +AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) +{ + if (m_Cancelled) + { + Finish(false); + return; + } - if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + if (Type == AgentMessageType::None) + { + if (m_ChildChannel->IsDetached()) { - return It->second.get(); + ZEN_WARN("connection lost during upload"); + Finish(false); + return; } + // Timeout - retry read + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); + return; + } - const std::filesystem::path Path = BundleDir / (Key + ".blob"); - std::error_code Ec; - auto File = std::make_unique<BasicFile>(); - File->Open(Path, BasicFile::Mode::kRead, Ec); + if (Type == AgentMessageType::WriteFilesResponse) + { + // This bundle upload is done - move to next + ++m_CurrentBundleIndex; + DoUploadNext(); + return; + } - if (Ec) + if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + if (!AsyncAgentMessageChannel::ReadException(Data, Size, Ex)) { - ZEN_ERROR("cannot read blob file: '{}'", Path); - return nullptr; + ZEN_ERROR("malformed Exception message during upload (size={})", Size); + Finish(false); + return; } + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + return; + } - BasicFile* Ptr = File.get(); - BlobFiles.emplace(std::move(Key), std::move(File)); - return Ptr; - }; - - // The upload protocol is request-driven: we send WriteFiles, then the remote agent - // sends ReadBlob requests for each blob it needs. We respond with Blob data until - // the agent sends WriteFilesResponse indicating the upload is complete. 
- constexpr int32_t ReadResponseTimeoutMs = 1000; + if (Type != AgentMessageType::ReadBlob) + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + Finish(false); + return; + } - for (;;) + // Handle ReadBlob request + BlobRequest Req; + if (!AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req)) { - bool TimedOut = false; + ZEN_ERROR("malformed ReadBlob message during upload (size={})", Size); + Finish(false); + return; + } - if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) - { - if (TimedOut) - { - continue; - } - // End of stream - check if it was a successful upload - if (Type == AgentMessageType::WriteFilesResponse) - { - return true; - } - else if (Type == AgentMessageType::Exception) - { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); - } - else - { - ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); - } - return false; - } + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob"); - BlobRequest Req; - m_ChildChannel->ReadBlobRequest(Req); + std::error_code FsEc; + BasicFile File; + File.Open(BlobPath, BasicFile::Mode::kRead, FsEc); - BasicFile* File = FindOrOpenBlob(Req.Locator); - if (!File) - { - return false; - } + if (FsEc) + { + ZEN_ERROR("cannot read blob file: '{}'", BlobPath); + Finish(false); + return; + } - // Read from offset to end of file - const uint64_t TotalSize = File->FileSize(); - const uint64_t Offset = static_cast<uint64_t>(Req.Offset); - if (Offset >= TotalSize) - { - ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); - m_ChildChannel->Blob(nullptr, 0); - continue; - } + const uint64_t TotalSize = File.FileSize(); + const 
uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + } + else + { + const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize()); + } + + // Continue the upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); - m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); +void +AsyncHordeAgent::DoExecute() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute"); + + std::vector<const char*> ArgPtrs; + ArgPtrs.reserve(m_Config.Args.size()); + for (const std::string& Arg : m_Config.Args) + { + ArgPtrs.push_back(Arg.c_str()); } + + m_ChildChannel->Execute(m_Config.Executable.c_str(), + ArgPtrs.data(), + ArgPtrs.size(), + nullptr, + nullptr, + 0, + m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + m_Config.Machine.GetConnectionAddress(), + m_Config.Machine.GetConnectionPort(), + m_Config.Machine.LeaseId); + + m_State = State::Polling; + DoPoll(); } void -HordeAgent::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - bool UseWine) +AsyncHordeAgent::DoPoll() { - ZEN_TRACE_CPU("HordeAgent::Execute"); - m_ChildChannel - ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? 
ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + if (m_Cancelled) + { + Finish(false); + return; + } + + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnPollResponse(Type, Data, Size); + }); } -bool -HordeAgent::Poll(bool LogOutput) +void +AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) { - constexpr int32_t ReadResponseTimeoutMs = 100; - AgentMessageType Type; + if (m_Cancelled) + { + Finish(false); + return; + } - while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + switch (Type) { - switch (Type) - { - case AgentMessageType::ExecuteOutput: + case AgentMessageType::None: + if (m_ChildChannel->IsDetached()) + { + ZEN_WARN("connection lost during execution"); + Finish(false); + } + else + { + // Timeout - poll again + DoPoll(); + } + break; + + case AgentMessageType::ExecuteOutput: + // Silently consume remote stdout (matching LogOutput=false in provisioner) + DoPoll(); + break; + + case AgentMessageType::ExecuteResult: + { + int32_t ExitCode = -1; + if (!AsyncAgentMessageChannel::ReadExecuteResult(Data, Size, ExitCode)) { - if (LogOutput && m_ChildChannel->GetResponseSize() > 0) - { - const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); - size_t ResponseSize = m_ChildChannel->GetResponseSize(); - - // Trim trailing newlines - while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) - { - --ResponseSize; - } - - if (ResponseSize > 0) - { - const std::string_view Output(ResponseData, ResponseSize); - ZEN_INFO("[remote] {}", Output); - } - } + // A remote with a malformed ExecuteResult cannot be trusted to report + // process outcome - treat as a protocol error and tear down rather than + // silently recording \"exited with -1\". 
+ ZEN_ERROR("malformed ExecuteResult (size={}, lease={}) - disconnecting", Size, m_Config.Machine.LeaseId); + Finish(false); break; } + ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId); + Finish(ExitCode == 0, ExitCode); + } + break; - case AgentMessageType::ExecuteResult: + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + if (AsyncAgentMessageChannel::ReadException(Data, Size, Ex)) { - if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) - { - int32_t ExitCode; - memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); - ZEN_INFO("remote process exited with code {}", ExitCode); - } - m_IsValid = false; - return false; + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); } - - case AgentMessageType::Exception: + else { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); - m_HasErrors = true; - break; + ZEN_ERROR("malformed Exception message (size={})", Size); } + Finish(false); + } + break; - default: - break; - } + default: + DoPoll(); + break; } - - return m_IsValid && !m_HasErrors; } void -HordeAgent::CloseConnection() +AsyncHordeAgent::Finish(bool Success, int32_t ExitCode) { - if (m_ChildChannel) + if (m_State == State::Done) { - m_ChildChannel->Close(); + return; // Already finished } - if (m_AgentChannel) + + if (!Success) { - m_AgentChannel->Close(); + ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId); } -} -bool -HordeAgent::IsValid() const -{ - return m_IsValid && !m_HasErrors; + m_State = State::Done; + + if (m_Socket) + { + m_Socket->Close(); + } + + if (m_OnDone) + { + AsyncAgentResult Result; + Result.Success = Success; + Result.ExitCode = ExitCode; + Result.CoreCount = m_Config.Machine.LogicalCores; + + auto Handler = std::move(m_OnDone); + m_OnDone = nullptr; + Handler(Result); + } } } // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h 
b/src/zenhorde/hordeagent.h index e0ae89ead..a5b3248ab 100644 --- a/src/zenhorde/hordeagent.h +++ b/src/zenhorde/hordeagent.h @@ -10,68 +10,107 @@ #include <zencore/logbase.h> #include <filesystem> +#include <functional> #include <memory> #include <string> +#include <vector> + +namespace asio { +class io_context; +} namespace zen::horde { -/** Manages the lifecycle of a single Horde compute agent. +class AsyncComputeTransport; + +/** Result passed to the completion handler when an async agent finishes. */ +struct AsyncAgentResult +{ + bool Success = false; + int32_t ExitCode = -1; + uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine +}; + +/** Completion handler for async agent lifecycle. */ +using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>; + +/** Configuration for launching a remote zenserver instance via an async agent. */ +struct AsyncAgentConfig +{ + MachineInfo Machine; + std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs + std::string Executable; + std::vector<std::string> Args; + bool UseWine = false; +}; + +/** Async agent that manages the full lifecycle of a single Horde compute connection. * - * Handles the full connection sequence for one provisioned machine: - * 1. Connect via TCP transport (with optional AES encryption wrapping) - * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) - * 3. Perform the Attach/Fork handshake to establish the child channel - * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol - * 5. Execute zenserver remotely via ExecuteV2 - * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + * Driven by a state machine using callbacks on a shared io_context - no dedicated + * threads. Call Start() to begin the connection/handshake/upload/execute/poll + * sequence. The completion handler is invoked when the remote process exits or + * an error occurs. 
*/ -class HordeAgent +class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent> { public: - explicit HordeAgent(const MachineInfo& Info); - ~HordeAgent(); + AsyncHordeAgent(asio::io_context& IoContext); + ~AsyncHordeAgent(); - HordeAgent(const HordeAgent&) = delete; - HordeAgent& operator=(const HordeAgent&) = delete; + AsyncHordeAgent(const AsyncHordeAgent&) = delete; + AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete; - /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). - * Returns false if the handshake times out or receives an unexpected message. */ - bool BeginCommunication(); + /** Start the full agent lifecycle. The completion handler is called exactly once. */ + void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone); - /** Upload binary files to the remote agent. - * @param BundleDir Directory containing .blob files. - * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ - bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + /** Cancel in-flight operations. The completion handler is still called (with Success=false). */ + void Cancel(); - /** Execute a command on the remote machine. */ - void Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir = nullptr, - const char* const* EnvVars = nullptr, - size_t NumEnvVars = 0, - bool UseWine = false); + const MachineInfo& GetMachineInfo() const { return m_Config.Machine; } - /** Poll for output and results. Returns true if the agent is still running. - * When LogOutput is true, remote stdout is logged via ZEN_INFO. 
*/ - bool Poll(bool LogOutput = true); - - void CloseConnection(); - bool IsValid() const; - - const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + enum class State + { + Idle, + Connecting, + WaitAgentAttach, + SentFork, + WaitChildAttach, + Uploading, + Executing, + Polling, + Done + }; private: LoggerRef Log() { return m_Log; } - std::unique_ptr<ComputeSocket> m_Socket; - std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control - std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O - - LoggerRef m_Log; - bool m_IsValid = false; - bool m_HasErrors = false; - MachineInfo m_MachineInfo; + void DoConnect(); + void OnConnected(const std::error_code& Ec); + void DoWaitAgentAttach(); + void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoSendFork(); + void DoWaitChildAttach(); + void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoUploadNext(); + void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoExecute(); + void DoPoll(); + void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void Finish(bool Success, int32_t ExitCode = -1); + + asio::io_context& m_IoContext; + LoggerRef m_Log; + State m_State = State::Idle; + bool m_Cancelled = false; + + AsyncAgentConfig m_Config; + AsyncAgentCompletionHandler m_OnDone; + size_t m_CurrentBundleIndex = 0; + + std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport; + std::shared_ptr<AsyncComputeSocket> m_Socket; + std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel; + std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp index 998134a96..bef1bdda8 100644 --- a/src/zenhorde/hordeagentmessage.cpp +++ b/src/zenhorde/hordeagentmessage.cpp @@ -4,337 +4,496 @@ #include <zencore/intmath.h> -#include <cassert> 
+ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <zencore/except_fmt.h> +#include <zencore/logging.h> + #include <cstring> +#include <limits> namespace zen::horde { -AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) -{ -} - -AgentMessageChannel::~AgentMessageChannel() = default; +// --- AsyncAgentMessageChannel --- -void -AgentMessageChannel::Close() +AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext) +: m_Socket(std::move(Socket)) +, m_ChannelId(ChannelId) +, m_IoContext(IoContext) +, m_TimeoutTimer(std::make_unique<asio::steady_timer>(m_Socket->GetStrand())) { - CreateMessage(AgentMessageType::None, 0); - FlushMessage(); } -void -AgentMessageChannel::Ping() +AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { - CreateMessage(AgentMessageType::Ping, 0); - FlushMessage(); + if (m_TimeoutTimer) + { + m_TimeoutTimer->cancel(); + } } -void -AgentMessageChannel::Fork(int ChannelId, int BufferSize) +// --- Message building helpers --- + +std::vector<uint8_t> +AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { - CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); - WriteInt32(ChannelId); - WriteInt32(BufferSize); - FlushMessage(); + std::vector<uint8_t> Buf; + Buf.reserve(MessageHeaderLength + ReservePayload); + Buf.push_back(static_cast<uint8_t>(Type)); + Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder + return Buf; } void -AgentMessageChannel::Attach() +AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg) { - CreateMessage(AgentMessageType::Attach, 0); - FlushMessage(); + const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength); + memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); + m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void 
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value) { - CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); - WriteString(Path); - WriteString(Locator); - FlushMessage(); + const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value); + Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } -void -AgentMessageChannel::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - ExecuteProcessFlags Flags) +int +AsyncAgentMessageChannel::ReadInt32(ReadCursor& C) { - size_t RequiredSize = 50 + strlen(Exe); - for (size_t i = 0; i < NumArgs; ++i) - { - RequiredSize += strlen(Args[i]) + 10; - } - if (WorkingDir) - { - RequiredSize += strlen(WorkingDir) + 10; - } - for (size_t i = 0; i < NumEnvVars; ++i) + if (!C.CheckAvailable(sizeof(int32_t))) { - RequiredSize += strlen(EnvVars[i]) + 20; + return 0; } + int32_t Value; + memcpy(&Value, C.Pos, sizeof(int32_t)); + C.Pos += sizeof(int32_t); + return Value; +} - CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); - WriteString(Exe); +void +AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length) +{ + Buf.insert(Buf.end(), Data, Data + Length); +} - WriteUnsignedVarInt(NumArgs); - for (size_t i = 0; i < NumArgs; ++i) +const uint8_t* +AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length) +{ + if (!C.CheckAvailable(Length)) { - WriteString(Args[i]); + return nullptr; } + const uint8_t* Data = C.Pos; + C.Pos += Length; + return Data; +} - WriteOptionalString(WorkingDir); - - // ExecuteV2 protocol requires env vars as separate key/value pairs. - // Callers pass "KEY=VALUE" strings; we split on the first '=' here. 
- WriteUnsignedVarInt(NumEnvVars); - for (size_t i = 0; i < NumEnvVars; ++i) +size_t +AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +{ + if (Value == 0) { - const char* Eq = strchr(EnvVars[i], '='); - assert(Eq != nullptr); - - WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); - if (*(Eq + 1) == '\0') - { - WriteOptionalString(nullptr); - } - else - { - WriteOptionalString(Eq + 1); - } + return 1; } - - WriteInt32(static_cast<int>(Flags)); - FlushMessage(); + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } void -AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value) { - // Blob responses are chunked to fit within the compute buffer's chunk size. - // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). - const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; - for (size_t ChunkOffset = 0; ChunkOffset < Length;) - { - const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); - - CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); - WriteInt32(static_cast<int>(ChunkOffset)); - WriteInt32(static_cast<int>(Length)); - WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); - FlushMessage(); + const size_t ByteCount = MeasureUnsignedVarInt(Value); + const size_t StartPos = Buf.size(); + Buf.resize(StartPos + ByteCount); - ChunkOffset += ChunkLength; + uint8_t* Output = Buf.data() + StartPos; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); } -AgentMessageType -AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +size_t +AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C) { - // Deferred advance: the previous response's 
buffer is only released when the next - // ReadResponse is called. This allows callers to read response data between calls - // without copying, since the pointer comes directly from the ring buffer. - if (m_ResponseData) + // Need at least the leading byte to determine the encoded length. + if (!C.CheckAvailable(1)) { - m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); - m_ResponseData = nullptr; - m_ResponseLength = 0; + return 0; } - const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); - if (!Header) + const uint8_t FirstByte = C.Pos[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + + // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds. + if (!C.CheckAvailable(NumBytes)) { - return AgentMessageType::None; + return 0; } - uint32_t Length; - memcpy(&Length, Header + 1, sizeof(uint32_t)); - - Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); - if (!Header) + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) { - return AgentMessageType::None; + Value <<= 8; + Value |= C.Pos[i]; } - m_ResponseType = static_cast<AgentMessageType>(Header[0]); - m_ResponseData = Header + MessageHeaderLength; - m_ResponseLength = Length; - - return m_ResponseType; + C.Pos += NumBytes; + return Value; } void -AgentMessageChannel::ReadException(ExceptionInfo& Ex) +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text) { - assert(m_ResponseType == AgentMessageType::Exception); - const uint8_t* Pos = m_ResponseData; - Ex.Message = ReadString(&Pos); - Ex.Description = ReadString(&Pos); + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); } -int 
-AgentMessageChannel::ReadExecuteResult() +void +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text) { - assert(m_ResponseType == AgentMessageType::ExecuteResult); - const uint8_t* Pos = m_ResponseData; - return ReadInt32(&Pos); + WriteUnsignedVarInt(Buf, Text.size()); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); } -void -AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +std::string_view +AsyncAgentMessageChannel::ReadString(ReadCursor& C) { - assert(m_ResponseType == AgentMessageType::ReadBlob); - const uint8_t* Pos = m_ResponseData; - Req.Locator = ReadString(&Pos); - Req.Offset = ReadUnsignedVarInt(&Pos); - Req.Length = ReadUnsignedVarInt(&Pos); + const size_t Length = ReadUnsignedVarInt(C); + const uint8_t* Start = ReadFixedLengthBytes(C, Length); + if (C.ParseError || !Start) + { + return {}; + } + return std::string_view(reinterpret_cast<const char*>(Start), Length); } void -AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text) { - m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); - m_RequestData[0] = static_cast<uint8_t>(Type); - m_MaxRequestSize = MaxLength; - m_RequestSize = 0; + if (!Text) + { + WriteUnsignedVarInt(Buf, 0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length + 1); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); + } } +// --- Send methods --- + void -AgentMessageChannel::FlushMessage() +AsyncAgentMessageChannel::Close() { - const uint32_t Size = static_cast<uint32_t>(m_RequestSize); - memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); - m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); - m_RequestSize = 0; - m_MaxRequestSize = 0; - m_RequestData = nullptr; + auto Msg = BeginMessage(AgentMessageType::None, 0); + 
FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteInt32(int Value) +AsyncAgentMessageChannel::Ping() { - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); + auto Msg = BeginMessage(AgentMessageType::Ping, 0); + FinalizeAndSend(std::move(Msg)); } -int -AgentMessageChannel::ReadInt32(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { - int Value; - memcpy(&Value, *Pos, sizeof(int)); - *Pos += sizeof(int); - return Value; + auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(Msg, ChannelId); + WriteInt32(Msg, BufferSize); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::Attach() { - assert(m_RequestSize + Length <= m_MaxRequestSize); - memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); - m_RequestSize += Length; + auto Msg = BeginMessage(AgentMessageType::Attach, 0); + FinalizeAndSend(std::move(Msg)); } -const uint8_t* -AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +void +AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { - const uint8_t* Data = *Pos; - *Pos += Length; - return Data; + auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Msg, Path); + WriteString(Msg, Locator); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +void +AsyncAgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) { - if (Value == 0) + size_t ReserveSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) { - return 1; + ReserveSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + ReserveSize += strlen(WorkingDir) + 10; + } + for 
(size_t i = 0; i < NumEnvVars; ++i) + { + ReserveSize += strlen(EnvVars[i]) + 20; } - return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; -} -void -AgentMessageChannel::WriteUnsignedVarInt(size_t Value) -{ - const size_t ByteCount = MeasureUnsignedVarInt(Value); - assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); + WriteString(Msg, Exe); - uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; - for (size_t i = 1; i < ByteCount; ++i) + WriteUnsignedVarInt(Msg, NumArgs); + for (size_t i = 0; i < NumArgs; ++i) { - Output[ByteCount - i] = static_cast<uint8_t>(Value); - Value >>= 8; + WriteString(Msg, Args[i]); } - Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); - m_RequestSize += ByteCount; + WriteOptionalString(Msg, WorkingDir); + + WriteUnsignedVarInt(Msg, NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + if (Eq == nullptr) + { + // assert() would be compiled out in release and leave *(Eq+1) as UB - + // refuse to build the message for a malformed KEY=VALUE string instead. 
+ throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i); + } + + WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(Msg, nullptr); + } + else + { + WriteOptionalString(Msg, Eq + 1); + } + } + + WriteInt32(Msg, static_cast<int>(Flags)); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { - const uint8_t* Data = *Pos; - const uint8_t FirstByte = Data[0]; - const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + static constexpr size_t MaxBlobChunkSize = 512 * 1024; - size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); - for (size_t i = 1; i < NumBytes; ++i) + // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total + // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the + // remote parser. Refuse the send rather than produce a protocol violation. 
+ if (Length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) { - Value <<= 8; - Value |= Data[i]; + throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length); } - *Pos += NumBytes; - return Value; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); + + auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(Msg, static_cast<int32_t>(ChunkOffset)); + WriteInt32(Msg, static_cast<int32_t>(Length)); + WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); + FinalizeAndSend(std::move(Msg)); + + ChunkOffset += ChunkLength; + } } -size_t -AgentMessageChannel::MeasureString(const char* Text) const +// --- Async response reading --- + +void +AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { - const size_t Length = strlen(Text); - return MeasureUnsignedVarInt(Length) + Length; + // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto + // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the + // timer wait completion would run on a bare io_context thread (3 concurrent run() + // loops in the provisioner) and race with OnFrame on m_PendingHandler. 
+ asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable { + if (!m_IncomingFrames.empty()) + { + std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front()); + m_IncomingFrames.pop_front(); + + if (Frame.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]); + const uint8_t* Data = Frame.data() + MessageHeaderLength; + size_t Size = Frame.size() - MessageHeaderLength; + asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { + // The Frame is captured to keep Data pointer valid + Handler(Type, Data, Size); + }); + } + else + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + } + return; + } + + if (m_Detached) + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + return; + } + + // No frames queued - store pending handler and arm timeout + m_PendingHandler = std::move(Handler); + + if (TimeoutMs >= 0) + { + m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); + m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled - frame arrived before timeout + } + + // Already on the strand: safe to mutate m_PendingHandler. 
+ if (m_PendingHandler) + { + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } + })); + } + }); } void -AgentMessageChannel::WriteString(const char* Text) +AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data) { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + if (m_PendingHandler) + { + // Cancel the timeout timer + m_TimeoutTimer->cancel(); + + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + + if (Data.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Data[0]); + const uint8_t* Payload = Data.data() + MessageHeaderLength; + size_t PayloadSize = Data.size() - MessageHeaderLength; + Handler(Type, Payload, PayloadSize); + } + else + { + Handler(AgentMessageType::None, nullptr, 0); + } + } + else + { + m_IncomingFrames.push_back(std::move(Data)); + } } void -AgentMessageChannel::WriteString(std::string_view Text) +AsyncAgentMessageChannel::OnDetach() { - WriteUnsignedVarInt(Text.size()); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + m_Detached = true; + + if (m_PendingHandler) + { + m_TimeoutTimer->cancel(); + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } } -std::string_view -AgentMessageChannel::ReadString(const uint8_t** Pos) +// --- Response parsing helpers --- + +bool +AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex) { - const size_t Length = ReadUnsignedVarInt(Pos); - const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); - return std::string_view(Start, Length); + ReadCursor C{Data, Data + Size, false}; + Ex.Message = ReadString(C); + Ex.Description = ReadString(C); + if 
(C.ParseError) + { + Ex = {}; + return false; + } + return true; } -void -AgentMessageChannel::WriteOptionalString(const char* Text) +bool +AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode) { - // Optional strings use length+1 encoding: 0 means null/absent, - // N>0 means a string of length N-1 follows. This matches the UE - // FAgentMessageChannel serialization convention. - if (!Text) + ReadCursor C{Data, Data + Size, false}; + OutExitCode = ReadInt32(C); + return !C.ParseError; +} + +static bool +IsSafeLocator(std::string_view Locator) +{ + // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or + // control-character-containing locators. The locator is used as a filename component + // joined with a trusted BundleDir, so the only safe characters are a restricted + // filename alphabet. + if (Locator.empty() || Locator.size() > 255) { - WriteUnsignedVarInt(0); + return false; } - else + if (Locator == "." || Locator == "..") { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length + 1); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + return false; + } + for (char Ch : Locator) + { + const unsigned char U = static_cast<unsigned char>(Ch); + if (U < 0x20 || U == 0x7F) + { + return false; // control / NUL / DEL + } + if (Ch == '/' || Ch == '\\' || Ch == ':') + { + return false; // path separators / drive letters + } + } + // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges) + if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' 
|| Locator.back() == ' ') + { + return false; + } + return true; +} + +bool +AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req) +{ + ReadCursor C{Data, Data + Size, false}; + Req.Locator = ReadString(C); + Req.Offset = ReadUnsignedVarInt(C); + Req.Length = ReadUnsignedVarInt(C); + if (C.ParseError || !IsSafeLocator(Req.Locator)) + { + Req = {}; + return false; } + return true; } } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h index 38c4375fd..fb7c5ed29 100644 --- a/src/zenhorde/hordeagentmessage.h +++ b/src/zenhorde/hordeagentmessage.h @@ -4,14 +4,22 @@ #include <zenbase/zenbase.h> -#include "hordecomputechannel.h" +#include "hordecomputesocket.h" #include <cstddef> #include <cstdint> +#include <deque> +#include <functional> +#include <memory> #include <string> #include <string_view> +#include <system_error> #include <vector> +namespace asio { +class io_context; +} // namespace asio + namespace zen::horde { /** Agent message types matching the UE EAgentMessageType byte values. @@ -55,45 +63,34 @@ struct BlobRequest size_t Length = 0; }; -/** Channel for sending and receiving agent messages over a ComputeChannel. +/** Handler for async response reads. Receives the message type and a view of the payload data. + * The payload vector is valid until the next AsyncReadResponse call. */ +using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>; + +/** Async channel for sending and receiving agent messages over an AsyncComputeSocket. * - * Implements the Horde agent message protocol, matching the UE - * FAgentMessageChannel serialization format exactly. Messages are framed as - * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; - * integers use variable-length encoding. + * Send methods build messages into vectors and submit them via AsyncComputeSocket. 
+ * Receives are delivered via the socket's FrameHandler callback and queued internally. + * AsyncReadResponse checks the queue and invokes the handler, with optional timeout. * - * The protocol has two directions: - * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob - * - Responses (remote -> initiator): ReadResponse returns the type, then call the - * appropriate Read* method to parse the payload. + * All operations must be externally serialized (e.g. via the socket's strand). */ -class AgentMessageChannel +class AsyncAgentMessageChannel { public: - explicit AgentMessageChannel(Ref<ComputeChannel> Channel); - ~AgentMessageChannel(); + AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext); + ~AsyncAgentMessageChannel(); - AgentMessageChannel(const AgentMessageChannel&) = delete; - AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete; + AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete; - // --- Requests (Initiator -> Remote) --- + // --- Requests (fire-and-forget sends) --- - /** Close the channel. */ void Close(); - - /** Send a keepalive ping. */ void Ping(); - - /** Fork communication to a new channel with the given ID and buffer size. */ void Fork(int ChannelId, int BufferSize); - - /** Send an attach request (used during channel setup handshake). */ void Attach(); - - /** Request the remote agent to write files from the given bundle locator. */ void UploadFiles(const char* Path, const char* Locator); - - /** Execute a process on the remote machine. */ void Execute(const char* Exe, const char* const* Args, size_t NumArgs, @@ -101,61 +98,85 @@ public: const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags = ExecuteProcessFlags::None); - - /** Send blob data in response to a ReadBlob request. 
*/ void Blob(const uint8_t* Data, size_t Length); - // --- Responses (Remote -> Initiator) --- - - /** Read the next response message. Returns the message type, or None on timeout. - * After this returns, use GetResponseData()/GetResponseSize() or the typed - * Read* methods to access the payload. */ - AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + // --- Async response reading --- - const void* GetResponseData() const { return m_ResponseData; } - size_t GetResponseSize() const { return m_ResponseLength; } + /** Read the next response. If a frame is already queued, the handler is posted immediately. + * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler + * with AgentMessageType::None. */ + void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler); - /** Parse an Exception response payload. */ - void ReadException(ExceptionInfo& Ex); + /** Called by the socket's FrameHandler when a frame arrives for this channel. */ + void OnFrame(std::vector<uint8_t> Data); - /** Parse an ExecuteResult response payload. Returns the exit code. */ - int ReadExecuteResult(); + /** Called by the socket's DetachHandler. */ + void OnDetach(); - /** Parse a ReadBlob response payload into a BlobRequest. */ - void ReadBlobRequest(BlobRequest& Req); + /** Returns true if the channel has been detached (connection lost). */ + bool IsDetached() const { return m_Detached; } -private: - static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + // --- Response parsing helpers --- - Ref<ComputeChannel> m_Channel; + /** Parse an Exception message payload. Returns false on malformed/truncated input. */ + [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex); - uint8_t* m_RequestData = nullptr; - size_t m_RequestSize = 0; - size_t m_MaxRequestSize = 0; + /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. 
*/ + [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode); - AgentMessageType m_ResponseType = AgentMessageType::None; - const uint8_t* m_ResponseData = nullptr; - size_t m_ResponseLength = 0; + /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or + * if the Locator contains characters that would not be safe to use as a path component. */ + [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req); - void CreateMessage(AgentMessageType Type, size_t MaxLength); - void FlushMessage(); +private: + static constexpr size_t MessageHeaderLength = 5; + + // Message building helpers + std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload); + void FinalizeAndSend(std::vector<uint8_t> Msg); + + /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */ + struct ReadCursor + { + const uint8_t* Pos = nullptr; + const uint8_t* End = nullptr; + bool ParseError = false; + + [[nodiscard]] bool CheckAvailable(size_t N) + { + if (ParseError || static_cast<size_t>(End - Pos) < N) + { + ParseError = true; + return false; + } + return true; + } + }; + + static void WriteInt32(std::vector<uint8_t>& Buf, int Value); + static int ReadInt32(ReadCursor& C); + + static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length); - void WriteInt32(int Value); - static int ReadInt32(const uint8_t** Pos); + static size_t MeasureUnsignedVarInt(size_t Value); + static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value); + static size_t ReadUnsignedVarInt(ReadCursor& C); - void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); - static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + static void WriteString(std::vector<uint8_t>& Buf, const char* Text); + static void 
WriteString(std::vector<uint8_t>& Buf, std::string_view Text); + static std::string_view ReadString(ReadCursor& C); - static size_t MeasureUnsignedVarInt(size_t Value); - void WriteUnsignedVarInt(size_t Value); - static size_t ReadUnsignedVarInt(const uint8_t** Pos); + static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text); - size_t MeasureString(const char* Text) const; - void WriteString(const char* Text); - void WriteString(std::string_view Text); - static std::string_view ReadString(const uint8_t** Pos); + std::shared_ptr<AsyncComputeSocket> m_Socket; + int m_ChannelId; + asio::io_context& m_IoContext; - void WriteOptionalString(const char* Text); + std::deque<std::vector<uint8_t>> m_IncomingFrames; + AsyncResponseHandler m_PendingHandler; + std::unique_ptr<asio::steady_timer> m_TimeoutTimer; + bool m_Detached = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp index d3974bc28..8493a9456 100644 --- a/src/zenhorde/hordebundle.cpp +++ b/src/zenhorde/hordebundle.cpp @@ -10,6 +10,7 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/trace.h> +#include <zencore/uid.h> #include <algorithm> #include <chrono> @@ -48,7 +49,7 @@ static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1 static constexpr size_t BlobTypeSize = 20; -// ─── VarInt helpers (UE format) ───────────────────────────────────────────── +// --- VarInt helpers (UE format) --------------------------------------------- static size_t MeasureVarInt(size_t Value) @@ -57,7 +58,7 @@ MeasureVarInt(size_t Value) { return 1; } - return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } static void @@ -76,7 +77,7 @@ WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value) Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); } -// ─── Binary helpers 
───────────────────────────────────────────────────────── +// --- Binary helpers --------------------------------------------------------- static void WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value) @@ -121,7 +122,7 @@ PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value) memcpy(Buffer.data() + Offset, &Value, 4); } -// ─── Packet builder ───────────────────────────────────────────────────────── +// --- Packet builder --------------------------------------------------------- // Builds a single uncompressed Horde V2 packet. Layout: // [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header) @@ -229,7 +230,7 @@ struct PacketBuilder { AlignTo4(Data); - // ── Type table: count(int32) + count * BlobTypeSize bytes ── + // -- Type table: count(int32) + count * BlobTypeSize bytes -- const int32_t TypeTableOffset = static_cast<int32_t>(Data.size()); WriteLE32(Data, static_cast<int32_t>(Types.size())); for (const uint8_t* TypeEntry : Types) @@ -237,12 +238,12 @@ struct PacketBuilder WriteBytes(Data, TypeEntry, BlobTypeSize); } - // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ── + // -- Import table: count(int32) + (count+1) offsets(int32 each) + import data -- const int32_t ImportTableOffset = static_cast<int32_t>(Data.size()); const int32_t ImportCount = static_cast<int32_t>(Imports.size()); WriteLE32(Data, ImportCount); - // Reserve space for (count+1) offset entries — will be patched below + // Reserve space for (count+1) offset entries - will be patched below const size_t ImportOffsetsStart = Data.size(); for (int32_t i = 0; i <= ImportCount; ++i) { @@ -266,7 +267,7 @@ struct PacketBuilder // Sentinel offset (points past the last import's data) PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size())); - // ── Export table: count(int32) + (count+1) offsets(int32 each) ── + // -- Export table: count(int32) + (count+1) offsets(int32 each) -- const int32_t 
ExportTableOffset = static_cast<int32_t>(Data.size()); const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size()); WriteLE32(Data, ExportCount); @@ -278,7 +279,7 @@ struct PacketBuilder // Sentinel: points to the start of the type table (end of export data region) WriteLE32(Data, TypeTableOffset); - // ── Patch header ── + // -- Patch header -- // PacketLength = total packet size including the 8-byte header const int32_t PacketLength = static_cast<int32_t>(Data.size()); PatchLE32(Data, 4, PacketLength); @@ -290,7 +291,7 @@ struct PacketBuilder } }; -// ─── Encoded packet wrapper ───────────────────────────────────────────────── +// --- Encoded packet wrapper ------------------------------------------------- // Wraps an uncompressed packet with the encoded header: // [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes @@ -327,24 +328,22 @@ EncodePacket(std::vector<uint8_t> UncompressedPacket) return Encoded; } -// ─── Bundle blob name generation ──────────────────────────────────────────── +// --- Bundle blob name generation -------------------------------------------- static std::string GenerateBlobName() { - static std::atomic<uint32_t> s_Counter{0}; - - const int Pid = GetCurrentProcessId(); - - auto Now = std::chrono::steady_clock::now().time_since_epoch(); - auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count(); - - ExtendableStringBuilder<64> Name; - Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1); - return std::string(Name.ToView()); + // Oid is a 12-byte identifier built from a timestamp, a monotonic serial number + // initialised from std::random_device, and a per-process run id also drawn from + // std::random_device. The 24-hex-char rendering gives ~80 bits of effective + // name-prediction entropy, so a local attacker cannot race-create the blob + // path before we open it. 
Previously the name was pid+ms+counter, which two + // zenserver processes with the same PID could collide on and which was + // entirely predictable. + return zen::Oid::NewOid().ToString(); } -// ─── File info for bundling ───────────────────────────────────────────────── +// --- File info for bundling ------------------------------------------------- struct FileInfo { @@ -357,7 +356,7 @@ struct FileInfo IoHash RootExportHash; // IoHash of the root export for this file }; -// ─── CreateBundle implementation ──────────────────────────────────────────── +// --- CreateBundle implementation -------------------------------------------- bool BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult) @@ -534,7 +533,7 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil FileInfo& Info = ValidFiles[i]; DirImports.push_back(Info.DirectoryExportImportIndex); - // IoHash of target (20 bytes) — import is consumed sequentially from the + // IoHash of target (20 bytes) - import is consumed sequentially from the // export's import list by ReadBlobRef, not encoded in the payload WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash)); // name (string) @@ -557,8 +556,16 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil std::vector<uint8_t> UncompressedPacket = Packet.Finish(); std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket)); - // Write .blob file + // Write .blob file. Refuse to proceed if a file with this name already exists - + // the Oid-based BlobName should make collisions astronomically unlikely, so an + // existing file implies either an extraordinary collision or an attacker having + // pre-seeded the path; either way, we do not want to overwrite it. 
const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob"); + if (std::filesystem::exists(BlobFilePath, Ec)) + { + ZEN_ERROR("blob file already exists at {} - refusing to overwrite", BlobFilePath.string()); + return false; + } { BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec); if (Ec) @@ -574,8 +581,10 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex; const std::string LocatorStr(Locator.ToView()); - // Write .ref file (use first file's name as the ref base) - const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref"); + // Write .ref file. Include the Oid-based BlobName so that two concurrent + // CreateBundle() calls into the same OutputDir that happen to share the first + // filename don't clobber each other's ref file. + const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + "." 
+ BlobName + ".Bundle.ref"); { BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec); if (Ec) diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index 0eefc57c6..762edce06 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -4,6 +4,7 @@ #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/memoryview.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhorde/hordeclient.h> #include <zenhttp/httpclient.h> @@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client") { } @@ -32,16 +33,24 @@ HordeClient::Initialize() Settings.RetryCount = 1; Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; - if (!m_Config.AuthToken.empty()) + if (m_Config.AccessTokenProvider) { + Settings.AccessTokenProvider = m_Config.AccessTokenProvider; + } + else if (!m_Config.AuthToken.empty()) + { + // Static tokens have no wire-provided expiry. Synthesising \"now + 24h\" is wrong + // in both directions: if the real token expires before 24h we keep sending it after + // it dies; if it's long-lived we force unnecessary re-auth churn every day. Use the + // never-expires sentinel, matching zenhttp's CreateFromStaticToken. 
Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { - return HttpClientAccessToken(token, HttpClientAccessToken::Clock::now() + std::chrono::hours{24}); + return HttpClientAccessToken(token, HttpClientAccessToken::TimePoint::max()); }; } m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings); - if (!m_Config.AuthToken.empty()) + if (Settings.AccessTokenProvider) { if (!m_Http->Authenticate()) { @@ -63,24 +72,21 @@ HordeClient::BuildRequestBody() const Requirements["pool"] = m_Config.Pool; } - std::string Condition; -#if ZEN_PLATFORM_WINDOWS ExtendableStringBuilder<256> CondBuf; +#if ZEN_PLATFORM_WINDOWS CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')"; - Condition = std::string(CondBuf); #elif ZEN_PLATFORM_MAC - Condition = "OSFamily == 'MacOS'"; + CondBuf << "OSFamily == 'MacOS'"; #else - Condition = "OSFamily == 'Linux'"; + CondBuf << "OSFamily == 'Linux'"; #endif if (!m_Config.Condition.empty()) { - Condition += " "; - Condition += m_Config.Condition; + CondBuf << " " << m_Config.Condition; } - Requirements["condition"] = Condition; + Requirements["condition"] = std::string(CondBuf); Requirements["exclusive"] = true; json11::Json::object Connection; @@ -156,39 +162,27 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus return false; } - OutCluster.ClusterId = ClusterIdVal.string_value(); - return true; -} - -bool -HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) -{ - if (Hex.size() != OutSize * 2) + // A server-returned ClusterId is interpolated directly into the request URL below + // (api/v2/compute/<ClusterId>), so a compromised or MITM'd Horde server could + // otherwise inject additional path segments or query strings. Constrain to a + // conservative identifier alphabet. 
+ const std::string& ClusterIdStr = ClusterIdVal.string_value(); + if (ClusterIdStr.size() > 64) { + ZEN_WARN("rejecting overlong clusterId ({} bytes) in cluster resolution response", ClusterIdStr.size()); return false; } - - for (size_t i = 0; i < OutSize; ++i) + static constexpr AsciiSet ValidClusterIdCharactersSet{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-"}; + if (!AsciiSet::HasOnly(ClusterIdStr, ValidClusterIdCharactersSet)) { - auto HexToByte = [](char c) -> int { - if (c >= '0' && c <= '9') - return c - '0'; - if (c >= 'a' && c <= 'f') - return c - 'a' + 10; - if (c >= 'A' && c <= 'F') - return c - 'A' + 10; - return -1; - }; - - const int Hi = HexToByte(Hex[i * 2]); - const int Lo = HexToByte(Hex[i * 2 + 1]); - if (Hi < 0 || Lo < 0) - { - return false; - } - Out[i] = static_cast<uint8_t>((Hi << 4) | Lo); + ZEN_WARN("rejecting clusterId with unsafe character in cluster resolution response"); + return false; } + OutCluster.ClusterId = ClusterIdStr; + + ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId); + return true; } @@ -197,8 +191,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { ZEN_TRACE_CPU("HordeClient::RequestMachine"); - ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); - ExtendableStringBuilder<128> ResourcePath; ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? 
"default" : ClusterId.c_str()); @@ -318,11 +310,15 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C } else if (Prop.starts_with("LogicalCores=")) { - LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13)); + LogicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(13)).value_or(0); } else if (Prop.starts_with("PhysicalCores=")) { - PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14)); + PhysicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(14)).value_or(0); + } + else if (Prop.starts_with("Pool=")) + { + OutMachine.Pool = Prop.substr(5); } } } @@ -367,10 +363,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, + OutMachine.Pool, OutMachine.LeaseId); return true; diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp deleted file mode 100644 index 0d032b5d5..000000000 --- a/src/zenhorde/hordecomputebuffer.cpp +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputebuffer.h" - -#include <algorithm> -#include <cassert> -#include <chrono> -#include <condition_variable> -#include <cstring> - -namespace zen::horde { - -// Simplified ring buffer implementation for in-process use only. -// Uses a single contiguous buffer with write/read cursors and -// mutex+condvar for synchronization. This is simpler than the UE version -// which uses lock-free atomics and shared memory, but sufficient for our -// use case where we're the initiator side of the compute protocol. 
- -struct ComputeBuffer::Detail : TRefCounted<Detail> -{ - std::vector<uint8_t> Data; - size_t NumChunks = 0; - size_t ChunkLength = 0; - - // Current write state - size_t WriteChunkIdx = 0; - size_t WriteOffset = 0; - bool WriteComplete = false; - - // Current read state - size_t ReadChunkIdx = 0; - size_t ReadOffset = 0; - bool Detached = false; - - // Per-chunk written length - std::vector<size_t> ChunkWrittenLength; - std::vector<bool> ChunkFinished; // Writer moved to next chunk - - std::mutex Mutex; - std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes - std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space - - bool HasWriter = false; - bool HasReader = false; - - uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } - const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } -}; - -// ComputeBuffer - -ComputeBuffer::ComputeBuffer() -{ -} -ComputeBuffer::~ComputeBuffer() -{ -} - -bool -ComputeBuffer::CreateNew(const Params& InParams) -{ - auto* NewDetail = new Detail(); - NewDetail->NumChunks = InParams.NumChunks; - NewDetail->ChunkLength = InParams.ChunkLength; - NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); - NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); - NewDetail->ChunkFinished.resize(InParams.NumChunks, false); - - m_Detail = NewDetail; - return true; -} - -void -ComputeBuffer::Close() -{ - m_Detail = nullptr; -} - -bool -ComputeBuffer::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -ComputeBufferReader -ComputeBuffer::CreateReader() -{ - assert(m_Detail); - m_Detail->HasReader = true; - return ComputeBufferReader(m_Detail); -} - -ComputeBufferWriter -ComputeBuffer::CreateWriter() -{ - assert(m_Detail); - m_Detail->HasWriter = true; - return ComputeBufferWriter(m_Detail); -} - -// ComputeBufferReader - -ComputeBufferReader::ComputeBufferReader() -{ -} 
-ComputeBufferReader::~ComputeBufferReader() -{ -} - -ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; -ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; -ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; -ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; - -ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferReader::Close() -{ - m_Detail = nullptr; -} - -void -ComputeBufferReader::Detach() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->Detached = true; - m_Detail->ReadCV.notify_all(); - } -} - -bool -ComputeBufferReader::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -bool -ComputeBufferReader::IsComplete() const -{ - if (!m_Detail) - { - return true; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (m_Detail->Detached) - { - return true; - } - return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && - m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; -} - -void -ComputeBufferReader::AdvanceReadPosition(size_t Size) -{ - if (!m_Detail) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - - m_Detail->ReadOffset += Size; - - // Check if we need to move to next chunk - const size_t ReadChunk = m_Detail->ReadChunkIdx; - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - } - - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferReader::GetMaxReadSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> 
Lock(m_Detail->Mutex); - const size_t ReadChunk = m_Detail->ReadChunkIdx; - return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; -} - -const uint8_t* -ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - auto Predicate = [&]() -> bool { - if (m_Detail->Detached) - { - return true; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available >= MinSize) - { - return true; - } - - // If chunk is finished and we've read everything, try to move to next - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - // Move to next chunk - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - return false; // Re-check with new chunk - } - - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - - return false; - }; - - if (TimeoutMs < 0) - { - m_Detail->ReadCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - if (OutTimedOut) - { - *OutTimedOut = true; - } - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available < MinSize) - { - return nullptr; // End of stream - } - - return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; -} - -size_t -ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) -{ - const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); - if (!Data) 
- { - return 0; - } - - const size_t Available = GetMaxReadSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Buffer, Data, ToCopy); - AdvanceReadPosition(ToCopy); - return ToCopy; -} - -// ComputeBufferWriter - -ComputeBufferWriter::ComputeBufferWriter() = default; -ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; -ComputeBufferWriter::~ComputeBufferWriter() = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; - -ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferWriter::Close() -{ - if (m_Detail) - { - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (!m_Detail->WriteComplete) - { - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } - } - m_Detail = nullptr; - } -} - -bool -ComputeBufferWriter::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -void -ComputeBufferWriter::MarkComplete() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } -} - -void -ComputeBufferWriter::AdvanceWritePosition(size_t Size) -{ - if (!m_Detail || Size == 0) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - m_Detail->ChunkWrittenLength[WriteChunk] += Size; - m_Detail->WriteOffset += Size; - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferWriter::GetMaxWriteSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - return m_Detail->ChunkLength - 
m_Detail->ChunkWrittenLength[WriteChunk]; -} - -size_t -ComputeBufferWriter::GetChunkMaxLength() const -{ - if (!m_Detail) - { - return 0; - } - return m_Detail->ChunkLength; -} - -size_t -ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) -{ - uint8_t* Dest = WaitToWrite(1, TimeoutMs); - if (!Dest) - { - return 0; - } - - const size_t Available = GetMaxWriteSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Dest, Buffer, ToCopy); - AdvanceWritePosition(ToCopy); - return ToCopy; -} - -uint8_t* -ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - if (m_Detail->WriteComplete) - { - return nullptr; - } - - const size_t WriteChunk = m_Detail->WriteChunkIdx; - const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; - - // If current chunk has enough space, return pointer - if (Available >= MinSize) - { - return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; - } - - // Current chunk is full - mark it as finished and move to next. - // The writer cannot advance until the reader has fully consumed the next chunk, - // preventing the writer from overwriting data the reader hasn't processed yet. 
- m_Detail->ChunkFinished[WriteChunk] = true; - m_Detail->ReadCV.notify_all(); - - const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; - - // Wait until reader has consumed the next chunk - auto Predicate = [&]() -> bool { - // Check if read has moved past this chunk - return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; - }; - - if (TimeoutMs < 0) - { - m_Detail->WriteCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - // Reset next chunk - m_Detail->ChunkWrittenLength[NextChunk] = 0; - m_Detail->ChunkFinished[NextChunk] = false; - m_Detail->WriteChunkIdx = NextChunk; - m_Detail->WriteOffset = 0; - - return m_Detail->ChunkPtr(NextChunk); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h deleted file mode 100644 index 64ef91b7a..000000000 --- a/src/zenhorde/hordecomputebuffer.h +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zenbase/refcount.h> - -#include <cstddef> -#include <cstdint> -#include <mutex> -#include <vector> - -namespace zen::horde { - -class ComputeBufferReader; -class ComputeBufferWriter; - -/** Simplified in-process ring buffer for the Horde compute protocol. - * - * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, - * this implementation uses plain heap-allocated memory since we only need in-process - * communication between channel and transport threads. The buffer is divided into - * fixed-size chunks; readers and writers block when no space is available. 
- */ -class ComputeBuffer -{ -public: - struct Params - { - size_t NumChunks = 2; - size_t ChunkLength = 512 * 1024; - }; - - ComputeBuffer(); - ~ComputeBuffer(); - - ComputeBuffer(const ComputeBuffer&) = delete; - ComputeBuffer& operator=(const ComputeBuffer&) = delete; - - bool CreateNew(const Params& InParams); - void Close(); - - bool IsValid() const; - - ComputeBufferReader CreateReader(); - ComputeBufferWriter CreateWriter(); - -private: - struct Detail; - Ref<Detail> m_Detail; - - friend class ComputeBufferReader; - friend class ComputeBufferWriter; -}; - -/** Read endpoint for a ComputeBuffer. - * - * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceReadPosition() after consuming the data. - */ -class ComputeBufferReader -{ -public: - ComputeBufferReader(); - ComputeBufferReader(const ComputeBufferReader&); - ComputeBufferReader(ComputeBufferReader&&) noexcept; - ~ComputeBufferReader(); - - ComputeBufferReader& operator=(const ComputeBufferReader&); - ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; - - void Close(); - void Detach(); - bool IsValid() const; - bool IsComplete() const; - - void AdvanceReadPosition(size_t Size); - size_t GetMaxReadSize() const; - - /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ - size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - - /** Wait until at least MinSize bytes are available and return a direct pointer. - * Returns nullptr on timeout or if the writer has completed. */ - const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - -private: - friend class ComputeBuffer; - explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -/** Write endpoint for a ComputeBuffer. - * - * Provides blocking writes into the ring buffer. 
WaitToWrite() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal - * that no more data will be written. - */ -class ComputeBufferWriter -{ -public: - ComputeBufferWriter(); - ComputeBufferWriter(const ComputeBufferWriter&); - ComputeBufferWriter(ComputeBufferWriter&&) noexcept; - ~ComputeBufferWriter(); - - ComputeBufferWriter& operator=(const ComputeBufferWriter&); - ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; - - void Close(); - bool IsValid() const; - - /** Signal that no more data will be written. Unblocks any waiting readers. */ - void MarkComplete(); - - void AdvanceWritePosition(size_t Size); - size_t GetMaxWriteSize() const; - size_t GetChunkMaxLength() const; - - /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ - size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); - - /** Wait until at least MinSize bytes of write space are available and return a direct pointer. - * Returns nullptr on timeout. */ - uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); - -private: - friend class ComputeBuffer; - explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp deleted file mode 100644 index ee2a6f327..000000000 --- a/src/zenhorde/hordecomputechannel.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include "hordecomputechannel.h" - -namespace zen::horde { - -ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) -: Reader(std::move(InReader)) -, Writer(std::move(InWriter)) -{ -} - -bool -ComputeChannel::IsValid() const -{ - return Reader.IsValid() && Writer.IsValid(); -} - -size_t -ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) -{ - return Writer.Write(Data, Size, TimeoutMs); -} - -size_t -ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) -{ - return Reader.Read(Data, Size, TimeoutMs); -} - -void -ComputeChannel::MarkComplete() -{ - Writer.MarkComplete(); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h deleted file mode 100644 index c1dff20e4..000000000 --- a/src/zenhorde/hordecomputechannel.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "hordecomputebuffer.h" - -namespace zen::horde { - -/** Bidirectional communication channel using a pair of compute buffers. - * - * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter - * (for sending data). Used by ComputeSocket to represent one logical channel - * within a multiplexed connection. - */ -class ComputeChannel : public TRefCounted<ComputeChannel> -{ -public: - ComputeBufferReader Reader; - ComputeBufferWriter Writer; - - ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); - - bool IsValid() const; - - size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); - size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); - - /** Signal that no more data will be sent on this channel. 
*/ - void MarkComplete(); -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp index 6ef67760c..92a56c077 100644 --- a/src/zenhorde/hordecomputesocket.cpp +++ b/src/zenhorde/hordecomputesocket.cpp @@ -6,198 +6,326 @@ namespace zen::horde { -ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) -: m_Log(zen::logging::Get("horde.socket")) +AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext) +: m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) +, m_Strand(asio::make_strand(IoContext)) +, m_PingTimer(m_Strand) { } -ComputeSocket::~ComputeSocket() +AsyncComputeSocket::~AsyncComputeSocket() { - // Shutdown order matters: first stop the ping thread, then unblock send threads - // by detaching readers, then join send threads, and finally close the transport - // to unblock the recv thread (which is blocked on RecvMessage). 
- { - std::lock_guard<std::mutex> Lock(m_PingMutex); - m_PingShouldStop = true; - m_PingCV.notify_all(); - } - - for (auto& Reader : m_Readers) - { - Reader.Detach(); - } - - for (auto& [Id, Thread] : m_SendThreads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } + Close(); +} - m_Transport->Close(); +void +AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) +{ + m_FrameHandlers[ChannelId] = std::move(OnFrame); + m_DetachHandlers[ChannelId] = std::move(OnDetach); +} - if (m_RecvThread.joinable()) - { - m_RecvThread.join(); - } - if (m_PingThread.joinable()) - { - m_PingThread.join(); - } +void +AsyncComputeSocket::StartRecvPump() +{ + StartPingTimer(); + DoRecvHeader(); } -Ref<ComputeChannel> -ComputeSocket::CreateChannel(int ChannelId) +void +AsyncComputeSocket::DoRecvHeader() { - ComputeBuffer::Params Params; + auto Self = shared_from_this(); + m_Transport->AsyncRead(&m_RecvHeader, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv header error: {}", Ec.message()); + HandleError(); + } + return; + } - ComputeBuffer RecvBuffer; - if (!RecvBuffer.CreateNew(Params)) - { - return {}; - } + if (m_Closed) + { + return; + } - ComputeBuffer SendBuffer; - if (!SendBuffer.CreateNew(Params)) - { - return {}; - } + if (m_RecvHeader.Size >= 0) + { + DoRecvPayload(m_RecvHeader); + } + else if (m_RecvHeader.Size == ControlDetach) + { + if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) + { + It->second(); + } + DoRecvHeader(); + } + else if (m_RecvHeader.Size == ControlPing) + { + DoRecvHeader(); + } + else + { + ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); + } + })); +} - Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); +void 
+AsyncComputeSocket::DoRecvPayload(FrameHeader Header) +{ + auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size)); + auto Self = shared_from_this(); - // Attach recv buffer writer (transport recv thread writes into this) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); - } + m_Transport->AsyncRead(PayloadBuf->data(), + PayloadBuf->size(), + asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); + HandleError(); + } + return; + } - // Attach send buffer reader (send thread reads from this) - { - ComputeBufferReader Reader = SendBuffer.CreateReader(); - m_Readers.push_back(Reader); - m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); - } + if (m_Closed) + { + return; + } + + if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second) + { + It->second(std::move(*PayloadBuf)); + } + else + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + } - return Channel; + DoRecvHeader(); + })); } void -ComputeSocket::StartCommunication() +AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler) { - m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); - m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { + if (m_Closed) + { + if (Handler) + { + Handler(asio::error::make_error_code(asio::error::operation_aborted)); + } + return; + } + + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = 
static_cast<int32_t>(Data.size()); + Write.Data = std::move(Data); + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::PingThreadProc() +AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { - while (true) - { + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { + if (m_Closed) { - std::unique_lock<std::mutex> Lock(m_PingMutex); - if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + if (Handler) { - break; + Handler(asio::error::make_error_code(asio::error::operation_aborted)); } + return; } - std::lock_guard<std::mutex> Lock(m_SendMutex); - FrameHeader Header; - Header.Channel = 0; - Header.Size = ControlPing; - m_Transport->SendMessage(&Header, sizeof(Header)); - } + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = ControlDetach; + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::RecvThreadProc() +AsyncComputeSocket::FlushNextSend() { - // Writers are cached locally to avoid taking m_WritersMutex on every frame. - // The shared m_Writers map is only accessed when a channel is seen for the first time. 
- std::unordered_map<int, ComputeBufferWriter> CachedWriters; + if (m_SendQueue.empty() || m_Closed) + { + return; + } - FrameHeader Header; - while (m_Transport->RecvMessage(&Header, sizeof(Header))) + PendingWrite& Front = m_SendQueue.front(); + + if (Front.Data.empty()) { - if (Header.Size >= 0) - { - // Data frame - auto It = CachedWriters.find(Header.Channel); - if (It == CachedWriters.end()) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto WIt = m_Writers.find(Header.Channel); - if (WIt == m_Writers.end()) - { - ZEN_WARN("recv frame for unknown channel {}", Header.Channel); - // Skip the data - std::vector<uint8_t> Discard(Header.Size); - m_Transport->RecvMessage(Discard.data(), Header.Size); - continue; - } - It = CachedWriters.emplace(Header.Channel, WIt->second).first; - } + // Control frame - header only + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); - ComputeBufferWriter& Writer = It->second; - uint8_t* Dest = Writer.WaitToWrite(Header.Size); - if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) - { - ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); - return; - } - Writer.AdvanceWritePosition(Header.Size); - } - else if (Header.Size == ControlDetach) - { - // Detach the recv buffer for this channel - CachedWriters.erase(Header.Channel); + if (Handler) + { + Handler(Ec); + } - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto It = m_Writers.find(Header.Channel); - if (It != m_Writers.end()) - { - It->second.MarkComplete(); - m_Writers.erase(It); - } - } - else if (Header.Size == ControlPing) + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); 
+ } + else + { + // Data frame - write header first, then payload + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + if (Handler) + { + Handler(Ec); + } + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send header error: {}", Ec.message()); + HandleError(); + } + return; + } + + PendingWrite& Payload = m_SendQueue.front(); + m_Transport->AsyncWrite( + Payload.Data.data(), + Payload.Data.size(), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + + if (Handler) + { + Handler(Ec); + } + + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send payload error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + })); + } +} + +void +AsyncComputeSocket::StartPingTimer() +{ + if (m_Closed) + { + return; + } + + m_PingTimer.expires_after(std::chrono::seconds(2)); + + auto Self = shared_from_this(); + m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { + if (Ec || m_Closed) { - // Ping response - ignore + return; } - else + + // Enqueue a ping control frame + PendingWrite Write; + Write.Header.Channel = 0; + Write.Header.Size = ControlPing; + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) { - ZEN_WARN("invalid frame header size: {}", Header.Size); - return; + FlushNextSend(); } - } + + StartPingTimer(); + })); } void -ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +AsyncComputeSocket::HandleError() { - // Each channel has its own send thread. 
All send threads share m_SendMutex - // to serialize writes to the transport, since TCP requires atomic frame writes. - FrameHeader Header; - Header.Channel = Channel; + if (m_Closed) + { + return; + } + + Close(); - const uint8_t* Data; - while ((Data = Reader.WaitToRead(1)) != nullptr) + // Notify all channels that the connection is gone so agents can clean up + for (auto& [ChannelId, Handler] : m_DetachHandlers) { - std::lock_guard<std::mutex> Lock(m_SendMutex); + if (Handler) + { + Handler(); + } + } +} - Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); - m_Transport->SendMessage(&Header, sizeof(Header)); - m_Transport->SendMessage(Data, Header.Size); - Reader.AdvanceReadPosition(Header.Size); +void +AsyncComputeSocket::Close() +{ + if (m_Closed) + { + return; } - if (Reader.IsComplete()) + m_Closed = true; + m_PingTimer.cancel(); + + if (m_Transport) { - std::lock_guard<std::mutex> Lock(m_SendMutex); - Header.Size = ControlDetach; - m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->Close(); } } diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h index 0c3cb4195..6c494603a 100644 --- a/src/zenhorde/hordecomputesocket.h +++ b/src/zenhorde/hordecomputesocket.h @@ -2,45 +2,74 @@ #pragma once -#include "hordecomputebuffer.h" -#include "hordecomputechannel.h" #include "hordetransport.h" #include <zencore/logbase.h> -#include <condition_variable> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +#include <deque> +#include <functional> #include <memory> -#include <mutex> -#include <thread> +#include <system_error> #include <unordered_map> #include <vector> namespace zen::horde { -/** Multiplexed socket that routes data between multiple channels over a single transport. +class AsyncComputeTransport; + +/** Handler called when a data frame arrives for a channel. 
*/ +using FrameHandler = std::function<void(std::vector<uint8_t> Data)>; + +/** Handler called when a channel is detached by the remote peer. */ +using DetachHandler = std::function<void()>; + +/** Handler for async send completion. */ +using SendHandler = std::function<void(const std::error_code&)>; + +/** Async multiplexed socket that routes data between channels over a single transport. * - * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. - * A recv thread demultiplexes incoming frames to channel-specific buffers, while - * per-channel send threads multiplex outgoing data onto the shared transport. + * Uses an async recv pump, a serialized send queue, and a periodic ping timer - + * all running on a shared io_context. * - * Wire format per frame: [channelId (4B)][size (4B)][data] - * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + * Wire format per frame: [channelId(4B)][size(4B)][data]. + * Control messages use negative sizes: -2 = detach, -3 = ping. */ -class ComputeSocket +class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket> { public: - explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport); - ~ComputeSocket(); + AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext); + ~AsyncComputeSocket(); + + AsyncComputeSocket(const AsyncComputeSocket&) = delete; + AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete; + + /** Register callbacks for a channel. Must be called before StartRecvPump(). */ + void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach); - ComputeSocket(const ComputeSocket&) = delete; - ComputeSocket& operator=(const ComputeSocket&) = delete; + /** Begin the async recv pump and ping timer. */ + void StartRecvPump(); - /** Create a channel with the given ID. - * Allocates anonymous in-process buffers and spawns a send thread for the channel. 
*/ - Ref<ComputeChannel> CreateChannel(int ChannelId); + /** Enqueue a data frame for async transmission. */ + void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {}); - /** Start the recv pump and ping threads. Must be called after all channels are created. */ - void StartCommunication(); + /** Send a control frame (detach) for a channel. */ + void AsyncSendDetach(int ChannelId, SendHandler Handler = {}); + + /** Close the transport and cancel all pending operations. */ + void Close(); + + /** The strand on which all socket I/O callbacks run. Channels that need to serialize + * their own state with OnFrame/OnDetach (which are invoked from this strand) should + * bind their timers and async operations to it as well. */ + asio::strand<asio::any_io_executor>& GetStrand() { return m_Strand; } private: struct FrameHeader @@ -49,31 +78,35 @@ private: int32_t Size = 0; }; + struct PendingWrite + { + FrameHeader Header; + std::vector<uint8_t> Data; + SendHandler Handler; + }; + static constexpr int32_t ControlDetach = -2; static constexpr int32_t ControlPing = -3; LoggerRef Log() { return m_Log; } - void RecvThreadProc(); - void SendThreadProc(int Channel, ComputeBufferReader Reader); - void PingThreadProc(); - - LoggerRef m_Log; - std::unique_ptr<ComputeTransport> m_Transport; - std::mutex m_SendMutex; ///< Serializes writes to the transport - - std::mutex m_WritersMutex; - std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID + void DoRecvHeader(); + void DoRecvPayload(FrameHeader Header); + void FlushNextSend(); + void StartPingTimer(); + void HandleError(); - std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction - std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel + LoggerRef m_Log; + std::unique_ptr<AsyncComputeTransport> m_Transport; + asio::strand<asio::any_io_executor> m_Strand; + asio::steady_timer m_PingTimer; - 
std::thread m_RecvThread; - std::thread m_PingThread; + std::unordered_map<int, FrameHandler> m_FrameHandlers; + std::unordered_map<int, DetachHandler> m_DetachHandlers; - bool m_PingShouldStop = false; - std::mutex m_PingMutex; - std::condition_variable m_PingCV; + FrameHeader m_RecvHeader; + std::deque<PendingWrite> m_SendQueue; + bool m_Closed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index 2dca228d9..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,5 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zencore/logging.h> +#include <zencore/string.h> #include <zenhorde/hordeconfig.h> namespace zen::horde { @@ -9,12 +11,14 @@ HordeConfig::Validate() const { if (ServerUrl.empty()) { + ZEN_WARN("Horde server URL is not configured"); return false; } // Relay mode implies AES encryption if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) { + ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode)); return false; } @@ -52,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return 
false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index f88c95da2..ea0ea1e83 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -6,49 +6,83 @@ #include "hordeagent.h" #include "hordebundle.h" +#include <zencore/compactbinary.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/scopeguard.h> #include <zencore/thread.h> #include <zencore/trace.h> +#include <zenhttp/httpclient.h> +#include <zenutil/workerpools.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> #include <chrono> #include <thread> namespace zen::horde { -struct HordeProvisioner::AgentWrapper -{ - std::thread Thread; - std::atomic<bool> ShouldExit{false}; -}; - HordeProvisioner::HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint) + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_BinariesPath(BinariesPath) , m_WorkingDir(WorkingDir) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_Log(zen::logging::Get("horde.provisioner")) { + m_IoContext = std::make_unique<asio::io_context>(); + + auto Work = asio::make_work_guard(*m_IoContext); + for (int i = 0; i < IoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i, Work] { + zen::SetCurrentThreadName(fmt::format("horde_io_{}", i)); + m_IoContext->run(); + }); + } } HordeProvisioner::~HordeProvisioner() { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto& Agent : m_Agents) + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + + // Shut down async agents and io_context { - Agent->ShouldExit.store(true); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for 
(auto& Entry : m_AsyncAgents) + { + Entry.Agent->Cancel(); + } + m_AsyncAgents.clear(); } - for (auto& Agent : m_Agents) + + m_IoContext->stop(); + + for (auto& Thread : m_IoThreads) { - if (Agent->Thread.joinable()) + if (Thread.joinable()) { - Agent->Thread.join(); + Thread.join(); } } + + // Wait for all pool work items to finish before destroying members they reference + if (m_PendingWorkItems.load() > 0) + { + m_AllWorkDone.Wait(); + } } void @@ -56,9 +90,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) { ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); - m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores))); + const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)); + const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount); + + if (ClampedCount != PreviousTarget) + { + ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})", + PreviousTarget, + ClampedCount, + m_ActiveCoreCount.load(), + m_EstimatedCoreCount.load()); + } - while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + // Only provision if the gap is at least one agent-sized chunk. Without + // this, draining a 32-core agent to cover a 28-core excess would leave a + // 4-core gap that triggers a 32-core provision, which triggers another + // drain, ad infinitum. 
+ while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load()) { if (!m_AskForAgents.load()) { @@ -67,21 +115,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) RequestAgent(); } - // Clean up finished agent threads - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto It = m_Agents.begin(); It != m_Agents.end();) + // Scale down async agents { - if ((*It)->ShouldExit.load()) + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + + uint32_t AsyncActive = m_ActiveCoreCount.load(); + uint32_t AsyncTarget = m_TargetCoreCount.load(); + + uint32_t AlreadyDrainingCores = 0; + for (const auto& Entry : m_AsyncAgents) { - if ((*It)->Thread.joinable()) + if (Entry.Draining) { - (*It)->Thread.join(); + AlreadyDrainingCores += Entry.CoreCount; } - It = m_Agents.erase(It); } - else + + uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0; + + if (EffectiveAsync > AsyncTarget) { - ++It; + struct Candidate + { + AsyncAgentEntry* Entry; + int Workload; + }; + std::vector<Candidate> Candidates; + + for (auto& Entry : m_AsyncAgents) + { + if (Entry.Draining || Entry.RemoteEndpoint.empty()) + { + continue; + } + + int Workload = 0; + bool Reachable = false; + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{2000}; + Settings.Timeout = std::chrono::milliseconds{3000}; + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); + HttpClient::Response Resp = Client.Get("/compute/session/status"); + if (Resp.IsSuccess()) + { + CbObject Status = Resp.AsObject(); + Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0); + Reachable = true; + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what()); + } + + if (Reachable) + { + Candidates.push_back({&Entry, Workload}); + } + } + + const uint32_t ExcessCores 
= EffectiveAsync - AsyncTarget; + uint32_t CoresDrained = 0; + + while (CoresDrained < ExcessCores && !Candidates.empty()) + { + const uint32_t Remaining = ExcessCores - CoresDrained; + + Candidates.erase(std::remove_if(Candidates.begin(), + Candidates.end(), + [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }), + Candidates.end()); + + if (Candidates.empty()) + { + break; + } + + Candidate* Best = &Candidates[0]; + for (auto& C : Candidates) + { + if (C.Entry->CoreCount > Best->Entry->CoreCount || + (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload)) + { + Best = &C; + } + } + + ZEN_INFO("draining async agent lease={} ({} cores, workload={})", + Best->Entry->LeaseId, + Best->Entry->CoreCount, + Best->Workload); + + DrainAsyncAgent(*Best->Entry); + CoresDrained += Best->Entry->CoreCount; + + AsyncAgentEntry* Drained = Best->Entry; + Candidates.erase( + std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }), + Candidates.end()); + } } } } @@ -101,266 +236,395 @@ HordeProvisioner::GetStats() const uint32_t HordeProvisioner::GetAgentCount() const { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - return static_cast<uint32_t>(m_Agents.size()); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + return static_cast<uint32_t>(m_AsyncAgents.size()); } -void -HordeProvisioner::RequestAgent() +compute::AgentProvisioningStatus +HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const { - m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + // Worker IDs are "horde-{LeaseId}" - strip the prefix to match lease ID + constexpr std::string_view Prefix = "horde-"; + if (!WorkerId.starts_with(Prefix)) + { + return compute::AgentProvisioningStatus::Unknown; + } + std::string_view LeaseId = WorkerId.substr(Prefix.size()); - std::lock_guard<std::mutex> Lock(m_AgentsLock); + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + for (const auto& 
Entry : m_AsyncAgents) + { + if (Entry.LeaseId == LeaseId) + { + if (Entry.Draining) + { + return compute::AgentProvisioningStatus::Draining; + } + return compute::AgentProvisioningStatus::Active; + } + } - auto Wrapper = std::make_unique<AgentWrapper>(); - AgentWrapper& Ref = *Wrapper; - Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + // Check recently-drained agents that have already been cleaned up + std::string WorkerIdStr(WorkerId); + if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0) + { + // Also remove from the ordering queue so size accounting stays consistent. + auto It = std::find(m_RecentlyDrainedOrder.begin(), m_RecentlyDrainedOrder.end(), WorkerIdStr); + if (It != m_RecentlyDrainedOrder.end()) + { + m_RecentlyDrainedOrder.erase(It); + } + return compute::AgentProvisioningStatus::Draining; + } - m_Agents.push_back(std::move(Wrapper)); + return compute::AgentProvisioningStatus::Unknown; } -void -HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +std::vector<std::string> +HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const { - ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + std::vector<std::string> Args; + Args.emplace_back("compute"); + Args.emplace_back("--http=asio"); + Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen"); - static std::atomic<uint32_t> ThreadIndex{0}; - const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + if (m_CleanStart) + { + Args.emplace_back("--clean"); + } - zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + Args.emplace_back(CoordArg.ToView()); + } - std::unique_ptr<HordeAgent> Agent; - uint32_t MachineCoreCount = 0; + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + 
Args.emplace_back(IdArg.ToView()); + } - auto _ = MakeGuard([&] { - if (Agent) - { - Agent->CloseConnection(); - } - Wrapper.ShouldExit.store(true); - }); + if (!m_CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << m_CoordinatorSession; + Args.emplace_back(SessionArg.ToView()); + } + if (!m_TraceHost.empty()) { - // EstimatedCoreCount is incremented speculatively when the agent is requested - // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. - auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << m_TraceHost; + Args.emplace_back(TraceArg.ToView()); + } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. + if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else { - ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } - std::lock_guard<std::mutex> BundleLock(m_BundleLock); + return Args; +} - if (!m_BundlesCreated) - { - const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; +bool +HordeProvisioner::InitializeHordeClient() +{ + ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient"); + + std::lock_guard<std::mutex> BundleLock(m_BundleLock); - std::vector<BundleFile> Files; + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector<BundleFile> Files; #if ZEN_PLATFORM_WINDOWS - Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + 
Files.emplace_back(m_BinariesPath / "zenserver.pdb", true); #elif ZEN_PLATFORM_LINUX - Files.emplace_back(m_BinariesPath / "zenserver", false); - Files.emplace_back(m_BinariesPath / "zenserver.debug", true); + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); #elif ZEN_PLATFORM_MAC - Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver", false); #endif - BundleResult Result; - if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) - { - ZEN_WARN("failed to create bundle, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - - m_Bundles.emplace_back(Result.Locator, Result.BundleDir); - m_BundlesCreated = true; - } - - if (!m_HordeClient) - { - m_HordeClient = std::make_unique<HordeClient>(m_Config); - if (!m_HordeClient->Initialize()) - { - ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - } + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + return false; } - if (!m_AskForAgents.load()) + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique<HordeClient>(m_Config); + if (!m_HordeClient->Initialize()) { - return; + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + return false; } + } - m_AgentsRequesting.fetch_add(1); - auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + return true; +} - // Simple backoff: if the last machine request failed, wait up to 5 seconds - // before trying again. 
- // - // Note however that it's possible that multiple threads enter this code at - // the same time if multiple agents are requested at once, and they will all - // see the same last failure time and back off accordingly. We might want to - // use a semaphore or similar to limit the number of concurrent requests. +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); - if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) - { - auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); - const uint64_t ElapsedNs = Now - LastFail; - const uint64_t ElapsedMs = ElapsedNs / 1'000'000; - if (ElapsedMs < 5000) - { - const uint64_t WaitMs = 5000 - ElapsedMs; - for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } + if (m_PendingWorkItems.fetch_add(1) == 0) + { + m_AllWorkDone.Reset(); + } - if (Wrapper.ShouldExit.load()) + GetSmallWorkerPool(EWorkloadType::Background) + .ScheduleWork( + [this] { + ProvisionAgent(); + if (m_PendingWorkItems.fetch_sub(1) == 1) { - return; + m_AllWorkDone.Set(); } - } - } + }, + WorkerThreadPool::EMode::EnableBacklog); +} - if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) - { - return; - } +void +HordeProvisioner::ProvisionAgent() +{ + ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent"); + + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. 
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); - std::string RequestBody = m_HordeClient->BuildRequestBody(); + if (!InitializeHordeClient()) + { + return; + } - // Resolve cluster if needed - std::string ClusterId = m_Config.Cluster; - if (ClusterId == HordeConfig::ClusterAuto) + if (!m_AskForAgents.load()) + { + return; + } + + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. + + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) { - ClusterInfo Cluster; - if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) + // Wait on m_ShutdownEvent so shutdown wakes this pool thread immediately instead + // of stalling for up to 5s in 100ms sleep chunks. Wait() returns true iff the + // event was signaled (shutdown); false means the backoff elapsed normally. 
+ const uint64_t WaitMs = 5000 - ElapsedMs; + if (m_ShutdownEvent.Wait(static_cast<int>(WaitMs))) { - ZEN_WARN("failed to resolve cluster"); - m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); return; } - ClusterId = Cluster.ClusterId; } + } - MachineInfo Machine; - if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) { + ZEN_WARN("failed to resolve cluster"); m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); return; } + ClusterId = Cluster.ClusterId; + } - m_LastRequestFailTime.store(0); + ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})", + ClusterId.empty() ? 
"default" : ClusterId.c_str(), + m_ActiveCoreCount.load(), + m_TargetCoreCount.load()); - if (Wrapper.ShouldExit.load()) - { - return; - } + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } - // Connect to agent and perform handshake - Agent = std::make_unique<HordeAgent>(Machine); - if (!Agent->IsValid()) - { - ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); - return; - } + m_LastRequestFailTime.store(0); - if (!Agent->BeginCommunication()) - { - ZEN_WARN("BeginCommunication failed"); - return; - } + if (!m_AskForAgents.load()) + { + return; + } - for (auto& [Locator, BundleDir] : m_Bundles) + AsyncAgentConfig AgentConfig; + AgentConfig.Machine = Machine; + AgentConfig.Bundles = m_Bundles; + AgentConfig.Args = BuildAgentArgs(Machine); + +#if ZEN_PLATFORM_WINDOWS + AgentConfig.UseWine = !Machine.IsWindows; + AgentConfig.Executable = "zenserver.exe"; +#else + AgentConfig.UseWine = false; + AgentConfig.Executable = "zenserver"; +#endif + + auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext); + + AsyncAgentEntry Entry; + Entry.Agent = AsyncAgent; + Entry.LeaseId = Machine.LeaseId; + Entry.CoreCount = Machine.LogicalCores; + + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (EndpointAddr.find(':') != std::string::npos) + { + Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); + } + else + { + Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); + } + + { + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + m_AsyncAgents.push_back(std::move(Entry)); + } + + AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) { + if (Result.CoreCount > 0) { - 
if (Wrapper.ShouldExit.load()) + // Only subtract estimated cores if not already subtracted by DrainAsyncAgent + bool WasDraining = false; { - return; + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.Agent == AsyncAgent) + { + WasDraining = Entry.Draining; + break; + } + } } - if (!Agent->UploadBinaries(BundleDir, Locator)) + if (!WasDraining) { - ZEN_WARN("UploadBinaries failed"); - return; + m_EstimatedCoreCount.fetch_sub(Result.CoreCount); } + m_ActiveCoreCount.fetch_sub(Result.CoreCount); + m_AgentsActive.fetch_sub(1); } + OnAsyncAgentDone(AsyncAgent); + }); - if (Wrapper.ShouldExit.load()) - { - return; - } - - // Build command line for remote zenserver - std::vector<std::string> ArgStrings; - ArgStrings.push_back("compute"); - ArgStrings.push_back("--http=asio"); + // Track active cores (estimated was already added by RequestAgent) + m_EstimatedCoreCount.fetch_add(Machine.LogicalCores); + m_ActiveCoreCount.fetch_add(Machine.LogicalCores); + m_AgentsActive.fetch_add(1); +} - // TEMP HACK - these should be made fully dynamic - // these are currently here to allow spawning the compute agent locally - // for debugging purposes (i.e with a local Horde Server+Agent setup) - ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); - ArgStrings.push_back("--data-dir=c:\\temp\\123"); +void +HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry) +{ + Entry.Draining = true; + m_EstimatedCoreCount.fetch_sub(Entry.CoreCount); + m_AgentsDraining.fetch_add(1); - if (!m_OrchestratorEndpoint.empty()) - { - ExtendableStringBuilder<256> CoordArg; - CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; - ArgStrings.emplace_back(CoordArg.ToView()); - } + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{5000}; + Settings.Timeout = std::chrono::milliseconds{10000}; - { - ExtendableStringBuilder<128> IdArg; - IdArg << 
"--instance-id=horde-" << Machine.LeaseId; - ArgStrings.emplace_back(IdArg.ToView()); - } + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); - std::vector<const char*> Args; - Args.reserve(ArgStrings.size()); - for (const std::string& Arg : ArgStrings) + HttpClient::Response Response = Client.Post("/compute/session/drain"); + if (!Response.IsSuccess()) { - Args.push_back(Arg.c_str()); + ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode)); + return; } -#if ZEN_PLATFORM_WINDOWS - const bool UseWine = !Machine.IsWindows; - const char* AppName = "zenserver.exe"; -#else - const bool UseWine = false; - const char* AppName = "zenserver"; -#endif - - Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); - - ZEN_INFO("remote execution started on [{}:{}] lease={}", - Machine.GetConnectionAddress(), - Machine.GetConnectionPort(), - Machine.LeaseId); - - MachineCoreCount = Machine.LogicalCores; - m_EstimatedCoreCount.fetch_add(MachineCoreCount); - m_ActiveCoreCount.fetch_add(MachineCoreCount); - m_AgentsActive.fetch_add(1); + ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId); + (void)Client.Post("/compute/session/sunset"); } + catch (const std::exception& Ex) + { + ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what()); + } +} - // Agent poll loop - - auto ActiveGuard = MakeGuard([&]() { - m_EstimatedCoreCount.fetch_sub(MachineCoreCount); - m_ActiveCoreCount.fetch_sub(MachineCoreCount); - m_AgentsActive.fetch_sub(1); - }); - - while (Agent->IsValid() && !Wrapper.ShouldExit.load()) +void +HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent) +{ + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It) { - const bool LogOutput = false; - if (!Agent->Poll(LogOutput)) + if (It->Agent == Agent) { + if (It->Draining) + { + m_AgentsDraining.fetch_sub(1); + 
std::string WorkerId = "horde-" + It->LeaseId; + if (m_RecentlyDrainedWorkerIds.insert(WorkerId).second) + { + m_RecentlyDrainedOrder.push_back(WorkerId); + while (m_RecentlyDrainedOrder.size() > RecentlyDrainedCapacity) + { + m_RecentlyDrainedWorkerIds.erase(m_RecentlyDrainedOrder.front()); + m_RecentlyDrainedOrder.pop_front(); + } + } + } + m_AsyncAgents.erase(It); break; } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp index 69766e73e..65eaea477 100644 --- a/src/zenhorde/hordetransport.cpp +++ b/src/zenhorde/hordetransport.cpp @@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif - namespace zen::horde { -// ComputeTransport base +// --- AsyncTcpComputeTransport --- -bool -ComputeTransport::SendMessage(const void* Data, size_t Size) +struct AsyncTcpComputeTransport::Impl { - const uint8_t* Ptr = static_cast<const uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Sent = Send(Ptr, Remaining); - if (Sent == 0) - { - return false; - } - Ptr += Sent; - Remaining -= Sent; - } + asio::io_context& IoContext; + asio::ip::tcp::socket Socket; - return true; -} + explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {} +}; -bool -ComputeTransport::RecvMessage(void* Data, size_t Size) +AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext) +: m_Impl(std::make_unique<Impl>(IoContext)) +, m_Log(zen::logging::Get("horde.transport.async")) { - uint8_t* Ptr = static_cast<uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Received = Recv(Ptr, Remaining); - if (Received == 0) - { - return false; - } - Ptr += Received; - Remaining -= Received; - } - - return true; } -// TcpComputeTransport - ASIO pimpl - -struct TcpComputeTransport::Impl 
+AsyncTcpComputeTransport::~AsyncTcpComputeTransport() { - asio::io_context IoContext; - asio::ip::tcp::socket Socket; - - Impl() : Socket(IoContext) {} -}; + Close(); +} -// Uses ASIO in synchronous mode only — no async operations or io_context::run(). -// The io_context is only needed because ASIO sockets require one to be constructed. -TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) -: m_Impl(std::make_unique<Impl>()) -, m_Log(zen::logging::Get("horde.transport")) +void +AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler) { - ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect"); asio::error_code Ec; @@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) { ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); m_HasErrors = true; + asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); }); return; } const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); - m_Impl->Socket.connect(Endpoint, Ec); - if (Ec) - { - ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); - m_HasErrors = true; - return; - } + // Copy the nonce so it survives past this scope into the async callback + auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize); - // Disable Nagle's algorithm for lower latency - m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); -} + m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("async connect failed: {}", Ec.message()); + m_HasErrors = true; + Handler(Ec); + return; + } -TcpComputeTransport::~TcpComputeTransport() -{ - Close(); + asio::error_code SetOptEc; + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc); + + // 
Send the 64-byte nonce as the first thing on the wire + asio::async_write(m_Impl->Socket, + asio::buffer(*NonceBuf), + [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) { + if (Ec) + { + ZEN_WARN("nonce write failed: {}", Ec.message()); + m_HasErrors = true; + } + Handler(Ec); + }); + }); } bool -TcpComputeTransport::IsValid() const +AsyncTcpComputeTransport::IsValid() const { return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; } -size_t -TcpComputeTransport::Send(const void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - m_HasErrors = true; - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Sent; + asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } -size_t -TcpComputeTransport::Recv(void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Received; -} - -void -TcpComputeTransport::MarkComplete() -{ + asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } void -TcpComputeTransport::Close() +AsyncTcpComputeTransport::Close() { if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) { diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h index 1b178dc0f..b5e841d7a 100644 --- 
a/src/zenhorde/hordetransport.h +++ b/src/zenhorde/hordetransport.h @@ -8,55 +8,60 @@ #include <cstddef> #include <cstdint> +#include <functional> #include <memory> +#include <system_error> -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif +namespace asio { +class io_context; +} namespace zen::horde { -/** Abstract base interface for compute transports. +/** Handler types for async transport operations. */ +using AsyncConnectHandler = std::function<void(const std::error_code&)>; +using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>; + +/** Abstract base for asynchronous compute transports. * - * Matches the UE FComputeTransport pattern. Concrete implementations handle - * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides - * blocking message helpers on top. + * All callbacks are invoked on the io_context that was provided at construction. + * Callers are responsible for strand serialization if needed. */ -class ComputeTransport +class AsyncComputeTransport { public: - virtual ~ComputeTransport() = default; + virtual ~AsyncComputeTransport() = default; + + virtual bool IsValid() const = 0; - virtual bool IsValid() const = 0; - virtual size_t Send(const void* Data, size_t Size) = 0; - virtual size_t Recv(void* Data, size_t Size) = 0; - virtual void MarkComplete() = 0; - virtual void Close() = 0; + /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */ + virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking send that loops until all bytes are transferred. Returns false on error. */ - bool SendMessage(const void* Data, size_t Size); + /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */ + virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking receive that loops until all bytes are transferred. Returns false on error. 
*/ - bool RecvMessage(void* Data, size_t Size); + virtual void Close() = 0; }; -/** TCP socket transport using ASIO. +/** Async TCP transport using ASIO. * - * Connects to the Horde compute endpoint specified by MachineInfo and provides - * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the - * header clean. + * Connects to the Horde compute endpoint and provides async send/receive. + * The socket is created on a caller-provided io_context (shared across agents). */ -class TcpComputeTransport final : public ComputeTransport +class AsyncTcpComputeTransport final : public AsyncComputeTransport { public: - explicit TcpComputeTransport(const MachineInfo& Info); - ~TcpComputeTransport() override; - - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + /** Construct a transport on the given io_context. Does not connect yet. */ + explicit AsyncTcpComputeTransport(asio::io_context& IoContext); + ~AsyncTcpComputeTransport() override; + + /** Asynchronously connect to the endpoint and send the nonce. 
*/ + void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler); + + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: LoggerRef Log() { return m_Log; } diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..0b94a4397 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,9 +5,12 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <algorithm> #include <cstring> -#include <random> #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; - -#if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; -#endif - - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). 
- std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution<int> Dist(0, 255); - for (auto& Byte : EncryptNonce) - { - Byte = static_cast<uint8_t>(Dist(Gen)); - } - -#if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } - - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); -#endif - } +namespace { - ~CryptoContext() - { -#if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); -#else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); - } - if (DecCtx) - { - EVP_CIPHER_CTX_free(DecCtx); - } -#endif - } + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; - void UpdateNonce() + /** AES-256-GCM crypto context. Not exposed outside this translation unit. */ + struct AesCryptoContext { - uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics) + bool HasErrors = false; - // On Windows, BCrypt algorithm/key handles are created per call. 
This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. #if ZEN_PLATFORM_WINDOWS BCRYPT_ALG_HANDLE hAlg = nullptr; BCRYPT_KEY_HANDLE hKey = nullptr; +#else + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - HasErrors = true; - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } + memcpy(Key, InKey, KeySize); - // Write header: length + nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic + // construction): fixed_field = 0, counter starts at 0 and is incremented + // before each encryption by UpdateNonce(). No RNG is used here because + // std::random_device is not guaranteed to be a CSPRNG (historic MinGW, + // some WASI targets), and the deterministic construction does not need + // one as long as each session uses a unique key. 
- BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#if ZEN_PLATFORM_WINDOWS + NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hAlg = nullptr; + HasErrors = true; + return; + } + + Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return; + } + + Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hKey = nullptr; + HasErrors = true; + return; + } #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + DecCtx = EVP_CIPHER_CTX_new(); + if (!EncCtx || !DecCtx) + { + ZEN_ERROR("EVP_CIPHER_CTX_new failed"); + HasErrors = true; + return; + } + + if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } + + if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } +#endif } - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + 
Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + ~AesCryptoContext() { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + if (hKey) + { + BCryptDestroyKey(hKey); + } + if (hAlg) + { + BCryptCloseAlgorithmProvider(hAlg, 0); + } + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - Offset += OutLen; - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + void UpdateNonce() { + // NIST SP 800-38D §8.2.1 deterministic construction: + // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)] + // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64 + // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated + // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity. + for (int i = 11; i >= 4; --i) + { + if (++EncryptNonce[i] != 0) + { + return; + } + } HasErrors = true; - return 0; } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) { - HasErrors = true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } + UpdateNonce(); + if (HasErrors) + { + return 0; + } - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. 
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } - - return static_cast<int32_t>(PlainLen); + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init. 
+ uint8_t Tag[TagBytes]; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + const NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; #else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } - - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; + // Reset per message so any stale state from a previous encrypt (e.g. partial + // completion after a prior error) cannot bleed into this operation. Re-bind + // the cipher/key; the IV is then set via the normal init call below. 
+ if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + const NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, 
+ nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure - + // either in-flight corruption or active tampering. Log distinctly from + // other BCryptDecrypt failures so that tamper attempts are auditable. + static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L); + if (Status == STATUS_AUTH_TAG_MISMATCH_VAL) + { + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + } + else + { + ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter); + } + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return static_cast<int32_t>(PlainLen); +#else + // Same rationale as EncryptMessage: reset the context and re-bind the cipher + // before each decrypt to avoid stale state from a previous operation. + if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, 
static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure + // once the tag has been set. Log distinctly so tamper attempts are auditable. + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique<CryptoContext>(Key)) , m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard<std::mutex> Lock(m_Lock); - const int32_t DataLength = static_cast<int32_t>(Size); 
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } + // Encrypt directly into the per-write buffer rather than a long-lived member. Using a + // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL + // path) would leave plaintext on the heap indefinitely and would also make the + // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr + // exactly to EncryptedLen afterwards. + auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength); - const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) - { - return 0; - } + EncBuf->resize(static_cast<size_t>(EncryptedLen)); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. 
Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard<std::mutex> Lock(m_Lock); + uint8_t* Dest = static_cast<uint8_t*>(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if 
(Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast<size_t>(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + 
size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast<size_t>(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { - if (m_Inner && m_Inner->IsValid()) + // Always forward Close() to the inner transport if we have one. Gating on + // IsValid() skipped cleanup when the inner transport was partially torn down + // (e.g. after a read/write error marked it non-valid but left its socket open), + // leaking OS handles. Close implementations are expected to be idempotent. + if (m_Inner) { m_Inner->Close(); } diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..7846073dc 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -6,47 +6,54 @@ #include <cstdint> #include <memory> -#include <mutex> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { -/** AES-256-GCM encrypted transport wrapper. +/** Async AES-256-GCM encrypted transport wrapper. * - * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting - * all incoming data using AES-256-GCM. 
The nonce is mutated per message using - * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming + * data using AES-256-GCM. Outgoing nonces follow the NIST SP 800-38D §8.2.1 + * deterministic construction: a 4-byte fixed field followed by an 8-byte + * big-endian monotonic counter. The session is torn down if the counter + * would wrap. * * Wire format per encrypted message: * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] * * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). + * + * Thread safety: all operations must be serialized by the caller (e.g. via a strand). */ -class AesComputeTransport final : public ComputeTransport +class AsyncAesComputeTransport final : public AsyncComputeTransport { public: - AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); - ~AesComputeTransport() override; + AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext); + ~AsyncAesComputeTransport() override; - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: - static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size - static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler); struct CryptoContext; - std::unique_ptr<CryptoContext> m_Crypto; - std::unique_ptr<ComputeTransport> m_Inner; - std::vector<uint8_t> m_EncryptBuffer; - 
std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv - size_t m_RemainingOffset = 0; - std::mutex m_Lock; - bool m_IsClosed = false; + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<AsyncComputeTransport> m_Inner; + asio::io_context& m_IoContext; + std::vector<uint8_t> m_DecryptBuffer; + std::vector<uint8_t> m_RemainingData; + size_t m_RemainingOffset = 0; + bool m_IsClosed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 201d68b83..87caec019 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -45,14 +45,15 @@ struct MachineInfo uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) bool IsWindows = false; std::string LeaseId; + std::string Pool; std::map<std::string, PortInfo> Ports; /** Return the address to connect to, accounting for connection mode. */ - const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } /** Return the port to connect to, accounting for connection mode and port mapping. */ - uint16_t GetConnectionPort() const + [[nodiscard]] uint16_t GetConnectionPort() const { if (Mode == ConnectionMode::Relay) { @@ -65,7 +66,20 @@ struct MachineInfo return Port; } - bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. 
*/ + [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + if (auto It = Ports.find("ZenPort"); It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + + [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; /** Result of cluster auto-resolution via the Horde API. */ @@ -83,31 +97,29 @@ struct ClusterInfo class HordeClient { public: - explicit HordeClient(const HordeConfig& Config); + explicit HordeClient(HordeConfig Config); ~HordeClient(); HordeClient(const HordeClient&) = delete; HordeClient& operator=(const HordeClient&) = delete; /** Initialize the underlying HTTP client. Must be called before other methods. */ - bool Initialize(); + [[nodiscard]] bool Initialize(); /** Build the JSON request body for cluster resolution and machine requests. * Encodes pool, condition, connection mode, encryption, and port requirements. */ - std::string BuildRequestBody() const; + [[nodiscard]] std::string BuildRequestBody() const; /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ - bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. * On success, populates OutMachine with connection details and credentials. 
*/ - bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); LoggerRef Log() { return m_Log; } private: - bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); - HordeConfig m_Config; std::unique_ptr<zen::HttpClient> m_Http; LoggerRef m_Log; diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h index dd70f9832..3a4dfb386 100644 --- a/src/zenhorde/include/zenhorde/hordeconfig.h +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -4,6 +4,10 @@ #include <zenhorde/zenhorde.h> +#include <zenhttp/httpclient.h> + +#include <functional> +#include <optional> #include <string> namespace zen::horde { @@ -33,20 +37,25 @@ struct HordeConfig static constexpr const char* ClusterDefault = "default"; static constexpr const char* ClusterAuto = "_auto"; - bool Enabled = false; ///< Whether Horde provisioning is active - std::string ServerUrl; ///< Horde server base URL (e.g. 
"https://horde.example.com") - std::string AuthToken; ///< Authentication token for the Horde API - std::string Pool; ///< Pool name to request machines from - std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve - std::string Condition; ///< Agent filter expression for machine selection - std::string HostAddress; ///< Address that provisioned agents use to connect back to us - std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload - uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication - - int MaxCores = 2048; - bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents - ConnectionMode Mode = ConnectionMode::Direct; - Encryption EncryptionMode = Encryption::None; + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API (static fallback) + + /// Optional token provider with automatic refresh (e.g. from OidcToken executable). + /// When set, takes priority over the static AuthToken string. 
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min) + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; /** Validate the configuration. Returns false if the configuration is invalid * (e.g. Relay mode without AES encryption). */ diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h index 4e2e63bbd..ea2fd7783 100644 --- a/src/zenhorde/include/zenhorde/hordeprovisioner.h +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -2,21 +2,32 @@ #pragma once +#include <zenhorde/hordeclient.h> #include <zenhorde/hordeconfig.h> +#include <zencompute/provisionerstate.h> #include <zencore/logbase.h> +#include <zencore/thread.h> #include <atomic> #include <cstdint> +#include <deque> #include <filesystem> #include <memory> #include <mutex> #include <string> +#include <thread> +#include <unordered_set> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { class HordeClient; +class AsyncHordeAgent; /** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). 
*/ struct ProvisioningStats @@ -35,13 +46,12 @@ struct ProvisioningStats * binary, and executing it remotely. Each provisioned machine runs zenserver * in compute mode, which announces itself back to the orchestrator. * - * Spawns one thread per agent. Each thread handles the full lifecycle: - * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> - * channel setup -> binary upload -> remote execution -> poll until exit. + * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread + * pool rather than spawning a dedicated thread per agent. * * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. */ -class HordeProvisioner +class HordeProvisioner : public compute::IProvisionerStateProvider { public: /** Construct a provisioner. @@ -52,38 +62,48 @@ public: HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint); + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); - /** Signals all agent threads to exit and joins them. */ - ~HordeProvisioner(); + /** Signals all agents to exit and waits for completion. */ + ~HordeProvisioner() override; HordeProvisioner(const HordeProvisioner&) = delete; HordeProvisioner& operator=(const HordeProvisioner&) = delete; /** Set the target number of cores to provision. - * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the - * estimated core count is below the target. Also joins any finished - * agent threads. */ - void SetTargetCoreCount(uint32_t Count); + * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the + * estimated core count is below the target. Also removes finished agents. */ + void SetTargetCoreCount(uint32_t Count) override; /** Return a snapshot of the current provisioning counters. 
*/ ProvisioningStats GetStats() const; - uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } - uint32_t GetAgentCount() const; + // IProvisionerStateProvider + std::string_view GetName() const override { return "horde"; } + uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); } + uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); } + uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const override; + uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); } + compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override; private: LoggerRef Log() { return m_Log; } - struct AgentWrapper; - void RequestAgent(); - void ThreadAgent(AgentWrapper& Wrapper); + void ProvisionAgent(); + bool InitializeHordeClient(); HordeConfig m_Config; std::filesystem::path m_BinariesPath; std::filesystem::path m_WorkingDir; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr<HordeClient> m_HordeClient; @@ -91,20 +111,54 @@ private: std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs bool m_BundlesCreated = false; - mutable std::mutex m_AgentsLock; - std::vector<std::unique_ptr<AgentWrapper>> m_Agents; - std::atomic<uint64_t> m_LastRequestFailTime{0}; std::atomic<uint32_t> m_TargetCoreCount{0}; std::atomic<uint32_t> m_EstimatedCoreCount{0}; std::atomic<uint32_t> m_ActiveCoreCount{0}; std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsDraining{0}; std::atomic<uint32_t> m_AgentsRequesting{0}; std::atomic<bool> m_AskForAgents{true}; + std::atomic<uint32_t> m_PendingWorkItems{0}; + Event m_AllWorkDone; + /** Manual-reset event set alongside m_AskForAgents=false so pool-thread backoff waits + * wake immediately on shutdown instead of polling 
a 100ms sleep. */ + Event m_ShutdownEvent; LoggerRef m_Log; + // Async I/O + std::unique_ptr<asio::io_context> m_IoContext; + std::vector<std::thread> m_IoThreads; + + struct AsyncAgentEntry + { + std::shared_ptr<AsyncHordeAgent> Agent; + std::string RemoteEndpoint; + std::string LeaseId; + uint16_t CoreCount = 0; + bool Draining = false; + }; + + mutable std::mutex m_AsyncAgentsLock; + std::vector<AsyncAgentEntry> m_AsyncAgents; + + /** Worker IDs of agents that completed after draining. + * GetAgentStatus() consumes entries when queried, but if no one queries, entries would + * otherwise accumulate unbounded across the lifetime of the provisioner. Cap the set + * at RecentlyDrainedCapacity by evicting the oldest entry (tracked in an insertion-order + * queue) whenever we insert past the limit. */ + mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds; + mutable std::deque<std::string> m_RecentlyDrainedOrder; + static constexpr size_t RecentlyDrainedCapacity = 256; + + void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent); + void DrainAsyncAgent(AsyncAgentEntry& Entry); + + std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const; + static constexpr uint32_t EstimatedCoresPerAgent = 32; + static constexpr int IoThreadCount = 3; }; } // namespace zen::horde diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp new file mode 100644 index 000000000..151863370 --- /dev/null +++ b/src/zenhttp/asynchttpclient_test.cpp @@ -0,0 +1,315 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenhttp/asynchttpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Reusable test service for async client tests + +class AsyncHttpClientTestService : public HttpService +{ +public: + AsyncHttpClientTestService() + { + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}"); + }, + HttpVerb::kGet); + } + + virtual const 
char* BaseUri() const override { return "/api/async-test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + +private: + HttpRequestRouter m_Router; +}; + +////////////////////////////////////////////////////////////////////////// + +struct AsyncTestServerFixture +{ + AsyncHttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + AsyncTestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(0, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~AsyncTestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + AsyncHttpClient MakeClient(HttpClientSettings Settings = {}) { return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), Settings); } + + AsyncHttpClient MakeClient(asio::io_context& IoContext, HttpClientSettings Settings = {}) + { + return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), IoContext, Settings); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.asynchttpclient"); + +TEST_CASE("asynchttpclient.future.verbs") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + auto Future = Client.Get("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + auto Future = Client.Post("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + auto Future = Client.Put("/api/async-test/echo/method"); + auto Resp = 
Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + auto Future = Client.Delete("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + auto Future = Client.Head("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("asynchttpclient.future.get") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET returning JSON") + { + auto Future = Client.Get("/api/async-test/json"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"ok\":true}"); + } + + SUBCASE("GET 204 NoContent") + { + auto Future = Client.Get("/api/async-test/nocontent"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("asynchttpclient.future.post.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PayloadStr = "async payload data"; + IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Post("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "async payload data"); +} + +TEST_CASE("asynchttpclient.future.put.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PutStr = "put payload"; + IoBuffer Payload(IoBuffer::Clone, 
PutStr.data(), PutStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Put("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload"); +} + +TEST_CASE("asynchttpclient.callback") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::promise<HttpClient::Response> Promise; + auto Future = Promise.get_future(); + + Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); +} + +TEST_CASE("asynchttpclient.concurrent.requests") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + // Fire multiple requests concurrently + auto Future1 = Client.Get("/api/async-test/hello"); + auto Future2 = Client.Get("/api/async-test/json"); + auto Future3 = Client.Post("/api/async-test/echo/method"); + auto Future4 = Client.Delete("/api/async-test/echo/method"); + + auto Resp1 = Future1.get(); + auto Resp2 = Future2.get(); + auto Resp3 = Future3.get(); + auto Resp4 = Future4.get(); + + CHECK(Resp1.IsSuccess()); + CHECK_EQ(Resp1.AsText(), "hello world"); + + CHECK(Resp2.IsSuccess()); + CHECK_EQ(Resp2.AsText(), "{\"ok\":true}"); + + CHECK(Resp3.IsSuccess()); + CHECK_EQ(Resp3.AsText(), "POST"); + + CHECK(Resp4.IsSuccess()); + CHECK_EQ(Resp4.AsText(), "DELETE"); +} + +TEST_CASE("asynchttpclient.external.io_context") +{ + AsyncTestServerFixture Fixture; + + asio::io_context IoContext; + auto WorkGuard = asio::make_work_guard(IoContext); + std::thread IoThread([&IoContext]() { IoContext.run(); }); + + { + AsyncHttpClient Client = Fixture.MakeClient(IoContext); + + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + WorkGuard.reset(); + IoThread.join(); 
+} + +TEST_CASE("asynchttpclient.connection.error") +{ + // Connect to a port where nothing is listening + AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + + auto Future = Client.Get("/should-fail"); + auto Resp = Future.get(); + + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(Resp.Error->IsConnectionError()); +} + +TEST_SUITE_END(); + +void +asynchttpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index 209276621..2fa22f2c2 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -132,7 +132,7 @@ public: } } - RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); + Ref<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -232,10 +232,10 @@ public: private: struct OpenIdProvider { - std::string Name; - std::string Url; - std::string ClientId; - RefPtr<OidcClient> HttpClient; + std::string Name; + std::string Url; + std::string ClientId; + Ref<OidcClient> HttpClient; }; struct OpenIdToken @@ -262,7 +262,7 @@ private: { ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken"); - RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; + Ref<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp new file mode 100644 index 000000000..ea88fc783 --- /dev/null +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -0,0 +1,1033 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenhttp/asynchttpclient.h> + +#include "httpclientcurlhelpers.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <thread> +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// TransferContext: per-transfer state associated with each CURL easy handle + +struct TransferContext +{ + AsyncHttpCallback Callback; + std::string Body; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + CurlWriteCallbackData WriteData; + CurlHeaderCallbackData HeaderData; + curl_slist* HeaderList = nullptr; + + // For PUT/POST with payload: keep the data alive until transfer completes + IoBuffer PayloadBuffer; + CurlReadCallbackData ReadData; + + TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) + { + WriteData.Body = &Body; + HeaderData.Headers = &ResponseHeaders; + } + + ~TransferContext() + { + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + } + + TransferContext(const TransferContext&) = delete; + TransferContext& operator=(const TransferContext&) = delete; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient::Impl + +struct AsyncHttpClient::Impl +{ + Impl(std::string_view BaseUri, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + m_WorkGuard.emplace(m_IoContext.get_executor()); + m_IoThread = std::thread([this]() { + SetCurrentThreadName("async_http"); + try + { + m_IoContext.run(); + } + catch (const 
std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what()); + } + }); + } + + Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + } + + ~Impl() + { + // Clean up curl state on the strand where all curl_multi operations + // are serialized. Use a promise to block until the cleanup handler + // has actually executed - essential for the external io_context case + // where we don't own the run loop. + std::promise<void> Done; + std::future<void> DoneFuture = Done.get_future(); + + asio::post(m_Strand, [this, &Done]() { + m_ShuttingDown = true; + m_Timer.cancel(); + + // Release all tracked sockets (don't close - curl owns the fds). + for (auto& [Fd, Info] : m_Sockets) + { + if (Info->Socket.is_open()) + { + Info->Socket.cancel(); + Info->Socket.release(); + } + } + m_Sockets.clear(); + + for (auto& [Handle, Ctx] : m_Transfers) + { + curl_multi_remove_handle(m_Multi, Handle); + curl_easy_cleanup(Handle); + } + m_Transfers.clear(); + + for (CURL* Handle : m_HandlePool) + { + curl_easy_cleanup(Handle); + } + m_HandlePool.clear(); + + Done.set_value(); + }); + + // For owned io_context: release work guard so run() can return after + // processing the cleanup handler above. + m_WorkGuard.reset(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + else + { + // External io_context: wait for the cleanup handler to complete. 
+ DoneFuture.wait(); + } + + if (m_Multi) + { + curl_multi_cleanup(m_Multi); + } + } + + LoggerRef Log() { return m_Log; } + + void Init() + { + m_Multi = curl_multi_init(); + if (!m_Multi) + { + throw std::runtime_error("curl_multi_init failed"); + } + + SetupMultiCallbacks(); + + if (m_Settings.SessionId == Oid::Zero) + { + m_SessionId = std::string(GetSessionIdString()); + } + else + { + m_SessionId = m_Settings.SessionId.ToString(); + } + } + + // -- Handle pool ----------------------------------------------------- + + CURL* AllocHandle() + { + if (!m_HandlePool.empty()) + { + CURL* Handle = m_HandlePool.back(); + m_HandlePool.pop_back(); + curl_easy_reset(Handle); + return Handle; + } + CURL* Handle = curl_easy_init(); + if (!Handle) + { + throw std::runtime_error("curl_easy_init failed"); + } + return Handle; + } + + void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); } + + // -- Configure a handle with common settings ------------------------- + // Called only from DoAsync* lambdas running on the strand. 
+ + void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) + { + // Build URL + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Unix domain socket + if (!m_Settings.UnixSocketPath.empty()) + { + m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str()); + } + + // Timeouts + if (m_Settings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_Settings.ConnectTimeout.count())); + } + if (m_Settings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count())); + } + + // HTTP/2 + if (m_Settings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // SSL + if (m_Settings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!m_Settings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str()); + } + + // Verbose/debug + if (m_Settings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + } + + // Thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (m_Settings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + } + + // -- Access token ---------------------------------------------------- + + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return 
m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()(); + if (!NewToken.IsValid()) + { + ZEN_WARN("AsyncHttpClient: failed to refresh access token, retrying once"); + NewToken = m_Settings.AccessTokenProvider.value()(); + } + if (NewToken.IsValid()) + { + m_CachedAccessToken = NewToken; + return m_CachedAccessToken.GetValue(); + } + ZEN_WARN("AsyncHttpClient: access token provider returned invalid token"); + return {}; + } + + // -- Submit a transfer ----------------------------------------------- + + void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer"); + // Setup write/header callbacks + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData); + + m_Transfers[Handle] = std::move(Ctx); + + CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle); + if (Mc != CURLM_OK) + { + auto Stolen = std::move(m_Transfers[Handle]); + m_Transfers.erase(Handle); + ReleaseHandle(Handle); + + HttpClient::Response ErrorResponse; + ErrorResponse.Error = + HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))}; + asio::post(m_IoContext, + [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); + return; + } + } + + // -- Socket-action integration --------------------------------------- + // + // curl_multi drives I/O via two callbacks: + // - SocketCallback: curl tells us which sockets to watch for read/write + // - TimerCallback: curl tells us when to fire a timeout + // + // On each socket event or timeout we call curl_multi_socket_action(), + // then drain completed transfers via 
curl_multi_info_read(). + + // Per-socket state: wraps the native fd in an ASIO socket for async_wait. + struct SocketInfo + { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + + explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} + }; + + // Static thunks registered with curl_multi ---------------------------- + + static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) + { + ZEN_UNUSED(Easy); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + return 0; + } + + static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr) + { + ZEN_UNUSED(Multi); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlTimer(TimeoutMs); + return 0; + } + + void SetupMultiCallbacks() + { + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this); + } + + // Called by curl when socket watch state changes --------------------- + + void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) + { + if (Action == CURL_POLL_REMOVE) + { + if (Info) + { + // Cancel pending async_wait ops before releasing the fd. + // curl owns the fd, so we must release() rather than close(). + Info->Socket.cancel(); + if (Info->Socket.is_open()) + { + Info->Socket.release(); + } + m_Sockets.erase(Fd); + } + return; + } + + if (!Info) + { + // New socket - wrap the native fd in an ASIO socket. + auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); + Info = It->second.get(); + + asio::error_code Ec; + // Determine protocol from the fd (v4 vs v6). Default to v4. 
+ Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + // Try v6 as fallback + Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) + { + ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + m_Sockets.erase(Fd); + return; + } + + curl_multi_assign(m_Multi, Fd, Info); + } + + Info->WatchFlags = Action; + SetSocketWatch(Fd, Info); + } + + void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + { + // Cancel any pending wait before issuing a new one. + Info->Socket.cancel(); + + if (Info->WatchFlags & CURL_POLL_IN) + { + Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + })); + } + + if (Info->WatchFlags & CURL_POLL_OUT) + { + Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_OUT); + })); + } + } + + void OnSocketReady(curl_socket_t Fd, int CurlAction) + { + ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); + CheckCompleted(); + + // Re-arm the watch if the socket is still tracked. + auto It = m_Sockets.find(Fd); + if (It != m_Sockets.end()) + { + SetSocketWatch(Fd, It->second.get()); + } + } + + // Called by curl when it wants a timeout ------------------------------ + + void OnCurlTimer(long TimeoutMs) + { + m_Timer.cancel(); + + if (TimeoutMs < 0) + { + // curl says "no timeout needed" + return; + } + + if (TimeoutMs == 0) + { + // curl wants immediate action - run it directly on the strand. 
+ asio::post(m_Strand, [this]() { + if (m_ShuttingDown) + { + return; + } + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + }); + return; + } + + m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); + m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + })); + } + + // Drain completed transfers from curl_multi -------------------------- + + void CheckCompleted() + { + int MsgsLeft = 0; + CURLMsg* Msg = nullptr; + while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr) + { + if (Msg->msg != CURLMSG_DONE) + { + continue; + } + + CURL* Handle = Msg->easy_handle; + CURLcode Result = Msg->data.result; + + curl_multi_remove_handle(m_Multi, Handle); + + auto It = m_Transfers.find(Handle); + if (It == m_Transfers.end()) + { + ReleaseHandle(Handle); + continue; + } + + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + + CompleteTransfer(Handle, Result, std::move(Ctx)); + } + } + + void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer"); + // Extract result info + long StatusCode = 0; + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + + ReleaseHandle(Handle); + + // Build response + HttpClient::Response Response; + Response.StatusCode = HttpResponseCode(StatusCode); + Response.UploadedBytes = static_cast<int64_t>(UpBytes); + 
Response.DownloadedBytes = static_cast<int64_t>(DownBytes); + Response.ElapsedSeconds = Elapsed; + Response.Header = BuildHeaderMap(Ctx->ResponseHeaders); + + if (CurlResult != CURLE_OK) + { + const char* ErrorMsg = curl_easy_strerror(CurlResult); + + if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg); + } + + if (!Ctx->Body.empty()) + { + Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + } + + Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)}; + } + else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty()) + { + // No payload + } + else + { + IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders); + + const HttpResponseCode Code = HttpResponseCode(StatusCode); + if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound) + { + ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri); + } + + Response.ResponsePayload = std::move(PayloadBuffer); + } + + // Dispatch the user callback off the strand so a slow callback + // cannot starve the curl_multi poll loop. 
+ asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable { + try + { + Cb(std::move(Response)); + } + catch (const std::exception& Ex) + { + ZEN_SCOPED_LOG(LogRef); + ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what()); + } + }); + } + + // -- Async verb implementations -------------------------------------- + + void DoAsyncGet(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Get"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Head"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncDelete(std::string 
Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Delete"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPost(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Post"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPostWithPayload(std::string Url, + IoBuffer Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + ContentType, + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader)]() mutable { + 
ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = + BuildHeaderList(AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + // Set up read callback for payload data + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutWithPayload(std::string Url, + IoBuffer Payload, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = BuildHeaderList( + AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", 
std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + + HttpClient::KeyValueMap ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + // -- Members --------------------------------------------------------- + + std::string m_BaseUri; + HttpClientSettings m_Settings; + LoggerRef m_Log; + std::string m_SessionId; + std::string m_UnixSocketPathUtf8; + + // io_context and strand - all curl_multi operations are serialized on the + // strand, making this safe even when the io_context has multiple threads. 
+ std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + asio::strand<asio::io_context::executor_type> m_Strand; + std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; + std::thread m_IoThread; + + // curl_multi and socket-action state + CURLM* m_Multi = nullptr; + std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; + asio::steady_timer m_Timer; + bool m_ShuttingDown = false; + + // Access token cache + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient public API + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, Settings)) +{ +} + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings)) +{ +} + +AsyncHttpClient::~AsyncHttpClient() = default; + +// -- Callback-based API -------------------------------------------------- + +void +AsyncHttpClient::AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void 
+AsyncHttpClient::AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), Parameters); +} + +// -- Future-based API ---------------------------------------------------- + +std::future<HttpClient::Response> +AsyncHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncGet( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto 
Future = Promise->get_future(); + AsyncHead( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncDelete( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + ContentType, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, 
+ Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + Parameters); + return Future; +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index d150b44c6..3be7337c1 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpclientcurl.h" +#include "httpclientcurlhelpers.h" #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -29,153 +30,7 @@ static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; ////////////////////////////////////////////////////////////////////////// -static HttpClientErrorCode -MapCurlError(CURLcode Code) -{ - switch (Code) - { - case CURLE_OK: - return HttpClientErrorCode::kOK; - case CURLE_COULDNT_CONNECT: - return HttpClientErrorCode::kConnectionFailure; - case CURLE_COULDNT_RESOLVE_HOST: - return HttpClientErrorCode::kHostResolutionFailure; - case CURLE_COULDNT_RESOLVE_PROXY: - return HttpClientErrorCode::kProxyResolutionFailure; - case CURLE_RECV_ERROR: - return HttpClientErrorCode::kNetworkReceiveError; - case CURLE_SEND_ERROR: - return HttpClientErrorCode::kNetworkSendFailure; - case CURLE_OPERATION_TIMEDOUT: - return HttpClientErrorCode::kOperationTimedOut; - case CURLE_SSL_CONNECT_ERROR: - return HttpClientErrorCode::kSSLConnectError; - case CURLE_SSL_CERTPROBLEM: - return HttpClientErrorCode::kSSLCertificateError; - case CURLE_PEER_FAILED_VERIFICATION: - return HttpClientErrorCode::kSSLCACertError; - case CURLE_SSL_CIPHER: - case 
CURLE_SSL_ENGINE_NOTFOUND: - case CURLE_SSL_ENGINE_SETFAILED: - return HttpClientErrorCode::kGenericSSLError; - case CURLE_ABORTED_BY_CALLBACK: - return HttpClientErrorCode::kRequestCancelled; - default: - return HttpClientErrorCode::kOtherError; - } -} - -////////////////////////////////////////////////////////////////////////// -// -// Curl callback helpers - -struct WriteCallbackData -{ - std::string* Body = nullptr; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<WriteCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return 0; // Signal abort to curl - } - - Data->Body->append(Ptr, TotalBytes); - return TotalBytes; -} - -struct HeaderCallbackData -{ - std::vector<std::pair<std::string, std::string>>* Headers = nullptr; -}; - -// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. -// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). 
-static std::optional<std::pair<std::string_view, std::string_view>> -ParseHeaderLine(std::string_view Line) -{ - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) - { - return std::nullopt; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos == std::string_view::npos) - { - return std::nullopt; - } - - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - - return std::pair{Key, Value}; -} - -static size_t -CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<HeaderCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) - { - auto& [Key, Value] = *Header; - Data->Headers->emplace_back(std::string(Key), std::string(Value)); - } - - return TotalBytes; -} - -struct ReadCallbackData -{ - const uint8_t* DataPtr = nullptr; - size_t DataSize = 0; - size_t Offset = 0; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<ReadCallbackData*>(UserData); - size_t MaxRead = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return CURL_READFUNC_ABORT; - } - - size_t Remaining = Data->DataSize - Data->Offset; - size_t ToRead = std::min(MaxRead, Remaining); - - if (ToRead > 0) - { - memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); - Data->Offset += ToRead; - } - - return ToRead; -} +// Curl callback helpers and shared utilities are in httpclientcurlhelpers.h struct StreamReadCallbackData { @@ -233,7 +88,7 @@ CurlDebugCallback(CURL* Handle, curl_infotype 
Type, char* Data, size_t Size, voi { ZEN_UNUSED(Handle); LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); - auto Log = [&]() -> LoggerRef { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); std::string_view DataView(Data, Size); @@ -281,120 +136,6 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi ////////////////////////////////////////////////////////////////////////// -static std::pair<std::string, std::string> -HeaderContentType(ZenContentType ContentType) -{ - return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); -} - -static curl_slist* -BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, - std::string_view SessionId, - const std::optional<std::string>& AccessToken, - const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) -{ - curl_slist* Headers = nullptr; - - for (const auto& [Key, Value] : *AdditionalHeader) - { - ExtendableStringBuilder<64> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - if (!SessionId.empty()) - { - ExtendableStringBuilder<64> SessionHeader; - SessionHeader << "UE-Session: " << SessionId; - Headers = curl_slist_append(Headers, SessionHeader.c_str()); - } - - if (AccessToken.has_value()) - { - ExtendableStringBuilder<128> AuthHeader; - AuthHeader << "Authorization: " << AccessToken.value(); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); - } - - for (const auto& [Key, Value] : ExtraHeaders) - { - ExtendableStringBuilder<128> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - return Headers; -} - -static HttpClient::KeyValueMap -BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) -{ - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HeaderMap; -} - -// Scans response headers for 
Content-Type and applies it to the buffer. -static void -ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) -{ - for (const auto& [Key, Value] : Headers) - { - if (StrCaseCompare(Key, "Content-Type") == 0) - { - Buffer.SetContentType(ParseContentType(Value)); - break; - } - } -} - -static void -AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) -{ - static constexpr char HexDigits[] = "0123456789ABCDEF"; - static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); - - for (char C : Input) - { - if (Unreserved.Contains(C)) - { - Out.Append(C); - } - else - { - uint8_t Byte = static_cast<uint8_t>(C); - char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; - Out.Append(std::string_view(Encoded, 3)); - } - } -} - -static void -BuildUrlWithParameters(StringBuilderBase& Url, - std::string_view BaseUrl, - std::string_view ResourcePath, - const HttpClient::KeyValueMap& Parameters) -{ - Url.Append(BaseUrl); - Url.Append(ResourcePath); - - if (!Parameters->empty()) - { - char Separator = '?'; - for (const auto& [Key, Value] : *Parameters) - { - Url.Append(Separator); - AppendUrlEncoded(Url, Key); - Url.Append('='); - AppendUrlEncoded(Url, Value); - Separator = '&'; - } - } -} - ////////////////////////////////////////////////////////////////////////// CurlHttpClient::CurlHttpClient(std::string_view BaseUri, @@ -440,9 +181,9 @@ CurlHttpClient::CurlResult CurlHttpClient::Session::PerformWithResponseCallbacks() { std::string Body; - WriteCallbackData WriteData{.Body = &Body, + CurlWriteCallbackData WriteData{.Body = &Body, .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? 
&Outer->m_CheckIfAbortFunction : nullptr}; - HeaderCallbackData HdrData{}; + CurlHeaderCallbackData HdrData{}; std::vector<std::pair<std::string, std::string>> ResponseHeaders; HdrData.Headers = &ResponseHeaders; @@ -487,6 +228,13 @@ CurlHttpClient::Session::Perform() curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + char* EffectiveUrl = nullptr; + curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl); + if (EffectiveUrl) + { + Result.Url = EffectiveUrl; + } + return Result; } @@ -553,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'", SessionId, + Result.Url, static_cast<int>(Result.ErrorCode), Result.ErrorMessage); } @@ -702,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result) case CURLE_RECV_ERROR: case CURLE_SEND_ERROR: case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: return true; default: return false; @@ -748,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult { if (Result.ErrorCode != CURLE_OK) { - ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", SessionId, static_cast<int>(MapCurlError(Result.ErrorCode)), Result.ErrorMessage, + static_cast<int>(Result.ErrorCode), Attempt, m_ConnectionSettings.RetryCount + 1); } @@ -998,9 +749,9 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - ReadCallbackData ReadData{.DataPtr = static_cast<const 
uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); @@ -1213,7 +964,7 @@ CurlHttpClient::Post(std::string_view Url, std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1266,7 +1017,7 @@ CurlHttpClient::Post(std::string_view Url, std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1367,9 +1118,9 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV return Sess.PerformWithResponseCallbacks(); } - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? 
&m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); @@ -1532,7 +1283,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1618,7 +1369,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", Data->TempFolderPath->string(), Ec.message()); diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index bdeb46633..ea9193e65 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -73,6 +73,7 @@ private: int64_t DownloadedBytes = 0; CURLcode ErrorCode = CURLE_OK; std::string ErrorMessage; + std::string Url; }; struct Session diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h new file mode 100644 index 000000000..cb5f5d9a9 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurlhelpers.h @@ -0,0 +1,298 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Shared helpers for curl-based HTTP client implementations (sync and async). +// This is an internal header, not part of the public API. 
+ +#include <zencore/string.h> + +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Error mapping + +inline HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback data structures and callbacks + +struct CurlWriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +inline size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if 
// Splits one raw header line into a (name, value) pair of views into Line.
//
// Trailing CR/LF characters are removed, the line is split on the first colon,
// trailing spaces/tabs are removed from the name and leading and trailing
// spaces/tabs are removed from the value.
//
// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
inline std::optional<std::pair<std::string_view, std::string_view>>
ParseHeaderLine(std::string_view Line)
{
	constexpr std::string_view LineEnd	  = "\r\n";
	constexpr std::string_view Whitespace = " \t";

	// Removes every trailing character of Text that appears in Chars
	const auto TrimRight = [](std::string_view Text, std::string_view Chars) {
		const auto Last = Text.find_last_not_of(Chars);
		return Text.substr(0, Last == std::string_view::npos ? 0 : Last + 1);
	};

	Line = TrimRight(Line, LineEnd);
	if (Line.empty())
	{
		return std::nullopt;
	}

	const auto ColonPos = Line.find(':');
	if (ColonPos == std::string_view::npos)
	{
		return std::nullopt;
	}

	const std::string_view Key = TrimRight(Line.substr(0, ColonPos), Whitespace);

	// Trim the tail first so an all-whitespace value collapses to an empty view
	std::string_view Value = TrimRight(Line.substr(ColonPos + 1), Whitespace);
	if (const auto First = Value.find_first_not_of(Whitespace); First != std::string_view::npos)
	{
		Value.remove_prefix(First);
	}

	return std::pair{Key, Value};
}
*Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +////////////////////////////////////////////////////////////////////////// +// +// URL and header construction + +inline void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +inline void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) +{ + Url.Append(BaseUrl); + Url.Append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); + Separator = '&'; + } + } +} + +inline std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +inline curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<std::string>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + ExtendableStringBuilder<64> 
HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken.has_value()) + { + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken.value(); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + bool HasContentTypeOverride = AdditionalHeader->contains("Content-Type"); + for (const auto& [Key, Value] : ExtraHeaders) + { + if (HasContentTypeOverride && Key == "Content-Type") + { + continue; + } + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +inline HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +inline void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index ace7a3c7f..3da8a9220 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -520,7 +520,7 @@ MeasureLatency(HttpClient& Client, std::string_view Url) ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. 
- // Bail out immediately — retrying will just burn the connect timeout each time. + // Bail out immediately - retrying will just burn the connect timeout each time. if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) { break; diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index af653cbb2..deaeca2a8 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -194,7 +194,7 @@ public: "slow", [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { - Sleep(2000); + Sleep(100); Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); }); }, @@ -414,6 +414,17 @@ TEST_CASE("httpclient.post") CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); } + SUBCASE("POST with content type override via additional header") + { + const char* Payload = "test payload"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON, {{"Content-Type", "text/plain"}}); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload"); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + SUBCASE("POST with CbObject payload round-trip") { CbObjectWriter Writer; @@ -750,7 +761,9 @@ TEST_CASE("httpclient.error-handling") { SUBCASE("Connection refused") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); HttpClient::Response Resp = Client.Get("/api/test/hello"); CHECK(!Resp.IsSuccess()); CHECK(Resp.Error.has_value()); @@ -760,7 +773,7 @@ TEST_CASE("httpclient.error-handling") { TestServerFixture Fixture; HttpClientSettings Settings; - Settings.Timeout = std::chrono::milliseconds(500); + Settings.Timeout = 
std::chrono::milliseconds(50); HttpClient Client = Fixture.MakeClient(Settings); HttpClient::Response Resp = Client.Get("/api/test/slow"); @@ -970,7 +983,9 @@ TEST_CASE("httpclient.measurelatency") SUBCASE("Failed measurement against unreachable port") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); CHECK(!Result.Success); CHECK(!Result.FailureReason.empty()); @@ -1144,7 +1159,7 @@ struct FaultTcpServer ~FaultTcpServer() { // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this - // thread — ASIO I/O objects are not safe for concurrent access and the io_context + // thread - ASIO I/O objects are not safe for concurrent access and the io_context // thread may be touching the acceptor in StartAccept(). 
m_IoContext.stop(); if (m_Thread.joinable()) @@ -1498,7 +1513,7 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) std::atomic<bool> StallActive{true}; FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { DrainHttpRequest(Socket); - // Stop reading body — TCP window will fill and client send will stall + // Stop reading body - TCP window will fill and client send will stall while (StallActive.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(50)); @@ -1735,21 +1750,21 @@ TEST_CASE("httpclient.uri_decoding") TestServerFixture Fixture; HttpClient Client = Fixture.MakeClient(); - // URI without encoding — should pass through unchanged + // URI without encoding - should pass through unchanged { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt"); } - // Percent-encoded space — server should see decoded path + // Percent-encoded space - server should see decoded path { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt"); } - // Percent-encoded slash (%2F) — should be decoded to / + // Percent-encoded slash (%2F) - should be decoded to / { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt"); REQUIRE(Resp.IsSuccess()); @@ -1763,21 +1778,21 @@ TEST_CASE("httpclient.uri_decoding") CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt"); } - // No capture — echo/uri route returns just RelativeUri + // No capture - echo/uri route returns just RelativeUri { HttpClient::Response Resp = Client.Get("/api/test/echo/uri"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "echo/uri"); } - // Literal percent that is not an escape (%ZZ) — should be kept as-is + // Literal percent that is not an escape (%ZZ) - should be kept as-is { 
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt"); } - // Query params — raw values are returned as-is from GetQueryParams + // Query params - raw values are returned as-is from GetQueryParams { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test"); REQUIRE(Resp.IsSuccess()); @@ -1788,7 +1803,7 @@ TEST_CASE("httpclient.uri_decoding") { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3"); REQUIRE(Resp.IsSuccess()); - // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly + // GetQueryParams returns raw (still-encoded) values - callers must Decode() explicitly CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3"); } diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index c42841922..26a7298b3 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -50,8 +50,6 @@ namespace zen { namespace httpclientauth { IoBuffer Payload{IoBuffer::Wrap, Body.data(), Body.size()}; - // TODO: ensure this gets the right Content-Type passed along - HttpClient::Response Response = Http.Post("", Payload, {{"Content-Type", "application/x-www-form-urlencoded"}}); if (!Response || Response.StatusCode != HttpResponseCode::OK) @@ -94,7 +92,8 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Unattended, bool Quiet, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { Stopwatch Timer; @@ -117,8 +116,9 @@ namespace zen { namespace httpclientauth { } }); - const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}", OidcExecutablePath, + IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl", CloudHost, AuthTokenPath, Unattended ? 
"true"sv : "false"sv); @@ -193,7 +193,7 @@ namespace zen { namespace httpclientauth { } else { - ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode); } return HttpClientAccessToken{}; } @@ -202,9 +202,10 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { - HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); if (InitialToken.IsValid()) { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), @@ -212,12 +213,13 @@ namespace zen { namespace httpclientauth { Token = InitialToken, Quiet, Unattended, - Hidden]() mutable { + Hidden, + IsHordeUrl]() mutable { if (!Token.NeedsRefresh()) { return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); }; } return {}; diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index e05c9815f..03117ee6c 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -266,10 +266,10 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return false; } - const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); - const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); + const auto Start = ParseInt<uint64_t>(Token.substr(0, Delim)); + const auto End = ParseInt<uint64_t>(Token.substr(Delim + 1)); - if (Start.has_value() && End.has_value() && End.value() > Start.value()) + if (Start.has_value() && End.has_value() && End.value() >= Start.value()) { Ranges.push_back({.Start = Start.value(), .End 
= End.value()}); } @@ -286,6 +286,45 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return Count != Ranges.size(); } +MultipartByteRangesResult +BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges) +{ + Oid::String_t BoundaryStr; + Oid::NewOid().ToString(BoundaryStr); + std::string_view Boundary(BoundaryStr, Oid::StringLength); + + const uint64_t TotalSize = Data.GetSize(); + + std::vector<IoBuffer> Parts; + Parts.reserve(Ranges.size() * 2 + 1); + + for (const HttpRange& Range : Ranges) + { + uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + return {}; + } + + uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + + std::string PartHeader = fmt::format("\r\n--{}\r\nContent-Type: application/octet-stream\r\nContent-Range: bytes {}-{}/{}\r\n\r\n", + Boundary, + Range.Start, + RangeEnd, + TotalSize); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(PartHeader.data(), PartHeader.size())); + + IoBuffer RangeData(Data, Range.Start, RangeSize); + Parts.push_back(RangeData); + } + + std::string ClosingBoundary = fmt::format("\r\n--{}--", Boundary); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(ClosingBoundary.data(), ClosingBoundary.size())); + + return {.Parts = std::move(Parts), .ContentType = fmt::format("multipart/byteranges; boundary={}", Boundary)}; +} + ////////////////////////////////////////////////////////////////////////// const std::string_view @@ -479,6 +518,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return Ref<IHttpPackageHandler>(); } +bool +HttpService::AcceptsLocalFileReferences() const +{ + return false; +} + +const ILocalRefPolicy* +HttpService::GetLocalRefPolicy() const +{ + return nullptr; +} + ////////////////////////////////////////////////////////////////////////// HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) @@ -552,6 +603,56 @@ 
HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType } void +HttpServerRequest::WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges) +{ + if (Ranges.empty()) + { + WriteResponse(HttpResponseCode::OK, ContentType, IoBuffer(Data)); + return; + } + + if (Ranges.size() == 1) + { + const HttpRange& Range = Ranges[0]; + const uint64_t TotalSize = Data.GetSize(); + // ~uint64_t(0) is the sentinel meaning "end of file" (suffix range). + const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + m_ContentRangeHeader = fmt::format("bytes */{}", TotalSize); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + + const uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + IoBuffer RangeBuf(Data, Range.Start, RangeSize); + + m_ContentRangeHeader = fmt::format("bytes {}-{}/{}", Range.Start, RangeEnd, TotalSize); + WriteResponse(HttpResponseCode::PartialContent, ContentType, std::move(RangeBuf)); + return; + } + + // Multi-range + MultipartByteRangesResult MultipartResult = BuildMultipartByteRanges(Data, Ranges); + if (MultipartResult.Parts.empty()) + { + m_ContentRangeHeader = fmt::format("bytes */{}", Data.GetSize()); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + WriteResponse(HttpResponseCode::PartialContent, std::move(MultipartResult.ContentType), std::span<IoBuffer>(MultipartResult.Parts)); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(ParseContentType(CustomContentType) == HttpContentType::kUnknownContentType); + m_ContentTypeOverride = CustomContentType; + WriteResponse(ResponseCode, HttpContentType::kBinary, Blobs); +} + +void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) { std::span<const 
SharedBuffer> Segments = Payload.GetSegments(); @@ -705,7 +806,10 @@ HttpServerRequest::ReadPayloadPackage() { if (IoBuffer Payload = ReadPayload()) { - return ParsePackageMessage(std::move(Payload)); + ParseFlags Flags = + (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault; + const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr; + return ParsePackageMessage(std::move(Payload), {}, Flags, Policy); } return {}; @@ -816,7 +920,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) // Strip the separator slash left over after the service prefix is removed. // When a service has BaseUri "/foo", the prefix length is set to len("/foo") = 4. - // Stripping 4 chars from "/foo/bar" yields "/bar" — the path separator becomes + // Stripping 4 chars from "/foo/bar" yields "/bar" - the path separator becomes // the first character of the relative URI. Remove it so patterns like "bar" or // "{id}" match without needing to account for the leading slash. if (!Uri.empty() && Uri.front() == '/') @@ -1273,7 +1377,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP return PackageHandlerRef->CreateTarget(Cid, Size); }; - CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences()) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? 
Service.GetLocalRefPolicy() : nullptr; + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy); PackageHandlerRef->OnRequestComplete(); } @@ -1512,7 +1621,7 @@ TEST_CASE("http.common") }, HttpVerb::kGet); - // Single-segment literal with leading slash — simulates real server RelativeUri + // Single-segment literal with leading slash - simulates real server RelativeUri { Reset(); TestHttpServerRequest req{Service, "/activity_counters"sv}; @@ -1532,7 +1641,7 @@ TEST_CASE("http.common") CHECK_EQ(Captures[0], "hello"sv); } - // Two-segment route with leading slash — first literal segment + // Two-segment route with leading slash - first literal segment { Reset(); TestHttpServerRequest req{Service, "/prefix/world"sv}; diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h new file mode 100644 index 000000000..cb41626b9 --- /dev/null +++ b/src/zenhttp/include/zenhttp/asynchttpclient.h @@ -0,0 +1,123 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> + +#include <functional> +#include <future> +#include <memory> + +namespace asio { +class io_context; +} + +namespace zen { + +/// Completion callback for async HTTP operations. +using AsyncHttpCallback = std::function<void(HttpClient::Response)>; + +/** Asynchronous HTTP client backed by curl_multi and ASIO. + * + * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process + * transfers without blocking the caller. All curl_multi operations are + * serialized on an internal strand; callers may issue requests from any + * thread, and the io_context may have multiple threads. + * + * Two construction modes: + * - Owned io_context: creates an internal thread (self-contained). + * - External io_context: caller runs the event loop. 
+ * + * Completion callbacks are dispatched on the io_context (not the internal + * strand), so a slow callback will not block the curl poll loop. Future- + * based wrappers (Get, Post, ...) return a std::future<Response> for + * callers that prefer blocking on a result. + */ +class AsyncHttpClient +{ +public: + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + /// Construct with an internally-owned io_context and thread. + explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {}); + + /// Construct with an externally-managed io_context. The io_context must + /// outlive this client and must be running (via run()) on at least one thread. + AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {}); + + ~AsyncHttpClient(); + + AsyncHttpClient(const AsyncHttpClient&) = delete; + AsyncHttpClient& operator=(const AsyncHttpClient&) = delete; + + // -- Callback-based API ---------------------------------------------- + + void AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}); + + void AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const 
KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {}); + + // -- Future-based API ------------------------------------------------ + + [[nodiscard]] std::future<Response> Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, + const IoBuffer& Payload, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, const KeyValueMap& Parameters = {}); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void asynchttpclient_test_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index ce646ebd7..9220a50b6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -33,7 +33,8 @@ namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden); + bool Hidden, + bool IsHordeUrl = false); } // namespace httpclientauth } // namespace zen diff --git 
a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index f9a99f3cc..1d921600d 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -19,8 +19,8 @@ class StringBuilderBase; struct HttpRange { - uint32_t Start = ~uint32_t(0); - uint32_t End = ~uint32_t(0); + uint64_t Start = ~uint64_t(0); + uint64_t End = ~uint64_t(0); }; using HttpRanges = std::vector<HttpRange>; @@ -30,6 +30,16 @@ extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeSt std::string_view ReasonStringForHttpResultCode(int HttpCode); bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); +struct MultipartByteRangesResult +{ + std::vector<IoBuffer> Parts; + std::string ContentType; +}; + +// Build a multipart/byteranges response body from the given data and ranges. +// Generates a unique boundary per call. Returns empty Parts if any range is out of bounds. +MultipartByteRangesResult BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges); + enum class HttpVerb : uint8_t { kGet = 1 << 0, diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 5eaed6004..955b8ed15 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -12,6 +12,7 @@ #include <zencore/string.h> #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zenhttp/localrefpolicy.h> #include <zentelemetry/hyperloglog.h> #include <zentelemetry/stats.h> @@ -121,11 +122,13 @@ public: virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + void WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs); void WriteResponse(HttpResponseCode ResponseCode, CbObject 
Data); void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + void WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges); virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; @@ -151,6 +154,8 @@ protected: std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); mutable Oid m_SessionId = Oid::Zero; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; inline void SetIsHandled() { m_Flags |= kIsHandled; } @@ -193,9 +198,16 @@ public: HttpService() = default; virtual ~HttpService() = default; - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + [[nodiscard]] virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + /// Whether this service accepts local file references in inbound packages from local clients. + [[nodiscard]] virtual bool AcceptsLocalFileReferences() const; + + /// Returns the local ref policy for validating file paths in inbound local references. + /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed). + [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const; // Internals @@ -290,12 +302,12 @@ public: std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } - /** Track active WebSocket connections — called by server implementations on upgrade/close. 
*/ + /** Track active WebSocket connections - called by server implementations on upgrade/close. */ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } - /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + /** Track WebSocket frame and byte counters - called by WS connection implementations per frame. */ void OnWebSocketFrameReceived(uint64_t Bytes) { m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); @@ -317,7 +329,7 @@ private: int m_EffectiveHttpsPort = 0; std::string m_ExternalHost; metrics::Meter m_RequestMeter; - metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error — sufficient for client counting + metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error - sufficient for client counting metrics::HyperLogLog<12> m_ClientSessions; std::string m_DefaultRedirect; std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; @@ -510,7 +522,8 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -void http_forcelink(); // internal -void websocket_forcelink(); // internal +void http_forcelink(); // internal +void httpparser_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index bce771c75..51ab2e06e 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -23,11 +23,11 @@ namespace zen { class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - /// Construct without an io_context — optionally uses a 
dedicated push thread + /// Construct without an io_context - optionally uses a dedicated push thread /// for WebSocket stats broadcasting. explicit HttpStatsService(bool EnableWebSockets = false); - /// Construct with an external io_context — uses an asio timer instead + /// Construct with an external io_context - uses an asio timer instead /// of a dedicated thread for WebSocket stats broadcasting. /// The caller must ensure the io_context outlives this service and that /// its run loop is active. @@ -43,7 +43,7 @@ public: virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 9c3b909a2..fd2f79171 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -26,7 +26,7 @@ namespace zen { * Callback interface for WebSocket client events * * Separate from the server-side IWebSocketHandler because the caller - * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + * already owns the HttpWsClient - no Ref<WebSocketConnection> needed. */ class IWsClientHandler { @@ -85,9 +85,9 @@ private: /// it is treated as a plain host:port and gets the ws:// prefix. 
/// /// Examples: -/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws" -/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo" -/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar" +/// HttpToWsUrl("http://host:8080", "/orch/ws") -> "ws://host:8080/orch/ws" +/// HttpToWsUrl("https://host", "/foo") -> "wss://host/foo" +/// HttpToWsUrl("host:8080", "/bar") -> "ws://host:8080/bar" std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h new file mode 100644 index 000000000..0b37f9dc7 --- /dev/null +++ b/src/zenhttp/include/zenhttp/localrefpolicy.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> + +namespace zen { + +/// Policy interface for validating local file reference paths in inbound CbPackage messages. +/// Implementations should throw std::invalid_argument if the path is not allowed. +class ILocalRefPolicy +{ +public: + virtual ~ILocalRefPolicy() = default; + + /// Validate that a local file reference path is allowed. + /// Throws std::invalid_argument if the path escapes the allowed root. 
+ virtual void ValidatePath(const std::filesystem::path& Path) const = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index 1a5068580..66e3f6e55 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -5,6 +5,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> +#include <zenhttp/localrefpolicy.h> #include <functional> #include <gsl/gsl-lite.hpp> @@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); -CbPackage ParsePackageMessage( - IoBuffer Payload, - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + +enum class ParseFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags); + +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; - }); + }, + ParseFlags Flags = ParseFlags::kDefault, + const ILocalRefPolicy* Policy = nullptr); bool IsPackageMessage(IoBuffer Payload); bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); @@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe class CbPackageReader { public: - CbPackageReader(); + CbPackageReader(ParseFlags Flags = ParseFlags::kDefault); ~CbPackageReader(); void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> 
CreateBuffer); + void SetLocalRefPolicy(const ILocalRefPolicy* Policy); /** Process compact binary package data stream @@ -149,6 +162,8 @@ private: kReadingBuffers } m_CurrentState = State::kInitialState; + ParseFlags m_Flags; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; std::vector<IoBuffer> m_PayloadBuffers; std::vector<CbAttachmentEntry> m_AttachmentEntries; diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 710579faa..2d25515d3 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -59,7 +59,7 @@ class IWebSocketHandler public: virtual ~IWebSocketHandler() = default; - virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0; virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; }; diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index 7e6207e56..5ad5ebcc7 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) // void -HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); ZEN_INFO("Stats WebSocket client connected"); diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 9c62c1f2d..267ce386c 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:"); typedef 
eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t; +/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security. +/// If no policy is set, file-path local refs are rejected (fail-closed). +static void +ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path) +{ + if (Policy) + { + Policy->ValidatePath(Path); + } + else + { + throw std::invalid_argument("local file reference rejected: no validation policy"); + } +} + +// Validates the CbPackageHeader magic and attachment count. Returns the total +// chunk count (AttachmentCount + 1, including the implicit root object). +static uint32_t +ValidatePackageHeader(const CbPackageHeader& Hdr) +{ + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::invalid_argument( + fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic)); + } + // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against + // UINT32_MAX wrapping to 0, which would bypass subsequent size checks. + if (Hdr.AttachmentCount == UINT32_MAX) + { + throw std::invalid_argument("invalid CbPackage, attachment count overflow"); + } + return Hdr.AttachmentCount + 1; +} + +struct ValidatedLocalRef +{ + bool Valid = false; + const CbAttachmentReferenceHeader* Header = nullptr; + std::string_view Path; + std::string Error; +}; + +// Validates that the attachment buffer contains a well-formed local reference +// header and path. On failure, Valid is false and Error contains the reason. 
+static ValidatedLocalRef +ValidateLocalRef(const IoBuffer& AttachmentBuffer) +{ + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader)) + { + return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())}; + } + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength) + { + return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})", + sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength, + AttachmentBuffer.Size())}; + } + + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)}; +} + IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle); std::vector<IoBuffer> @@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload) } CbPackage -ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +ParsePackageMessage(IoBuffer Payload, + std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer, + ParseFlags Flags, + const ILocalRefPolicy* Policy) { ZEN_TRACE_CPU("ParsePackageMessage"); @@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint BinaryReader Reader(Payload); - const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); - if (Hdr->HeaderMagic != kCbPkgMagic) - { - throw std::invalid_argument( - fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic)); - } + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + const 
uint32_t ChunkCount = ValidatePackageHeader(*Hdr); Reader.Skip(sizeof(CbPackageHeader)); - const uint32_t ChunkCount = Hdr->AttachmentCount + 1; - - if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount) + // Widen to uint64_t so the multiplication cannot wrap on 32-bit. + const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount; + if (Reader.Remaining() < AttachmentTableSize) { throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)", sizeof(CbAttachmentEntry) * ChunkCount, @@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) { - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i)); + } - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash))); + continue; + } + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::string_view PathView = LocalRef.Path; IoBuffer FullFileBuffer; @@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const 
IoHash&, uint } else { + ApplyLocalRefPolicy(Policy, Path); FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; } } if (FullFileBuffer) { - IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + // Guard against offset+size overflow or exceeding the file bounds. + const uint64_t FileSize = FullFileBuffer.GetSize(); + if (AttachRefHdr->PayloadByteOffset > FileSize || + AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset) + { + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}", + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + FileSize, + Entry.AttachmentHash))); + continue; + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize ? FullFileBuffer : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa return OutPackage.TryLoad(Response); } -CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +CbPackageReader::CbPackageReader(ParseFlags Flags) +: m_Flags(Flags) +, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) { } @@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci m_CreateBuffer = CreateBuffer; } +void +CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy) +{ + m_LocalRefPolicy = Policy; +} + uint64_t CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) { @@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) return sizeof m_PackageHeader; case 
State::kReadingHeader: - ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); - memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); - ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); - m_CurrentState = State::kReadingAttachmentEntries; - m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); - return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + { + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(ChunkCount); + return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry); + } case State::kReadingAttachmentEntries: ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); @@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) { // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + throw std::invalid_argument(LocalRef.Error); + } - std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::filesystem::path Path(Utf8ToWide(LocalRef.Path)); - std::filesystem::path Path{PathView}; + if (!LocalRef.Path.starts_with(HandlePrefix)) + { + ApplyLocalRefPolicy(m_LocalRefPolicy, Path); + } IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, 
AttachRefHdr->PayloadByteSize); @@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) AttachRefHdr->PayloadByteSize)); } + // MakeFromFile silently clamps offset+size to the file size. Detect this + // to avoid returning a short buffer that could cause subtle downstream issues. + if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize) + { + throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + ChunkReference.GetSize())); + } + return ChunkReference; }; @@ -732,6 +843,13 @@ CbPackageReader::Finalize() { IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", + CurrentAttachmentIndex)); + } + if (CurrentAttachmentIndex == 0) { // Root object @@ -815,6 +933,13 @@ CbPackageReader::Finalize() TEST_SUITE_BEGIN("http.packageformat"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. 
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef") RemainingBytes -= ByteCount; }; + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackageReader Reader(ParseFlags::kAllowLocalReferences); + Reader.SetLocalRefPolicy(&AllowAllPolicy); + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.Validation.TruncatedHeader") +{ + // Payload too small for a CbPackageHeader + uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa}; + IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagic") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xDEADBEEF; + Hdr.AttachmentCount = 0; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflow") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable") +{ + // Valid header but not enough data for the attachment entries + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = 10; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + 
CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentData") +{ + // Valid header + one attachment entry claiming more data than available + std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry)); + + CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data()); + Hdr->HeaderMagic = kCbPkgMagic; + Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object) + + CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader)); + Entry->PayloadSize = 9999; // way more than available + Entry->Flags = CbAttachmentEntry::kIsObject; + Entry->AttachmentHash = IoHash(); + + IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size()); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault") +{ + // Build a valid package with local refs backed by compressed-format files, + // then verify it's rejected with default ParseFlags and accepted when allowed. + ScopedTemporaryDirectory TempDir; + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + // Compress data and write to disk, then create file-backed compressed attachments. + // The files must contain compressed-format data because ParsePackageMessage expects it + // when resolving local refs. 
+ CompressedBuffer Comp1 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None); + CompressedBuffer Comp2 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None); + + IoHash Hash1 = Comp1.DecodeRawHash(); + IoHash Hash2 = Comp2.DecodeRawHash(); + + { + IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer(); + IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(Path1, Buf1); + WriteFile(Path2, Buf2); + } + + // Create attachments from file-backed buffers so FormatPackageMessage uses local refs + CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1}; + CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Default flags should reject local refs + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); + + // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip + // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader) + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 2); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader") +{ + // Same test but via CbPackageReader + ScopedTemporaryDirectory TempDir; + auto FilePath = TempDir.Path() / "testdata"; + + { + IoBuffer Buf = 
IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata")); + WriteFile(FilePath, Buf); + } + + IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath); + CbAttachment Attach{SharedBuffer(FileBuffer)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + // Default flags should reject CbPackageReader Reader; uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); @@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef") CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); } - Reader.Finalize(); + CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagicViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xBADCAFE; + Hdr.AttachmentCount = 0; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + + CbPackageReader Reader; + uint64_t InitialRead = 
Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot") +{ + // A file outside the allowed root should be rejected by the policy + ScopedTemporaryDirectory AllowedRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "secret.dat"; + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret")); + WriteFile(OutsidePath, Buf); + } + + // Create file-backed compressed attachment from outside root + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Policy rooted at AllowedRoot should reject the file in OutsideDir + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot") +{ + // 
A file inside the allowed root should be accepted by the policy + ScopedTemporaryDirectory TempRoot; + + auto FilePath = TempRoot.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 1); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal") +{ + // A file path containing ".." 
that resolves outside root should be rejected + ScopedTemporaryDirectory TempRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "evil.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + // Root is TempRoot, but the file lives in OutsideDir + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed") +{ + // When local refs are allowed but no policy is provided, file-path refs should be rejected + ScopedTemporaryDirectory TempDir; + + auto FilePath = TempDir.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = 
Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // kAllowLocalReferences but nullptr policy => fail-closed + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument); } TEST_SUITE_END(); diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 7972777b8..a1a775ba3 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -625,6 +625,8 @@ public: void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } /** * Initialize the response for sending a payload made up of multiple blobs @@ -768,10 +770,18 @@ public: { ZEN_MEMSCOPE(GetHttpasioTag()); + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? 
MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -898,7 +908,9 @@ private: bool m_AllowZeroCopyFileSend = true; State m_State = State::kUninitialized; HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::buffer(ResponseStr->data(), ResponseStr->size()), asio::bind_executor( m_Strand, - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()]( + const asio::error_code& Ec, + std::size_t) { if (Ec) { ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); @@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + std::string_view FullUrl = Conn->m_RequestData.Url(); + std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); 
WsConn->Start(); })); @@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() return; } } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } if (!m_RequestData.IsKeepAlive()) @@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 918b55dc6..8b07c7905 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -8,6 +8,13 @@ #include <limits> +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <cstring> +# include <string> +# include <string_view> +#endif + namespace zen { using namespace std::literals; @@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W // HttpRequestParser // -http_parser_settings HttpRequestParser::s_ParserSettings{ - 
.on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; +// clang-format off +llhttp_settings_t HttpRequestParser::s_ParserSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); }; + S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); +// clang-format on HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : 
m_Connection(Connection) { - http_parser_init(&m_Parser, HTTP_REQUEST); + llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings); m_Parser.data = this; ResetState(); @@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser() size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize); + if (Err == HPE_OK) { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; + return DataSize; } - return ConsumedBytes; + if (Err == HPE_PAUSED_UPGRADE) + { + return DataSize; + } + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser)); + return ~0ull; } int @@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_UrlRange.Length == 0) @@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_HeaderEntries.empty()) @@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); @@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete() } } - 
m_KeepAlive = !!http_should_keep_alive(&m_Parser); + m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser); - switch (m_Parser.method) + switch (llhttp_get_method(&m_Parser)) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; @@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete() break; default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser)))); break; } @@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; + return -1; } memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body size mismatch! 
{} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - return 0; } @@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete() catch (const AssertException& AssertEx) { ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); - return 1; + return -1; } catch (const std::system_error& SystemError) { @@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete() ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); - return 1; + return -1; } catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); - return 1; + return -1; } catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); - return 1; + return -1; } } @@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; } +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace { + + struct MockCallbacks : HttpRequestParserCallbacks + { + int HandleRequestCount = 0; + int TerminateCount = 0; + + HttpRequestParser* Parser = nullptr; + + HttpVerb LastVerb{}; + std::string LastUrl; + std::string LastQueryString; + std::string LastBody; + bool LastKeepAlive = false; + bool LastIsWebSocketUpgrade = false; + std::string LastSecWebSocketKey; + std::string LastUpgradeHeader; + HttpContentType LastContentType{}; + + void HandleRequest() override + { + ++HandleRequestCount; + if (Parser) + { + LastVerb = Parser->RequestVerb(); + LastUrl = std::string(Parser->Url()); + LastQueryString = std::string(Parser->QueryString()); + LastKeepAlive = Parser->IsKeepAlive(); + LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade(); + LastSecWebSocketKey = std::string(Parser->SecWebSocketKey()); + LastUpgradeHeader = std::string(Parser->UpgradeHeader()); + LastContentType = 
Parser->ContentType(); + + IoBuffer Body = Parser->Body(); + if (Body.Size() > 0) + { + LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size()); + } + else + { + LastBody.clear(); + } + } + } + + void TerminateConnection() override { ++TerminateCount; } + }; + +} // anonymous namespace + +TEST_SUITE_BEGIN("http.httpparser"); + +TEST_CASE("httpparser.basic_get") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kGet); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK(Mock.LastKeepAlive); +} + +TEST_CASE("httpparser.post_with_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 13\r\n" + "Content-Type: application/json\r\n" + "\r\n" + "{\"key\":\"val\"}"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kPost); + CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}"); + CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON); +} + +TEST_CASE("httpparser.pipelined_requests") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n" + "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 2); + CHECK_EQ(Mock.LastUrl, "/second"); +} + +TEST_CASE("httpparser.partial_header") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + 
std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc"; + std::string Chunk2 = "alhost\r\n\r\n"; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, Chunk2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); +} + +TEST_CASE("httpparser.partial_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Headers = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 10\r\n" + "\r\n"; + std::string BodyPart1 = "hello"; + std::string BodyPart2 = "world"; + + std::string Chunk1 = Headers + BodyPart1; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, BodyPart2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "helloworld"); +} + +TEST_CASE("httpparser.invalid_request") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Garbage = "NOT_HTTP garbage data\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 0); +} + +TEST_CASE("httpparser.body_overflow") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes, + // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY" + // as a new HTTP request which fails. 
+ std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 2\r\n" + "\r\n" + "TOO_LONG_BODY"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "TO"); +} + +TEST_CASE("httpparser.websocket_upgrade") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); + CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(Mock.LastUpgradeHeader, "websocket"); +} + +TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string HttpPart = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + // Append fake WebSocket frame bytes after the HTTP message + std::string Request = HttpPart; + Request.push_back('\x81'); + Request.push_back('\x05'); + Request.append("hello"); + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_NE(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); +} + +TEST_CASE("httpparser.keep_alive_detection") +{ + SUBCASE("HTTP/1.1 default keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: 
localhost\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK(Mock.LastKeepAlive); + } + + SUBCASE("Connection: close disables keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK_FALSE(Mock.LastKeepAlive); + } +} + +TEST_CASE("httpparser.all_verbs") +{ + struct VerbTest + { + const char* Method; + HttpVerb Expected; + }; + + VerbTest Tests[] = { + {"GET", HttpVerb::kGet}, + {"POST", HttpVerb::kPost}, + {"PUT", HttpVerb::kPut}, + {"DELETE", HttpVerb::kDelete}, + {"HEAD", HttpVerb::kHead}, + {"COPY", HttpVerb::kCopy}, + {"OPTIONS", HttpVerb::kOptions}, + }; + + for (const VerbTest& Test : Tests) + { + CAPTURE(Test.Method); + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, Test.Expected); + } +} + +TEST_CASE("httpparser.query_string") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK_EQ(Mock.LastQueryString, "key=val&other=123"); +} + +TEST_SUITE_END(); + +void +httpparser_forcelink() +{ +} + +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 23ad9d8fb..4ff216248 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ 
-8,7 +8,7 @@ #include <EASTL/fixed_vector.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <http_parser.h> +#include <llhttp.h> ZEN_THIRD_PARTY_INCLUDES_END #include <atomic> @@ -100,7 +100,7 @@ private: Oid m_SessionId{}; IoBuffer m_BodyBuffer; uint64_t m_BodyPosition = 0; - http_parser m_Parser; + llhttp_t m_Parser; eastl::fixed_vector<char, 512> m_HeaderData; std::string m_NormalizedUrl; @@ -114,8 +114,8 @@ private: int OnBody(const char* Data, size_t Bytes); int OnMessageComplete(); - static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } - static http_parser_settings s_ParserSettings; + static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } + static llhttp_settings_t s_ParserSettings; }; } // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 31b0315d4..b0fb020e0 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -185,13 +185,17 @@ public: const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } void SuppressPayload() { m_ResponseBuffers.resize(1); } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } std::string_view GetHeaders(); private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; uint64_t m_ContentLength = 0; std::vector<IoBuffer> m_ResponseBuffers; ExtendableStringBuilder<160> m_Headers; @@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders() if (m_Headers.Size() == 0) { + std::string_view ContentTypeStr = + 
m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(ContentType)); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 2cad97725..67b1230a0 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -464,6 +464,8 @@ public: inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { 
m_ContentRangeHeader = std::move(V); } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -473,6 +475,8 @@ private: uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; std::string m_LocationHeader; @@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + std::string_view ContentTypeString = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); @@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); } + // Content-Range header (for 206 Partial Content single-range responses) + + if (!m_ContentRangeHeader.empty()) + { + PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange]; + ContentRangeHeader->pRawValue = m_ContentRangeHeader.data(); + ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) { // Port may be owned by another process's wildcard registration (access denied) - // or actively in use (sharing violation) — retry on a different port + // or actively in use (sharing violation) - 
retry on a different port ShouldRetryNextPort = true; } else @@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + if (!m_ContentTypeOverride.empty()) + { + Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT &Transaction().Server())); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + ExtendableStringBuilder<128> UrlUtf8; + WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath, + gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))}, + UrlUtf8); + int PrefixLen = Service->UriPrefixLength(); + std::string_view RelativeUri{UrlUtf8.ToView()}; + RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); return nullptr; @@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - // WebSocket upgrade failed — return nullptr since ServerRequest() + // WebSocket upgrade failed - return nullptr since ServerRequest() // was never 
populated (no InvokeRequestHandler call) return nullptr; } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } } diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 5ae48f5b3..078c21ea1 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index af320172d..8520e9f60 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown() return; } - // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } @@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: @@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); - // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 59c46a418..a58037fec 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -5,6 +5,7 @@ # include <zencore/scopeguard.h> # include <zencore/testing.h> # include 
<zencore/testutils.h> +# include <zencore/timer.h> # include <zenhttp/httpserver.h> # include <zenhttp/httpwsclient.h> @@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec") std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); - // Server frames are unmasked — TryParseFrame should handle them + // Server frames are unmasked - TryParseFrame should handle them WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); CHECK(Result.IsValid); @@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec") { std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); - // Pass only 1 byte — not enough for a frame header + // Pass only 1 byte - not enough for a frame header WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); CHECK_FALSE(Result.IsValid); CHECK_EQ(Result.BytesConsumed, 0u); @@ -335,8 +336,9 @@ namespace { } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_OpenCount.fetch_add(1); m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); @@ -463,7 +465,7 @@ namespace { if (!Done.load()) { - // Timeout — cancel the read + // Timeout - cancel the read asio::error_code Ec; Sock.cancel(Ec); } @@ -476,6 +478,23 @@ namespace { return Result; } + static void WaitForServerListening(int Port) + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + asio::io_context IoCtx; + asio::ip::tcp::socket Probe(IoCtx); + asio::error_code Ec; + Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec); + if (!Ec) + { + return; + } + Sleep(10); + } + } + } // anonymous namespace TEST_CASE("websocket.integration") @@ -501,8 +520,8 @@ TEST_CASE("websocket.integration") Server->Close(); }); - // Give server a 
moment to start accepting - Sleep(100); + // Wait for server to start accepting + WaitForServerListening(Port); SUBCASE("handshake succeeds with 101") { @@ -692,7 +711,7 @@ TEST_CASE("websocket.integration") std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); - // Should NOT get 101 — should fall through to normal request handling + // Should NOT get 101 - should fall through to normal request handling CHECK(Response.find("101") == std::string::npos); Sock.close(); @@ -813,7 +832,7 @@ TEST_CASE("websocket.client") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close") { @@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close over unix socket") { diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 7b050ae35..67a01403d 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -9,7 +9,7 @@ target('zenhttp') add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio") - add_packages("http_parser", "json11", "libcurl") + add_packages("llhttp", "json11", "libcurl") add_options("httpsys") if is_plat("linux", "macosx") then diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 3ac8eea8d..e15aa4d30 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -4,6 +4,7 @@ #if ZEN_WITH_TESTS +# include <zenhttp/asynchttpclient.h> # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> # include <zenhttp/packageformat.h> @@ -16,7 +17,9 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpparser_forcelink(); httpclient_test_forcelink(); + asynchttpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); websocket_forcelink(); diff --git a/src/zennomad/include/zennomad/nomadclient.h 
b/src/zennomad/include/zennomad/nomadclient.h index 0a3411ace..cebf217e1 100644 --- a/src/zennomad/include/zennomad/nomadclient.h +++ b/src/zennomad/include/zennomad/nomadclient.h @@ -52,7 +52,11 @@ public: /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. * The JSON structure varies based on the configured driver and distribution mode. */ - std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + std::string BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession = {}, + bool CleanStart = false, + const std::string& TraceHost = {}) const; /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */ bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h index 750693b3f..a8368e3dc 100644 --- a/src/zennomad/include/zennomad/nomadprovisioner.h +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -47,7 +47,11 @@ public: /** Construct a provisioner. * @param Config Nomad connection and job configuration. * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */ - NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); /** Signals the management thread to exit and stops all tracked jobs. 
*/ ~NomadProvisioner(); @@ -83,6 +87,9 @@ private: NomadConfig m_Config; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr<NomadClient> m_Client; diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp index 9edcde125..4bb09a930 100644 --- a/src/zennomad/nomadclient.cpp +++ b/src/zennomad/nomadclient.cpp @@ -58,7 +58,11 @@ NomadClient::Initialize() } std::string -NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +NomadClient::BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession, + bool CleanStart, + const std::string& TraceHost) const { ZEN_TRACE_CPU("NomadClient::BuildJobJson"); @@ -94,6 +98,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } else @@ -115,6 +135,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << 
"--tracehost=" << TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp index 1ae968fb7..deecdef05 100644 --- a/src/zennomad/nomadprocess.cpp +++ b/src/zennomad/nomadprocess.cpp @@ -37,7 +37,7 @@ struct NomadProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp index 3fe9c0ac3..e07ce155e 100644 --- a/src/zennomad/nomadprovisioner.cpp +++ b/src/zennomad/nomadprovisioner.cpp @@ -14,9 +14,16 @@ namespace zen::nomad { -NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId())) , m_Log(zen::logging::Get("nomad.provisioner")) { @@ -154,7 +161,7 @@ NomadProvisioner::SubmitNewJobs() ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); - const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint, m_CoordinatorSession, m_CleanStart, m_TraceHost); NomadJobInfo JobInfo; JobInfo.Id = JobId; diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp index 
0e0b14dca..40ea757eb 100644 --- a/src/zenremotestore/builds/buildstoragecache.cpp +++ b/src/zenremotestore/builds/buildstoragecache.cpp @@ -96,7 +96,8 @@ public: ZEN_ASSERT(!IsFlushed); ZEN_ASSERT(ContentType == ZenContentType::kCompressedBinary); - // Move all segments in Payload to be file handle based so if Payload is materialized it does not affect buffers in queue + // Move all segments in Payload to be file handle based unless they are very small so if Payload is materialized it does not affect + // buffers in queue std::vector<SharedBuffer> FileBasedSegments; std::span<const SharedBuffer> Segments = Payload.GetSegments(); FileBasedSegments.reserve(Segments.size()); @@ -104,42 +105,56 @@ public: tsl::robin_map<void*, std::filesystem::path> HandleToPath; for (const SharedBuffer& Segment : Segments) { - std::filesystem::path FilePath; - IoBufferFileReference Ref; - if (Segment.AsIoBuffer().GetFileReference(Ref)) + const uint64_t SegmentSize = Segment.GetSize(); + if (SegmentSize < 16u * 1024u) { - if (auto It = HandleToPath.find(Ref.FileHandle); It != HandleToPath.end()) - { - FilePath = It->second; - } - else + FileBasedSegments.push_back(Segment); + } + else + { + std::filesystem::path FilePath; + IoBufferFileReference Ref; + if (Segment.AsIoBuffer().GetFileReference(Ref)) { - std::error_code Ec; - std::filesystem::path Path = PathFromHandle(Ref.FileHandle, Ec); - if (!Ec && !Path.empty()) + if (auto It = HandleToPath.find(Ref.FileHandle); It != HandleToPath.end()) { - HandleToPath.insert_or_assign(Ref.FileHandle, Path); - FilePath = std::move(Path); + FilePath = It->second; + } + else + { + std::error_code Ec; + std::filesystem::path Path = PathFromHandle(Ref.FileHandle, Ec); + if (!Ec && !Path.empty()) + { + HandleToPath.insert_or_assign(Ref.FileHandle, Path); + FilePath = std::move(Path); + } + else + { + ZEN_WARN("Failed getting path for chunk to upload to cache. 
Skipping upload."); + return; + } } } - } - if (!FilePath.empty()) - { - IoBuffer BufferFromFile = IoBufferBuilder::MakeFromFile(FilePath, Ref.FileChunkOffset, Ref.FileChunkSize); - if (BufferFromFile) + if (!FilePath.empty()) { - FileBasedSegments.push_back(SharedBuffer(std::move(BufferFromFile))); + IoBuffer BufferFromFile = IoBufferBuilder::MakeFromFile(FilePath, Ref.FileChunkOffset, Ref.FileChunkSize); + if (BufferFromFile) + { + FileBasedSegments.push_back(SharedBuffer(std::move(BufferFromFile))); + } + else + { + ZEN_WARN("Failed opening file '{}' to upload to cache. Skipping upload.", FilePath); + return; + } } else { FileBasedSegments.push_back(Segment); } } - else - { - FileBasedSegments.push_back(Segment); - } } } diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index a04063c4c..6d93d8de6 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -11,8 +11,7 @@ #include <zenremotestore/chunking/chunkblock.h> #include <zenremotestore/chunking/chunkingcache.h> #include <zenremotestore/chunking/chunkingcontroller.h> -#include <zenremotestore/filesystemutils.h> -#include <zenremotestore/operationlogoutput.h> +#include <zenutil/progress.h> #include <zencore/basicfile.h> #include <zencore/compactbinary.h> @@ -26,6 +25,7 @@ #include <zencore/string.h> #include <zencore/timer.h> #include <zencore/trace.h> +#include <zenutil/filesystemutils.h> #include <zenutil/wildcard.h> #include <numeric> @@ -79,7 +79,8 @@ namespace { return CacheFolderPath / RawHash.ToHexString(); } - bool CleanDirectory(OperationLogOutput& OperationLogOutput, + bool CleanDirectory(LoggerRef InLog, + ProgressBase& Progress, WorkerThreadPool& IOWorkerPool, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -88,10 +89,10 @@ namespace { std::span<const std::string> ExcludeDirectories) { ZEN_TRACE_CPU("CleanDirectory"); + ZEN_SCOPED_LOG(InLog); 
Stopwatch Timer; - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(OperationLogOutput.CreateProgressBar("Clean Folder")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = Progress.CreateProgressBar("Clean Folder"); CleanDirectoryResult Result = CleanDirectory( IOWorkerPool, @@ -100,16 +101,16 @@ namespace { Path, ExcludeDirectories, [&](const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted) { - Progress.UpdateState({.Task = "Cleaning folder ", - .Details = std::string(Details), - .TotalCount = TotalCount, - .RemainingCount = RemainingCount, - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = "Cleaning folder ", + .Details = std::string(Details), + .TotalCount = TotalCount, + .RemainingCount = RemainingCount, + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }, - OperationLogOutput.GetProgressUpdateDelayMS()); + Progress.GetProgressUpdateDelayMS()); - Progress.Finish(); + ProgressBar->Finish(); if (AbortFlag) { @@ -128,17 +129,16 @@ namespace { Result.FailedRemovePaths[FailedPathIndex].second.value(), Result.FailedRemovePaths[FailedPathIndex].second.message()); } - ZEN_OPERATION_LOG_WARN(OperationLogOutput, "Clean failed to remove files from '{}': {}", Path, SB.ToView()); + ZEN_WARN("Clean failed to remove files from '{}': {}", Path, SB.ToView()); } if (ElapsedTimeMs >= 200 && !IsQuiet) { - ZEN_OPERATION_LOG_INFO(OperationLogOutput, - "Wiped folder '{}' {} ({}) in {}", - Path, - Result.FoundCount, - NiceBytes(Result.DeletedByteCount), - NiceTimeSpanMs(ElapsedTimeMs)); + ZEN_INFO("Wiped folder '{}' {} ({}) in {}", + Path, + Result.FoundCount, + NiceBytes(Result.DeletedByteCount), + NiceTimeSpanMs(ElapsedTimeMs)); } return Result.FailedRemovePaths.empty(); @@ -150,11 +150,9 @@ namespace { } bool 
IsChunkCompressable(const tsl::robin_set<uint32_t>& NonCompressableExtensionHashes, - const ChunkedFolderContent& Content, const ChunkedContentLookup& Lookup, uint32_t ChunkIndex) { - ZEN_UNUSED(Content); const uint32_t ChunkLocationCount = Lookup.ChunkSequenceLocationCounts[ChunkIndex]; if (ChunkLocationCount == 0) { @@ -180,6 +178,54 @@ namespace { return SB.ToString(); } + uint32_t SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes) + { +#if ZEN_PLATFORM_WINDOWS + if (SourcePlatform == SourcePlatform::Windows) + { + SetFileAttributesToPath(FilePath, Attributes); + return Attributes; + } + else + { + uint32_t CurrentAttributes = GetFileAttributesFromPath(FilePath); + uint32_t NewAttributes = zen::MakeFileAttributeReadOnly(CurrentAttributes, zen::IsFileModeReadOnly(Attributes)); + if (CurrentAttributes != NewAttributes) + { + SetFileAttributesToPath(FilePath, NewAttributes); + } + return NewAttributes; + } +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (SourcePlatform != SourcePlatform::Windows) + { + zen::SetFileMode(FilePath, Attributes); + return Attributes; + } + else + { + uint32_t CurrentMode = zen::GetFileMode(FilePath); + uint32_t NewMode = zen::MakeFileModeReadOnly(CurrentMode, zen::IsFileAttributeReadOnly(Attributes)); + if (CurrentMode != NewMode) + { + zen::SetFileMode(FilePath, NewMode); + } + return NewMode; + } +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + }; + + uint32_t GetNativeFileAttributes(const std::filesystem::path FilePath) + { +#if ZEN_PLATFORM_WINDOWS + return GetFileAttributesFromPath(FilePath); +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + return GetFileMode(FilePath); +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + } + void DownloadLargeBlob(BuildStorageBase& Storage, const std::filesystem::path& DownloadFolder, const Oid& BuildId, @@ -219,7 +265,7 @@ namespace { 
Workload->TempFile.Write(Chunk.GetView(), Offset); } }, - [&Work, Workload, &DownloadedChunkByteCount, OnDownloadComplete = std::move(OnDownloadComplete)]() { + [&Work, Workload, OnDownloadComplete = std::move(OnDownloadComplete)]() { if (!Work.IsAborted()) { ZEN_TRACE_CPU("Async_DownloadLargeBlob_OnComplete"); @@ -334,8 +380,120 @@ namespace { return CompositeBuffer{}; } + std::filesystem::path TryMoveDownloadedChunk(IoBuffer& BlockBuffer, const std::filesystem::path& Path, bool ForceDiskBased) + { + uint64_t BlockSize = BlockBuffer.GetSize(); + IoBufferFileReference FileRef; + if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == BlockSize)) + { + ZEN_TRACE_CPU("MoveTempFullBlock"); + std::error_code Ec; + std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); + if (!Ec) + { + BlockBuffer.SetDeleteOnClose(false); + BlockBuffer = {}; + RenameFile(TempBlobPath, Path, Ec); + if (Ec) + { + // Re-open the temp file again + BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); + BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true); + BlockBuffer.SetDeleteOnClose(true); + } + else + { + return Path; + } + } + } + + if (ForceDiskBased) + { + // Could not be moved and rather large, lets store it on disk + ZEN_TRACE_CPU("WriteTempFullBlock"); + TemporaryFile::SafeWriteFile(Path, BlockBuffer); + BlockBuffer = {}; + return Path; + } + + return {}; + } + } // namespace +class ReadFileCache +{ +public: + // A buffered file reader that provides CompositeBuffer where the buffers are owned and the memory never overwritten + ReadFileCache(std::atomic<uint64_t>& OpenReadCount, + std::atomic<uint64_t>& CurrentOpenFileCount, + std::atomic<uint64_t>& ReadCount, + std::atomic<uint64_t>& ReadByteCount, + const std::filesystem::path& Path, + const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + size_t MaxOpenFileCount) + : m_Path(Path) + , 
m_LocalContent(LocalContent) + , m_LocalLookup(LocalLookup) + , m_OpenReadCount(OpenReadCount) + , m_CurrentOpenFileCount(CurrentOpenFileCount) + , m_ReadCount(ReadCount) + , m_ReadByteCount(ReadByteCount) + { + m_OpenFiles.reserve(MaxOpenFileCount); + } + ~ReadFileCache() { m_OpenFiles.clear(); } + + CompositeBuffer GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size) + { + ZEN_TRACE_CPU("ReadFileCache::GetRange"); + + auto CacheIt = + std::find_if(m_OpenFiles.begin(), m_OpenFiles.end(), [SequenceIndex](const auto& Lhs) { return Lhs.first == SequenceIndex; }); + if (CacheIt != m_OpenFiles.end()) + { + if (CacheIt != m_OpenFiles.begin()) + { + auto CachedFile(std::move(CacheIt->second)); + m_OpenFiles.erase(CacheIt); + m_OpenFiles.insert(m_OpenFiles.begin(), std::make_pair(SequenceIndex, std::move(CachedFile))); + } + CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size); + return Result; + } + const uint32_t LocalPathIndex = m_LocalLookup.SequenceIndexFirstPathIndex[SequenceIndex]; + const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred(); + if (Size == m_LocalContent.RawSizes[LocalPathIndex]) + { + IoBuffer Result = IoBufferBuilder::MakeFromFile(LocalFilePath); + return CompositeBuffer(SharedBuffer(Result)); + } + if (m_OpenFiles.size() == m_OpenFiles.capacity()) + { + m_OpenFiles.pop_back(); + } + m_OpenFiles.insert( + m_OpenFiles.begin(), + std::make_pair( + SequenceIndex, + std::make_unique<BufferedOpenFile>(LocalFilePath, m_OpenReadCount, m_CurrentOpenFileCount, m_ReadCount, m_ReadByteCount))); + CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size); + return Result; + } + +private: + const std::filesystem::path m_Path; + const ChunkedFolderContent& m_LocalContent; + const ChunkedContentLookup& m_LocalLookup; + std::vector<std::pair<uint32_t, std::unique_ptr<BufferedOpenFile>>> m_OpenFiles; + std::atomic<uint64_t>& m_OpenReadCount; + 
std::atomic<uint64_t>& m_CurrentOpenFileCount; + std::atomic<uint64_t>& m_ReadCount; + std::atomic<uint64_t>& m_ReadByteCount; +}; + bool IsSingleFileChunk(const ChunkedFolderContent& RemoteContent, const std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> Locations) @@ -498,7 +656,8 @@ ZenTempFolderPath(const std::filesystem::path& ZenFolderPath) ////////////////////// BuildsOperationUpdateFolder -BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(OperationLogOutput& OperationLogOutput, +BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -513,7 +672,8 @@ BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(OperationLogOutput& const std::vector<ChunkBlockDescription>& BlockDescriptions, const std::vector<IoHash>& LooseChunkHashes, const Options& Options) -: m_LogOutput(OperationLogOutput) +: m_Log(Log) +, m_Progress(Progress) , m_Storage(Storage) , m_AbortFlag(AbortFlag) , m_PauseFlag(PauseFlag) @@ -551,65 +711,30 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) }; auto EndProgress = - MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); - - ZEN_ASSERT((!m_Options.PrimeCacheOnly) || - (m_Options.PrimeCacheOnly && (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off))); + MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ScanExistingData, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ScanExistingData, (uint32_t)TaskSteps::StepCount); CreateDirectories(m_CacheFolderPath); CreateDirectories(m_TempDownloadFolderPath); CreateDirectories(m_TempBlockFolderPath); - Stopwatch CacheMappingTimer; - std::vector<std::atomic<uint32_t>> 
SequenceIndexChunksLeftToWriteCounters(m_RemoteContent.ChunkedContent.SequenceRawHashes.size()); std::vector<bool> RemoteChunkIndexNeedsCopyFromLocalFileFlags(m_RemoteContent.ChunkedContent.ChunkHashes.size()); std::vector<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags(m_RemoteContent.ChunkedContent.ChunkHashes.size()); tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedChunkHashesFound; tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedSequenceHashesFound; - if (!m_Options.PrimeCacheOnly) - { - ScanCacheFolder(CachedChunkHashesFound, CachedSequenceHashesFound); - } + ScanCacheFolder(CachedChunkHashesFound, CachedSequenceHashesFound); tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedBlocksFound; - if (!m_Options.PrimeCacheOnly) - { - ScanTempBlocksFolder(CachedBlocksFound); - } + ScanTempBlocksFolder(CachedBlocksFound); tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceIndexesLeftToFindToRemoteIndex; - - if (!m_Options.PrimeCacheOnly && m_Options.EnableTargetFolderScavenging) - { - // Pick up all whole files we can use from current local state - ZEN_TRACE_CPU("GetLocalSequences"); - - Stopwatch LocalTimer; - - std::vector<uint32_t> MissingSequenceIndexes = ScanTargetFolder(CachedChunkHashesFound, CachedSequenceHashesFound); - - for (uint32_t RemoteSequenceIndex : MissingSequenceIndexes) - { - // We must write the sequence - const uint32_t ChunkCount = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex]; - const IoHash& RemoteSequenceRawHash = m_RemoteContent.ChunkedContent.SequenceRawHashes[RemoteSequenceIndex]; - SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = ChunkCount; - SequenceIndexesLeftToFindToRemoteIndex.insert({RemoteSequenceRawHash, RemoteSequenceIndex}); - } - } - else - { - for (uint32_t RemoteSequenceIndex = 0; RemoteSequenceIndex < m_RemoteContent.ChunkedContent.SequenceRawHashes.size(); - RemoteSequenceIndex++) - { - const uint32_t ChunkCount = 
m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex]; - SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = ChunkCount; - } - } + InitializeSequenceCounters(SequenceIndexChunksLeftToWriteCounters, + SequenceIndexesLeftToFindToRemoteIndex, + CachedChunkHashesFound, + CachedSequenceHashesFound); std::vector<ChunkedFolderContent> ScavengedContents; std::vector<ChunkedContentLookup> ScavengedLookups; @@ -618,7 +743,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) std::vector<ScavengedSequenceCopyOperation> ScavengedSequenceCopyOperations; uint64_t ScavengedPathsCount = 0; - if (!m_Options.PrimeCacheOnly && m_Options.EnableOtherDownloadsScavenging) + if (m_Options.EnableOtherDownloadsScavenging) { ZEN_TRACE_CPU("GetScavengedSequences"); @@ -627,123 +752,19 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) if (!SequenceIndexesLeftToFindToRemoteIndex.empty()) { std::vector<ScavengeSource> ScavengeSources = FindScavengeSources(); - - const size_t ScavengePathCount = ScavengeSources.size(); - - ScavengedContents.resize(ScavengePathCount); - ScavengedLookups.resize(ScavengePathCount); - ScavengedPaths.resize(ScavengePathCount); - - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Scavenging")); - OperationLogOutput::ProgressBar& ScavengeProgressBar(*ProgressBarPtr); - - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - - std::atomic<uint64_t> PathsFound(0); - std::atomic<uint64_t> ChunksFound(0); - std::atomic<uint64_t> PathsScavenged(0); - - for (size_t ScavengeIndex = 0; ScavengeIndex < ScavengePathCount; ScavengeIndex++) - { - Work.ScheduleWork(m_IOWorkerPool, - [this, - &ScavengeSources, - &ScavengedContents, - &ScavengedPaths, - &ScavengedLookups, - &PathsFound, - &ChunksFound, - &PathsScavenged, - ScavengeIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_FindScavengeContent"); - - const 
ScavengeSource& Source = ScavengeSources[ScavengeIndex]; - ChunkedFolderContent& ScavengedLocalContent = ScavengedContents[ScavengeIndex]; - ChunkedContentLookup& ScavengedLookup = ScavengedLookups[ScavengeIndex]; - - if (FindScavengeContent(Source, ScavengedLocalContent, ScavengedLookup)) - { - ScavengedPaths[ScavengeIndex] = Source.Path; - PathsFound += ScavengedLocalContent.Paths.size(); - ChunksFound += ScavengedLocalContent.ChunkedContent.ChunkHashes.size(); - } - else - { - ScavengedPaths[ScavengeIndex].clear(); - } - PathsScavenged++; - } - }); - } - { - ZEN_TRACE_CPU("ScavengeScan_Wait"); - - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { - ZEN_UNUSED(PendingWork); - std::string Details = fmt::format("{}/{} scanned. {} paths and {} chunks found for scavenging", - PathsScavenged.load(), - ScavengePathCount, - PathsFound.load(), - ChunksFound.load()); - ScavengeProgressBar.UpdateState( - {.Task = "Scavenging ", - .Details = Details, - .TotalCount = ScavengePathCount, - .RemainingCount = ScavengePathCount - PathsScavenged.load(), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); - }); - } - - ScavengeProgressBar.Finish(); + ScanScavengeSources(ScavengeSources, ScavengedContents, ScavengedLookups, ScavengedPaths); if (m_AbortFlag) { return; } - for (uint32_t ScavengedContentIndex = 0; - ScavengedContentIndex < ScavengedContents.size() && (!SequenceIndexesLeftToFindToRemoteIndex.empty()); - ScavengedContentIndex++) - { - const std::filesystem::path& ScavengePath = ScavengedPaths[ScavengedContentIndex]; - if (!ScavengePath.empty()) - { - const ChunkedFolderContent& ScavengedLocalContent = ScavengedContents[ScavengedContentIndex]; - const ChunkedContentLookup& ScavengedLookup = ScavengedLookups[ScavengedContentIndex]; - - for (uint32_t ScavengedSequenceIndex = 0; - ScavengedSequenceIndex < 
ScavengedLocalContent.ChunkedContent.SequenceRawHashes.size(); - ScavengedSequenceIndex++) - { - const IoHash& SequenceRawHash = ScavengedLocalContent.ChunkedContent.SequenceRawHashes[ScavengedSequenceIndex]; - if (auto It = SequenceIndexesLeftToFindToRemoteIndex.find(SequenceRawHash); - It != SequenceIndexesLeftToFindToRemoteIndex.end()) - { - const uint32_t RemoteSequenceIndex = It->second; - const uint64_t RawSize = - m_RemoteContent.RawSizes[m_RemoteLookup.SequenceIndexFirstPathIndex[RemoteSequenceIndex]]; - ZEN_ASSERT(RawSize > 0); - - const uint32_t ScavengedPathIndex = ScavengedLookup.SequenceIndexFirstPathIndex[ScavengedSequenceIndex]; - ZEN_ASSERT_SLOW(IsFile((ScavengePath / ScavengedLocalContent.Paths[ScavengedPathIndex]).make_preferred())); - - ScavengedSequenceCopyOperations.push_back({.ScavengedContentIndex = ScavengedContentIndex, - .ScavengedPathIndex = ScavengedPathIndex, - .RemoteSequenceIndex = RemoteSequenceIndex, - .RawSize = RawSize}); - - SequenceIndexesLeftToFindToRemoteIndex.erase(SequenceRawHash); - SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = 0; - - m_CacheMappingStats.ScavengedPathsMatchingSequencesCount++; - m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount += RawSize; - } - } - ScavengedPathsCount++; - } - } + MatchScavengedSequencesToRemote(ScavengedContents, + ScavengedLookups, + ScavengedPaths, + SequenceIndexesLeftToFindToRemoteIndex, + SequenceIndexChunksLeftToWriteCounters, + ScavengedSequenceCopyOperations, + ScavengedPathsCount); } m_CacheMappingStats.ScavengeElapsedWallTimeUs += ScavengeTimer.GetElapsedTimeUs(); } @@ -762,7 +783,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) tsl::robin_map<IoHash, size_t, IoHash::Hasher> RawHashToCopyChunkDataIndex; std::vector<CopyChunkData> CopyChunkDatas; - if (!m_Options.PrimeCacheOnly && m_Options.EnableTargetFolderScavenging) + if (m_Options.EnableTargetFolderScavenging) { ZEN_TRACE_CPU("GetLocalChunks"); @@ -782,7 +803,7 @@ 
BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) m_CacheMappingStats.LocalScanElapsedWallTimeUs += LocalTimer.GetElapsedTimeUs(); } - if (!m_Options.PrimeCacheOnly && m_Options.EnableOtherDownloadsScavenging) + if (m_Options.EnableOtherDownloadsScavenging) { ZEN_TRACE_CPU("GetScavengeChunks"); @@ -813,54 +834,40 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) if (m_CacheMappingStats.CacheSequenceHashesCount > 0 || m_CacheMappingStats.CacheChunkCount > 0 || m_CacheMappingStats.CacheBlockCount > 0) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Download cache: Found {} ({}) chunk sequences, {} ({}) chunks, {} ({}) blocks in {}", - m_CacheMappingStats.CacheSequenceHashesCount, - NiceBytes(m_CacheMappingStats.CacheSequenceHashesByteCount), - m_CacheMappingStats.CacheChunkCount, - NiceBytes(m_CacheMappingStats.CacheChunkByteCount), - m_CacheMappingStats.CacheBlockCount, - NiceBytes(m_CacheMappingStats.CacheBlocksByteCount), - NiceTimeSpanMs(m_CacheMappingStats.CacheScanElapsedWallTimeUs / 1000)); + ZEN_INFO("Download cache: Found {} ({}) chunk sequences, {} ({}) chunks, {} ({}) blocks in {}", + m_CacheMappingStats.CacheSequenceHashesCount, + NiceBytes(m_CacheMappingStats.CacheSequenceHashesByteCount), + m_CacheMappingStats.CacheChunkCount, + NiceBytes(m_CacheMappingStats.CacheChunkByteCount), + m_CacheMappingStats.CacheBlockCount, + NiceBytes(m_CacheMappingStats.CacheBlocksByteCount), + NiceTimeSpanMs(m_CacheMappingStats.CacheScanElapsedWallTimeUs / 1000)); } if (m_CacheMappingStats.LocalPathsMatchingSequencesCount > 0 || m_CacheMappingStats.LocalChunkMatchingRemoteCount > 0) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Local state : Found {} ({}) chunk sequences, {} ({}) chunks in {}", - m_CacheMappingStats.LocalPathsMatchingSequencesCount, - NiceBytes(m_CacheMappingStats.LocalPathsMatchingSequencesByteCount), - m_CacheMappingStats.LocalChunkMatchingRemoteCount, - 
NiceBytes(m_CacheMappingStats.LocalChunkMatchingRemoteByteCount), - NiceTimeSpanMs(m_CacheMappingStats.LocalScanElapsedWallTimeUs / 1000)); + ZEN_INFO("Local state : Found {} ({}) chunk sequences, {} ({}) chunks in {}", + m_CacheMappingStats.LocalPathsMatchingSequencesCount, + NiceBytes(m_CacheMappingStats.LocalPathsMatchingSequencesByteCount), + m_CacheMappingStats.LocalChunkMatchingRemoteCount, + NiceBytes(m_CacheMappingStats.LocalChunkMatchingRemoteByteCount), + NiceTimeSpanMs(m_CacheMappingStats.LocalScanElapsedWallTimeUs / 1000)); } if (m_CacheMappingStats.ScavengedPathsMatchingSequencesCount > 0 || m_CacheMappingStats.ScavengedChunkMatchingRemoteCount > 0) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Scavenge of {} paths, found {} ({}) chunk sequences, {} ({}) chunks in {}", - ScavengedPathsCount, - m_CacheMappingStats.ScavengedPathsMatchingSequencesCount, - NiceBytes(m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount), - m_CacheMappingStats.ScavengedChunkMatchingRemoteCount, - NiceBytes(m_CacheMappingStats.ScavengedChunkMatchingRemoteByteCount), - NiceTimeSpanMs(m_CacheMappingStats.ScavengeElapsedWallTimeUs / 1000)); + ZEN_INFO("Scavenge of {} paths, found {} ({}) chunk sequences, {} ({}) chunks in {}", + ScavengedPathsCount, + m_CacheMappingStats.ScavengedPathsMatchingSequencesCount, + NiceBytes(m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount), + m_CacheMappingStats.ScavengedChunkMatchingRemoteCount, + NiceBytes(m_CacheMappingStats.ScavengedChunkMatchingRemoteByteCount), + NiceTimeSpanMs(m_CacheMappingStats.ScavengeElapsedWallTimeUs / 1000)); } } - uint64_t BytesToWrite = 0; - - for (uint32_t RemoteChunkIndex = 0; RemoteChunkIndex < m_RemoteContent.ChunkedContent.ChunkHashes.size(); RemoteChunkIndex++) - { - uint64_t ChunkWriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - if (ChunkWriteCount > 0) - { - BytesToWrite += m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] * 
ChunkWriteCount; - if (!RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) - { - RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex] = true; - } - } - } + uint64_t BytesToWrite = CalculateBytesToWriteAndFlagNeededChunks(SequenceIndexChunksLeftToWriteCounters, + RemoteChunkIndexNeedsCopyFromLocalFileFlags, + RemoteChunkIndexNeedsCopyFromSourceFlags); for (const ScavengedSequenceCopyOperation& ScavengeCopyOp : ScavengedSequenceCopyOperations) { @@ -885,7 +892,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) BlobsExistsResult ExistsResult; { ChunkBlockAnalyser BlockAnalyser( - m_LogOutput, + Log(), m_BlockDescriptions, ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, .IsVerbose = m_Options.IsVerbose, @@ -900,183 +907,21 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) std::vector<uint32_t> FetchBlockIndexes; std::vector<uint32_t> CachedChunkBlockIndexes; + ClassifyCachedAndFetchBlocks(NeededBlocks, CachedBlocksFound, TotalPartWriteCount, CachedChunkBlockIndexes, FetchBlockIndexes); - { - ZEN_TRACE_CPU("BlockCacheFileExists"); - for (const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks) - { - if (m_Options.PrimeCacheOnly) - { - FetchBlockIndexes.push_back(NeededBlock.BlockIndex); - } - else - { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; - bool UsingCachedBlock = false; - if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) - { - TotalPartWriteCount++; - - std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - if (IsFile(BlockPath)) - { - CachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex); - UsingCachedBlock = true; - } - } - if (!UsingCachedBlock) - { - FetchBlockIndexes.push_back(NeededBlock.BlockIndex); - } - } - } - } - - std::vector<uint32_t> NeededLooseChunkIndexes; - - { - 
NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size()); - for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++) - { - const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; - auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); - ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); - const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - - if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Skipping chunk {} due to cache reuse", - m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); - } - continue; - } - - bool NeedsCopy = true; - if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) - { - uint64_t WriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - if (WriteCount == 0) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Skipping chunk {} due to cache reuse", - m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); - } - } - else - { - NeededLooseChunkIndexes.push_back(LooseChunkIndex); - } - } - } - } - - if (m_Storage.CacheStorage) - { - ZEN_TRACE_CPU("BlobCacheExistCheck"); - Stopwatch Timer; - - std::vector<IoHash> BlobHashes; - BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size()); + std::vector<uint32_t> NeededLooseChunkIndexes = DetermineNeededLooseChunkIndexes(SequenceIndexChunksLeftToWriteCounters, + RemoteChunkIndexNeedsCopyFromLocalFileFlags, + RemoteChunkIndexNeedsCopyFromSourceFlags); - for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes) - { - BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]); - } - - for (uint32_t BlockIndex : FetchBlockIndexes) - { - BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash); - } - - const 
std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = - m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); - - if (CacheExistsResult.size() == BlobHashes.size()) - { - ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size()); - for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) - { - if (CacheExistsResult[BlobIndex].HasBody) - { - ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]); - } - } - } - ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs(); - if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Remote cache : Found {} out of {} needed blobs in {}", - ExistsResult.ExistingBlobs.size(), - BlobHashes.size(), - NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); - } - } - - std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes; - - if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off) - { - BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); - } - else - { - ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; - ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; - - switch (m_Options.PartialBlockRequestMode) - { - case EPartialBlockRequestMode::Off: - break; - case EPartialBlockRequestMode::ZenCacheOnly: - CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; - CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; - break; - case EPartialBlockRequestMode::Mixed: - CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 - ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; - CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; - break; - case EPartialBlockRequestMode::All: - CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; - CloudPartialDownloadMode = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1 - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange - : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; - break; - default: - ZEN_ASSERT(false); - break; - } - - BlockPartialDownloadModes.reserve(m_BlockDescriptions.size()); - for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) - { - const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); - BlockPartialDownloadModes.push_back(BlockExistInCache ? 
CachePartialDownloadMode : CloudPartialDownloadMode); - } - } + ExistsResult = QueryBlobCacheExists(NeededLooseChunkIndexes, FetchBlockIndexes); + std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes = + DeterminePartialDownloadModes(ExistsResult); ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size()); ChunkBlockAnalyser::BlockResult PartialBlocks = BlockAnalyser.CalculatePartialBlockDownloads(NeededBlocks, BlockPartialDownloadModes); - struct LooseChunkHashWorkData - { - std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs; - uint32_t RemoteChunkIndex = (uint32_t)-1; - }; - TotalRequestCount += NeededLooseChunkIndexes.size(); TotalPartWriteCount += NeededLooseChunkIndexes.size(); TotalRequestCount += PartialBlocks.BlockRanges.size(); @@ -1084,626 +929,52 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) TotalRequestCount += PartialBlocks.FullBlockIndexes.size(); TotalPartWriteCount += PartialBlocks.FullBlockIndexes.size(); - std::vector<LooseChunkHashWorkData> LooseChunkHashWorks; - for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes) - { - const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; - auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); - ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); - const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - - std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs = - GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - - ZEN_ASSERT(!ChunkTargetPtrs.empty()); - LooseChunkHashWorks.push_back( - LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); - } + std::vector<LooseChunkHashWorkData> LooseChunkHashWorks = + BuildLooseChunkHashWorks(NeededLooseChunkIndexes, SequenceIndexChunksLeftToWriteCounters); ZEN_TRACE_CPU("WriteChunks"); - 
m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); Stopwatch WriteTimer; FilteredRate FilteredDownloadedBytesPerSecond; FilteredRate FilteredWrittenBytesPerSecond; - std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr( - m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing")); - OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Writing"); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); TotalPartWriteCount += CopyChunkDatas.size(); TotalPartWriteCount += ScavengedSequenceCopyOperations.size(); BufferedWriteFileCache WriteCache; - for (uint32_t ScavengeOpIndex = 0; ScavengeOpIndex < ScavengedSequenceCopyOperations.size(); ScavengeOpIndex++) - { - if (m_AbortFlag) - { - break; - } - if (!m_Options.PrimeCacheOnly) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &ScavengedPaths, - &ScavengedSequenceCopyOperations, - &ScavengedContents, - &FilteredWrittenBytesPerSecond, - ScavengeOpIndex, - &WritePartsComplete, - TotalPartWriteCount](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WriteScavenged"); - - FilteredWrittenBytesPerSecond.Start(); - - const ScavengedSequenceCopyOperation& ScavengeOp = ScavengedSequenceCopyOperations[ScavengeOpIndex]; - const ChunkedFolderContent& ScavengedContent = ScavengedContents[ScavengeOp.ScavengedContentIndex]; - const std::filesystem::path& ScavengeRootPath = ScavengedPaths[ScavengeOp.ScavengedContentIndex]; - - WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp); - - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) - { - 
FilteredWrittenBytesPerSecond.Stop(); - } - } - }); - } - } - - for (uint32_t LooseChunkHashWorkIndex = 0; LooseChunkHashWorkIndex < LooseChunkHashWorks.size(); LooseChunkHashWorkIndex++) - { - if (m_AbortFlag) - { - break; - } - - if (m_Options.PrimeCacheOnly) - { - const uint32_t RemoteChunkIndex = LooseChunkHashWorks[LooseChunkHashWorkIndex].RemoteChunkIndex; - if (ExistsResult.ExistingBlobs.contains(m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex])) - { - m_DownloadStats.RequestsCompleteCount++; - continue; - } - } - - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &SequenceIndexChunksLeftToWriteCounters, - &Work, - &ExistsResult, - &WritePartsComplete, - &LooseChunkHashWorks, - LooseChunkHashWorkIndex, - TotalRequestCount, - TotalPartWriteCount, - &WriteCache, - &FilteredDownloadedBytesPerSecond, - &FilteredWrittenBytesPerSecond](std::atomic<bool>&) mutable { - ZEN_TRACE_CPU("Async_ReadPreDownloadedChunk"); - if (!m_AbortFlag) - { - LooseChunkHashWorkData& LooseChunkHashWork = LooseChunkHashWorks[LooseChunkHashWorkIndex]; - const uint32_t RemoteChunkIndex = LooseChunkHashWorks[LooseChunkHashWorkIndex].RemoteChunkIndex; - WriteLooseChunk(RemoteChunkIndex, - ExistsResult, - SequenceIndexChunksLeftToWriteCounters, - WritePartsComplete, - std::move(LooseChunkHashWork.ChunkTargetPtrs), - WriteCache, - Work, - TotalRequestCount, - TotalPartWriteCount, - FilteredDownloadedBytesPerSecond, - FilteredWrittenBytesPerSecond); - } - }, - WorkerThreadPool::EMode::EnableBacklog); - } - - std::unique_ptr<CloneQueryInterface> CloneQuery; - if (m_Options.AllowFileClone) - { - CloneQuery = GetCloneQueryInterface(m_CacheFolderPath); - } - - for (size_t CopyDataIndex = 0; CopyDataIndex < CopyChunkDatas.size(); CopyDataIndex++) - { - ZEN_ASSERT(!m_Options.PrimeCacheOnly); - if (m_AbortFlag) - { - break; - } - - Work.ScheduleWork(m_IOWorkerPool, - [this, - &CloneQuery, - &SequenceIndexChunksLeftToWriteCounters, - &WriteCache, - &Work, - &FilteredWrittenBytesPerSecond, - 
&CopyChunkDatas, - &ScavengedContents, - &ScavengedLookups, - &ScavengedPaths, - &WritePartsComplete, - TotalPartWriteCount, - CopyDataIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_CopyLocal"); - - FilteredWrittenBytesPerSecond.Start(); - const CopyChunkData& CopyData = CopyChunkDatas[CopyDataIndex]; - - std::vector<uint32_t> WrittenSequenceIndexes = WriteLocalChunkToCache(CloneQuery.get(), - CopyData, - ScavengedContents, - ScavengedLookups, - ScavengedPaths, - WriteCache); - WritePartsComplete++; - if (!m_AbortFlag) - { - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); - } - - // Write tracking, updating this must be done without any files open - std::vector<uint32_t> CompletedChunkSequences; - for (uint32_t RemoteSequenceIndex : WrittenSequenceIndexes) - { - if (CompleteSequenceChunk(RemoteSequenceIndex, SequenceIndexChunksLeftToWriteCounters)) - { - CompletedChunkSequences.push_back(RemoteSequenceIndex); - } - } - WriteCache.Close(CompletedChunkSequences); - VerifyAndCompleteChunkSequencesAsync(CompletedChunkSequences, Work); - } - } - }); - } - - for (uint32_t BlockIndex : CachedChunkBlockIndexes) - { - ZEN_ASSERT(!m_Options.PrimeCacheOnly); - if (m_AbortFlag) - { - break; - } - - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WriteCache, - &Work, - &FilteredWrittenBytesPerSecond, - &WritePartsComplete, - TotalPartWriteCount, - BlockIndex](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WriteCachedBlock"); - - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - FilteredWrittenBytesPerSecond.Start(); - - std::filesystem::path BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - IoBuffer BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); - if (!BlockBuffer) - { - throw std::runtime_error( - 
fmt::format("Can not read block {} at {}", BlockDescription.BlockHash, BlockChunkPath)); - } - - if (!m_AbortFlag) - { - if (!WriteChunksBlockToCache(BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockBuffer)), - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) - { - std::error_code DummyEc; - RemoveFile(BlockChunkPath, DummyEc); - throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash)); - } + WriteChunksContext Context{.Work = Work, + .WriteCache = WriteCache, + .SequenceIndexChunksLeftToWriteCounters = SequenceIndexChunksLeftToWriteCounters, + .RemoteChunkIndexNeedsCopyFromSourceFlags = RemoteChunkIndexNeedsCopyFromSourceFlags, + .WritePartsComplete = WritePartsComplete, + .TotalPartWriteCount = TotalPartWriteCount, + .TotalRequestCount = TotalRequestCount, + .ExistsResult = ExistsResult, + .FilteredDownloadedBytesPerSecond = FilteredDownloadedBytesPerSecond, + .FilteredWrittenBytesPerSecond = FilteredWrittenBytesPerSecond}; - std::error_code Ec = TryRemoveFile(BlockChunkPath); - if (Ec) - { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockChunkPath, - Ec.value(), - Ec.message()); - } - - WritePartsComplete++; - - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); - } - } - } - }); - } - - for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();) - { - ZEN_ASSERT(!m_Options.PrimeCacheOnly); - if (m_AbortFlag) - { - break; - } - - size_t RangeCount = 1; - size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex; - const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; - while (RangeCount < RangesLeft && - CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) - { - RangeCount++; - } - - Work.ScheduleWork( - m_NetworkPool, - [this, - 
&RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &ExistsResult, - &WriteCache, - &FilteredDownloadedBytesPerSecond, - TotalRequestCount, - &WritePartsComplete, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - &Work, - &PartialBlocks, - BlockRangeStartIndex = BlockRangeIndex, - RangeCount = RangeCount](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_GetPartialBlockRanges"); - - FilteredDownloadedBytesPerSecond.Start(); - - DownloadPartialBlock( - PartialBlocks.BlockRanges, - BlockRangeStartIndex, - RangeCount, - ExistsResult, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalRequestCount, - TotalPartWriteCount, - &FilteredDownloadedBytesPerSecond, - &FilteredWrittenBytesPerSecond, - &PartialBlocks](IoBuffer&& InMemoryBuffer, - const std::filesystem::path& OnDiskPath, - size_t BlockRangeStartIndex, - std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - &PartialBlocks, - BlockRangeStartIndex, - BlockChunkPath = std::filesystem::path(OnDiskPath), - BlockPartialBuffer = std::move(InMemoryBuffer), - OffsetAndLengths = std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), - OffsetAndLengths.end())]( - std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WritePartialBlock"); - - const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex; - - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - if 
(BlockChunkPath.empty()) - { - ZEN_ASSERT(BlockPartialBuffer); - } - else - { - ZEN_ASSERT(!BlockPartialBuffer); - BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); - if (!BlockPartialBuffer) - { - throw std::runtime_error( - fmt::format("Could not open downloaded block {} from {}", - BlockDescription.BlockHash, - BlockChunkPath)); - } - } - - FilteredWrittenBytesPerSecond.Start(); - - size_t RangeCount = OffsetAndLengths.size(); - - for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++) - { - const std::pair<uint64_t, uint64_t>& OffsetAndLength = - OffsetAndLengths[PartialRangeIndex]; - IoBuffer BlockRangeBuffer(BlockPartialBuffer, - OffsetAndLength.first, - OffsetAndLength.second); - - const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor = - PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex]; - - if (!WritePartialBlockChunksToCache(BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockRangeBuffer)), - RangeDescriptor.ChunkBlockIndexStart, - RangeDescriptor.ChunkBlockIndexStart + - RangeDescriptor.ChunkBlockIndexCount - 1, - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) - { - std::error_code DummyEc; - RemoveFile(BlockChunkPath, DummyEc); - throw std::runtime_error( - fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); - } - - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); - } - } - std::error_code Ec = TryRemoveFile(BlockChunkPath); - if (Ec) - { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockChunkPath, - Ec.value(), - Ec.message()); - } - } - }, - OnDiskPath.empty() ? 
WorkerThreadPool::EMode::DisableBacklog - : WorkerThreadPool::EMode::EnableBacklog); - } - }); - } - }); - BlockRangeIndex += RangeCount; - } - - for (uint32_t BlockIndex : PartialBlocks.FullBlockIndexes) - { - if (m_AbortFlag) - { - break; - } - - if (m_Options.PrimeCacheOnly && ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash)) - { - m_DownloadStats.RequestsCompleteCount++; - continue; - } - - Work.ScheduleWork( - m_NetworkPool, - [this, - &WritePartsComplete, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - &ExistsResult, - &Work, - &WriteCache, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &FilteredDownloadedBytesPerSecond, - TotalRequestCount, - BlockIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_GetFullBlock"); - - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - FilteredDownloadedBytesPerSecond.Start(); - - IoBuffer BlockBuffer; - const bool ExistsInCache = - m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); - if (ExistsInCache) - { - BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); - } - if (!BlockBuffer) - { - BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); - if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BlockBuffer))); - } - } - if (!BlockBuffer) - { - throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash)); - } - if (!m_AbortFlag) - { - uint64_t BlockSize = BlockBuffer.GetSize(); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += BlockSize; - m_DownloadStats.RequestsCompleteCount++; - if (m_DownloadStats.RequestsCompleteCount == 
TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } + ScheduleScavengedSequenceWrites(Context, ScavengedSequenceCopyOperations, ScavengedContents, ScavengedPaths); + ScheduleLooseChunkWrites(Context, LooseChunkHashWorks); - if (!m_Options.PrimeCacheOnly) - { - std::filesystem::path BlockChunkPath; + std::unique_ptr<CloneQueryInterface> CloneQuery = + m_Options.AllowFileClone ? GetCloneQueryInterface(m_CacheFolderPath) : nullptr; - // Check if the dowloaded block is file based and we can move it directly without rewriting it - { - IoBufferFileReference FileRef; - if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && - (FileRef.FileChunkSize == BlockSize)) - { - ZEN_TRACE_CPU("MoveTempFullBlock"); - std::error_code Ec; - std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); - if (!Ec) - { - BlockBuffer.SetDeleteOnClose(false); - BlockBuffer = {}; - BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - RenameFile(TempBlobPath, BlockChunkPath, Ec); - if (Ec) - { - BlockChunkPath = std::filesystem::path{}; - - // Re-open the temp file again - BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); - BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true); - BlockBuffer.SetDeleteOnClose(true); - } - } - } - } - - if (BlockChunkPath.empty() && (BlockSize > m_Options.MaximumInMemoryPayloadSize)) - { - ZEN_TRACE_CPU("WriteTempFullBlock"); - // Could not be moved and rather large, lets store it on disk - BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - TemporaryFile::SafeWriteFile(BlockChunkPath, BlockBuffer); - BlockBuffer = {}; - } - - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &Work, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - BlockIndex, - &WriteCache, - &WritePartsComplete, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - 
BlockChunkPath, - BlockBuffer = std::move(BlockBuffer)](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WriteFullBlock"); - - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - if (BlockChunkPath.empty()) - { - ZEN_ASSERT(BlockBuffer); - } - else - { - ZEN_ASSERT(!BlockBuffer); - BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); - if (!BlockBuffer) - { - throw std::runtime_error( - fmt::format("Could not open dowloaded block {} from {}", - BlockDescription.BlockHash, - BlockChunkPath)); - } - } - - FilteredWrittenBytesPerSecond.Start(); - if (!WriteChunksBlockToCache(BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockBuffer)), - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) - { - std::error_code DummyEc; - RemoveFile(BlockChunkPath, DummyEc); - throw std::runtime_error( - fmt::format("Block {} is malformed", BlockDescription.BlockHash)); - } - - if (!BlockChunkPath.empty()) - { - std::error_code Ec = TryRemoveFile(BlockChunkPath); - if (Ec) - { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockChunkPath, - Ec.value(), - Ec.message()); - } - } - - WritePartsComplete++; - - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); - } - } - }, - BlockChunkPath.empty() ? 
WorkerThreadPool::EMode::DisableBacklog - : WorkerThreadPool::EMode::EnableBacklog); - } - } - } - } - }); - } + ScheduleLocalChunkCopies(Context, CopyChunkDatas, CloneQuery.get(), ScavengedContents, ScavengedLookups, ScavengedPaths); + ScheduleCachedBlockWrites(Context, CachedChunkBlockIndexes); + SchedulePartialBlockDownloads(Context, PartialBlocks); + ScheduleFullBlockDownloads(Context, PartialBlocks.FullBlockIndexes); { ZEN_TRACE_CPU("WriteChunks_Wait"); - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(PendingWork); uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load() + @@ -1719,12 +990,11 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) { CloneDetails = fmt::format(" ({} cloned)", NiceBytes(m_DiskStats.CloneByteCount.load())); } - std::string WriteDetails = m_Options.PrimeCacheOnly ? 
"" - : fmt::format(" {}/{} ({}B/s) written{}", - NiceBytes(m_WrittenChunkByteCount.load()), - NiceBytes(BytesToWrite), - NiceNum(FilteredWrittenBytesPerSecond.GetCurrent()), - CloneDetails); + std::string WriteDetails = fmt::format(" {}/{} ({}B/s) written{}", + NiceBytes(m_WrittenChunkByteCount.load()), + NiceBytes(BytesToWrite), + NiceNum(FilteredWrittenBytesPerSecond.GetCurrent()), + CloneDetails); std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}", m_DownloadStats.RequestsCompleteCount.load(), @@ -1734,11 +1004,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) WriteDetails); std::string Task; - if (m_Options.PrimeCacheOnly) - { - Task = "Downloading "; - } - else if ((m_WrittenChunkByteCount < BytesToWrite) || (BytesToValidate == 0)) + if ((m_WrittenChunkByteCount < BytesToWrite) || (BytesToValidate == 0)) { Task = "Writing chunks "; } @@ -1747,15 +1013,13 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) Task = "Verifying chunks "; } - WriteProgressBar.UpdateState( - {.Task = Task, - .Details = Details, - .TotalCount = m_Options.PrimeCacheOnly ? TotalRequestCount : (BytesToWrite + BytesToValidate), - .RemainingCount = m_Options.PrimeCacheOnly ? 
(TotalRequestCount - m_DownloadStats.RequestsCompleteCount.load()) - : ((BytesToWrite + BytesToValidate) - - (m_WrittenChunkByteCount.load() + m_ValidatedChunkByteCount.load())), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = Task, + .Details = Details, + .TotalCount = (BytesToWrite + BytesToValidate), + .RemainingCount = ((BytesToWrite + BytesToValidate) - + (m_WrittenChunkByteCount.load() + m_ValidatedChunkByteCount.load())), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }); } @@ -1764,40 +1028,13 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) FilteredWrittenBytesPerSecond.Stop(); FilteredDownloadedBytesPerSecond.Stop(); - WriteProgressBar.Finish(); + ProgressBar->Finish(); if (m_AbortFlag) { return; } - if (!m_Options.PrimeCacheOnly) - { - uint32_t RawSequencesMissingWriteCount = 0; - for (uint32_t SequenceIndex = 0; SequenceIndex < SequenceIndexChunksLeftToWriteCounters.size(); SequenceIndex++) - { - const auto& SequenceIndexChunksLeftToWriteCounter = SequenceIndexChunksLeftToWriteCounters[SequenceIndex]; - if (SequenceIndexChunksLeftToWriteCounter.load() != 0) - { - RawSequencesMissingWriteCount++; - const uint32_t PathIndex = m_RemoteLookup.SequenceIndexFirstPathIndex[SequenceIndex]; - const std::filesystem::path& IncompletePath = m_RemoteContent.Paths[PathIndex]; - ZEN_ASSERT(!IncompletePath.empty()); - const uint32_t ExpectedSequenceCount = m_RemoteContent.ChunkedContent.ChunkCounts[SequenceIndex]; - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "{}: Max count {}, Current count {}", - IncompletePath, - ExpectedSequenceCount, - SequenceIndexChunksLeftToWriteCounter.load()); - } - ZEN_ASSERT(SequenceIndexChunksLeftToWriteCounter.load() <= ExpectedSequenceCount); - } - } - ZEN_ASSERT(RawSequencesMissingWriteCount == 0); - ZEN_ASSERT(m_WrittenChunkByteCount == BytesToWrite); 
- ZEN_ASSERT(m_ValidatedChunkByteCount == BytesToValidate); - } + VerifyWriteChunksComplete(SequenceIndexChunksLeftToWriteCounters, BytesToWrite, BytesToValidate); const uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load() + @@ -1809,17 +1046,15 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) { CloneDetails = fmt::format(" ({} cloned)", NiceBytes(m_DiskStats.CloneByteCount.load())); } - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Downloaded {} ({}bits/s) in {}. Wrote {} ({}B/s){} in {}. Completed in {}", - NiceBytes(DownloadedBytes), - NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)), - NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000), - NiceBytes(m_WrittenChunkByteCount.load()), - NiceNum(GetBytesPerSecond(FilteredWrittenBytesPerSecond.GetElapsedTimeUS(), m_DiskStats.WriteByteCount.load())), - CloneDetails, - NiceTimeSpanMs(FilteredWrittenBytesPerSecond.GetElapsedTimeUS() / 1000), - NiceTimeSpanMs(WriteTimer.GetElapsedTimeMs())); + ZEN_INFO("Downloaded {} ({}bits/s) in {}. Wrote {} ({}B/s){} in {}. 
Completed in {}", + NiceBytes(DownloadedBytes), + NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)), + NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000), + NiceBytes(m_WrittenChunkByteCount.load()), + NiceNum(GetBytesPerSecond(FilteredWrittenBytesPerSecond.GetElapsedTimeUS(), m_DiskStats.WriteByteCount.load())), + CloneDetails, + NiceTimeSpanMs(FilteredWrittenBytesPerSecond.GetElapsedTimeUS() / 1000), + NiceTimeSpanMs(WriteTimer.GetElapsedTimeMs())); } m_WriteChunkStats.WriteChunksElapsedWallTimeUs = WriteTimer.GetElapsedTimeUs(); @@ -1827,199 +1062,40 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) m_WriteChunkStats.WriteTimeUs = FilteredWrittenBytesPerSecond.GetElapsedTimeUS(); } - if (m_Options.PrimeCacheOnly) + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::PrepareTarget, (uint32_t)TaskSteps::StepCount); + + if (m_AbortFlag) { return; } - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::PrepareTarget, (uint32_t)TaskSteps::StepCount); - - tsl::robin_map<uint32_t, uint32_t> RemotePathIndexToLocalPathIndex; - RemotePathIndexToLocalPathIndex.reserve(m_RemoteContent.Paths.size()); - - tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceHashToLocalPathIndex; - std::vector<uint32_t> RemoveLocalPathIndexes; + LocalPathCategorization Categorization = CategorizeLocalPaths(RemotePathToRemoteIndex); if (m_AbortFlag) { return; } + std::atomic<uint64_t> CachedCount = 0; + std::atomic<uint64_t> CachedByteCount = 0; + ScheduleLocalFileCaching(Categorization.FilesToCache, CachedCount, CachedByteCount); + if (m_AbortFlag) { - ZEN_TRACE_CPU("PrepareTarget"); - - tsl::robin_set<IoHash, IoHash::Hasher> CachedRemoteSequences; - - std::vector<uint32_t> FilesToCache; - - uint64_t MatchCount = 0; - uint64_t PathMismatchCount = 0; - uint64_t HashMismatchCount = 0; - std::atomic<uint64_t> CachedCount = 0; - std::atomic<uint64_t> CachedByteCount = 0; - uint64_t 
SkippedCount = 0; - uint64_t DeleteCount = 0; - for (uint32_t LocalPathIndex = 0; LocalPathIndex < m_LocalContent.Paths.size(); LocalPathIndex++) - { - if (m_AbortFlag) - { - break; - } - const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex]; - const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex]; - - ZEN_ASSERT_SLOW(IsFile((m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred())); - - if (m_Options.EnableTargetFolderScavenging) - { - if (!m_Options.WipeTargetFolder) - { - // Check if it is already in the correct place - if (auto RemotePathIt = RemotePathToRemoteIndex.find(LocalPath.generic_string()); - RemotePathIt != RemotePathToRemoteIndex.end()) - { - const uint32_t RemotePathIndex = RemotePathIt->second; - if (m_RemoteContent.RawHashes[RemotePathIndex] == RawHash) - { - // It is already in it's correct place - RemotePathIndexToLocalPathIndex[RemotePathIndex] = LocalPathIndex; - SequenceHashToLocalPathIndex.insert({RawHash, LocalPathIndex}); - MatchCount++; - continue; - } - else - { - HashMismatchCount++; - } - } - else - { - PathMismatchCount++; - } - } - - // Do we need it? 
- if (m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash)) - { - if (!CachedRemoteSequences.contains(RawHash)) - { - // We need it, make sure we move it to the cache - FilesToCache.push_back(LocalPathIndex); - CachedRemoteSequences.insert(RawHash); - continue; - } - else - { - SkippedCount++; - } - } - } - - if (!m_Options.WipeTargetFolder) - { - // Explicitly delete the unneeded local file - RemoveLocalPathIndexes.push_back(LocalPathIndex); - DeleteCount++; - } - } - - if (m_AbortFlag) - { - return; - } - - { - ZEN_TRACE_CPU("CopyToCache"); - - Stopwatch Timer; - - std::unique_ptr<OperationLogOutput::ProgressBar> CacheLocalProgressBarPtr( - m_LogOutput.CreateProgressBar("Cache Local Data")); - OperationLogOutput::ProgressBar& CacheLocalProgressBar(*CacheLocalProgressBarPtr); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - - for (uint32_t LocalPathIndex : FilesToCache) - { - if (m_AbortFlag) - { - break; - } - Work.ScheduleWork(m_IOWorkerPool, [this, &CachedCount, &CachedByteCount, LocalPathIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_CopyToCache"); - - const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex]; - const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex]; - const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash); - ZEN_ASSERT_SLOW(!IsFileWithRetry(CacheFilePath)); - const std::filesystem::path LocalFilePath = (m_Path / LocalPath).make_preferred(); - - std::error_code Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath); - if (Ec) - { - ZEN_OPERATION_LOG_WARN(m_LogOutput, - "Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...", - LocalFilePath, - CacheFilePath, - Ec.value(), - Ec.message()); - Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath); - if (Ec) - { - throw std::system_error(std::error_code(Ec.value(), std::system_category()), - fmt::format("Failed to file from '{}' to 
'{}', reason: ({}) {}", - LocalFilePath, - CacheFilePath, - Ec.value(), - Ec.message())); - } - } - - CachedCount++; - CachedByteCount += m_LocalContent.RawSizes[LocalPathIndex]; - } - }); - } - - { - ZEN_TRACE_CPU("CopyToCache_Wait"); - - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { - ZEN_UNUSED(PendingWork); - const uint64_t WorkTotal = FilesToCache.size(); - const uint64_t WorkComplete = CachedCount.load(); - std::string Details = fmt::format("{}/{} ({}) files", WorkComplete, WorkTotal, NiceBytes(CachedByteCount)); - CacheLocalProgressBar.UpdateState( - {.Task = "Caching local ", - .Details = Details, - .TotalCount = gsl::narrow<uint64_t>(WorkTotal), - .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); - }); - } - - CacheLocalProgressBar.Finish(); - if (m_AbortFlag) - { - return; - } - - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Local state prep: Match: {}, PathMismatch: {}, HashMismatch: {}, Cached: {} ({}), Skipped: {}, " - "Delete: {}", - MatchCount, - PathMismatchCount, - HashMismatchCount, - CachedCount.load(), - NiceBytes(CachedByteCount.load()), - SkippedCount, - DeleteCount); - } + return; } - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FinalizeTarget, (uint32_t)TaskSteps::StepCount); + ZEN_DEBUG( + "Local state prep: Match: {}, PathMismatch: {}, HashMismatch: {}, Cached: {} ({}), Skipped: {}, " + "Delete: {}", + Categorization.MatchCount, + Categorization.PathMismatchCount, + Categorization.HashMismatchCount, + CachedCount.load(), + NiceBytes(CachedByteCount.load()), + Categorization.SkippedCount, + Categorization.DeleteCount); + + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FinalizeTarget, (uint32_t)TaskSteps::StepCount); if (m_Options.WipeTargetFolder) { @@ -2027,9 +1103,16 @@ BuildsOperationUpdateFolder::Execute(FolderContent& 
OutLocalFolderState) Stopwatch Timer; // Clean target folder - if (!CleanDirectory(m_LogOutput, m_IOWorkerPool, m_AbortFlag, m_PauseFlag, m_Options.IsQuiet, m_Path, m_Options.ExcludeFolders)) + if (!CleanDirectory(Log(), + m_Progress, + m_IOWorkerPool, + m_AbortFlag, + m_PauseFlag, + m_Options.IsQuiet, + m_Path, + m_Options.ExcludeFolders)) { - ZEN_OPERATION_LOG_WARN(m_LogOutput, "Some files in {} could not be removed", m_Path); + ZEN_WARN("Some files in {} could not be removed", m_Path); } m_RebuildFolderStateStats.CleanFolderElapsedWallTimeUs = Timer.GetElapsedTimeUs(); } @@ -2044,315 +1127,49 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) Stopwatch Timer; - std::unique_ptr<OperationLogOutput::ProgressBar> RebuildProgressBarPtr(m_LogOutput.CreateProgressBar("Rebuild State")); - OperationLogOutput::ProgressBar& RebuildProgressBar(*RebuildProgressBarPtr); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Rebuild State"); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); OutLocalFolderState.Paths.resize(m_RemoteContent.Paths.size()); OutLocalFolderState.RawSizes.resize(m_RemoteContent.Paths.size()); OutLocalFolderState.Attributes.resize(m_RemoteContent.Paths.size()); OutLocalFolderState.ModificationTicks.resize(m_RemoteContent.Paths.size()); - std::atomic<uint64_t> DeletedCount = 0; - - for (uint32_t LocalPathIndex : RemoveLocalPathIndexes) - { - if (m_AbortFlag) - { - break; - } - Work.ScheduleWork(m_IOWorkerPool, [this, &DeletedCount, LocalPathIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_RemoveFile"); - - const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred(); - SetFileReadOnlyWithRetry(LocalFilePath, false); - RemoveFileWithRetry(LocalFilePath); - DeletedCount++; - } - }); - } - + 
std::atomic<uint64_t> DeletedCount = 0; std::atomic<uint64_t> TargetsComplete = 0; - struct FinalizeTarget - { - IoHash RawHash; - uint32_t RemotePathIndex; - }; - - std::vector<FinalizeTarget> Targets; - Targets.reserve(m_RemoteContent.Paths.size()); - for (uint32_t RemotePathIndex = 0; RemotePathIndex < m_RemoteContent.Paths.size(); RemotePathIndex++) - { - Targets.push_back( - FinalizeTarget{.RawHash = m_RemoteContent.RawHashes[RemotePathIndex], .RemotePathIndex = RemotePathIndex}); - } - std::sort(Targets.begin(), Targets.end(), [](const FinalizeTarget& Lhs, const FinalizeTarget& Rhs) { - if (Lhs.RawHash < Rhs.RawHash) - { - return true; - } - else if (Lhs.RawHash > Rhs.RawHash) - { - return false; - } - return Lhs.RemotePathIndex < Rhs.RemotePathIndex; - }); - - size_t TargetOffset = 0; - while (TargetOffset < Targets.size()) - { - if (m_AbortFlag) - { - break; - } - - size_t TargetCount = 1; - while ((TargetOffset + TargetCount) < Targets.size() && - (Targets[TargetOffset + TargetCount].RawHash == Targets[TargetOffset].RawHash)) - { - TargetCount++; - } - - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &SequenceHashToLocalPathIndex, - &Targets, - &RemotePathIndexToLocalPathIndex, - &OutLocalFolderState, - BaseTargetOffset = TargetOffset, - TargetCount, - &TargetsComplete](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_FinalizeChunkSequence"); + ScheduleLocalFileRemovals(Work, Categorization.RemoveLocalPathIndexes, DeletedCount); - size_t TargetOffset = BaseTargetOffset; - const IoHash& RawHash = Targets[TargetOffset].RawHash; + std::vector<FinalizeTarget> Targets = BuildSortedFinalizeTargets(); - if (RawHash == IoHash::Zero) - { - ZEN_TRACE_CPU("CreateEmptyFiles"); - while (TargetOffset < (BaseTargetOffset + TargetCount)) - { - const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex; - ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash); - const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex]; 
- std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred(); - if (!RemotePathIndexToLocalPathIndex[RemotePathIndex]) - { - if (IsFileWithRetry(TargetFilePath)) - { - SetFileReadOnlyWithRetry(TargetFilePath, false); - } - else - { - CreateDirectories(TargetFilePath.parent_path()); - } - BasicFile OutputFile; - OutputFile.Open(TargetFilePath, BasicFile::Mode::kTruncate); - } - OutLocalFolderState.Paths[RemotePathIndex] = TargetPath; - OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex]; - - OutLocalFolderState.Attributes[RemotePathIndex] = - m_RemoteContent.Attributes.empty() - ? GetNativeFileAttributes(TargetFilePath) - : SetNativeFileAttributes(TargetFilePath, - m_RemoteContent.Platform, - m_RemoteContent.Attributes[RemotePathIndex]); - OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath); - - TargetOffset++; - TargetsComplete++; - } - } - else - { - ZEN_TRACE_CPU("FinalizeFile"); - ZEN_ASSERT(m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash)); - const uint32_t FirstRemotePathIndex = Targets[TargetOffset].RemotePathIndex; - const std::filesystem::path& FirstTargetPath = m_RemoteContent.Paths[FirstRemotePathIndex]; - std::filesystem::path FirstTargetFilePath = (m_Path / FirstTargetPath).make_preferred(); - - if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(FirstRemotePathIndex); - InPlaceIt != RemotePathIndexToLocalPathIndex.end()) - { - ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath)); - } - else - { - if (IsFileWithRetry(FirstTargetFilePath)) - { - SetFileReadOnlyWithRetry(FirstTargetFilePath, false); - } - else - { - CreateDirectories(FirstTargetFilePath.parent_path()); - } - - if (auto InplaceIt = SequenceHashToLocalPathIndex.find(RawHash); - InplaceIt != SequenceHashToLocalPathIndex.end()) - { - ZEN_TRACE_CPU("Copy"); - const uint32_t LocalPathIndex = InplaceIt->second; - const std::filesystem::path& SourcePath = 
m_LocalContent.Paths[LocalPathIndex]; - std::filesystem::path SourceFilePath = (m_Path / SourcePath).make_preferred(); - ZEN_ASSERT_SLOW(IsFileWithRetry(SourceFilePath)); - - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Copying from '{}' -> '{}'", - SourceFilePath, - FirstTargetFilePath); - const uint64_t RawSize = m_LocalContent.RawSizes[LocalPathIndex]; - FastCopyFile(m_Options.AllowFileClone, - m_Options.UseSparseFiles, - SourceFilePath, - FirstTargetFilePath, - RawSize, - m_DiskStats.WriteCount, - m_DiskStats.WriteByteCount, - m_DiskStats.CloneCount, - m_DiskStats.CloneByteCount); - - m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++; - } - else - { - ZEN_TRACE_CPU("Rename"); - const std::filesystem::path CacheFilePath = - GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash); - ZEN_ASSERT_SLOW(IsFileWithRetry(CacheFilePath)); - - std::error_code Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath); - if (Ec) - { - ZEN_OPERATION_LOG_WARN(m_LogOutput, - "Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...", - CacheFilePath, - FirstTargetFilePath, - Ec.value(), - Ec.message()); - Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath); - if (Ec) - { - throw std::system_error( - std::error_code(Ec.value(), std::system_category()), - fmt::format("Failed to move file from '{}' to '{}', reason: ({}) {}", - CacheFilePath, - FirstTargetFilePath, - Ec.value(), - Ec.message())); - } - } - - m_RebuildFolderStateStats.FinalizeTreeFilesMovedCount++; - } - } - - OutLocalFolderState.Paths[FirstRemotePathIndex] = FirstTargetPath; - OutLocalFolderState.RawSizes[FirstRemotePathIndex] = m_RemoteContent.RawSizes[FirstRemotePathIndex]; - - OutLocalFolderState.Attributes[FirstRemotePathIndex] = - m_RemoteContent.Attributes.empty() - ? 
GetNativeFileAttributes(FirstTargetFilePath) - : SetNativeFileAttributes(FirstTargetFilePath, - m_RemoteContent.Platform, - m_RemoteContent.Attributes[FirstRemotePathIndex]); - OutLocalFolderState.ModificationTicks[FirstRemotePathIndex] = - GetModificationTickFromPath(FirstTargetFilePath); - - TargetOffset++; - TargetsComplete++; - - while (TargetOffset < (BaseTargetOffset + TargetCount)) - { - const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex; - ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash); - const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex]; - std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred(); - - if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex); - InPlaceIt != RemotePathIndexToLocalPathIndex.end()) - { - ZEN_ASSERT_SLOW(IsFileWithRetry(TargetFilePath)); - } - else - { - ZEN_TRACE_CPU("Copy"); - if (IsFileWithRetry(TargetFilePath)) - { - SetFileReadOnlyWithRetry(TargetFilePath, false); - } - else - { - CreateDirectories(TargetFilePath.parent_path()); - } - - ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath)); - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Copying from '{}' -> '{}'", - FirstTargetFilePath, - TargetFilePath); - const uint64_t RawSize = m_RemoteContent.RawSizes[RemotePathIndex]; - FastCopyFile(m_Options.AllowFileClone, - m_Options.UseSparseFiles, - FirstTargetFilePath, - TargetFilePath, - RawSize, - m_DiskStats.WriteCount, - m_DiskStats.WriteByteCount, - m_DiskStats.CloneCount, - m_DiskStats.CloneByteCount); - - m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++; - } - - OutLocalFolderState.Paths[RemotePathIndex] = TargetPath; - OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex]; - - OutLocalFolderState.Attributes[RemotePathIndex] = - m_RemoteContent.Attributes.empty() - ? 
GetNativeFileAttributes(TargetFilePath) - : SetNativeFileAttributes(TargetFilePath, - m_RemoteContent.Platform, - m_RemoteContent.Attributes[RemotePathIndex]); - OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath); - - TargetOffset++; - TargetsComplete++; - } - } - } - }); - - TargetOffset += TargetCount; - } + ScheduleTargetFinalization(Work, + Targets, + Categorization.SequenceHashToLocalPathIndex, + Categorization.RemotePathIndexToLocalPathIndex, + OutLocalFolderState, + TargetsComplete); { ZEN_TRACE_CPU("FinalizeTree_Wait"); - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(PendingWork); - const uint64_t WorkTotal = Targets.size() + RemoveLocalPathIndexes.size(); + const uint64_t WorkTotal = Targets.size() + Categorization.RemoveLocalPathIndexes.size(); const uint64_t WorkComplete = TargetsComplete.load() + DeletedCount.load(); std::string Details = fmt::format("{}/{} files", WorkComplete, WorkTotal); - RebuildProgressBar.UpdateState({.Task = "Rebuilding state ", - .Details = Details, - .TotalCount = gsl::narrow<uint64_t>(WorkTotal), - .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = "Rebuilding state ", + .Details = Details, + .TotalCount = gsl::narrow<uint64_t>(WorkTotal), + .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }); } m_RebuildFolderStateStats.FinalizeTreeElapsedWallTimeUs = Timer.GetElapsedTimeUs(); - RebuildProgressBar.Finish(); + ProgressBar->Finish(); } - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, 
(uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount); } catch (const std::exception&) { @@ -2416,11 +1233,7 @@ BuildsOperationUpdateFolder::ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, Io std::error_code Ec = TryRemoveFile(CacheDirContent.Files[Index]); if (Ec) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - CacheDirContent.Files[Index], - Ec.value(), - Ec.message()); + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", CacheDirContent.Files[Index], Ec.value(), Ec.message()); } } m_CacheMappingStats.CacheScanElapsedWallTimeUs += CacheTimer.GetElapsedTimeUs(); @@ -2476,17 +1289,1302 @@ BuildsOperationUpdateFolder::ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_ std::error_code Ec = TryRemoveFile(BlockDirContent.Files[Index]); if (Ec) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockDirContent.Files[Index], - Ec.value(), - Ec.message()); + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockDirContent.Files[Index], Ec.value(), Ec.message()); } } m_CacheMappingStats.CacheScanElapsedWallTimeUs += CacheTimer.GetElapsedTimeUs(); } +void +BuildsOperationUpdateFolder::InitializeSequenceCounters(std::vector<std::atomic<uint32_t>>& OutSequenceCounters, + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutSequencesLeftToFind, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedChunkHashesFound, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedSequenceHashesFound) +{ + if (m_Options.EnableTargetFolderScavenging) + { + // Pick up all whole files we can use from current local state + ZEN_TRACE_CPU("GetLocalSequences"); + + std::vector<uint32_t> MissingSequenceIndexes = ScanTargetFolder(CachedChunkHashesFound, CachedSequenceHashesFound); + + for (uint32_t RemoteSequenceIndex : MissingSequenceIndexes) + { + // We must write the sequence + const uint32_t ChunkCount 
= m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex]; + const IoHash& RemoteSequenceRawHash = m_RemoteContent.ChunkedContent.SequenceRawHashes[RemoteSequenceIndex]; + OutSequenceCounters[RemoteSequenceIndex] = ChunkCount; + OutSequencesLeftToFind.insert({RemoteSequenceRawHash, RemoteSequenceIndex}); + } + } + else + { + for (uint32_t RemoteSequenceIndex = 0; RemoteSequenceIndex < m_RemoteContent.ChunkedContent.SequenceRawHashes.size(); + RemoteSequenceIndex++) + { + OutSequenceCounters[RemoteSequenceIndex] = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex]; + } + } +} + +void +BuildsOperationUpdateFolder::MatchScavengedSequencesToRemote(std::span<const ChunkedFolderContent> Contents, + std::span<const ChunkedContentLookup> Lookups, + std::span<const std::filesystem::path> Paths, + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& InOutSequencesLeftToFind, + std::vector<std::atomic<uint32_t>>& InOutSequenceCounters, + std::vector<ScavengedSequenceCopyOperation>& OutCopyOperations, + uint64_t& OutScavengedPathsCount) +{ + for (uint32_t ScavengedContentIndex = 0; ScavengedContentIndex < Contents.size() && !InOutSequencesLeftToFind.empty(); + ScavengedContentIndex++) + { + const std::filesystem::path& ScavengePath = Paths[ScavengedContentIndex]; + if (ScavengePath.empty()) + { + continue; + } + const ChunkedFolderContent& ScavengedLocalContent = Contents[ScavengedContentIndex]; + const ChunkedContentLookup& ScavengedLookup = Lookups[ScavengedContentIndex]; + + for (uint32_t ScavengedSequenceIndex = 0; ScavengedSequenceIndex < ScavengedLocalContent.ChunkedContent.SequenceRawHashes.size(); + ScavengedSequenceIndex++) + { + const IoHash& SequenceRawHash = ScavengedLocalContent.ChunkedContent.SequenceRawHashes[ScavengedSequenceIndex]; + auto It = InOutSequencesLeftToFind.find(SequenceRawHash); + if (It == InOutSequencesLeftToFind.end()) + { + continue; + } + const uint32_t RemoteSequenceIndex = It->second; + const uint64_t RawSize = 
m_RemoteContent.RawSizes[m_RemoteLookup.SequenceIndexFirstPathIndex[RemoteSequenceIndex]]; + ZEN_ASSERT(RawSize > 0); + + const uint32_t ScavengedPathIndex = ScavengedLookup.SequenceIndexFirstPathIndex[ScavengedSequenceIndex]; + ZEN_ASSERT_SLOW(IsFile((ScavengePath / ScavengedLocalContent.Paths[ScavengedPathIndex]).make_preferred())); + + OutCopyOperations.push_back({.ScavengedContentIndex = ScavengedContentIndex, + .ScavengedPathIndex = ScavengedPathIndex, + .RemoteSequenceIndex = RemoteSequenceIndex, + .RawSize = RawSize}); + + InOutSequencesLeftToFind.erase(SequenceRawHash); + InOutSequenceCounters[RemoteSequenceIndex] = 0; + + m_CacheMappingStats.ScavengedPathsMatchingSequencesCount++; + m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount += RawSize; + } + OutScavengedPathsCount++; + } +} + +uint64_t +BuildsOperationUpdateFolder::CalculateBytesToWriteAndFlagNeededChunks(std::span<const std::atomic<uint32_t>> SequenceCounters, + const std::vector<bool>& NeedsCopyFromLocalFileFlags, + std::span<std::atomic<bool>> OutNeedsCopyFromSourceFlags) +{ + uint64_t BytesToWrite = 0; + for (uint32_t RemoteChunkIndex = 0; RemoteChunkIndex < m_RemoteContent.ChunkedContent.ChunkHashes.size(); RemoteChunkIndex++) + { + const uint64_t ChunkWriteCount = GetChunkWriteCount(SequenceCounters, RemoteChunkIndex); + if (ChunkWriteCount > 0) + { + BytesToWrite += m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] * ChunkWriteCount; + if (!NeedsCopyFromLocalFileFlags[RemoteChunkIndex]) + { + OutNeedsCopyFromSourceFlags[RemoteChunkIndex] = true; + } + } + } + return BytesToWrite; +} + +void +BuildsOperationUpdateFolder::ClassifyCachedAndFetchBlocks(std::span<const ChunkBlockAnalyser::NeededBlock> NeededBlocks, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedBlocksFound, + uint64_t& TotalPartWriteCount, + std::vector<uint32_t>& OutCachedChunkBlockIndexes, + std::vector<uint32_t>& OutFetchBlockIndexes) +{ + ZEN_TRACE_CPU("BlockCacheFileExists"); + for 
(const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; + bool UsingCachedBlock = false; + if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) + { + TotalPartWriteCount++; + + std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); + if (IsFile(BlockPath)) + { + OutCachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex); + UsingCachedBlock = true; + } + } + if (!UsingCachedBlock) + { + OutFetchBlockIndexes.push_back(NeededBlock.BlockIndex); + } + } +} + +std::vector<uint32_t> +BuildsOperationUpdateFolder::DetermineNeededLooseChunkIndexes(std::span<const std::atomic<uint32_t>> SequenceCounters, + const std::vector<bool>& NeedsCopyFromLocalFileFlags, + std::span<std::atomic<bool>> NeedsCopyFromSourceFlags) +{ + std::vector<uint32_t> NeededLooseChunkIndexes; + NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size()); + for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++) + { + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; + + if (NeedsCopyFromLocalFileFlags[RemoteChunkIndex]) + { + if (m_Options.IsVerbose) + { + ZEN_INFO("Skipping chunk {} due to cache reuse", m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); + } + continue; + } + + bool NeedsCopy = true; + if (NeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) + { + const uint64_t WriteCount = GetChunkWriteCount(SequenceCounters, RemoteChunkIndex); + if (WriteCount == 0) + { + if (m_Options.IsVerbose) + { + ZEN_INFO("Skipping chunk {} due to cache reuse", 
m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); + } + } + else + { + NeededLooseChunkIndexes.push_back(LooseChunkIndex); + } + } + } + return NeededLooseChunkIndexes; +} + +BuildsOperationUpdateFolder::BlobsExistsResult +BuildsOperationUpdateFolder::QueryBlobCacheExists(std::span<const uint32_t> NeededLooseChunkIndexes, + std::span<const uint32_t> FetchBlockIndexes) +{ + BlobsExistsResult Result; + if (!m_Storage.CacheStorage) + { + return Result; + } + + ZEN_TRACE_CPU("BlobCacheExistCheck"); + Stopwatch Timer; + + std::vector<IoHash> BlobHashes; + BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size()); + + for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes) + { + BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]); + } + + for (uint32_t BlockIndex : FetchBlockIndexes) + { + BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash); + } + + const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); + + if (CacheExistsResult.size() == BlobHashes.size()) + { + Result.ExistingBlobs.reserve(CacheExistsResult.size()); + for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) + { + if (CacheExistsResult[BlobIndex].HasBody) + { + Result.ExistingBlobs.insert(BlobHashes[BlobIndex]); + } + } + } + Result.ElapsedTimeMs = Timer.GetElapsedTimeMs(); + if (!Result.ExistingBlobs.empty() && !m_Options.IsQuiet) + { + ZEN_INFO("Remote cache : Found {} out of {} needed blobs in {}", + Result.ExistingBlobs.size(), + BlobHashes.size(), + NiceTimeSpanMs(Result.ElapsedTimeMs)); + } + return Result; +} + +std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> +BuildsOperationUpdateFolder::DeterminePartialDownloadModes(const BlobsExistsResult& ExistsResult) +{ + std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> Modes; + + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off) + { + 
Modes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + return Modes; + } + + const bool MultiRangeCache = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1; + const bool MultiRangeBuild = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1; + ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = + MultiRangeCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + + switch (m_Options.PartialBlockRequestMode) + { + case EPartialBlockRequestMode::Off: + break; + case EPartialBlockRequestMode::ZenCacheOnly: + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + break; + case EPartialBlockRequestMode::Mixed: + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + case EPartialBlockRequestMode::All: + CloudPartialDownloadMode = MultiRangeBuild ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + default: + ZEN_ASSERT(false); + break; + } + + Modes.reserve(m_BlockDescriptions.size()); + for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) + { + const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); + Modes.push_back(BlockExistInCache ? 
CachePartialDownloadMode : CloudPartialDownloadMode); + } + return Modes; +} + +std::vector<BuildsOperationUpdateFolder::LooseChunkHashWorkData> +BuildsOperationUpdateFolder::BuildLooseChunkHashWorks(std::span<const uint32_t> NeededLooseChunkIndexes, + std::span<const std::atomic<uint32_t>> SequenceCounters) +{ + std::vector<LooseChunkHashWorkData> LooseChunkHashWorks; + LooseChunkHashWorks.reserve(NeededLooseChunkIndexes.size()); + for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes) + { + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; + + std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs = + GetRemainingChunkTargets(SequenceCounters, RemoteChunkIndex); + + ZEN_ASSERT(!ChunkTargetPtrs.empty()); + LooseChunkHashWorks.push_back(LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); + } + return LooseChunkHashWorks; +} + +void +BuildsOperationUpdateFolder::VerifyWriteChunksComplete(std::span<const std::atomic<uint32_t>> SequenceCounters, + uint64_t BytesToWrite, + uint64_t BytesToValidate) +{ + uint32_t RawSequencesMissingWriteCount = 0; + for (uint32_t SequenceIndex = 0; SequenceIndex < SequenceCounters.size(); SequenceIndex++) + { + const auto& Counter = SequenceCounters[SequenceIndex]; + if (Counter.load() != 0) + { + RawSequencesMissingWriteCount++; + const uint32_t PathIndex = m_RemoteLookup.SequenceIndexFirstPathIndex[SequenceIndex]; + const std::filesystem::path& IncompletePath = m_RemoteContent.Paths[PathIndex]; + ZEN_ASSERT(!IncompletePath.empty()); + const uint32_t ExpectedSequenceCount = m_RemoteContent.ChunkedContent.ChunkCounts[SequenceIndex]; + if (!m_Options.IsQuiet) + { + ZEN_INFO("{}: Max count {}, Current count {}", IncompletePath, 
ExpectedSequenceCount, Counter.load()); + } + ZEN_ASSERT(Counter.load() <= ExpectedSequenceCount); + } + } + ZEN_ASSERT(RawSequencesMissingWriteCount == 0); + ZEN_ASSERT(m_WrittenChunkByteCount == BytesToWrite); + ZEN_ASSERT(m_ValidatedChunkByteCount == BytesToValidate); +} + +std::vector<BuildsOperationUpdateFolder::FinalizeTarget> +BuildsOperationUpdateFolder::BuildSortedFinalizeTargets() +{ + std::vector<FinalizeTarget> Targets; + Targets.reserve(m_RemoteContent.Paths.size()); + for (uint32_t RemotePathIndex = 0; RemotePathIndex < m_RemoteContent.Paths.size(); RemotePathIndex++) + { + Targets.push_back(FinalizeTarget{.RawHash = m_RemoteContent.RawHashes[RemotePathIndex], .RemotePathIndex = RemotePathIndex}); + } + std::sort(Targets.begin(), Targets.end(), [](const FinalizeTarget& Lhs, const FinalizeTarget& Rhs) { + return std::tie(Lhs.RawHash, Lhs.RemotePathIndex) < std::tie(Rhs.RawHash, Rhs.RemotePathIndex); + }); + return Targets; +} + +void +BuildsOperationUpdateFolder::ScanScavengeSources(std::span<const ScavengeSource> Sources, + std::vector<ChunkedFolderContent>& OutContents, + std::vector<ChunkedContentLookup>& OutLookups, + std::vector<std::filesystem::path>& OutPaths) +{ + ZEN_TRACE_CPU("ScanScavengeSources"); + + const size_t ScavengePathCount = Sources.size(); + OutContents.resize(ScavengePathCount); + OutLookups.resize(ScavengePathCount); + OutPaths.resize(ScavengePathCount); + + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Scavenging"); + + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::atomic<uint64_t> PathsFound(0); + std::atomic<uint64_t> ChunksFound(0); + std::atomic<uint64_t> PathsScavenged(0); + + for (size_t ScavengeIndex = 0; ScavengeIndex < ScavengePathCount; ScavengeIndex++) + { + Work.ScheduleWork(m_IOWorkerPool, + [this, &Sources, &OutContents, &OutPaths, &OutLookups, &PathsFound, &ChunksFound, &PathsScavenged, ScavengeIndex]( + std::atomic<bool>&) { 
+ if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_FindScavengeContent"); + + const ScavengeSource& Source = Sources[ScavengeIndex]; + ChunkedFolderContent& ScavengedLocalContent = OutContents[ScavengeIndex]; + ChunkedContentLookup& ScavengedLookup = OutLookups[ScavengeIndex]; + + if (FindScavengeContent(Source, ScavengedLocalContent, ScavengedLookup)) + { + OutPaths[ScavengeIndex] = Source.Path; + PathsFound += ScavengedLocalContent.Paths.size(); + ChunksFound += ScavengedLocalContent.ChunkedContent.ChunkHashes.size(); + } + else + { + OutPaths[ScavengeIndex].clear(); + } + PathsScavenged++; + } + }); + } + { + ZEN_TRACE_CPU("ScavengeScan_Wait"); + + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + std::string Details = fmt::format("{}/{} scanned. {} paths and {} chunks found for scavenging", + PathsScavenged.load(), + ScavengePathCount, + PathsFound.load(), + ChunksFound.load()); + ProgressBar->UpdateState({.Task = "Scavenging ", + .Details = Details, + .TotalCount = ScavengePathCount, + .RemainingCount = ScavengePathCount - PathsScavenged.load(), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); + }); + } + + ProgressBar->Finish(); +} + +BuildsOperationUpdateFolder::LocalPathCategorization +BuildsOperationUpdateFolder::CategorizeLocalPaths(const tsl::robin_map<std::string, uint32_t>& RemotePathToRemoteIndex) +{ + ZEN_TRACE_CPU("PrepareTarget"); + + LocalPathCategorization Result; + tsl::robin_set<IoHash, IoHash::Hasher> CachedRemoteSequences; + + Result.RemotePathIndexToLocalPathIndex.reserve(m_RemoteContent.Paths.size()); + + for (uint32_t LocalPathIndex = 0; LocalPathIndex < m_LocalContent.Paths.size(); LocalPathIndex++) + { + if (m_AbortFlag) + { + break; + } + const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex]; + const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex]; + + 
ZEN_ASSERT_SLOW(IsFile((m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred())); + + if (m_Options.EnableTargetFolderScavenging) + { + if (!m_Options.WipeTargetFolder) + { + // Check if it is already in the correct place + if (auto RemotePathIt = RemotePathToRemoteIndex.find(LocalPath.generic_string()); + RemotePathIt != RemotePathToRemoteIndex.end()) + { + const uint32_t RemotePathIndex = RemotePathIt->second; + if (m_RemoteContent.RawHashes[RemotePathIndex] == RawHash) + { + // It is already in it's correct place + Result.RemotePathIndexToLocalPathIndex[RemotePathIndex] = LocalPathIndex; + Result.SequenceHashToLocalPathIndex.insert({RawHash, LocalPathIndex}); + Result.MatchCount++; + continue; + } + else + { + Result.HashMismatchCount++; + } + } + else + { + Result.PathMismatchCount++; + } + } + + // Do we need it? + if (m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash)) + { + if (!CachedRemoteSequences.contains(RawHash)) + { + // We need it, make sure we move it to the cache + Result.FilesToCache.push_back(LocalPathIndex); + CachedRemoteSequences.insert(RawHash); + continue; + } + else + { + Result.SkippedCount++; + } + } + } + + if (!m_Options.WipeTargetFolder) + { + // Explicitly delete the unneeded local file + Result.RemoveLocalPathIndexes.push_back(LocalPathIndex); + Result.DeleteCount++; + } + } + + return Result; +} + +void +BuildsOperationUpdateFolder::ScheduleLocalFileCaching(std::span<const uint32_t> FilesToCache, + std::atomic<uint64_t>& OutCachedCount, + std::atomic<uint64_t>& OutCachedByteCount) +{ + ZEN_TRACE_CPU("CopyToCache"); + + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Cache Local Data"); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (uint32_t LocalPathIndex : FilesToCache) + { + if (m_AbortFlag) + { + break; + } + Work.ScheduleWork(m_IOWorkerPool, [this, &OutCachedCount, &OutCachedByteCount, LocalPathIndex](std::atomic<bool>&) { + if 
(!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_CopyToCache"); + + const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex]; + const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex]; + const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash); + ZEN_ASSERT_SLOW(!IsFileWithRetry(CacheFilePath)); + const std::filesystem::path LocalFilePath = (m_Path / LocalPath).make_preferred(); + + std::error_code Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath); + if (Ec) + { + ZEN_WARN("Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...", + LocalFilePath, + CacheFilePath, + Ec.value(), + Ec.message()); + Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath); + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed to file from '{}' to '{}', reason: ({}) {}", + LocalFilePath, + CacheFilePath, + Ec.value(), + Ec.message())); + } + } + + OutCachedCount++; + OutCachedByteCount += m_LocalContent.RawSizes[LocalPathIndex]; + } + }); + } + + { + ZEN_TRACE_CPU("CopyToCache_Wait"); + + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + const uint64_t WorkTotal = FilesToCache.size(); + const uint64_t WorkComplete = OutCachedCount.load(); + std::string Details = fmt::format("{}/{} ({}) files", WorkComplete, WorkTotal, NiceBytes(OutCachedByteCount)); + ProgressBar->UpdateState({.Task = "Caching local ", + .Details = Details, + .TotalCount = gsl::narrow<uint64_t>(WorkTotal), + .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); + }); + } + + ProgressBar->Finish(); +} + +void +BuildsOperationUpdateFolder::ScheduleScavengedSequenceWrites(WriteChunksContext& Context, + std::span<const ScavengedSequenceCopyOperation> CopyOperations, + const 
std::vector<ChunkedFolderContent>& ScavengedContents, + const std::vector<std::filesystem::path>& ScavengedPaths) +{ + for (uint32_t ScavengeOpIndex = 0; ScavengeOpIndex < CopyOperations.size(); ScavengeOpIndex++) + { + if (m_AbortFlag) + { + break; + } + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, &Context, &CopyOperations, &ScavengedContents, &ScavengedPaths, ScavengeOpIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_WriteScavenged"); + + Context.FilteredWrittenBytesPerSecond.Start(); + + const ScavengedSequenceCopyOperation& ScavengeOp = CopyOperations[ScavengeOpIndex]; + const ChunkedFolderContent& ScavengedContent = ScavengedContents[ScavengeOp.ScavengedContentIndex]; + const std::filesystem::path& ScavengeRootPath = ScavengedPaths[ScavengeOp.ScavengedContentIndex]; + + WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp); + + if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount) + { + Context.FilteredWrittenBytesPerSecond.Stop(); + } + } + }); + } +} + +void +BuildsOperationUpdateFolder::ScheduleLooseChunkWrites(WriteChunksContext& Context, std::vector<LooseChunkHashWorkData>& LooseChunkHashWorks) +{ + for (uint32_t LooseChunkHashWorkIndex = 0; LooseChunkHashWorkIndex < LooseChunkHashWorks.size(); LooseChunkHashWorkIndex++) + { + if (m_AbortFlag) + { + break; + } + + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, &Context, &LooseChunkHashWorks, LooseChunkHashWorkIndex](std::atomic<bool>&) { + ZEN_TRACE_CPU("Async_ReadPreDownloadedChunk"); + if (!m_AbortFlag) + { + LooseChunkHashWorkData& LooseChunkHashWork = LooseChunkHashWorks[LooseChunkHashWorkIndex]; + const uint32_t RemoteChunkIndex = LooseChunkHashWork.RemoteChunkIndex; + WriteLooseChunk(RemoteChunkIndex, + Context.ExistsResult, + Context.SequenceIndexChunksLeftToWriteCounters, + Context.WritePartsComplete, + std::move(LooseChunkHashWork.ChunkTargetPtrs), + Context.WriteCache, + Context.Work, + 
Context.TotalRequestCount, + Context.TotalPartWriteCount, + Context.FilteredDownloadedBytesPerSecond, + Context.FilteredWrittenBytesPerSecond); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } +} + +void +BuildsOperationUpdateFolder::ScheduleLocalChunkCopies(WriteChunksContext& Context, + std::span<const CopyChunkData> CopyChunkDatas, + CloneQueryInterface* CloneQuery, + const std::vector<ChunkedFolderContent>& ScavengedContents, + const std::vector<ChunkedContentLookup>& ScavengedLookups, + const std::vector<std::filesystem::path>& ScavengedPaths) +{ + for (size_t CopyDataIndex = 0; CopyDataIndex < CopyChunkDatas.size(); CopyDataIndex++) + { + if (m_AbortFlag) + { + break; + } + + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, &Context, CloneQuery, &CopyChunkDatas, &ScavengedContents, &ScavengedLookups, &ScavengedPaths, CopyDataIndex]( + std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_CopyLocal"); + + Context.FilteredWrittenBytesPerSecond.Start(); + const CopyChunkData& CopyData = CopyChunkDatas[CopyDataIndex]; + + std::vector<uint32_t> WrittenSequenceIndexes = WriteLocalChunkToCache(CloneQuery, + CopyData, + ScavengedContents, + ScavengedLookups, + ScavengedPaths, + Context.WriteCache); + bool WritePartsDone = Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount; + if (!m_AbortFlag) + { + if (WritePartsDone) + { + Context.FilteredWrittenBytesPerSecond.Stop(); + } + + // Write tracking, updating this must be done without any files open + std::vector<uint32_t> CompletedChunkSequences; + for (uint32_t RemoteSequenceIndex : WrittenSequenceIndexes) + { + if (CompleteSequenceChunk(RemoteSequenceIndex, Context.SequenceIndexChunksLeftToWriteCounters)) + { + CompletedChunkSequences.push_back(RemoteSequenceIndex); + } + } + Context.WriteCache.Close(CompletedChunkSequences); + VerifyAndCompleteChunkSequencesAsync(CompletedChunkSequences, Context.Work); + } + } + }); + } +} + +void 
+BuildsOperationUpdateFolder::ScheduleCachedBlockWrites(WriteChunksContext& Context, std::span<const uint32_t> CachedBlockIndexes) +{ + for (uint32_t BlockIndex : CachedBlockIndexes) + { + if (m_AbortFlag) + { + break; + } + + Context.Work.ScheduleWork(m_IOWorkerPool, [this, &Context, BlockIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_WriteCachedBlock"); + + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + Context.FilteredWrittenBytesPerSecond.Start(); + + std::filesystem::path BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); + IoBuffer BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (!BlockBuffer) + { + throw std::runtime_error(fmt::format("Can not read block {} at {}", BlockDescription.BlockHash, BlockChunkPath)); + } + + if (!m_AbortFlag) + { + if (!WriteChunksBlockToCache(BlockDescription, + Context.SequenceIndexChunksLeftToWriteCounters, + Context.Work, + CompositeBuffer(std::move(BlockBuffer)), + Context.RemoteChunkIndexNeedsCopyFromSourceFlags, + Context.WriteCache)) + { + std::error_code DummyEc; + RemoveFile(BlockChunkPath, DummyEc); + throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash)); + } + + std::error_code Ec = TryRemoveFile(BlockChunkPath); + if (Ec) + { + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message()); + } + + if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount) + { + Context.FilteredWrittenBytesPerSecond.Stop(); + } + } + } + }); + } +} + +void +BuildsOperationUpdateFolder::SchedulePartialBlockDownloads(WriteChunksContext& Context, + const ChunkBlockAnalyser::BlockResult& PartialBlocks) +{ + for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();) + { + if (m_AbortFlag) + { + break; + } + + size_t RangeCount = 1; + size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex; + 
const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; + while (RangeCount < RangesLeft && + CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) + { + RangeCount++; + } + + Context.Work.ScheduleWork( + m_NetworkPool, + [this, &Context, &PartialBlocks, BlockRangeStartIndex = BlockRangeIndex, RangeCount = RangeCount](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_GetPartialBlockRanges"); + + Context.FilteredDownloadedBytesPerSecond.Start(); + + DownloadPartialBlock( + PartialBlocks.BlockRanges, + BlockRangeStartIndex, + RangeCount, + Context.ExistsResult, + Context.TotalRequestCount, + Context.FilteredDownloadedBytesPerSecond, + [this, &Context, &PartialBlocks](IoBuffer&& InMemoryBuffer, + const std::filesystem::path& OnDiskPath, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) { + if (!m_AbortFlag) + { + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, + &Context, + &PartialBlocks, + BlockRangeStartIndex, + BlockChunkPath = std::filesystem::path(OnDiskPath), + BlockPartialBuffer = std::move(InMemoryBuffer), + OffsetAndLengths = + std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())]( + std::atomic<bool>&) mutable { + if (!m_AbortFlag) + { + WritePartialBlockToCache(Context, + BlockRangeStartIndex, + std::move(BlockPartialBuffer), + BlockChunkPath, + OffsetAndLengths, + PartialBlocks); + } + }, + OnDiskPath.empty() ? 
WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog); + } + }); + } + }); + BlockRangeIndex += RangeCount; + } +} + +void +BuildsOperationUpdateFolder::WritePartialBlockToCache(WriteChunksContext& Context, + size_t BlockRangeStartIndex, + IoBuffer BlockPartialBuffer, + const std::filesystem::path& BlockChunkPath, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths, + const ChunkBlockAnalyser::BlockResult& PartialBlocks) +{ + ZEN_TRACE_CPU("Async_WritePartialBlock"); + + const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex; + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + if (BlockChunkPath.empty()) + { + ZEN_ASSERT(BlockPartialBuffer); + } + else + { + ZEN_ASSERT(!BlockPartialBuffer); + BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (!BlockPartialBuffer) + { + throw std::runtime_error(fmt::format("Could not open downloaded block {} from {}", BlockDescription.BlockHash, BlockChunkPath)); + } + } + + Context.FilteredWrittenBytesPerSecond.Start(); + + const size_t RangeCount = OffsetAndLengths.size(); + + for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++) + { + const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengths[PartialRangeIndex]; + IoBuffer BlockRangeBuffer(BlockPartialBuffer, OffsetAndLength.first, OffsetAndLength.second); + + const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor = + PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex]; + + if (!WritePartialBlockChunksToCache(BlockDescription, + Context.SequenceIndexChunksLeftToWriteCounters, + Context.Work, + CompositeBuffer(std::move(BlockRangeBuffer)), + RangeDescriptor.ChunkBlockIndexStart, + RangeDescriptor.ChunkBlockIndexStart + RangeDescriptor.ChunkBlockIndexCount - 1, + Context.RemoteChunkIndexNeedsCopyFromSourceFlags, + Context.WriteCache)) + { + std::error_code DummyEc; + 
RemoveFile(BlockChunkPath, DummyEc); + throw std::runtime_error(fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); + } + + if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount) + { + Context.FilteredWrittenBytesPerSecond.Stop(); + } + } + std::error_code Ec = TryRemoveFile(BlockChunkPath); + if (Ec) + { + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message()); + } +} + +void +BuildsOperationUpdateFolder::ScheduleFullBlockDownloads(WriteChunksContext& Context, std::span<const uint32_t> FullBlockIndexes) +{ + for (uint32_t BlockIndex : FullBlockIndexes) + { + if (m_AbortFlag) + { + break; + } + + Context.Work.ScheduleWork(m_NetworkPool, [this, &Context, BlockIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_GetFullBlock"); + + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + Context.FilteredDownloadedBytesPerSecond.Start(); + + IoBuffer BlockBuffer; + const bool ExistsInCache = + m_Storage.CacheStorage && Context.ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); + if (ExistsInCache) + { + BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); + } + if (!BlockBuffer) + { + try + { + BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + } + if (!m_AbortFlag) + { + if (!BlockBuffer) + { + throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash)); + } + + uint64_t BlockSize = BlockBuffer.GetSize(); + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += BlockSize; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == Context.TotalRequestCount) + { + Context.FilteredDownloadedBytesPerSecond.Stop(); + } + + const bool PutInCache = 
!ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache; + + std::filesystem::path BlockChunkPath = + TryMoveDownloadedChunk(BlockBuffer, + m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(), + /* ForceDiskBased */ PutInCache || (BlockSize > m_Options.MaximumInMemoryPayloadSize)); + + if (PutInCache) + { + ZEN_ASSERT(!BlockChunkPath.empty()); + IoBuffer CacheBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (CacheBuffer) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(CacheBuffer))); + } + } + + if (!m_AbortFlag) + { + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, &Context, BlockIndex, BlockChunkPath, BlockBuffer = std::move(BlockBuffer)](std::atomic<bool>&) mutable { + if (!m_AbortFlag) + { + WriteFullBlockToCache(Context, BlockIndex, std::move(BlockBuffer), BlockChunkPath); + } + }, + BlockChunkPath.empty() ? WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog); + } + } + } + }); + } +} + +void +BuildsOperationUpdateFolder::WriteFullBlockToCache(WriteChunksContext& Context, + uint32_t BlockIndex, + IoBuffer BlockBuffer, + const std::filesystem::path& BlockChunkPath) +{ + ZEN_TRACE_CPU("Async_WriteFullBlock"); + + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + if (BlockChunkPath.empty()) + { + ZEN_ASSERT(BlockBuffer); + } + else + { + ZEN_ASSERT(!BlockBuffer); + BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (!BlockBuffer) + { + throw std::runtime_error(fmt::format("Could not open dowloaded block {} from {}", BlockDescription.BlockHash, BlockChunkPath)); + } + } + + Context.FilteredWrittenBytesPerSecond.Start(); + if (!WriteChunksBlockToCache(BlockDescription, + Context.SequenceIndexChunksLeftToWriteCounters, + Context.Work, + CompositeBuffer(std::move(BlockBuffer)), + Context.RemoteChunkIndexNeedsCopyFromSourceFlags, + 
Context.WriteCache)) + { + std::error_code DummyEc; + RemoveFile(BlockChunkPath, DummyEc); + throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash)); + } + + if (!BlockChunkPath.empty()) + { + std::error_code Ec = TryRemoveFile(BlockChunkPath); + if (Ec) + { + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message()); + } + } + + if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount) + { + Context.FilteredWrittenBytesPerSecond.Stop(); + } +} + +void +BuildsOperationUpdateFolder::ScheduleLocalFileRemovals(ParallelWork& Work, + std::span<const uint32_t> RemoveLocalPathIndexes, + std::atomic<uint64_t>& DeletedCount) +{ + for (uint32_t LocalPathIndex : RemoveLocalPathIndexes) + { + if (m_AbortFlag) + { + break; + } + Work.ScheduleWork(m_IOWorkerPool, [this, &DeletedCount, LocalPathIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_RemoveFile"); + + const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred(); + SetFileReadOnlyWithRetry(LocalFilePath, false); + RemoveFileWithRetry(LocalFilePath); + DeletedCount++; + } + }); + } +} + +void +BuildsOperationUpdateFolder::ScheduleTargetFinalization( + ParallelWork& Work, + std::span<const FinalizeTarget> Targets, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex, + const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex, + FolderContent& OutLocalFolderState, + std::atomic<uint64_t>& TargetsComplete) +{ + size_t TargetOffset = 0; + while (TargetOffset < Targets.size()) + { + if (m_AbortFlag) + { + break; + } + + size_t TargetCount = 1; + while ((TargetOffset + TargetCount) < Targets.size() && + (Targets[TargetOffset + TargetCount].RawHash == Targets[TargetOffset].RawHash)) + { + TargetCount++; + } + + Work.ScheduleWork(m_IOWorkerPool, + [this, + &SequenceHashToLocalPathIndex, + Targets, + 
&RemotePathIndexToLocalPathIndex, + &OutLocalFolderState, + BaseTargetOffset = TargetOffset, + TargetCount, + &TargetsComplete](std::atomic<bool>&) { + if (!m_AbortFlag) + { + FinalizeTargetGroup(BaseTargetOffset, + TargetCount, + Targets, + SequenceHashToLocalPathIndex, + RemotePathIndexToLocalPathIndex, + OutLocalFolderState, + TargetsComplete); + } + }); + + TargetOffset += TargetCount; + } +} + +void +BuildsOperationUpdateFolder::FinalizeTargetGroup(size_t BaseOffset, + size_t Count, + std::span<const FinalizeTarget> Targets, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex, + const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex, + FolderContent& OutLocalFolderState, + std::atomic<uint64_t>& TargetsComplete) +{ + ZEN_TRACE_CPU("Async_FinalizeChunkSequence"); + + size_t TargetOffset = BaseOffset; + const IoHash& RawHash = Targets[TargetOffset].RawHash; + + if (RawHash == IoHash::Zero) + { + ZEN_TRACE_CPU("CreateEmptyFiles"); + while (TargetOffset < (BaseOffset + Count)) + { + const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex; + ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash); + const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex]; + std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred(); + auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex); + if (InPlaceIt == RemotePathIndexToLocalPathIndex.end() || InPlaceIt->second == 0) + { + if (IsFileWithRetry(TargetFilePath)) + { + SetFileReadOnlyWithRetry(TargetFilePath, false); + } + else + { + CreateDirectories(TargetFilePath.parent_path()); + } + BasicFile OutputFile; + OutputFile.Open(TargetFilePath, BasicFile::Mode::kTruncate); + } + OutLocalFolderState.Paths[RemotePathIndex] = TargetPath; + OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex]; + + OutLocalFolderState.Attributes[RemotePathIndex] = + 
m_RemoteContent.Attributes.empty() + ? GetNativeFileAttributes(TargetFilePath) + : SetNativeFileAttributes(TargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[RemotePathIndex]); + OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath); + + TargetOffset++; + TargetsComplete++; + } + } + else + { + ZEN_TRACE_CPU("FinalizeFile"); + ZEN_ASSERT(m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash)); + const uint32_t FirstRemotePathIndex = Targets[TargetOffset].RemotePathIndex; + const std::filesystem::path& FirstTargetPath = m_RemoteContent.Paths[FirstRemotePathIndex]; + std::filesystem::path FirstTargetFilePath = (m_Path / FirstTargetPath).make_preferred(); + + if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(FirstRemotePathIndex); InPlaceIt != RemotePathIndexToLocalPathIndex.end()) + { + ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath)); + } + else + { + if (IsFileWithRetry(FirstTargetFilePath)) + { + SetFileReadOnlyWithRetry(FirstTargetFilePath, false); + } + else + { + CreateDirectories(FirstTargetFilePath.parent_path()); + } + + if (auto InplaceIt = SequenceHashToLocalPathIndex.find(RawHash); InplaceIt != SequenceHashToLocalPathIndex.end()) + { + ZEN_TRACE_CPU("Copy"); + const uint32_t LocalPathIndex = InplaceIt->second; + const std::filesystem::path& SourcePath = m_LocalContent.Paths[LocalPathIndex]; + std::filesystem::path SourceFilePath = (m_Path / SourcePath).make_preferred(); + ZEN_ASSERT_SLOW(IsFileWithRetry(SourceFilePath)); + + ZEN_DEBUG("Copying from '{}' -> '{}'", SourceFilePath, FirstTargetFilePath); + const uint64_t RawSize = m_LocalContent.RawSizes[LocalPathIndex]; + FastCopyFile(m_Options.AllowFileClone, + m_Options.UseSparseFiles, + SourceFilePath, + FirstTargetFilePath, + RawSize, + m_DiskStats.WriteCount, + m_DiskStats.WriteByteCount, + m_DiskStats.CloneCount, + m_DiskStats.CloneByteCount); + + m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++; + } + else + { + 
ZEN_TRACE_CPU("Rename"); + const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash); + ZEN_ASSERT_SLOW(IsFileWithRetry(CacheFilePath)); + + std::error_code Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath); + if (Ec) + { + ZEN_WARN("Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...", + CacheFilePath, + FirstTargetFilePath, + Ec.value(), + Ec.message()); + Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath); + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed to move file from '{}' to '{}', reason: ({}) {}", + CacheFilePath, + FirstTargetFilePath, + Ec.value(), + Ec.message())); + } + } + + m_RebuildFolderStateStats.FinalizeTreeFilesMovedCount++; + } + } + + OutLocalFolderState.Paths[FirstRemotePathIndex] = FirstTargetPath; + OutLocalFolderState.RawSizes[FirstRemotePathIndex] = m_RemoteContent.RawSizes[FirstRemotePathIndex]; + + OutLocalFolderState.Attributes[FirstRemotePathIndex] = + m_RemoteContent.Attributes.empty() + ? 
GetNativeFileAttributes(FirstTargetFilePath) + : SetNativeFileAttributes(FirstTargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[FirstRemotePathIndex]); + OutLocalFolderState.ModificationTicks[FirstRemotePathIndex] = GetModificationTickFromPath(FirstTargetFilePath); + + TargetOffset++; + TargetsComplete++; + + while (TargetOffset < (BaseOffset + Count)) + { + const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex; + ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash); + const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex]; + std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred(); + + if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex); InPlaceIt != RemotePathIndexToLocalPathIndex.end()) + { + ZEN_ASSERT_SLOW(IsFileWithRetry(TargetFilePath)); + } + else + { + ZEN_TRACE_CPU("Copy"); + if (IsFileWithRetry(TargetFilePath)) + { + SetFileReadOnlyWithRetry(TargetFilePath, false); + } + else + { + CreateDirectories(TargetFilePath.parent_path()); + } + + ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath)); + ZEN_DEBUG("Copying from '{}' -> '{}'", FirstTargetFilePath, TargetFilePath); + const uint64_t RawSize = m_RemoteContent.RawSizes[RemotePathIndex]; + FastCopyFile(m_Options.AllowFileClone, + m_Options.UseSparseFiles, + FirstTargetFilePath, + TargetFilePath, + RawSize, + m_DiskStats.WriteCount, + m_DiskStats.WriteByteCount, + m_DiskStats.CloneCount, + m_DiskStats.CloneByteCount); + + m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++; + } + + OutLocalFolderState.Paths[RemotePathIndex] = TargetPath; + OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex]; + + OutLocalFolderState.Attributes[RemotePathIndex] = + m_RemoteContent.Attributes.empty() + ? 
GetNativeFileAttributes(TargetFilePath) + : SetNativeFileAttributes(TargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[RemotePathIndex]); + OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath); + + TargetOffset++; + TargetsComplete++; + } + } +} + std::vector<BuildsOperationUpdateFolder::ScavengeSource> BuildsOperationUpdateFolder::FindScavengeSources() { @@ -2526,7 +2624,7 @@ BuildsOperationUpdateFolder::FindScavengeSources() } catch (const std::exception& Ex) { - ZEN_OPERATION_LOG_WARN(m_LogOutput, "{}", Ex.what()); + ZEN_WARN("{}", Ex.what()); DeleteEntry = true; } @@ -2563,11 +2661,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3 ZEN_ASSERT_SLOW(IsFile(CacheFilePath)); if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Found sequence {} at {} ({})", - RemoteSequenceRawHash, - CacheFilePath, - NiceBytes(RemoteRawSize)); + ZEN_INFO("Found sequence {} at {} ({})", RemoteSequenceRawHash, CacheFilePath, NiceBytes(RemoteRawSize)); } } else if (auto CacheChunkIt = CachedChunkHashesFound.find(RemoteSequenceRawHash); CacheChunkIt != CachedChunkHashesFound.end()) @@ -2576,11 +2670,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3 ZEN_ASSERT_SLOW(IsFile(CacheFilePath)); if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Found chunk {} at {} ({})", - RemoteSequenceRawHash, - CacheFilePath, - NiceBytes(RemoteRawSize)); + ZEN_INFO("Found chunk {} at {} ({})", RemoteSequenceRawHash, CacheFilePath, NiceBytes(RemoteRawSize)); } } else if (auto It = m_LocalLookup.RawHashToSequenceIndex.find(RemoteSequenceRawHash); @@ -2594,11 +2684,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3 m_CacheMappingStats.LocalPathsMatchingSequencesByteCount += RemoteRawSize; if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Found sequence {} at {} ({})", - 
RemoteSequenceRawHash, - LocalFilePath, - NiceBytes(RemoteRawSize)); + ZEN_INFO("Found sequence {} at {} ({})", RemoteSequenceRawHash, LocalFilePath, NiceBytes(RemoteRawSize)); } } else @@ -2624,10 +2710,9 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source, BuildSaveState SavedState = ReadBuildSaveStateFile(Source.StateFilePath); if (SavedState.Version == BuildSaveState::NoVersion) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Skipping old build state at '{}', state files before version {} can not be trusted during scavenge", - Source.StateFilePath, - BuildSaveState::kVersion1); + ZEN_DEBUG("Skipping old build state at '{}', state files before version {} can not be trusted during scavenge", + Source.StateFilePath, + BuildSaveState::kVersion1); return false; } OutScavengedLocalContent = std::move(SavedState.State.ChunkedContent); @@ -2635,7 +2720,7 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source, } catch (const std::exception& Ex) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Skipping invalid build state at '{}', reason: {}", Source.StateFilePath, Ex.what()); + ZEN_DEBUG("Skipping invalid build state at '{}', reason: {}", Source.StateFilePath, Ex.what()); return false; } @@ -2688,11 +2773,10 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source, } else { - ZEN_OPERATION_LOG_WARN(m_LogOutput, - "Scavenged state file at '{}' for '{}' is invalid, skipping scavenging for sequence {}", - Source.StateFilePath, - Source.Path, - SequenceHash); + ZEN_WARN("Scavenged state file at '{}' for '{}' is invalid, skipping scavenging for sequence {}", + Source.StateFilePath, + Source.Path, + SequenceHash); } } } @@ -2928,7 +3012,10 @@ BuildsOperationUpdateFolder::CheckRequiredDiskSpace(const tsl::robin_map<std::st if (Space.Free < (RequiredSpace + 16u * 1024u * 1024u)) { throw std::runtime_error( - fmt::format("Not enough free space for target path '{}', {} of free space is needed", m_Path, 
RequiredSpace)); + fmt::format("Not enough free space for target path '{}', {} of free space is needed but only {} is available", + m_Path, + NiceBytes(RequiredSpace), + NiceBytes(Space.Free))); } } @@ -2980,18 +3067,13 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd FilteredRate& FilteredDownloadedBytesPerSecond, FilteredRate& FilteredWrittenBytesPerSecond) { - std::filesystem::path ExistingCompressedChunkPath; - if (!m_Options.PrimeCacheOnly) + const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; + std::filesystem::path ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash); + if (!ExistingCompressedChunkPath.empty()) { - const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; - ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash); - if (!ExistingCompressedChunkPath.empty()) + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) { - m_DownloadStats.RequestsCompleteCount++; - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } + FilteredDownloadedBytesPerSecond.Stop(); } } if (!m_AbortFlag) @@ -3009,7 +3091,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd &FilteredWrittenBytesPerSecond, RemoteChunkIndex, ChunkTargetPtrs = std::move(ChunkTargetPtrs), - CompressedChunkPath = std::move(ExistingCompressedChunkPath)](std::atomic<bool>& AbortFlag) mutable { + CompressedChunkPath = std::move(ExistingCompressedChunkPath)](std::atomic<bool>& AbortFlag) { if (!AbortFlag) { ZEN_TRACE_CPU("Async_WritePreDownloadedChunk"); @@ -3027,11 +3109,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart)); - WritePartsComplete++; + bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount; if 
(!AbortFlag) { - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsDone) { FilteredWrittenBytesPerSecond.Stop(); } @@ -3039,11 +3121,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd std::error_code Ec = TryRemoveFile(CompressedChunkPath); if (Ec) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - CompressedChunkPath, - Ec.value(), - Ec.message()); + ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", CompressedChunkPath, Ec.value(), Ec.message()); } std::vector<uint32_t> CompletedSequences = @@ -3085,6 +3163,8 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd DownloadBuildBlob(RemoteChunkIndex, ExistsResult, Work, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, [this, &ExistsResult, SequenceIndexChunksLeftToWriteCounters, @@ -3092,19 +3172,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd &Work, &WritePartsComplete, TotalPartWriteCount, - TotalRequestCount, RemoteChunkIndex, - &FilteredDownloadedBytesPerSecond, &FilteredWrittenBytesPerSecond, ChunkTargetPtrs = std::move(ChunkTargetPtrs)](IoBuffer&& Payload) mutable { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - IoBufferFileReference FileRef; - bool EnableBacklog = Payload.GetFileReference(FileRef); - AsyncWriteDownloadedChunk(m_Options.ZenFolderPath, - RemoteChunkIndex, + AsyncWriteDownloadedChunk(RemoteChunkIndex, + ExistsResult, std::move(ChunkTargetPtrs), WriteCache, Work, @@ -3112,8 +3184,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd SequenceIndexChunksLeftToWriteCounters, WritePartsComplete, TotalPartWriteCount, - FilteredWrittenBytesPerSecond, - EnableBacklog); + FilteredWrittenBytesPerSecond); }); } }); @@ -3125,6 +3196,8 @@ void BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkIndex, const BlobsExistsResult& ExistsResult, 
ParallelWork& Work, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& Payload)>&& OnDownloaded) { const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; @@ -3140,7 +3213,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde uint64_t BlobSize = BuildBlob.GetSize(); m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.DownloadedChunkByteCount += BlobSize; - m_DownloadStats.RequestsCompleteCount++; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } OnDownloaded(std::move(BuildBlob)); } else @@ -3157,16 +3233,11 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde m_NetworkPool, m_DownloadStats.DownloadedChunkByteCount, m_DownloadStats.MultipartAttachmentCount, - [this, &Work, ChunkHash, RemoteChunkIndex, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable { + [this, &FilteredDownloadedBytesPerSecond, TotalRequestCount, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) { m_DownloadStats.DownloadedChunkCount++; - m_DownloadStats.RequestsCompleteCount++; - - if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache) + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); + FilteredDownloadedBytesPerSecond.Stop(); } OnDownloaded(std::move(Payload)); @@ -3174,26 +3245,34 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde } else { - BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); - if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache) + try { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BuildBlob))); + 
BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); } - if (!BuildBlob) + catch (const std::exception&) { - throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash)); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - if (!m_Options.PrimeCacheOnly) + if (!m_AbortFlag) { + if (!BuildBlob) + { + throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash)); + } + if (!m_AbortFlag) { uint64_t BlobSize = BuildBlob.GetSize(); m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.DownloadedChunkByteCount += BlobSize; - m_DownloadStats.RequestsCompleteCount++; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } OnDownloaded(std::move(BuildBlob)); } @@ -3208,6 +3287,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( size_t BlockRangeStartIndex, size_t BlockRangeCount, const BlobsExistsResult& ExistsResult, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, @@ -3222,6 +3303,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( IoBuffer&& BlockRangeBuffer, size_t BlockRangeStartIndex, std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, const std::function<void(IoBuffer && InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, @@ -3229,63 +3312,23 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize(); m_DownloadStats.DownloadedBlockCount++; m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize; - m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size(); - - std::filesystem::path BlockChunkPath; - - // Check if the dowloaded block is file based and we can move it 
directly without rewriting it + if (m_DownloadStats.RequestsCompleteCount.fetch_add(BlockOffsetAndLengths.size()) + BlockOffsetAndLengths.size() == + TotalRequestCount) { - IoBufferFileReference FileRef; - if (BlockRangeBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && - (FileRef.FileChunkSize == BlockRangeBufferSize)) - { - ZEN_TRACE_CPU("MoveTempPartialBlock"); - - std::error_code Ec; - std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); - if (!Ec) - { - BlockRangeBuffer.SetDeleteOnClose(false); - BlockRangeBuffer = {}; - - IoHashStream RangeId; - for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths) - { - RangeId.Append(&Range.first, sizeof(uint64_t)); - RangeId.Append(&Range.second, sizeof(uint64_t)); - } - - BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); - RenameFile(TempBlobPath, BlockChunkPath, Ec); - if (Ec) - { - BlockChunkPath = std::filesystem::path{}; - - // Re-open the temp file again - BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); - BlockRangeBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockRangeBufferSize, true); - BlockRangeBuffer.SetDeleteOnClose(true); - } - } - } + FilteredDownloadedBytesPerSecond.Stop(); } - if (BlockChunkPath.empty() && (BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize)) + IoHashStream RangeId; + for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths) { - ZEN_TRACE_CPU("WriteTempPartialBlock"); - - IoHashStream RangeId; - for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths) - { - RangeId.Append(&Range.first, sizeof(uint64_t)); - RangeId.Append(&Range.second, sizeof(uint64_t)); - } - - // Could not be moved and rather large, lets store it on disk - BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); - TemporaryFile::SafeWriteFile(BlockChunkPath, BlockRangeBuffer); - 
BlockRangeBuffer = {}; + RangeId.Append(&Range.first, sizeof(uint64_t)); + RangeId.Append(&Range.second, sizeof(uint64_t)); } + std::filesystem::path BlockChunkPath = + TryMoveDownloadedChunk(BlockRangeBuffer, + m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()), + /* ForceDiskBased */ BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize); + if (!m_AbortFlag) { OnDownloaded(std::move(BlockRangeBuffer), std::move(BlockChunkPath), BlockRangeStartIndex, BlockOffsetAndLengths); @@ -3337,6 +3380,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(PayloadBuffer), SubRangeStartIndex, std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)}, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3361,6 +3406,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3371,6 +3418,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3383,60 +3432,97 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); - BuildStorageBase::BuildBlobRanges RangeBuffers = - m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); - if (m_AbortFlag) + BuildStorageBase::BuildBlobRanges RangeBuffers; + + try { - break; + RangeBuffers = m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); } - if (RangeBuffers.PayloadBuffer) + catch (const std::exception&) 
{ - if (RangeBuffers.Ranges.empty()) + // Silence http errors due to abort + if (!m_AbortFlag) { - // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 - // Upload to cache (if enabled) and use the whole payload for the remaining ranges + throw; + } + } - if (m_Storage.CacheStorage && m_Options.PopulateCache) + if (!m_AbortFlag) + { + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer})); + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Upload to cache (if enabled) and use the whole payload for the remaining ranges + + const uint64_t Size = RangeBuffers.PayloadBuffer.GetSize(); + + const bool PopulateCache = !ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache; + + std::filesystem::path BlockPath = + TryMoveDownloadedChunk(RangeBuffers.PayloadBuffer, + m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(), + /* ForceDiskBased */ PopulateCache || Size > m_Options.MaximumInMemoryPayloadSize); + if (!BlockPath.empty()) + { + RangeBuffers.PayloadBuffer = IoBufferBuilder::MakeFromFile(BlockPath); + if (!RangeBuffers.PayloadBuffer) + { + throw std::runtime_error( + fmt::format("Failed to read block {} from temporary path '{}'", BlockDescription.BlockHash, BlockPath)); + } + RangeBuffers.PayloadBuffer.SetDeleteOnClose(true); + } + + if (PopulateCache) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(RangeBuffers.PayloadBuffer))); + } + if (m_AbortFlag) { break; } - } - SubRangeCount = Ranges.size() - SubRangeCountComplete; - ProcessDownload(BlockDescription, - std::move(RangeBuffers.PayloadBuffer), - SubRangeStartIndex, - 
RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), - OnDownloaded); + SubRangeCount = Ranges.size() - SubRangeCountComplete; + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + TotalRequestCount, + FilteredDownloadedBytesPerSecond, + OnDownloaded); + } + else + { + if (RangeBuffers.Ranges.size() != SubRanges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + SubRanges.size(), + BlockDescription.BlockHash, + RangeBuffers.Ranges.size())); + } + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, + OnDownloaded); + } } else { - if (RangeBuffers.Ranges.size() != SubRanges.size()) - { - throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", - SubRanges.size(), - BlockDescription.BlockHash, - RangeBuffers.Ranges.size())); - } - ProcessDownload(BlockDescription, - std::move(RangeBuffers.PayloadBuffer), - SubRangeStartIndex, - RangeBuffers.Ranges, - OnDownloaded); + throw std::runtime_error( + fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); } - } - else - { - throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); - } - SubRangeCountComplete += SubRangeCount; + SubRangeCountComplete += SubRangeCount; + } } } @@ -3654,7 +3740,7 @@ BuildsOperationUpdateFolder::WriteLocalChunkToCache(CloneQueryInterface* C if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Copied {} from {}", NiceBytes(CacheLocalFileBytesRead), SourceFilePath); + ZEN_INFO("Copied {} from {}", NiceBytes(CacheLocalFileBytesRead), SourceFilePath); } std::vector<uint32_t> Result; @@ -4149,8 +4235,8 @@ 
BuildsOperationUpdateFolder::WritePartialBlockChunksToCache(const ChunkBlockDesc } void -BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::path& ZenFolderPath, - uint32_t RemoteChunkIndex, +BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(uint32_t RemoteChunkIndex, + const BlobsExistsResult& ExistsResult, std::vector<const ChunkedContentLookup::ChunkSequenceLocation*>&& ChunkTargetPtrs, BufferedWriteFileCache& WriteCache, ParallelWork& Work, @@ -4158,8 +4244,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters, std::atomic<uint64_t>& WritePartsComplete, const uint64_t TotalPartWriteCount, - FilteredRate& FilteredWrittenBytesPerSecond, - bool EnableBacklog) + FilteredRate& FilteredWrittenBytesPerSecond) { ZEN_TRACE_CPU("AsyncWriteDownloadedChunk"); @@ -4167,48 +4252,32 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa const uint64_t Size = Payload.GetSize(); - std::filesystem::path CompressedChunkPath; + const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash); + + const bool PopulateCache = !ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache; - // Check if the dowloaded chunk is file based and we can move it directly without rewriting it + std::filesystem::path CompressedChunkPath = + TryMoveDownloadedChunk(Payload, + m_TempDownloadFolderPath / ChunkHash.ToHexString(), + /* ForceDiskBased */ PopulateCache || Size > m_Options.MaximumInMemoryPayloadSize); + if (PopulateCache) { - IoBufferFileReference FileRef; - if (Payload.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == Size)) + IoBuffer CacheBlob = IoBufferBuilder::MakeFromFile(CompressedChunkPath); + if (CacheBlob) { - ZEN_TRACE_CPU("MoveTempChunk"); - std::error_code Ec; - std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, 
Ec); - if (!Ec) - { - Payload.SetDeleteOnClose(false); - Payload = {}; - CompressedChunkPath = m_TempDownloadFolderPath / ChunkHash.ToHexString(); - RenameFile(TempBlobPath, CompressedChunkPath, Ec); - if (Ec) - { - CompressedChunkPath = std::filesystem::path{}; - - // Re-open the temp file again - BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); - Payload = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, Size, true); - Payload.SetDeleteOnClose(true); - } - } + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(CacheBlob))); } } - if (CompressedChunkPath.empty() && (Size > m_Options.MaximumInMemoryPayloadSize)) - { - ZEN_TRACE_CPU("WriteTempChunk"); - // Could not be moved and rather large, lets store it on disk - CompressedChunkPath = m_TempDownloadFolderPath / ChunkHash.ToHexString(); - TemporaryFile::SafeWriteFile(CompressedChunkPath, Payload); - Payload = {}; - } + IoBufferFileReference FileRef; + bool EnableBacklog = !CompressedChunkPath.empty() || Payload.GetFileReference(FileRef); Work.ScheduleWork( m_IOWorkerPool, - [&ZenFolderPath, - this, + [this, SequenceIndexChunksLeftToWriteCounters, &Work, CompressedChunkPath, @@ -4244,8 +4313,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart)); if (!m_AbortFlag) { - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -4255,11 +4323,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa std::error_code Ec = TryRemoveFile(CompressedChunkPath); if (Ec) { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - CompressedChunkPath, - Ec.value(), - Ec.message()); + ZEN_DEBUG("Failed removing file 
'{}', reason: ({}) {}", CompressedChunkPath, Ec.value(), Ec.message()); } } @@ -4412,7 +4476,8 @@ BuildsOperationUpdateFolder::VerifySequence(uint32_t RemoteSequenceIndex) ////////////////////// BuildsOperationUploadFolder -BuildsOperationUploadFolder::BuildsOperationUploadFolder(OperationLogOutput& OperationLogOutput, +BuildsOperationUploadFolder::BuildsOperationUploadFolder(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -4423,7 +4488,8 @@ BuildsOperationUploadFolder::BuildsOperationUploadFolder(OperationLogOutput& bool CreateBuild, const CbObject& MetaData, const Options& Options) -: m_LogOutput(OperationLogOutput) +: m_Log(Log) +, m_Progress(Progress) , m_Storage(Storage) , m_AbortFlag(AbortFlag) , m_PauseFlag(PauseFlag) @@ -4476,9 +4542,7 @@ BuildsOperationUploadFolder::PrepareBuild() } else if (m_Options.AllowMultiparts) { - ZEN_OPERATION_LOG_WARN(m_LogOutput, - "PreferredMultipartChunkSize is unknown. Defaulting to '{}'", - NiceBytes(Result.PreferredMultipartChunkSize)); + ZEN_WARN("PreferredMultipartChunkSize is unknown. 
Defaulting to '{}'", NiceBytes(Result.PreferredMultipartChunkSize)); } } @@ -4538,10 +4602,8 @@ BuildsOperationUploadFolder::ReadFolder() return true; }, m_IOWorkerPool, - m_LogOutput.GetProgressUpdateDelayMS(), - [&](bool, std::ptrdiff_t) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Found {} files in '{}'...", LocalFolderScanStats.AcceptedFileCount.load(), m_Path); - }, + m_Progress.GetProgressUpdateDelayMS(), + [&](bool, std::ptrdiff_t) { ZEN_INFO("Found {} files in '{}'...", LocalFolderScanStats.AcceptedFileCount.load(), m_Path); }, m_AbortFlag); Part.TotalRawSize = std::accumulate(Part.Content.RawSizes.begin(), Part.Content.RawSizes.end(), std::uint64_t(0)); @@ -4655,10 +4717,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId, if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Reading {} parts took {}", - UploadParts.size(), - NiceTimeSpanMs(ReadPartsTimer.GetElapsedTimeMs())); + ZEN_INFO("Reading {} parts took {}", UploadParts.size(), NiceTimeSpanMs(ReadPartsTimer.GetElapsedTimeMs())); } const uint32_t PartsUploadStepCount = gsl::narrow<uint32_t>(uint32_t(PartTaskSteps::StepCount) * UploadParts.size()); @@ -4669,7 +4728,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId, const uint32_t CleanupStep = FinalizeBuildStep + 1; const uint32_t StepCount = CleanupStep + 1; - auto EndProgress = MakeGuard([&]() { m_LogOutput.SetLogOperationProgress(StepCount, StepCount); }); + auto EndProgress = MakeGuard([&]() { m_Progress.SetLogOperationProgress(StepCount, StepCount); }); Stopwatch ProcessTimer; @@ -4677,7 +4736,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId, CreateDirectories(m_Options.TempDir); auto _ = MakeGuard([&]() { CleanAndRemoveDirectory(m_IOWorkerPool, m_AbortFlag, m_PauseFlag, m_Options.TempDir); }); - m_LogOutput.SetLogOperationProgress(PrepareBuildStep, StepCount); + m_Progress.SetLogOperationProgress(PrepareBuildStep, StepCount); m_PrepBuildResultFuture = 
m_NetworkPool.EnqueueTask(std::packaged_task<PrepareBuildResult()>{[this] { return PrepareBuild(); }}, WorkerThreadPool::EMode::EnableBacklog); @@ -4694,7 +4753,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId, } } - m_LogOutput.SetLogOperationProgress(FinalizeBuildStep, StepCount); + m_Progress.SetLogOperationProgress(FinalizeBuildStep, StepCount); if (m_CreateBuild && !m_AbortFlag) { @@ -4702,11 +4761,11 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId, m_Storage.BuildStorage->FinalizeBuild(m_BuildId); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "FinalizeBuild took {}", NiceTimeSpanMs(FinalizeBuildTimer.GetElapsedTimeMs())); + ZEN_INFO("FinalizeBuild took {}", NiceTimeSpanMs(FinalizeBuildTimer.GetElapsedTimeMs())); } } - m_LogOutput.SetLogOperationProgress(CleanupStep, StepCount); + m_Progress.SetLogOperationProgress(CleanupStep, StepCount); std::vector<std::pair<Oid, std::string>> Result; Result.reserve(UploadParts.size()); @@ -4864,235 +4923,256 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& { ZEN_TRACE_CPU("GenerateBuildBlocks"); const std::size_t NewBlockCount = NewBlockChunks.size(); - if (NewBlockCount > 0) + if (NewBlockCount == 0) { - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Generate Blocks")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); + return; + } - OutBlocks.BlockDescriptions.resize(NewBlockCount); - OutBlocks.BlockSizes.resize(NewBlockCount); - OutBlocks.BlockMetaDatas.resize(NewBlockCount); - OutBlocks.BlockHeaders.resize(NewBlockCount); - OutBlocks.MetaDataHasBeenUploaded.resize(NewBlockCount, 0); - OutBlocks.BlockHashToBlockIndex.reserve(NewBlockCount); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Generate Blocks"); + + OutBlocks.BlockDescriptions.resize(NewBlockCount); + OutBlocks.BlockSizes.resize(NewBlockCount); + 
OutBlocks.BlockMetaDatas.resize(NewBlockCount); + OutBlocks.BlockHeaders.resize(NewBlockCount); + OutBlocks.MetaDataHasBeenUploaded.resize(NewBlockCount, 0); + OutBlocks.BlockHashToBlockIndex.reserve(NewBlockCount); + + RwLock Lock; + FilteredRate FilteredGeneratedBytesPerSecond; + FilteredRate FilteredUploadedBytesPerSecond; + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + std::atomic<uint64_t> QueuedPendingBlocksForUpload = 0; + + GenerateBuildBlocksContext Context{.Work = Work, + .GenerateBlobsPool = m_IOWorkerPool, + .UploadBlocksPool = m_NetworkPool, + .FilteredGeneratedBytesPerSecond = FilteredGeneratedBytesPerSecond, + .FilteredUploadedBytesPerSecond = FilteredUploadedBytesPerSecond, + .QueuedPendingBlocksForUpload = QueuedPendingBlocksForUpload, + .Lock = Lock, + .OutBlocks = OutBlocks, + .GenerateBlocksStats = GenerateBlocksStats, + .UploadStats = UploadStats, + .NewBlockCount = NewBlockCount}; + + ScheduleBlockGeneration(Context, Content, Lookup, NewBlockChunks); + + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + + FilteredGeneratedBytesPerSecond.Update(GenerateBlocksStats.GeneratedBlockByteCount.load()); + FilteredUploadedBytesPerSecond.Update(UploadStats.BlocksBytes.load()); + + std::string Details = fmt::format("Generated {}/{} ({}, {}B/s). 
Uploaded {}/{} ({}, {}bits/s)", + GenerateBlocksStats.GeneratedBlockCount.load(), + NewBlockCount, + NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount.load()), + NiceNum(FilteredGeneratedBytesPerSecond.GetCurrent()), + UploadStats.BlockCount.load(), + NewBlockCount, + NiceBytes(UploadStats.BlocksBytes.load()), + NiceNum(FilteredUploadedBytesPerSecond.GetCurrent() * 8)); + + ProgressBar->UpdateState({.Task = "Generating blocks", + .Details = Details, + .TotalCount = gsl::narrow<uint64_t>(NewBlockCount), + .RemainingCount = gsl::narrow<uint64_t>(NewBlockCount - GenerateBlocksStats.GeneratedBlockCount.load()), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); + }); - RwLock Lock; + ZEN_ASSERT(m_AbortFlag || QueuedPendingBlocksForUpload.load() == 0); - WorkerThreadPool& GenerateBlobsPool = m_IOWorkerPool; - WorkerThreadPool& UploadBlocksPool = m_NetworkPool; + ProgressBar->Finish(); - FilteredRate FilteredGeneratedBytesPerSecond; - FilteredRate FilteredUploadedBytesPerSecond; + GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS = FilteredGeneratedBytesPerSecond.GetElapsedTimeUS(); + UploadStats.ElapsedWallTimeUS = FilteredUploadedBytesPerSecond.GetElapsedTimeUS(); +} - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); +void +BuildsOperationUploadFolder::ScheduleBlockGeneration(GenerateBuildBlocksContext& Context, + const ChunkedFolderContent& Content, + const ChunkedContentLookup& Lookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks) +{ + for (size_t BlockIndex = 0; BlockIndex < Context.NewBlockCount; BlockIndex++) + { + if (Context.Work.IsAborted()) + { + break; + } + const std::vector<uint32_t>& ChunksInBlock = NewBlockChunks[BlockIndex]; + Context.Work.ScheduleWork( + Context.GenerateBlobsPool, + [this, &Context, &Content, &Lookup, ChunksInBlock, BlockIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("GenerateBuildBlocks_Generate"); - 
std::atomic<uint64_t> QueuedPendingBlocksForUpload = 0; + Context.FilteredGeneratedBytesPerSecond.Start(); - for (size_t BlockIndex = 0; BlockIndex < NewBlockCount; BlockIndex++) - { - if (Work.IsAborted()) - { - break; - } - const std::vector<uint32_t>& ChunksInBlock = NewBlockChunks[BlockIndex]; - Work.ScheduleWork( - GenerateBlobsPool, - [this, - &Content, - &Lookup, - &Work, - &UploadBlocksPool, - NewBlockCount, - ChunksInBlock, - &Lock, - &OutBlocks, - &GenerateBlocksStats, - &UploadStats, - &FilteredGeneratedBytesPerSecond, - &QueuedPendingBlocksForUpload, - &FilteredUploadedBytesPerSecond, - BlockIndex](std::atomic<bool>&) { - if (!m_AbortFlag) + Stopwatch GenerateTimer; + CompressedBuffer CompressedBlock = + GenerateBlock(Content, Lookup, ChunksInBlock, Context.OutBlocks.BlockDescriptions[BlockIndex]); + if (m_Options.IsVerbose) { - ZEN_TRACE_CPU("GenerateBuildBlocks_Generate"); + ZEN_INFO("Generated block {} ({}) containing {} chunks in {}", + Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash, + NiceBytes(CompressedBlock.GetCompressedSize()), + Context.OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(), + NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs())); + } - FilteredGeneratedBytesPerSecond.Start(); + Context.OutBlocks.BlockSizes[BlockIndex] = CompressedBlock.GetCompressedSize(); + { + CbObjectWriter Writer; + Writer.AddString("createdBy", "zen"); + Context.OutBlocks.BlockMetaDatas[BlockIndex] = Writer.Save(); + } + Context.GenerateBlocksStats.GeneratedBlockByteCount += Context.OutBlocks.BlockSizes[BlockIndex]; + Context.GenerateBlocksStats.GeneratedBlockCount++; - Stopwatch GenerateTimer; - CompressedBuffer CompressedBlock = - GenerateBlock(Content, Lookup, ChunksInBlock, OutBlocks.BlockDescriptions[BlockIndex]); - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Generated block {} ({}) containing {} chunks in {}", - OutBlocks.BlockDescriptions[BlockIndex].BlockHash, - 
NiceBytes(CompressedBlock.GetCompressedSize()), - OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(), - NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs())); - } + Context.Lock.WithExclusiveLock([&]() { + Context.OutBlocks.BlockHashToBlockIndex.insert_or_assign(Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash, + BlockIndex); + }); - OutBlocks.BlockSizes[BlockIndex] = CompressedBlock.GetCompressedSize(); - { - CbObjectWriter Writer; - Writer.AddString("createdBy", "zen"); - OutBlocks.BlockMetaDatas[BlockIndex] = Writer.Save(); - } - GenerateBlocksStats.GeneratedBlockByteCount += OutBlocks.BlockSizes[BlockIndex]; - GenerateBlocksStats.GeneratedBlockCount++; + { + std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments(); + ZEN_ASSERT(Segments.size() >= 2); + Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); + } - Lock.WithExclusiveLock([&]() { - OutBlocks.BlockHashToBlockIndex.insert_or_assign(OutBlocks.BlockDescriptions[BlockIndex].BlockHash, BlockIndex); - }); + if (Context.GenerateBlocksStats.GeneratedBlockCount == Context.NewBlockCount) + { + Context.FilteredGeneratedBytesPerSecond.Stop(); + } + if (Context.QueuedPendingBlocksForUpload.load() > 16) + { + std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments(); + ZEN_ASSERT(Segments.size() >= 2); + Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); + } + else + { + if (!m_AbortFlag) { - std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments(); - ZEN_ASSERT(Segments.size() >= 2); - OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); + Context.QueuedPendingBlocksForUpload++; + Context.Work.ScheduleWork( + Context.UploadBlocksPool, + [this, &Context, BlockIndex, Payload = std::move(CompressedBlock)](std::atomic<bool>&) mutable { + UploadGeneratedBlock(Context, BlockIndex, std::move(Payload)); + }); } + } + } + 
}); + } +} - if (GenerateBlocksStats.GeneratedBlockCount == NewBlockCount) - { - FilteredGeneratedBytesPerSecond.Stop(); - } +void +BuildsOperationUploadFolder::UploadGeneratedBlock(GenerateBuildBlocksContext& Context, size_t BlockIndex, CompressedBuffer Payload) +{ + auto _ = MakeGuard([&Context] { Context.QueuedPendingBlocksForUpload--; }); + if (m_AbortFlag) + { + return; + } - if (QueuedPendingBlocksForUpload.load() > 16) - { - std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments(); - ZEN_ASSERT(Segments.size() >= 2); - OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); - } - else - { - if (!m_AbortFlag) - { - QueuedPendingBlocksForUpload++; + if (Context.GenerateBlocksStats.GeneratedBlockCount == Context.NewBlockCount) + { + ZEN_TRACE_CPU("GenerateBuildBlocks_Save"); - Work.ScheduleWork( - UploadBlocksPool, - [this, - NewBlockCount, - &GenerateBlocksStats, - &UploadStats, - &FilteredUploadedBytesPerSecond, - &QueuedPendingBlocksForUpload, - &OutBlocks, - BlockIndex, - Payload = std::move(CompressedBlock)](std::atomic<bool>&) mutable { - auto _ = MakeGuard([&QueuedPendingBlocksForUpload] { QueuedPendingBlocksForUpload--; }); - if (!m_AbortFlag) - { - if (GenerateBlocksStats.GeneratedBlockCount == NewBlockCount) - { - ZEN_TRACE_CPU("GenerateBuildBlocks_Save"); - - FilteredUploadedBytesPerSecond.Stop(); - std::span<const SharedBuffer> Segments = Payload.GetCompressed().GetSegments(); - ZEN_ASSERT(Segments.size() >= 2); - OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); - } - else - { - ZEN_TRACE_CPU("GenerateBuildBlocks_Upload"); - - FilteredUploadedBytesPerSecond.Start(); - - const CbObject BlockMetaData = - BuildChunkBlockDescription(OutBlocks.BlockDescriptions[BlockIndex], - OutBlocks.BlockMetaDatas[BlockIndex]); - - const IoHash& BlockHash = OutBlocks.BlockDescriptions[BlockIndex].BlockHash; - const uint64_t CompressedBlockSize = Payload.GetCompressedSize(); - - if 
(m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlockHash, - ZenContentType::kCompressedBinary, - Payload.GetCompressed()); - } - - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, - BlockHash, - ZenContentType::kCompressedBinary, - std::move(Payload).GetCompressed()); - UploadStats.BlocksBytes += CompressedBlockSize; - - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} ({}) containing {} chunks", - BlockHash, - NiceBytes(CompressedBlockSize), - OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); - } - - if (m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); - } - - bool MetadataSucceeded = - m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); - if (MetadataSucceeded) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} metadata ({})", - BlockHash, - NiceBytes(BlockMetaData.GetSize())); - } - - OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; - UploadStats.BlocksBytes += BlockMetaData.GetSize(); - } - - UploadStats.BlockCount++; - if (UploadStats.BlockCount == NewBlockCount) - { - FilteredUploadedBytesPerSecond.Stop(); - } - } - } - }); - } - } - } - }); + Context.FilteredUploadedBytesPerSecond.Stop(); + std::span<const SharedBuffer> Segments = Payload.GetCompressed().GetSegments(); + ZEN_ASSERT(Segments.size() >= 2); + Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]); + return; + } + + ZEN_TRACE_CPU("GenerateBuildBlocks_Upload"); + + Context.FilteredUploadedBytesPerSecond.Start(); + + const CbObject BlockMetaData = + BuildChunkBlockDescription(Context.OutBlocks.BlockDescriptions[BlockIndex], Context.OutBlocks.BlockMetaDatas[BlockIndex]); + + const IoHash& BlockHash = 
Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash; + const uint64_t CompressedBlockSize = Payload.GetCompressedSize(); + + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload.GetCompressed()); + } + + try + { + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, std::move(Payload).GetCompressed()); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; } + } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { - ZEN_UNUSED(PendingWork); + if (m_AbortFlag) + { + return; + } - FilteredGeneratedBytesPerSecond.Update(GenerateBlocksStats.GeneratedBlockByteCount.load()); - FilteredUploadedBytesPerSecond.Update(UploadStats.BlocksBytes.load()); + Context.UploadStats.BlocksBytes += CompressedBlockSize; - std::string Details = fmt::format("Generated {}/{} ({}, {}B/s). 
Uploaded {}/{} ({}, {}bits/s)", - GenerateBlocksStats.GeneratedBlockCount.load(), - NewBlockCount, - NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount.load()), - NiceNum(FilteredGeneratedBytesPerSecond.GetCurrent()), - UploadStats.BlockCount.load(), - NewBlockCount, - NiceBytes(UploadStats.BlocksBytes.load()), - NiceNum(FilteredUploadedBytesPerSecond.GetCurrent() * 8)); + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded block {} ({}) containing {} chunks", + BlockHash, + NiceBytes(CompressedBlockSize), + Context.OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); + } - Progress.UpdateState({.Task = "Generating blocks", - .Details = Details, - .TotalCount = gsl::narrow<uint64_t>(NewBlockCount), - .RemainingCount = gsl::narrow<uint64_t>(NewBlockCount - GenerateBlocksStats.GeneratedBlockCount.load()), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); - }); + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, std::vector<IoHash>({BlockHash}), std::vector<CbObject>({BlockMetaData})); + } - ZEN_ASSERT(m_AbortFlag || QueuedPendingBlocksForUpload.load() == 0); + bool MetadataSucceeded = false; + try + { + MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } - Progress.Finish(); + if (m_AbortFlag) + { + return; + } - GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS = FilteredGeneratedBytesPerSecond.GetElapsedTimeUS(); - UploadStats.ElapsedWallTimeUS = FilteredUploadedBytesPerSecond.GetElapsedTimeUS(); + if (MetadataSucceeded) + { + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded block {} metadata ({})", BlockHash, NiceBytes(BlockMetaData.GetSize())); + } + + Context.OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; + Context.UploadStats.BlocksBytes += BlockMetaData.GetSize(); 
+ } + + Context.UploadStats.BlockCount++; + if (Context.UploadStats.BlockCount == Context.NewBlockCount) + { + Context.FilteredUploadedBytesPerSecond.Stop(); } } @@ -5220,7 +5300,7 @@ BuildsOperationUploadFolder::GenerateBlock(const ChunkedFolderContent& Content, uint64_t RawSize = Chunk.GetSize(); const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock && - IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex); + IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex); const OodleCompressionLevel CompressionLevel = ShouldCompressChunk ? OodleCompressionLevel::VeryFast : OodleCompressionLevel::None; @@ -5259,9 +5339,9 @@ BuildsOperationUploadFolder::RebuildBlock(const ChunkedFolderContent& Content, Content.ChunkedContent.ChunkRawSizes[ChunkIndex]); ZEN_ASSERT_SLOW(IoHash::HashBuffer(Chunk) == Content.ChunkedContent.ChunkHashes[ChunkIndex]); - const uint64_t RawSize = Chunk.GetSize(); - const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock && - IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex); + const uint64_t RawSize = Chunk.GetSize(); + const bool ShouldCompressChunk = + RawSize >= m_Options.MinimumSizeForCompressInBlock && IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex); const OodleCompressionLevel CompressionLevel = ShouldCompressChunk ? 
OodleCompressionLevel::VeryFast : OodleCompressionLevel::None; @@ -5287,158 +5367,40 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController ReuseBlocksStatistics ReuseBlocksStats; UploadStatistics UploadStats; GenerateBlocksStatistics GenerateBlocksStats; + LooseChunksStatistics LooseChunksStats; - LooseChunksStatistics LooseChunksStats; - ChunkedFolderContent LocalContent; - - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::ChunkPartContent, StepCount); + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::ChunkPartContent, StepCount); - Stopwatch ScanTimer; - { - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Scan Folder")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); - - FilteredRate FilteredBytesHashed; - FilteredBytesHashed.Start(); - LocalContent = ChunkFolderContent( - ChunkingStats, - m_IOWorkerPool, - m_Path, - Part.Content, - ChunkController, - ChunkCache, - m_LogOutput.GetProgressUpdateDelayMS(), - [&](bool IsAborted, bool IsPaused, std::ptrdiff_t) { - FilteredBytesHashed.Update(ChunkingStats.BytesHashed.load()); - std::string Details = fmt::format("{}/{} ({}/{}, {}B/s) scanned, {} ({}) chunks found", - ChunkingStats.FilesProcessed.load(), - Part.Content.Paths.size(), - NiceBytes(ChunkingStats.BytesHashed.load()), - NiceBytes(Part.TotalRawSize), - NiceNum(FilteredBytesHashed.GetCurrent()), - ChunkingStats.UniqueChunksFound.load(), - NiceBytes(ChunkingStats.UniqueBytesFound.load())); - Progress.UpdateState({.Task = "Scanning files ", - .Details = Details, - .TotalCount = Part.TotalRawSize, - .RemainingCount = Part.TotalRawSize - ChunkingStats.BytesHashed.load(), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); - }, - m_AbortFlag, - m_PauseFlag); - FilteredBytesHashed.Stop(); - Progress.Finish(); - if (m_AbortFlag) - { - return; - } - } - - if 
(!m_Options.IsQuiet) + ChunkedFolderContent LocalContent = ScanPartContent(Part, ChunkController, ChunkCache, ChunkingStats); + if (m_AbortFlag) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Found {} ({}) files divided into {} ({}) unique chunks in '{}' in {}. Average hash rate {}B/sec", - Part.Content.Paths.size(), - NiceBytes(Part.TotalRawSize), - ChunkingStats.UniqueChunksFound.load(), - NiceBytes(ChunkingStats.UniqueBytesFound.load()), - m_Path, - NiceTimeSpanMs(ScanTimer.GetElapsedTimeMs()), - NiceNum(GetBytesPerSecond(ChunkingStats.ElapsedWallTimeUS, ChunkingStats.BytesHashed))); + return; } const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalContent); - std::vector<size_t> ReuseBlockIndexes; - std::vector<uint32_t> NewBlockChunkIndexes; - if (PartIndex == 0) { - const PrepareBuildResult PrepBuildResult = m_PrepBuildResultFuture.get(); - - m_FindBlocksStats.FindBlockTimeMS = PrepBuildResult.ElapsedTimeMs; - m_FindBlocksStats.FoundBlockCount = PrepBuildResult.KnownBlocks.size(); - - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Build prepare took {}. {} took {}, payload size {}{}", - NiceTimeSpanMs(PrepBuildResult.ElapsedTimeMs), - m_CreateBuild ? "PutBuild" : "GetBuild", - NiceTimeSpanMs(PrepBuildResult.PrepareBuildTimeMs), - NiceBytes(PrepBuildResult.PayloadSize), - m_Options.IgnoreExistingBlocks ? "" - : fmt::format(". Found {} blocks in {}", - PrepBuildResult.KnownBlocks.size(), - NiceTimeSpanMs(PrepBuildResult.FindBlocksTimeMs))); - } - - m_PreferredMultipartChunkSize = PrepBuildResult.PreferredMultipartChunkSize; - - m_LargeAttachmentSize = m_Options.AllowMultiparts ? 
m_PreferredMultipartChunkSize * 4u : (std::uint64_t)-1; - - m_KnownBlocks = std::move(PrepBuildResult.KnownBlocks); + ConsumePrepareBuildResult(); } ZEN_ASSERT(m_PreferredMultipartChunkSize != 0); ZEN_ASSERT(m_LargeAttachmentSize != 0); - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::CalculateDelta, StepCount); + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::CalculateDelta, StepCount); Stopwatch BlockArrangeTimer; - std::vector<std::uint32_t> LooseChunkIndexes; - { - bool EnableBlocks = true; - std::vector<std::uint32_t> BlockChunkIndexes; - for (uint32_t ChunkIndex = 0; ChunkIndex < LocalContent.ChunkedContent.ChunkHashes.size(); ChunkIndex++) - { - const uint64_t ChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[ChunkIndex]; - if (!EnableBlocks || ChunkRawSize == 0 || ChunkRawSize > m_Options.BlockParameters.MaxChunkEmbedSize) - { - LooseChunkIndexes.push_back(ChunkIndex); - LooseChunksStats.ChunkByteCount += ChunkRawSize; - } - else - { - BlockChunkIndexes.push_back(ChunkIndex); - FindBlocksStats.PotentialChunkByteCount += ChunkRawSize; - } - } - FindBlocksStats.PotentialChunkCount += BlockChunkIndexes.size(); - LooseChunksStats.ChunkCount = LooseChunkIndexes.size(); - - if (m_Options.IgnoreExistingBlocks) - { - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Ignoring any existing blocks in store"); - } - NewBlockChunkIndexes = std::move(BlockChunkIndexes); - } - else - { - ReuseBlockIndexes = FindReuseBlocks(m_LogOutput, - m_Options.BlockReuseMinPercentLimit, - m_Options.IsVerbose, - ReuseBlocksStats, - m_KnownBlocks, - LocalContent.ChunkedContent.ChunkHashes, - BlockChunkIndexes, - NewBlockChunkIndexes); - FindBlocksStats.AcceptedBlockCount += ReuseBlockIndexes.size(); - - for (const ChunkBlockDescription& Description : m_KnownBlocks) - { - for (uint32_t ChunkRawLength : Description.ChunkRawLengths) - { - FindBlocksStats.FoundBlockByteCount += ChunkRawLength; - } - 
FindBlocksStats.FoundBlockChunkCount += Description.ChunkRawHashes.size(); - } - } - } + std::vector<uint32_t> LooseChunkIndexes; + std::vector<uint32_t> NewBlockChunkIndexes; + std::vector<size_t> ReuseBlockIndexes; + ClassifyChunksByBlockEligibility(LocalContent, + LooseChunkIndexes, + NewBlockChunkIndexes, + ReuseBlockIndexes, + LooseChunksStats, + FindBlocksStats, + ReuseBlocksStats); std::vector<std::vector<uint32_t>> NewBlockChunks; ArrangeChunksIntoBlocks(LocalContent, LocalLookup, NewBlockChunkIndexes, NewBlockChunks); @@ -5460,43 +5422,43 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController : 0.0; if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Found {} chunks in {} ({}) blocks eligible for reuse in {}\n" - " Reusing {} ({}) matching chunks in {} blocks ({:.1f}%)\n" - " Accepting {} ({}) redundant chunks ({:.1f}%)\n" - " Rejected {} ({}) chunks in {} blocks\n" - " Arranged {} ({}) chunks in {} new blocks\n" - " Keeping {} ({}) chunks as loose chunks\n" - " Discovery completed in {}", - FindBlocksStats.FoundBlockChunkCount, - FindBlocksStats.FoundBlockCount, - NiceBytes(FindBlocksStats.FoundBlockByteCount), - NiceTimeSpanMs(FindBlocksStats.FindBlockTimeMS), + ZEN_INFO( + "Found {} chunks in {} ({}) blocks eligible for reuse in {}\n" + " Reusing {} ({}) matching chunks in {} blocks ({:.1f}%)\n" + " Accepting {} ({}) redundant chunks ({:.1f}%)\n" + " Rejected {} ({}) chunks in {} blocks\n" + " Arranged {} ({}) chunks in {} new blocks\n" + " Keeping {} ({}) chunks as loose chunks\n" + " Discovery completed in {}", + FindBlocksStats.FoundBlockChunkCount, + FindBlocksStats.FoundBlockCount, + NiceBytes(FindBlocksStats.FoundBlockByteCount), + NiceTimeSpanMs(FindBlocksStats.FindBlockTimeMS), - ReuseBlocksStats.AcceptedChunkCount, - NiceBytes(ReuseBlocksStats.AcceptedRawByteCount), - FindBlocksStats.AcceptedBlockCount, - AcceptedByteCountPercent, + ReuseBlocksStats.AcceptedChunkCount, + 
NiceBytes(ReuseBlocksStats.AcceptedRawByteCount), + FindBlocksStats.AcceptedBlockCount, + AcceptedByteCountPercent, - ReuseBlocksStats.AcceptedReduntantChunkCount, - NiceBytes(ReuseBlocksStats.AcceptedReduntantByteCount), - AcceptedReduntantByteCountPercent, + ReuseBlocksStats.AcceptedReduntantChunkCount, + NiceBytes(ReuseBlocksStats.AcceptedReduntantByteCount), + AcceptedReduntantByteCountPercent, - ReuseBlocksStats.RejectedChunkCount, - NiceBytes(ReuseBlocksStats.RejectedByteCount), - ReuseBlocksStats.RejectedBlockCount, + ReuseBlocksStats.RejectedChunkCount, + NiceBytes(ReuseBlocksStats.RejectedByteCount), + ReuseBlocksStats.RejectedBlockCount, - FindBlocksStats.NewBlocksChunkCount, - NiceBytes(FindBlocksStats.NewBlocksChunkByteCount), - FindBlocksStats.NewBlocksCount, + FindBlocksStats.NewBlocksChunkCount, + NiceBytes(FindBlocksStats.NewBlocksChunkByteCount), + FindBlocksStats.NewBlocksCount, - LooseChunksStats.ChunkCount, - NiceBytes(LooseChunksStats.ChunkByteCount), + LooseChunksStats.ChunkCount, + NiceBytes(LooseChunksStats.ChunkByteCount), - NiceTimeSpanMs(BlockArrangeTimer.GetElapsedTimeMs())); + NiceTimeSpanMs(BlockArrangeTimer.GetElapsedTimeMs())); } - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::GenerateBlocks, StepCount); + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::GenerateBlocks, StepCount); GeneratedBlocks NewBlocks; if (!NewBlockChunks.empty()) @@ -5506,295 +5468,523 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController uint64_t BlockGenerateTimeUs = GenerateBuildBlocksTimer.GetElapsedTimeUs(); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Generated {} ({}) and uploaded {} ({}) blocks in {}. Generate speed: {}B/sec. 
Transfer speed {}bits/sec.", - GenerateBlocksStats.GeneratedBlockCount.load(), - NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount), - UploadStats.BlockCount.load(), - NiceBytes(UploadStats.BlocksBytes.load()), - NiceTimeSpanMs(BlockGenerateTimeUs / 1000), - NiceNum(GetBytesPerSecond(GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS, - GenerateBlocksStats.GeneratedBlockByteCount)), - NiceNum(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.BlocksBytes * 8))); + ZEN_INFO("Generated {} ({}) and uploaded {} ({}) blocks in {}. Generate speed: {}B/sec. Transfer speed {}bits/sec.", + GenerateBlocksStats.GeneratedBlockCount.load(), + NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount), + UploadStats.BlockCount.load(), + NiceBytes(UploadStats.BlocksBytes.load()), + NiceTimeSpanMs(BlockGenerateTimeUs / 1000), + NiceNum(GetBytesPerSecond(GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS, + GenerateBlocksStats.GeneratedBlockByteCount)), + NiceNum(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.BlocksBytes * 8))); } }); GenerateBuildBlocks(LocalContent, LocalLookup, NewBlockChunks, NewBlocks, GenerateBlocksStats, UploadStats); } - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::BuildPartManifest, StepCount); + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::BuildPartManifest, StepCount); + + BuiltPartManifest Manifest = + BuildPartManifestObject(LocalContent, LocalLookup, ChunkController, ReuseBlockIndexes, NewBlocks, LooseChunkIndexes); + + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadBuildPart, StepCount); - CbObject PartManifest; + Stopwatch PutBuildPartResultTimer; + std::pair<IoHash, std::vector<IoHash>> PutBuildPartResult = + m_Storage.BuildStorage->PutBuildPart(m_BuildId, Part.PartId, Part.PartName, Manifest.PartManifest); + if (!m_Options.IsQuiet) { - CbObjectWriter PartManifestWriter; - Stopwatch ManifestGenerationTimer; - auto __ = 
MakeGuard([&]() { - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Generated build part manifest in {} ({})", - NiceTimeSpanMs(ManifestGenerationTimer.GetElapsedTimeMs()), - NiceBytes(PartManifestWriter.GetSaveSize())); - } - }); + ZEN_INFO("PutBuildPart took {}, payload size {}. {} attachments are needed.", + NiceTimeSpanMs(PutBuildPartResultTimer.GetElapsedTimeMs()), + NiceBytes(Manifest.PartManifest.GetSize()), + PutBuildPartResult.second.size()); + } + IoHash PartHash = PutBuildPartResult.first; - PartManifestWriter.BeginObject("chunker"sv); - { - PartManifestWriter.AddString("name"sv, ChunkController.GetName()); - PartManifestWriter.AddObject("parameters"sv, ChunkController.GetParameters()); - } - PartManifestWriter.EndObject(); // chunker + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadAttachments, StepCount); - std::vector<IoHash> AllChunkBlockHashes; - std::vector<ChunkBlockDescription> AllChunkBlockDescriptions; - AllChunkBlockHashes.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size()); - AllChunkBlockDescriptions.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size()); - for (size_t ReuseBlockIndex : ReuseBlockIndexes) - { - AllChunkBlockDescriptions.push_back(m_KnownBlocks[ReuseBlockIndex]); - AllChunkBlockHashes.push_back(m_KnownBlocks[ReuseBlockIndex].BlockHash); - } - AllChunkBlockDescriptions.insert(AllChunkBlockDescriptions.end(), - NewBlocks.BlockDescriptions.begin(), - NewBlocks.BlockDescriptions.end()); - for (const ChunkBlockDescription& BlockDescription : NewBlocks.BlockDescriptions) + std::vector<IoHash> UnknownChunks; + if (m_Options.IgnoreExistingBlocks) + { + if (m_Options.IsVerbose) { - AllChunkBlockHashes.push_back(BlockDescription.BlockHash); + ZEN_INFO("PutBuildPart uploading all attachments, needs are: {}", FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv)); } - std::vector<IoHash> AbsoluteChunkHashes; - if 
(m_Options.DoExtraContentValidation) + std::vector<IoHash> ForceUploadChunkHashes; + ForceUploadChunkHashes.reserve(LooseChunkIndexes.size()); + + for (uint32_t ChunkIndex : LooseChunkIndexes) { - tsl::robin_map<IoHash, size_t, IoHash::Hasher> ChunkHashToAbsoluteChunkIndex; - AbsoluteChunkHashes.reserve(LocalContent.ChunkedContent.ChunkHashes.size()); - for (uint32_t ChunkIndex : LooseChunkIndexes) - { - ChunkHashToAbsoluteChunkIndex.insert({LocalContent.ChunkedContent.ChunkHashes[ChunkIndex], AbsoluteChunkHashes.size()}); - AbsoluteChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); - } - for (const ChunkBlockDescription& Block : AllChunkBlockDescriptions) - { - for (const IoHash& ChunkHash : Block.ChunkRawHashes) - { - ChunkHashToAbsoluteChunkIndex.insert({ChunkHash, AbsoluteChunkHashes.size()}); - AbsoluteChunkHashes.push_back(ChunkHash); - } - } - for (const IoHash& ChunkHash : LocalContent.ChunkedContent.ChunkHashes) - { - ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(ChunkHash)] == ChunkHash); - ZEN_ASSERT(LocalContent.ChunkedContent.ChunkHashes[LocalLookup.ChunkHashToChunkIndex.at(ChunkHash)] == ChunkHash); - } - for (const uint32_t ChunkIndex : LocalContent.ChunkedContent.ChunkOrders) - { - ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex])] == - LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); - ZEN_ASSERT(LocalLookup.ChunkHashToChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]) == ChunkIndex); - } + ForceUploadChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); } - std::vector<uint32_t> AbsoluteChunkOrders = CalculateAbsoluteChunkOrders(LocalContent.ChunkedContent.ChunkHashes, - LocalContent.ChunkedContent.ChunkOrders, - LocalLookup.ChunkHashToChunkIndex, - LooseChunkIndexes, - AllChunkBlockDescriptions); - if (m_Options.DoExtraContentValidation) + for (size_t BlockIndex = 0; BlockIndex < 
NewBlocks.BlockHeaders.size(); BlockIndex++) { - for (uint32_t ChunkOrderIndex = 0; ChunkOrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); ChunkOrderIndex++) + if (NewBlocks.BlockHeaders[BlockIndex]) { - uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[ChunkOrderIndex]; - uint32_t AbsoluteChunkIndex = AbsoluteChunkOrders[ChunkOrderIndex]; - const IoHash& LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; - const IoHash& AbsoluteChunkHash = AbsoluteChunkHashes[AbsoluteChunkIndex]; - ZEN_ASSERT(LocalChunkHash == AbsoluteChunkHash); + // Block was not uploaded during generation + ForceUploadChunkHashes.push_back(NewBlocks.BlockDescriptions[BlockIndex].BlockHash); } } - - WriteBuildContentToCompactBinary(PartManifestWriter, - LocalContent.Platform, - LocalContent.Paths, - LocalContent.RawHashes, - LocalContent.RawSizes, - LocalContent.Attributes, - LocalContent.ChunkedContent.SequenceRawHashes, - LocalContent.ChunkedContent.ChunkCounts, - LocalContent.ChunkedContent.ChunkHashes, - LocalContent.ChunkedContent.ChunkRawSizes, - AbsoluteChunkOrders, - LooseChunkIndexes, - AllChunkBlockHashes); - - if (m_Options.DoExtraContentValidation) + UploadAttachmentBatch(ForceUploadChunkHashes, + UnknownChunks, + LocalContent, + LocalLookup, + NewBlockChunks, + NewBlocks, + LooseChunkIndexes, + UploadStats, + LooseChunksStats); + } + else if (!PutBuildPartResult.second.empty()) + { + if (m_Options.IsVerbose) { - ChunkedFolderContent VerifyFolderContent; - - std::vector<uint32_t> OutAbsoluteChunkOrders; - std::vector<IoHash> OutLooseChunkHashes; - std::vector<uint64_t> OutLooseChunkRawSizes; - std::vector<IoHash> OutBlockRawHashes; - ReadBuildContentFromCompactBinary(PartManifestWriter.Save(), - VerifyFolderContent.Platform, - VerifyFolderContent.Paths, - VerifyFolderContent.RawHashes, - VerifyFolderContent.RawSizes, - VerifyFolderContent.Attributes, - VerifyFolderContent.ChunkedContent.SequenceRawHashes, - 
VerifyFolderContent.ChunkedContent.ChunkCounts, - OutAbsoluteChunkOrders, - OutLooseChunkHashes, - OutLooseChunkRawSizes, - OutBlockRawHashes); - ZEN_ASSERT(OutBlockRawHashes == AllChunkBlockHashes); + ZEN_INFO("PutBuildPart needs attachments: {}", FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv)); + } + UploadAttachmentBatch(PutBuildPartResult.second, + UnknownChunks, + LocalContent, + LocalLookup, + NewBlockChunks, + NewBlocks, + LooseChunkIndexes, + UploadStats, + LooseChunksStats); + } + + FinalizeBuildPartWithRetries(Part, + PartHash, + UnknownChunks, + LocalContent, + LocalLookup, + NewBlockChunks, + NewBlocks, + LooseChunkIndexes, + UploadStats, + LooseChunksStats); - for (uint32_t OrderIndex = 0; OrderIndex < OutAbsoluteChunkOrders.size(); OrderIndex++) - { - uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex]; - const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; + if (!NewBlocks.BlockDescriptions.empty() && !m_AbortFlag) + { + UploadMissingBlockMetadata(NewBlocks, UploadStats); + // The newly generated blocks are now known blocks so the next part upload can use those blocks as well + m_KnownBlocks.insert(m_KnownBlocks.end(), NewBlocks.BlockDescriptions.begin(), NewBlocks.BlockDescriptions.end()); + } - uint32_t VerifyChunkIndex = OutAbsoluteChunkOrders[OrderIndex]; - const IoHash VerifyChunkHash = AbsoluteChunkHashes[VerifyChunkIndex]; + m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::PutBuildPartStats, StepCount); - ZEN_ASSERT(LocalChunkHash == VerifyChunkHash); - } + m_Storage.BuildStorage->PutBuildPartStats( + m_BuildId, + Part.PartId, + {{"totalSize", double(Part.LocalFolderScanStats.FoundFileByteCount.load())}, + {"reusedRatio", AcceptedByteCountPercent / 100.0}, + {"reusedBlockCount", double(FindBlocksStats.AcceptedBlockCount)}, + {"reusedBlockByteCount", double(ReuseBlocksStats.AcceptedRawByteCount)}, + {"newBlockCount", 
double(FindBlocksStats.NewBlocksCount)}, + {"newBlockByteCount", double(FindBlocksStats.NewBlocksChunkByteCount)}, + {"uploadedCount", double(UploadStats.BlockCount.load() + UploadStats.ChunkCount.load())}, + {"uploadedByteCount", double(UploadStats.BlocksBytes.load() + UploadStats.ChunksBytes.load())}, + {"uploadedBytesPerSec", + double(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.ChunksBytes + UploadStats.BlocksBytes))}, + {"elapsedTimeSec", double(UploadTimer.GetElapsedTimeMs() / 1000.0)}}); - CalculateLocalChunkOrders(OutAbsoluteChunkOrders, - OutLooseChunkHashes, - OutLooseChunkRawSizes, - AllChunkBlockDescriptions, - VerifyFolderContent.ChunkedContent.ChunkHashes, - VerifyFolderContent.ChunkedContent.ChunkRawSizes, - VerifyFolderContent.ChunkedContent.ChunkOrders, - m_Options.DoExtraContentValidation); + m_LocalFolderScanStats += Part.LocalFolderScanStats; + m_ChunkingStats += ChunkingStats; + m_FindBlocksStats += FindBlocksStats; + m_ReuseBlocksStats += ReuseBlocksStats; + m_UploadStats += UploadStats; + m_GenerateBlocksStats += GenerateBlocksStats; + m_LooseChunksStats += LooseChunksStats; +} - ZEN_ASSERT(LocalContent.Paths == VerifyFolderContent.Paths); - ZEN_ASSERT(LocalContent.RawHashes == VerifyFolderContent.RawHashes); - ZEN_ASSERT(LocalContent.RawSizes == VerifyFolderContent.RawSizes); - ZEN_ASSERT(LocalContent.Attributes == VerifyFolderContent.Attributes); - ZEN_ASSERT(LocalContent.ChunkedContent.SequenceRawHashes == VerifyFolderContent.ChunkedContent.SequenceRawHashes); - ZEN_ASSERT(LocalContent.ChunkedContent.ChunkCounts == VerifyFolderContent.ChunkedContent.ChunkCounts); +ChunkedFolderContent +BuildsOperationUploadFolder::ScanPartContent(const UploadPart& Part, + ChunkingController& ChunkController, + ChunkingCache& ChunkCache, + ChunkingStatistics& ChunkingStats) +{ + Stopwatch ScanTimer; - for (uint32_t OrderIndex = 0; OrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); OrderIndex++) - { - uint32_t LocalChunkIndex = 
LocalContent.ChunkedContent.ChunkOrders[OrderIndex]; - const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; - uint64_t LocalChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[LocalChunkIndex]; + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Scan Folder"); - uint32_t VerifyChunkIndex = VerifyFolderContent.ChunkedContent.ChunkOrders[OrderIndex]; - const IoHash VerifyChunkHash = VerifyFolderContent.ChunkedContent.ChunkHashes[VerifyChunkIndex]; - uint64_t VerifyChunkRawSize = VerifyFolderContent.ChunkedContent.ChunkRawSizes[VerifyChunkIndex]; + FilteredRate FilteredBytesHashed; + FilteredBytesHashed.Start(); + ChunkedFolderContent LocalContent = ChunkFolderContent( + ChunkingStats, + m_IOWorkerPool, + m_Path, + Part.Content, + ChunkController, + ChunkCache, + m_Progress.GetProgressUpdateDelayMS(), + [&](bool IsAborted, bool IsPaused, std::ptrdiff_t) { + FilteredBytesHashed.Update(ChunkingStats.BytesHashed.load()); + std::string Details = fmt::format("{}/{} ({}/{}, {}B/s) scanned, {} ({}) chunks found", + ChunkingStats.FilesProcessed.load(), + Part.Content.Paths.size(), + NiceBytes(ChunkingStats.BytesHashed.load()), + NiceBytes(Part.TotalRawSize), + NiceNum(FilteredBytesHashed.GetCurrent()), + ChunkingStats.UniqueChunksFound.load(), + NiceBytes(ChunkingStats.UniqueBytesFound.load())); + ProgressBar->UpdateState({.Task = "Scanning files ", + .Details = Details, + .TotalCount = Part.TotalRawSize, + .RemainingCount = Part.TotalRawSize - ChunkingStats.BytesHashed.load(), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); + }, + m_AbortFlag, + m_PauseFlag); + FilteredBytesHashed.Stop(); + ProgressBar->Finish(); + if (m_AbortFlag) + { + return LocalContent; + } - ZEN_ASSERT(LocalChunkHash == VerifyChunkHash); - ZEN_ASSERT(LocalChunkRawSize == VerifyChunkRawSize); - } - } - PartManifest = PartManifestWriter.Save(); + if (!m_Options.IsQuiet) + { + 
ZEN_INFO("Found {} ({}) files divided into {} ({}) unique chunks in '{}' in {}. Average hash rate {}B/sec", + Part.Content.Paths.size(), + NiceBytes(Part.TotalRawSize), + ChunkingStats.UniqueChunksFound.load(), + NiceBytes(ChunkingStats.UniqueBytesFound.load()), + m_Path, + NiceTimeSpanMs(ScanTimer.GetElapsedTimeMs()), + NiceNum(GetBytesPerSecond(ChunkingStats.ElapsedWallTimeUS, ChunkingStats.BytesHashed))); } - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadBuildPart, StepCount); + return LocalContent; +} + +void +BuildsOperationUploadFolder::ConsumePrepareBuildResult() +{ + const PrepareBuildResult PrepBuildResult = m_PrepBuildResultFuture.get(); + + m_FindBlocksStats.FindBlockTimeMS = PrepBuildResult.ElapsedTimeMs; + m_FindBlocksStats.FoundBlockCount = PrepBuildResult.KnownBlocks.size(); - Stopwatch PutBuildPartResultTimer; - std::pair<IoHash, std::vector<IoHash>> PutBuildPartResult = - m_Storage.BuildStorage->PutBuildPart(m_BuildId, Part.PartId, Part.PartName, PartManifest); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "PutBuildPart took {}, payload size {}. {} attachments are needed.", - NiceTimeSpanMs(PutBuildPartResultTimer.GetElapsedTimeMs()), - NiceBytes(PartManifest.GetSize()), - PutBuildPartResult.second.size()); + ZEN_INFO("Build prepare took {}. {} took {}, payload size {}{}", + NiceTimeSpanMs(PrepBuildResult.ElapsedTimeMs), + m_CreateBuild ? "PutBuild" : "GetBuild", + NiceTimeSpanMs(PrepBuildResult.PrepareBuildTimeMs), + NiceBytes(PrepBuildResult.PayloadSize), + m_Options.IgnoreExistingBlocks ? "" + : fmt::format(". 
Found {} blocks in {}", + PrepBuildResult.KnownBlocks.size(), + NiceTimeSpanMs(PrepBuildResult.FindBlocksTimeMs))); } - IoHash PartHash = PutBuildPartResult.first; - auto UploadAttachments = - [this, &LooseChunksStats, &UploadStats, &LocalContent, &LocalLookup, &NewBlockChunks, &NewBlocks, &LooseChunkIndexes]( - std::span<IoHash> RawHashes, - std::vector<IoHash>& OutUnknownChunks) { - if (!m_AbortFlag) - { - UploadStatistics TempUploadStats; - LooseChunksStatistics TempLooseChunksStats; - - Stopwatch TempUploadTimer; - auto __ = MakeGuard([&]() { - if (!m_Options.IsQuiet) - { - uint64_t TempChunkUploadTimeUs = TempUploadTimer.GetElapsedTimeUs(); - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Uploaded {} ({}) blocks. " - "Compressed {} ({} {}B/s) and uploaded {} ({}) chunks. " - "Transferred {} ({}bits/s) in {}", - TempUploadStats.BlockCount.load(), - NiceBytes(TempUploadStats.BlocksBytes), - - TempLooseChunksStats.CompressedChunkCount.load(), - NiceBytes(TempLooseChunksStats.CompressedChunkBytes.load()), - NiceNum(GetBytesPerSecond(TempLooseChunksStats.CompressChunksElapsedWallTimeUS, - TempLooseChunksStats.ChunkByteCount)), - TempUploadStats.ChunkCount.load(), - NiceBytes(TempUploadStats.ChunksBytes), - - NiceBytes(TempUploadStats.BlocksBytes + TempUploadStats.ChunksBytes), - NiceNum(GetBytesPerSecond(TempUploadStats.ElapsedWallTimeUS, TempUploadStats.ChunksBytes * 8)), - NiceTimeSpanMs(TempChunkUploadTimeUs / 1000)); - } - }); - UploadPartBlobs(LocalContent, - LocalLookup, - RawHashes, - NewBlockChunks, - NewBlocks, - LooseChunkIndexes, - m_LargeAttachmentSize, - TempUploadStats, - TempLooseChunksStats, - OutUnknownChunks); - UploadStats += TempUploadStats; - LooseChunksStats += TempLooseChunksStats; - } - }; + m_PreferredMultipartChunkSize = PrepBuildResult.PreferredMultipartChunkSize; + m_LargeAttachmentSize = m_Options.AllowMultiparts ? 
m_PreferredMultipartChunkSize * 4u : (std::uint64_t)-1; + m_KnownBlocks = std::move(PrepBuildResult.KnownBlocks); +} - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadAttachments, StepCount); +void +BuildsOperationUploadFolder::ClassifyChunksByBlockEligibility(const ChunkedFolderContent& LocalContent, + std::vector<uint32_t>& OutLooseChunkIndexes, + std::vector<uint32_t>& OutNewBlockChunkIndexes, + std::vector<size_t>& OutReuseBlockIndexes, + LooseChunksStatistics& LooseChunksStats, + FindBlocksStatistics& FindBlocksStats, + ReuseBlocksStatistics& ReuseBlocksStats) +{ + const bool EnableBlocks = true; + std::vector<std::uint32_t> BlockChunkIndexes; + for (uint32_t ChunkIndex = 0; ChunkIndex < LocalContent.ChunkedContent.ChunkHashes.size(); ChunkIndex++) + { + const uint64_t ChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[ChunkIndex]; + if (!EnableBlocks || ChunkRawSize == 0 || ChunkRawSize > m_Options.BlockParameters.MaxChunkEmbedSize) + { + OutLooseChunkIndexes.push_back(ChunkIndex); + LooseChunksStats.ChunkByteCount += ChunkRawSize; + } + else + { + BlockChunkIndexes.push_back(ChunkIndex); + FindBlocksStats.PotentialChunkByteCount += ChunkRawSize; + } + } + FindBlocksStats.PotentialChunkCount += BlockChunkIndexes.size(); + LooseChunksStats.ChunkCount = OutLooseChunkIndexes.size(); - std::vector<IoHash> UnknownChunks; if (m_Options.IgnoreExistingBlocks) { - if (m_Options.IsVerbose) + if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "PutBuildPart uploading all attachments, needs are: {}", - FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv)); + ZEN_INFO("Ignoring any existing blocks in store"); } + OutNewBlockChunkIndexes = std::move(BlockChunkIndexes); + return; + } - std::vector<IoHash> ForceUploadChunkHashes; - ForceUploadChunkHashes.reserve(LooseChunkIndexes.size()); + OutReuseBlockIndexes = FindReuseBlocks(Log(), + m_Options.BlockReuseMinPercentLimit, + m_Options.IsVerbose, + ReuseBlocksStats, 
+ m_KnownBlocks, + LocalContent.ChunkedContent.ChunkHashes, + BlockChunkIndexes, + OutNewBlockChunkIndexes); + FindBlocksStats.AcceptedBlockCount += OutReuseBlockIndexes.size(); - for (uint32_t ChunkIndex : LooseChunkIndexes) + for (const ChunkBlockDescription& Description : m_KnownBlocks) + { + for (uint32_t ChunkRawLength : Description.ChunkRawLengths) { - ForceUploadChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); + FindBlocksStats.FoundBlockByteCount += ChunkRawLength; } + FindBlocksStats.FoundBlockChunkCount += Description.ChunkRawHashes.size(); + } +} - for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockHeaders.size(); BlockIndex++) +BuildsOperationUploadFolder::BuiltPartManifest +BuildsOperationUploadFolder::BuildPartManifestObject(const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + ChunkingController& ChunkController, + std::span<const size_t> ReuseBlockIndexes, + const GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes) +{ + BuiltPartManifest Result; + + CbObjectWriter PartManifestWriter; + Stopwatch ManifestGenerationTimer; + auto __ = MakeGuard([&]() { + if (!m_Options.IsQuiet) + { + ZEN_INFO("Generated build part manifest in {} ({})", + NiceTimeSpanMs(ManifestGenerationTimer.GetElapsedTimeMs()), + NiceBytes(PartManifestWriter.GetSaveSize())); + } + }); + + PartManifestWriter.BeginObject("chunker"sv); + { + PartManifestWriter.AddString("name"sv, ChunkController.GetName()); + PartManifestWriter.AddObject("parameters"sv, ChunkController.GetParameters()); + } + PartManifestWriter.EndObject(); // chunker + + Result.AllChunkBlockHashes.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size()); + Result.AllChunkBlockDescriptions.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size()); + for (size_t ReuseBlockIndex : ReuseBlockIndexes) + { + Result.AllChunkBlockDescriptions.push_back(m_KnownBlocks[ReuseBlockIndex]); + 
Result.AllChunkBlockHashes.push_back(m_KnownBlocks[ReuseBlockIndex].BlockHash); + } + Result.AllChunkBlockDescriptions.insert(Result.AllChunkBlockDescriptions.end(), + NewBlocks.BlockDescriptions.begin(), + NewBlocks.BlockDescriptions.end()); + for (const ChunkBlockDescription& BlockDescription : NewBlocks.BlockDescriptions) + { + Result.AllChunkBlockHashes.push_back(BlockDescription.BlockHash); + } + + std::vector<IoHash> AbsoluteChunkHashes; + if (m_Options.DoExtraContentValidation) + { + tsl::robin_map<IoHash, size_t, IoHash::Hasher> ChunkHashToAbsoluteChunkIndex; + AbsoluteChunkHashes.reserve(LocalContent.ChunkedContent.ChunkHashes.size()); + for (uint32_t ChunkIndex : LooseChunkIndexes) { - if (NewBlocks.BlockHeaders[BlockIndex]) + ChunkHashToAbsoluteChunkIndex.insert({LocalContent.ChunkedContent.ChunkHashes[ChunkIndex], AbsoluteChunkHashes.size()}); + AbsoluteChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); + } + for (const ChunkBlockDescription& Block : Result.AllChunkBlockDescriptions) + { + for (const IoHash& ChunkHash : Block.ChunkRawHashes) { - // Block was not uploaded during generation - ForceUploadChunkHashes.push_back(NewBlocks.BlockDescriptions[BlockIndex].BlockHash); + ChunkHashToAbsoluteChunkIndex.insert({ChunkHash, AbsoluteChunkHashes.size()}); + AbsoluteChunkHashes.push_back(ChunkHash); } } - UploadAttachments(ForceUploadChunkHashes, UnknownChunks); + for (const IoHash& ChunkHash : LocalContent.ChunkedContent.ChunkHashes) + { + ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(ChunkHash)] == ChunkHash); + ZEN_ASSERT(LocalContent.ChunkedContent.ChunkHashes[LocalLookup.ChunkHashToChunkIndex.at(ChunkHash)] == ChunkHash); + } + for (const uint32_t ChunkIndex : LocalContent.ChunkedContent.ChunkOrders) + { + ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex])] == + LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]); + 
ZEN_ASSERT(LocalLookup.ChunkHashToChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]) == ChunkIndex); + } } - else if (!PutBuildPartResult.second.empty()) + + std::vector<uint32_t> AbsoluteChunkOrders = CalculateAbsoluteChunkOrders(LocalContent.ChunkedContent.ChunkHashes, + LocalContent.ChunkedContent.ChunkOrders, + LocalLookup.ChunkHashToChunkIndex, + LooseChunkIndexes, + Result.AllChunkBlockDescriptions); + + if (m_Options.DoExtraContentValidation) { - if (m_Options.IsVerbose) + for (uint32_t ChunkOrderIndex = 0; ChunkOrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); ChunkOrderIndex++) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "PutBuildPart needs attachments: {}", - FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv)); + uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[ChunkOrderIndex]; + uint32_t AbsoluteChunkIndex = AbsoluteChunkOrders[ChunkOrderIndex]; + const IoHash& LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; + const IoHash& AbsoluteChunkHash = AbsoluteChunkHashes[AbsoluteChunkIndex]; + ZEN_ASSERT(LocalChunkHash == AbsoluteChunkHash); } - UploadAttachments(PutBuildPartResult.second, UnknownChunks); } + WriteBuildContentToCompactBinary(PartManifestWriter, + LocalContent.Platform, + LocalContent.Paths, + LocalContent.RawHashes, + LocalContent.RawSizes, + LocalContent.Attributes, + LocalContent.ChunkedContent.SequenceRawHashes, + LocalContent.ChunkedContent.ChunkCounts, + LocalContent.ChunkedContent.ChunkHashes, + LocalContent.ChunkedContent.ChunkRawSizes, + AbsoluteChunkOrders, + LooseChunkIndexes, + Result.AllChunkBlockHashes); + + if (m_Options.DoExtraContentValidation) + { + ChunkedFolderContent VerifyFolderContent; + + std::vector<uint32_t> OutAbsoluteChunkOrders; + std::vector<IoHash> OutLooseChunkHashes; + std::vector<uint64_t> OutLooseChunkRawSizes; + std::vector<IoHash> OutBlockRawHashes; + ReadBuildContentFromCompactBinary(PartManifestWriter.Save(), + 
VerifyFolderContent.Platform, + VerifyFolderContent.Paths, + VerifyFolderContent.RawHashes, + VerifyFolderContent.RawSizes, + VerifyFolderContent.Attributes, + VerifyFolderContent.ChunkedContent.SequenceRawHashes, + VerifyFolderContent.ChunkedContent.ChunkCounts, + OutAbsoluteChunkOrders, + OutLooseChunkHashes, + OutLooseChunkRawSizes, + OutBlockRawHashes); + ZEN_ASSERT(OutBlockRawHashes == Result.AllChunkBlockHashes); + + for (uint32_t OrderIndex = 0; OrderIndex < OutAbsoluteChunkOrders.size(); OrderIndex++) + { + uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex]; + const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; + + uint32_t VerifyChunkIndex = OutAbsoluteChunkOrders[OrderIndex]; + const IoHash VerifyChunkHash = AbsoluteChunkHashes[VerifyChunkIndex]; + + ZEN_ASSERT(LocalChunkHash == VerifyChunkHash); + } + + CalculateLocalChunkOrders(OutAbsoluteChunkOrders, + OutLooseChunkHashes, + OutLooseChunkRawSizes, + Result.AllChunkBlockDescriptions, + VerifyFolderContent.ChunkedContent.ChunkHashes, + VerifyFolderContent.ChunkedContent.ChunkRawSizes, + VerifyFolderContent.ChunkedContent.ChunkOrders, + m_Options.DoExtraContentValidation); + + ZEN_ASSERT(LocalContent.Paths == VerifyFolderContent.Paths); + ZEN_ASSERT(LocalContent.RawHashes == VerifyFolderContent.RawHashes); + ZEN_ASSERT(LocalContent.RawSizes == VerifyFolderContent.RawSizes); + ZEN_ASSERT(LocalContent.Attributes == VerifyFolderContent.Attributes); + ZEN_ASSERT(LocalContent.ChunkedContent.SequenceRawHashes == VerifyFolderContent.ChunkedContent.SequenceRawHashes); + ZEN_ASSERT(LocalContent.ChunkedContent.ChunkCounts == VerifyFolderContent.ChunkedContent.ChunkCounts); + + for (uint32_t OrderIndex = 0; OrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); OrderIndex++) + { + uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex]; + const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex]; 
+ uint64_t LocalChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[LocalChunkIndex]; + + uint32_t VerifyChunkIndex = VerifyFolderContent.ChunkedContent.ChunkOrders[OrderIndex]; + const IoHash VerifyChunkHash = VerifyFolderContent.ChunkedContent.ChunkHashes[VerifyChunkIndex]; + uint64_t VerifyChunkRawSize = VerifyFolderContent.ChunkedContent.ChunkRawSizes[VerifyChunkIndex]; + + ZEN_ASSERT(LocalChunkHash == VerifyChunkHash); + ZEN_ASSERT(LocalChunkRawSize == VerifyChunkRawSize); + } + } + + Result.PartManifest = PartManifestWriter.Save(); + return Result; +} + +void +BuildsOperationUploadFolder::UploadAttachmentBatch(std::span<IoHash> RawHashes, + std::vector<IoHash>& OutUnknownChunks, + const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks, + GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + UploadStatistics& UploadStats, + LooseChunksStatistics& LooseChunksStats) +{ + if (m_AbortFlag) + { + return; + } + + UploadStatistics TempUploadStats; + LooseChunksStatistics TempLooseChunksStats; + + Stopwatch TempUploadTimer; + auto __ = MakeGuard([&]() { + if (!m_Options.IsQuiet) + { + uint64_t TempChunkUploadTimeUs = TempUploadTimer.GetElapsedTimeUs(); + ZEN_INFO( + "Uploaded {} ({}) blocks. " + "Compressed {} ({} {}B/s) and uploaded {} ({}) chunks. 
" + "Transferred {} ({}bits/s) in {}", + TempUploadStats.BlockCount.load(), + NiceBytes(TempUploadStats.BlocksBytes), + + TempLooseChunksStats.CompressedChunkCount.load(), + NiceBytes(TempLooseChunksStats.CompressedChunkBytes.load()), + NiceNum(GetBytesPerSecond(TempLooseChunksStats.CompressChunksElapsedWallTimeUS, TempLooseChunksStats.ChunkByteCount)), + TempUploadStats.ChunkCount.load(), + NiceBytes(TempUploadStats.ChunksBytes), + + NiceBytes(TempUploadStats.BlocksBytes + TempUploadStats.ChunksBytes), + NiceNum(GetBytesPerSecond(TempUploadStats.ElapsedWallTimeUS, TempUploadStats.ChunksBytes * 8)), + NiceTimeSpanMs(TempChunkUploadTimeUs / 1000)); + } + }); + UploadPartBlobs(LocalContent, + LocalLookup, + RawHashes, + NewBlockChunks, + NewBlocks, + LooseChunkIndexes, + m_LargeAttachmentSize, + TempUploadStats, + TempLooseChunksStats, + OutUnknownChunks); + UploadStats += TempUploadStats; + LooseChunksStats += TempLooseChunksStats; +} + +void +BuildsOperationUploadFolder::FinalizeBuildPartWithRetries(const UploadPart& Part, + const IoHash& PartHash, + std::vector<IoHash>& InOutUnknownChunks, + const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks, + GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + UploadStatistics& UploadStats, + LooseChunksStatistics& LooseChunksStats) +{ auto BuildUnkownChunksResponse = [](const std::vector<IoHash>& UnknownChunks, bool WillRetry) { return fmt::format( "The following build blobs was reported as needed for upload but was reported as existing at the start of the " @@ -5803,9 +5993,9 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController FormatArray<IoHash>(UnknownChunks, "\n "sv)); }; - if (!UnknownChunks.empty()) + if (!InOutUnknownChunks.empty()) { - ZEN_OPERATION_LOG_WARN(m_LogOutput, "{}", BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ true)); + ZEN_WARN("{}", 
BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ true)); } uint32_t FinalizeBuildPartRetryCount = 5; @@ -5815,10 +6005,9 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController std::vector<IoHash> Needs = m_Storage.BuildStorage->FinalizeBuildPart(m_BuildId, Part.PartId, PartHash); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "FinalizeBuildPart took {}. {} attachments are missing.", - NiceTimeSpanMs(FinalizeBuildPartTimer.GetElapsedTimeMs()), - Needs.size()); + ZEN_INFO("FinalizeBuildPart took {}. {} attachments are missing.", + NiceTimeSpanMs(FinalizeBuildPartTimer.GetElapsedTimeMs()), + Needs.size()); } if (Needs.empty()) { @@ -5826,12 +6015,20 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController } if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "FinalizeBuildPart needs attachments: {}", FormatArray<IoHash>(Needs, "\n "sv)); + ZEN_INFO("FinalizeBuildPart needs attachments: {}", FormatArray<IoHash>(Needs, "\n "sv)); } std::vector<IoHash> RetryUnknownChunks; - UploadAttachments(Needs, RetryUnknownChunks); - if (RetryUnknownChunks == UnknownChunks) + UploadAttachmentBatch(Needs, + RetryUnknownChunks, + LocalContent, + LocalLookup, + NewBlockChunks, + NewBlocks, + LooseChunkIndexes, + UploadStats, + LooseChunksStats); + if (RetryUnknownChunks == InOutUnknownChunks) { if (FinalizeBuildPartRetryCount > 0) { @@ -5841,100 +6038,68 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController } else { - UnknownChunks = RetryUnknownChunks; - ZEN_OPERATION_LOG_WARN(m_LogOutput, - "{}", - BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ FinalizeBuildPartRetryCount != 0)); + InOutUnknownChunks = RetryUnknownChunks; + ZEN_WARN("{}", BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ FinalizeBuildPartRetryCount != 0)); } } - if (!UnknownChunks.empty()) + if (!InOutUnknownChunks.empty()) { - throw 
std::runtime_error(BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ false)); + throw std::runtime_error(BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ false)); } +} - if (!NewBlocks.BlockDescriptions.empty() && !m_AbortFlag) - { - uint64_t UploadBlockMetadataCount = 0; - Stopwatch UploadBlockMetadataTimer; +void +BuildsOperationUploadFolder::UploadMissingBlockMetadata(GeneratedBlocks& NewBlocks, UploadStatistics& UploadStats) +{ + uint64_t UploadBlockMetadataCount = 0; + Stopwatch UploadBlockMetadataTimer; - uint32_t FailedMetadataUploadCount = 1; - int32_t MetadataUploadRetryCount = 3; - while ((MetadataUploadRetryCount-- > 0) && (FailedMetadataUploadCount > 0)) + uint32_t FailedMetadataUploadCount = 1; + int32_t MetadataUploadRetryCount = 3; + while ((MetadataUploadRetryCount-- > 0) && (FailedMetadataUploadCount > 0)) + { + FailedMetadataUploadCount = 0; + for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockDescriptions.size(); BlockIndex++) { - FailedMetadataUploadCount = 0; - for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockDescriptions.size(); BlockIndex++) + if (m_AbortFlag) { - if (m_AbortFlag) + break; + } + const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash; + if (!NewBlocks.MetaDataHasBeenUploaded[BlockIndex]) + { + const CbObject BlockMetaData = + BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - break; + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); } - const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash; - if (!NewBlocks.MetaDataHasBeenUploaded[BlockIndex]) + bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); + if (MetadataSucceeded) { - const CbObject BlockMetaData = - 
BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); - if (m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); - } - bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); - if (MetadataSucceeded) - { - UploadStats.BlocksBytes += BlockMetaData.GetSize(); - NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; - UploadBlockMetadataCount++; - } - else - { - FailedMetadataUploadCount++; - } + UploadStats.BlocksBytes += BlockMetaData.GetSize(); + NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; + UploadBlockMetadataCount++; + } + else + { + FailedMetadataUploadCount++; } } } - if (UploadBlockMetadataCount > 0) + } + if (UploadBlockMetadataCount > 0) + { + uint64_t ElapsedUS = UploadBlockMetadataTimer.GetElapsedTimeUs(); + UploadStats.ElapsedWallTimeUS += ElapsedUS; + if (!m_Options.IsQuiet) { - uint64_t ElapsedUS = UploadBlockMetadataTimer.GetElapsedTimeUs(); - UploadStats.ElapsedWallTimeUS += ElapsedUS; - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded metadata for {} blocks in {}", - UploadBlockMetadataCount, - NiceTimeSpanMs(ElapsedUS / 1000)); - } + ZEN_INFO("Uploaded metadata for {} blocks in {}", UploadBlockMetadataCount, NiceTimeSpanMs(ElapsedUS / 1000)); } - - // The newly generated blocks are now known blocks so the next part upload can use those blocks as well - m_KnownBlocks.insert(m_KnownBlocks.end(), NewBlocks.BlockDescriptions.begin(), NewBlocks.BlockDescriptions.end()); } - - m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::PutBuildPartStats, StepCount); - - m_Storage.BuildStorage->PutBuildPartStats( - m_BuildId, - Part.PartId, - {{"totalSize", double(Part.LocalFolderScanStats.FoundFileByteCount.load())}, - {"reusedRatio", AcceptedByteCountPercent / 100.0}, 
- {"reusedBlockCount", double(FindBlocksStats.AcceptedBlockCount)}, - {"reusedBlockByteCount", double(ReuseBlocksStats.AcceptedRawByteCount)}, - {"newBlockCount", double(FindBlocksStats.NewBlocksCount)}, - {"newBlockByteCount", double(FindBlocksStats.NewBlocksChunkByteCount)}, - {"uploadedCount", double(UploadStats.BlockCount.load() + UploadStats.ChunkCount.load())}, - {"uploadedByteCount", double(UploadStats.BlocksBytes.load() + UploadStats.ChunksBytes.load())}, - {"uploadedBytesPerSec", - double(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.ChunksBytes + UploadStats.BlocksBytes))}, - {"elapsedTimeSec", double(UploadTimer.GetElapsedTimeMs() / 1000.0)}}); - - m_LocalFolderScanStats += Part.LocalFolderScanStats; - m_ChunkingStats += ChunkingStats; - m_FindBlocksStats += FindBlocksStats; - m_ReuseBlocksStats += ReuseBlocksStats; - m_UploadStats += UploadStats; - m_GenerateBlocksStats += GenerateBlocksStats; - m_LooseChunksStats += LooseChunksStats; } void @@ -5950,463 +6115,490 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co std::vector<IoHash>& OutUnknownChunks) { ZEN_TRACE_CPU("UploadPartBlobs"); + + UploadPartClassification Classification = + ClassifyUploadRawHashes(RawHashes, Content, Lookup, NewBlocks, LooseChunkIndexes, OutUnknownChunks); + + if (Classification.BlockIndexes.empty() && Classification.LooseChunkOrderIndexes.empty()) { - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Upload Blobs")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); + return; + } - WorkerThreadPool& ReadChunkPool = m_IOWorkerPool; - WorkerThreadPool& UploadChunkPool = m_NetworkPool; + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Upload Blobs"); + + FilteredRate FilteredGenerateBlockBytesPerSecond; + FilteredRate FilteredCompressedBytesPerSecond; + FilteredRate FilteredUploadedBytesPerSecond; + + ParallelWork Work(m_AbortFlag, 
m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::atomic<size_t> UploadedBlockSize = 0; + std::atomic<size_t> UploadedBlockCount = 0; + std::atomic<size_t> UploadedRawChunkSize = 0; + std::atomic<size_t> UploadedCompressedChunkSize = 0; + std::atomic<uint32_t> UploadedChunkCount = 0; + std::atomic<uint64_t> GeneratedBlockCount = 0; + std::atomic<uint64_t> GeneratedBlockByteCount = 0; + std::atomic<uint64_t> QueuedPendingInMemoryBlocksForUpload = 0; + + const size_t UploadBlockCount = Classification.BlockIndexes.size(); + const uint32_t UploadChunkCount = gsl::narrow<uint32_t>(Classification.LooseChunkOrderIndexes.size()); + const uint64_t TotalRawSize = Classification.TotalLooseChunksSize + Classification.TotalBlocksSize; + + UploadPartBlobsContext Context{.Work = Work, + .ReadChunkPool = m_IOWorkerPool, + .UploadChunkPool = m_NetworkPool, + .FilteredGenerateBlockBytesPerSecond = FilteredGenerateBlockBytesPerSecond, + .FilteredCompressedBytesPerSecond = FilteredCompressedBytesPerSecond, + .FilteredUploadedBytesPerSecond = FilteredUploadedBytesPerSecond, + .UploadedBlockSize = UploadedBlockSize, + .UploadedBlockCount = UploadedBlockCount, + .UploadedRawChunkSize = UploadedRawChunkSize, + .UploadedCompressedChunkSize = UploadedCompressedChunkSize, + .UploadedChunkCount = UploadedChunkCount, + .GeneratedBlockCount = GeneratedBlockCount, + .GeneratedBlockByteCount = GeneratedBlockByteCount, + .QueuedPendingInMemoryBlocksForUpload = QueuedPendingInMemoryBlocksForUpload, + .UploadBlockCount = UploadBlockCount, + .UploadChunkCount = UploadChunkCount, + .LargeAttachmentSize = LargeAttachmentSize, + .NewBlocks = NewBlocks, + .Content = Content, + .Lookup = Lookup, + .NewBlockChunks = NewBlockChunks, + .LooseChunkIndexes = LooseChunkIndexes, + .TempUploadStats = TempUploadStats, + .TempLooseChunksStats = TempLooseChunksStats}; + + ScheduleBlockGenerationAndUpload(Context, Classification.BlockIndexes); + ScheduleLooseChunkCompressionAndUpload(Context, 
Classification.LooseChunkOrderIndexes); + + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + FilteredCompressedBytesPerSecond.Update(TempLooseChunksStats.CompressedChunkRawBytes.load()); + FilteredGenerateBlockBytesPerSecond.Update(GeneratedBlockByteCount.load()); + FilteredUploadedBytesPerSecond.Update(UploadedCompressedChunkSize.load() + UploadedBlockSize.load()); + uint64_t UploadedRawSize = UploadedRawChunkSize.load() + UploadedBlockSize.load(); + uint64_t UploadedCompressedSize = UploadedCompressedChunkSize.load() + UploadedBlockSize.load(); + + std::string Details = fmt::format( + "Compressed {}/{} ({}/{}{}) chunks. " + "Uploaded {}/{} ({}/{}) blobs " + "({}{})", + TempLooseChunksStats.CompressedChunkCount.load(), + Classification.LooseChunkOrderIndexes.size(), + NiceBytes(TempLooseChunksStats.CompressedChunkRawBytes), + NiceBytes(Classification.TotalLooseChunksSize), + (TempLooseChunksStats.CompressedChunkCount == Classification.LooseChunkOrderIndexes.size()) + ? "" + : fmt::format(" {}B/s", NiceNum(FilteredCompressedBytesPerSecond.GetCurrent())), + + UploadedBlockCount.load() + UploadedChunkCount.load(), + UploadBlockCount + UploadChunkCount, + NiceBytes(UploadedRawSize), + NiceBytes(TotalRawSize), + + NiceBytes(UploadedCompressedSize), + (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) + ? 
"" + : fmt::format(" {}bits/s", NiceNum(FilteredUploadedBytesPerSecond.GetCurrent()))); + + ProgressBar->UpdateState({.Task = "Uploading blobs ", + .Details = Details, + .TotalCount = gsl::narrow<uint64_t>(TotalRawSize), + .RemainingCount = gsl::narrow<uint64_t>(TotalRawSize - UploadedRawSize), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); + }); - FilteredRate FilteredGenerateBlockBytesPerSecond; - FilteredRate FilteredCompressedBytesPerSecond; - FilteredRate FilteredUploadedBytesPerSecond; + ZEN_ASSERT(m_AbortFlag || QueuedPendingInMemoryBlocksForUpload.load() == 0); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + ProgressBar->Finish(); - std::atomic<size_t> UploadedBlockSize = 0; - std::atomic<size_t> UploadedBlockCount = 0; - std::atomic<size_t> UploadedRawChunkSize = 0; - std::atomic<size_t> UploadedCompressedChunkSize = 0; - std::atomic<uint32_t> UploadedChunkCount = 0; + TempUploadStats.ElapsedWallTimeUS += FilteredUploadedBytesPerSecond.GetElapsedTimeUS(); + TempLooseChunksStats.CompressChunksElapsedWallTimeUS += FilteredCompressedBytesPerSecond.GetElapsedTimeUS(); +} - tsl::robin_map<uint32_t, uint32_t> ChunkIndexToLooseChunkOrderIndex; - ChunkIndexToLooseChunkOrderIndex.reserve(LooseChunkIndexes.size()); - for (uint32_t OrderIndex = 0; OrderIndex < LooseChunkIndexes.size(); OrderIndex++) - { - ChunkIndexToLooseChunkOrderIndex.insert_or_assign(LooseChunkIndexes[OrderIndex], OrderIndex); - } +BuildsOperationUploadFolder::UploadPartClassification +BuildsOperationUploadFolder::ClassifyUploadRawHashes(std::span<IoHash> RawHashes, + const ChunkedFolderContent& Content, + const ChunkedContentLookup& Lookup, + const GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + std::vector<IoHash>& OutUnknownChunks) +{ + UploadPartClassification Result; - std::vector<size_t> BlockIndexes; - std::vector<uint32_t> LooseChunkOrderIndexes; + tsl::robin_map<uint32_t, 
uint32_t> ChunkIndexToLooseChunkOrderIndex; + ChunkIndexToLooseChunkOrderIndex.reserve(LooseChunkIndexes.size()); + for (uint32_t OrderIndex = 0; OrderIndex < LooseChunkIndexes.size(); OrderIndex++) + { + ChunkIndexToLooseChunkOrderIndex.insert_or_assign(LooseChunkIndexes[OrderIndex], OrderIndex); + } - uint64_t TotalLooseChunksSize = 0; - uint64_t TotalBlocksSize = 0; - for (const IoHash& RawHash : RawHashes) + for (const IoHash& RawHash : RawHashes) + { + if (auto It = NewBlocks.BlockHashToBlockIndex.find(RawHash); It != NewBlocks.BlockHashToBlockIndex.end()) { - if (auto It = NewBlocks.BlockHashToBlockIndex.find(RawHash); It != NewBlocks.BlockHashToBlockIndex.end()) - { - BlockIndexes.push_back(It->second); - TotalBlocksSize += NewBlocks.BlockSizes[It->second]; - } - else if (auto ChunkIndexIt = Lookup.ChunkHashToChunkIndex.find(RawHash); ChunkIndexIt != Lookup.ChunkHashToChunkIndex.end()) - { - const uint32_t ChunkIndex = ChunkIndexIt->second; - if (auto LooseOrderIndexIt = ChunkIndexToLooseChunkOrderIndex.find(ChunkIndex); - LooseOrderIndexIt != ChunkIndexToLooseChunkOrderIndex.end()) - { - LooseChunkOrderIndexes.push_back(LooseOrderIndexIt->second); - TotalLooseChunksSize += Content.ChunkedContent.ChunkRawSizes[ChunkIndex]; - } - } - else + Result.BlockIndexes.push_back(It->second); + Result.TotalBlocksSize += NewBlocks.BlockSizes[It->second]; + } + else if (auto ChunkIndexIt = Lookup.ChunkHashToChunkIndex.find(RawHash); ChunkIndexIt != Lookup.ChunkHashToChunkIndex.end()) + { + const uint32_t ChunkIndex = ChunkIndexIt->second; + if (auto LooseOrderIndexIt = ChunkIndexToLooseChunkOrderIndex.find(ChunkIndex); + LooseOrderIndexIt != ChunkIndexToLooseChunkOrderIndex.end()) { - OutUnknownChunks.push_back(RawHash); + Result.LooseChunkOrderIndexes.push_back(LooseOrderIndexIt->second); + Result.TotalLooseChunksSize += Content.ChunkedContent.ChunkRawSizes[ChunkIndex]; } } - if (BlockIndexes.empty() && LooseChunkOrderIndexes.empty()) + else { - return; + 
OutUnknownChunks.push_back(RawHash); } + } + return Result; +} - uint64_t TotalRawSize = TotalLooseChunksSize + TotalBlocksSize; - - const size_t UploadBlockCount = BlockIndexes.size(); - const uint32_t UploadChunkCount = gsl::narrow<uint32_t>(LooseChunkOrderIndexes.size()); - - auto AsyncUploadBlock = [this, - &Work, - &NewBlocks, - UploadBlockCount, - &UploadedBlockCount, - UploadChunkCount, - &UploadedChunkCount, - &UploadedBlockSize, - &TempUploadStats, - &FilteredUploadedBytesPerSecond, - &UploadChunkPool](const size_t BlockIndex, - const IoHash BlockHash, - CompositeBuffer&& Payload, - std::atomic<uint64_t>& QueuedPendingInMemoryBlocksForUpload) { - bool IsInMemoryBlock = true; - if (QueuedPendingInMemoryBlocksForUpload.load() > 16) - { - ZEN_TRACE_CPU("AsyncUploadBlock_WriteTempBlock"); - std::filesystem::path TempFilePath = m_Options.TempDir / (BlockHash.ToHexString()); - Payload = CompositeBuffer(WriteToTempFile(std::move(Payload), TempFilePath)); - IsInMemoryBlock = false; - } - else - { - QueuedPendingInMemoryBlocksForUpload++; - } +void +BuildsOperationUploadFolder::ScheduleBlockGenerationAndUpload(UploadPartBlobsContext& Context, std::span<const size_t> BlockIndexes) +{ + for (const size_t BlockIndex : BlockIndexes) + { + const IoHash& BlockHash = Context.NewBlocks.BlockDescriptions[BlockIndex].BlockHash; + if (m_AbortFlag) + { + break; + } + Context.Work.ScheduleWork( + Context.ReadChunkPool, + [this, &Context, BlockHash = IoHash(BlockHash), BlockIndex, GenerateBlockCount = BlockIndexes.size()](std::atomic<bool>&) { + if (m_AbortFlag) + { + return; + } + ZEN_TRACE_CPU("UploadPartBlobs_GenerateBlock"); - Work.ScheduleWork( - UploadChunkPool, - [this, - &QueuedPendingInMemoryBlocksForUpload, - &NewBlocks, - UploadBlockCount, - &UploadedBlockCount, - UploadChunkCount, - &UploadedChunkCount, - &UploadedBlockSize, - &TempUploadStats, - &FilteredUploadedBytesPerSecond, - IsInMemoryBlock, - BlockIndex, - BlockHash, - Payload = 
CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable { - auto _ = MakeGuard([IsInMemoryBlock, &QueuedPendingInMemoryBlocksForUpload] { - if (IsInMemoryBlock) - { - QueuedPendingInMemoryBlocksForUpload--; - } - }); - if (!m_AbortFlag) + Context.FilteredGenerateBlockBytesPerSecond.Start(); + + Stopwatch GenerateTimer; + CompositeBuffer Payload; + if (Context.NewBlocks.BlockHeaders[BlockIndex]) + { + Payload = RebuildBlock(Context.Content, + Context.Lookup, + std::move(Context.NewBlocks.BlockHeaders[BlockIndex]), + Context.NewBlockChunks[BlockIndex]) + .GetCompressed(); + } + else + { + ChunkBlockDescription BlockDescription; + CompressedBuffer CompressedBlock = + GenerateBlock(Context.Content, Context.Lookup, Context.NewBlockChunks[BlockIndex], BlockDescription); + if (!CompressedBlock) { - ZEN_TRACE_CPU("AsyncUploadBlock"); + throw std::runtime_error(fmt::format("Failed generating block {}", BlockHash)); + } + ZEN_ASSERT(BlockDescription.BlockHash == BlockHash); + Payload = std::move(CompressedBlock).GetCompressed(); + } - const uint64_t PayloadSize = Payload.GetSize(); + Context.GeneratedBlockByteCount += Context.NewBlocks.BlockSizes[BlockIndex]; + if (Context.GeneratedBlockCount.fetch_add(1) + 1 == GenerateBlockCount) + { + Context.FilteredGenerateBlockBytesPerSecond.Stop(); + } + if (m_Options.IsVerbose) + { + ZEN_INFO("{} block {} ({}) containing {} chunks in {}", + Context.NewBlocks.BlockHeaders[BlockIndex] ? 
"Regenerated" : "Generated", + Context.NewBlocks.BlockDescriptions[BlockIndex].BlockHash, + NiceBytes(Context.NewBlocks.BlockSizes[BlockIndex]), + Context.NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(), + NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs())); + } + if (!m_AbortFlag) + { + UploadBlockPayload(Context, BlockIndex, BlockHash, std::move(Payload)); + } + }); + } +} - FilteredUploadedBytesPerSecond.Start(); - const CbObject BlockMetaData = - BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); +void +BuildsOperationUploadFolder::UploadBlockPayload(UploadPartBlobsContext& Context, + size_t BlockIndex, + const IoHash& BlockHash, + CompositeBuffer Payload) +{ + bool IsInMemoryBlock = true; + if (Context.QueuedPendingInMemoryBlocksForUpload.load() > 16) + { + ZEN_TRACE_CPU("AsyncUploadBlock_WriteTempBlock"); + std::filesystem::path TempFilePath = m_Options.TempDir / (BlockHash.ToHexString()); + Payload = CompositeBuffer(WriteToTempFile(std::move(Payload), TempFilePath)); + IsInMemoryBlock = false; + } + else + { + Context.QueuedPendingInMemoryBlocksForUpload++; + } - if (m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); - } - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} ({}) containing {} chunks", - BlockHash, - NiceBytes(PayloadSize), - NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); - } - UploadedBlockSize += PayloadSize; - TempUploadStats.BlocksBytes += PayloadSize; + Context.Work.ScheduleWork( + Context.UploadChunkPool, + [this, &Context, IsInMemoryBlock, BlockIndex, BlockHash = IoHash(BlockHash), Payload = CompositeBuffer(std::move(Payload))]( + std::atomic<bool>&) { + auto _ = MakeGuard([IsInMemoryBlock, &Context] { + 
if (IsInMemoryBlock) + { + Context.QueuedPendingInMemoryBlocksForUpload--; + } + }); + if (m_AbortFlag) + { + return; + } + ZEN_TRACE_CPU("AsyncUploadBlock"); - if (m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); - } - bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); - if (MetadataSucceeded) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} metadata ({})", - BlockHash, - NiceBytes(BlockMetaData.GetSize())); - } + const uint64_t PayloadSize = Payload.GetSize(); - NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; - TempUploadStats.BlocksBytes += BlockMetaData.GetSize(); - } + Context.FilteredUploadedBytesPerSecond.Start(); + const CbObject BlockMetaData = + BuildChunkBlockDescription(Context.NewBlocks.BlockDescriptions[BlockIndex], Context.NewBlocks.BlockMetaDatas[BlockIndex]); - TempUploadStats.BlockCount++; + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); + } - UploadedBlockCount++; - if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) - { - FilteredUploadedBytesPerSecond.Stop(); - } - } - }); - }; + try + { + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } - auto AsyncUploadLooseChunk = [this, - LargeAttachmentSize, - &Work, - &UploadChunkPool, - &FilteredUploadedBytesPerSecond, - &UploadedBlockCount, - &UploadedChunkCount, - UploadBlockCount, - UploadChunkCount, - &UploadedCompressedChunkSize, - &UploadedRawChunkSize, - &TempUploadStats](const IoHash& RawHash, uint64_t RawSize, CompositeBuffer&& Payload) 
{ - Work.ScheduleWork( - UploadChunkPool, - [this, - &Work, - LargeAttachmentSize, - &FilteredUploadedBytesPerSecond, - &UploadChunkPool, - &UploadedBlockCount, - &UploadedChunkCount, - UploadBlockCount, - UploadChunkCount, - &UploadedCompressedChunkSize, - &UploadedRawChunkSize, - &TempUploadStats, - RawHash, - RawSize, - Payload = CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("AsyncUploadLooseChunk"); + if (m_AbortFlag) + { + return; + } + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded block {} ({}) containing {} chunks", + BlockHash, + NiceBytes(PayloadSize), + Context.NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); + } + Context.UploadedBlockSize += PayloadSize; + Context.TempUploadStats.BlocksBytes += PayloadSize; - const uint64_t PayloadSize = Payload.GetSize(); + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); + } - if (m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); - } + bool MetadataSucceeded = false; + try + { + MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + if (m_AbortFlag) + { + return; + } + if (MetadataSucceeded) + { + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded block {} metadata ({})", BlockHash, NiceBytes(BlockMetaData.GetSize())); + } + Context.NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; + Context.TempUploadStats.BlocksBytes += BlockMetaData.GetSize(); + } - if (PayloadSize >= LargeAttachmentSize) - { - ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart"); - TempUploadStats.MultipartAttachmentCount++; - std::vector<std::function<void()>> 
MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob( - m_BuildId, - RawHash, - ZenContentType::kCompressedBinary, - PayloadSize, - [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset, - uint64_t Size) mutable -> IoBuffer { - FilteredUploadedBytesPerSecond.Start(); - - IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer(); - PartPayload.SetContentType(ZenContentType::kBinary); - return PartPayload; - }, - [RawSize, - &TempUploadStats, - &UploadedCompressedChunkSize, - &UploadChunkPool, - &UploadedBlockCount, - UploadBlockCount, - &UploadedChunkCount, - UploadChunkCount, - &FilteredUploadedBytesPerSecond, - &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) { - TempUploadStats.ChunksBytes += SentBytes; - UploadedCompressedChunkSize += SentBytes; - if (IsComplete) - { - TempUploadStats.ChunkCount++; - UploadedChunkCount++; - if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) - { - FilteredUploadedBytesPerSecond.Stop(); - } - UploadedRawChunkSize += RawSize; - } - }); - for (auto& WorkPart : MultipartWork) - { - Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) { - ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work"); - if (!AbortFlag) - { - Work(); - } - }); - } - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize)); - } - } - else - { - ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart"); - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize)); - } - TempUploadStats.ChunksBytes += Payload.GetSize(); - TempUploadStats.ChunkCount++; - UploadedCompressedChunkSize += Payload.GetSize(); - UploadedRawChunkSize += RawSize; - UploadedChunkCount++; - if (UploadedChunkCount == 
UploadChunkCount) - { - FilteredUploadedBytesPerSecond.Stop(); - } - } - } - }); - }; + Context.TempUploadStats.BlockCount++; - std::vector<size_t> GenerateBlockIndexes; + if (Context.UploadedBlockCount.fetch_add(1) + 1 == Context.UploadBlockCount && + Context.UploadedChunkCount == Context.UploadChunkCount) + { + Context.FilteredUploadedBytesPerSecond.Stop(); + } + }); +} - std::atomic<uint64_t> GeneratedBlockCount = 0; - std::atomic<uint64_t> GeneratedBlockByteCount = 0; +void +BuildsOperationUploadFolder::ScheduleLooseChunkCompressionAndUpload(UploadPartBlobsContext& Context, + std::span<const uint32_t> LooseChunkOrderIndexes) +{ + for (const uint32_t LooseChunkOrderIndex : LooseChunkOrderIndexes) + { + const uint32_t ChunkIndex = Context.LooseChunkIndexes[LooseChunkOrderIndex]; + Context.Work.ScheduleWork(Context.ReadChunkPool, + [this, &Context, LooseChunkOrderCount = LooseChunkOrderIndexes.size(), ChunkIndex](std::atomic<bool>&) { + if (m_AbortFlag) + { + return; + } + ZEN_TRACE_CPU("UploadPartBlobs_CompressChunk"); - std::atomic<uint64_t> QueuedPendingInMemoryBlocksForUpload = 0; + Context.FilteredCompressedBytesPerSecond.Start(); + Stopwatch CompressTimer; + CompositeBuffer Payload = + CompressChunk(Context.Content, Context.Lookup, ChunkIndex, Context.TempLooseChunksStats); + if (m_Options.IsVerbose) + { + ZEN_INFO("Compressed chunk {} ({} -> {}) in {}", + Context.Content.ChunkedContent.ChunkHashes[ChunkIndex], + NiceBytes(Context.Content.ChunkedContent.ChunkRawSizes[ChunkIndex]), + NiceBytes(Payload.GetSize()), + NiceTimeSpanMs(CompressTimer.GetElapsedTimeMs())); + } + const uint64_t ChunkRawSize = Context.Content.ChunkedContent.ChunkRawSizes[ChunkIndex]; + Context.TempUploadStats.ReadFromDiskBytes += ChunkRawSize; + if (Context.TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderCount) + { + Context.FilteredCompressedBytesPerSecond.Stop(); + } + if (!m_AbortFlag) + { + UploadLooseChunkPayload(Context, + 
Context.Content.ChunkedContent.ChunkHashes[ChunkIndex], + ChunkRawSize, + std::move(Payload)); + } + }); + } +} - // Start generation of any non-prebuilt blocks and schedule upload - for (const size_t BlockIndex : BlockIndexes) - { - const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash; - if (!m_AbortFlag) +void +BuildsOperationUploadFolder::UploadLooseChunkPayload(UploadPartBlobsContext& Context, + const IoHash& RawHash, + uint64_t RawSize, + CompositeBuffer Payload) +{ + Context.Work.ScheduleWork( + Context.UploadChunkPool, + [this, &Context, RawHash = IoHash(RawHash), RawSize, Payload = CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable { + if (m_AbortFlag) { - Work.ScheduleWork( - ReadChunkPool, - [this, - BlockHash = IoHash(BlockHash), - BlockIndex, - &FilteredGenerateBlockBytesPerSecond, - &Content, - &Lookup, - &NewBlocks, - &NewBlockChunks, - &GenerateBlockIndexes, - &GeneratedBlockCount, - &GeneratedBlockByteCount, - &AsyncUploadBlock, - &QueuedPendingInMemoryBlocksForUpload](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("UploadPartBlobs_GenerateBlock"); + return; + } + ZEN_TRACE_CPU("AsyncUploadLooseChunk"); - FilteredGenerateBlockBytesPerSecond.Start(); + const uint64_t PayloadSize = Payload.GetSize(); - Stopwatch GenerateTimer; - CompositeBuffer Payload; - if (NewBlocks.BlockHeaders[BlockIndex]) - { - Payload = - RebuildBlock(Content, Lookup, std::move(NewBlocks.BlockHeaders[BlockIndex]), NewBlockChunks[BlockIndex]) - .GetCompressed(); - } - else + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); + } + + if (PayloadSize >= Context.LargeAttachmentSize) + { + ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart"); + Context.TempUploadStats.MultipartAttachmentCount++; + try + { + std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob( + m_BuildId, + RawHash, + 
ZenContentType::kCompressedBinary, + PayloadSize, + [Payload = std::move(Payload), &Context](uint64_t Offset, uint64_t Size) -> IoBuffer { + Context.FilteredUploadedBytesPerSecond.Start(); + + IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer(); + PartPayload.SetContentType(ZenContentType::kBinary); + return PartPayload; + }, + [&Context, RawSize](uint64_t SentBytes, bool IsComplete) { + Context.TempUploadStats.ChunksBytes += SentBytes; + Context.UploadedCompressedChunkSize += SentBytes; + if (IsComplete) { - ChunkBlockDescription BlockDescription; - CompressedBuffer CompressedBlock = - GenerateBlock(Content, Lookup, NewBlockChunks[BlockIndex], BlockDescription); - if (!CompressedBlock) + Context.TempUploadStats.ChunkCount++; + if (Context.UploadedChunkCount.fetch_add(1) + 1 == Context.UploadChunkCount && + Context.UploadedBlockCount == Context.UploadBlockCount) { - throw std::runtime_error(fmt::format("Failed generating block {}", BlockHash)); + Context.FilteredUploadedBytesPerSecond.Stop(); } - ZEN_ASSERT(BlockDescription.BlockHash == BlockHash); - Payload = std::move(CompressedBlock).GetCompressed(); + Context.UploadedRawChunkSize += RawSize; } - - GeneratedBlockByteCount += NewBlocks.BlockSizes[BlockIndex]; - GeneratedBlockCount++; - if (GeneratedBlockCount == GenerateBlockIndexes.size()) - { - FilteredGenerateBlockBytesPerSecond.Stop(); - } - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "{} block {} ({}) containing {} chunks in {}", - NewBlocks.BlockHeaders[BlockIndex] ? 
"Regenerated" : "Generated", - NewBlocks.BlockDescriptions[BlockIndex].BlockHash, - NiceBytes(NewBlocks.BlockSizes[BlockIndex]), - NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(), - NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs())); - } - if (!m_AbortFlag) + }); + for (auto& WorkPart : MultipartWork) + { + Context.Work.ScheduleWork(Context.UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) { + ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work"); + if (!AbortFlag) { - AsyncUploadBlock(BlockIndex, BlockHash, std::move(Payload), QueuedPendingInMemoryBlocksForUpload); + Work(); } - } - }); - } - } - - // Start compression of any non-precompressed loose chunks and schedule upload - for (const uint32_t LooseChunkOrderIndex : LooseChunkOrderIndexes) - { - const uint32_t ChunkIndex = LooseChunkIndexes[LooseChunkOrderIndex]; - Work.ScheduleWork( - ReadChunkPool, - [this, - &Content, - &Lookup, - &TempLooseChunksStats, - &LooseChunkOrderIndexes, - &FilteredCompressedBytesPerSecond, - &TempUploadStats, - &AsyncUploadLooseChunk, - ChunkIndex](std::atomic<bool>&) { + }); + } + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize)); + } + } + catch (const std::exception&) + { + // Silence http errors due to abort if (!m_AbortFlag) { - ZEN_TRACE_CPU("UploadPartBlobs_CompressChunk"); - - FilteredCompressedBytesPerSecond.Start(); - Stopwatch CompressTimer; - CompositeBuffer Payload = CompressChunk(Content, Lookup, ChunkIndex, TempLooseChunksStats); - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Compressed chunk {} ({} -> {}) in {}", - Content.ChunkedContent.ChunkHashes[ChunkIndex], - NiceBytes(Content.ChunkedContent.ChunkRawSizes[ChunkIndex]), - NiceBytes(Payload.GetSize()), - NiceTimeSpanMs(CompressTimer.GetElapsedTimeMs())); - } - const uint64_t ChunkRawSize = Content.ChunkedContent.ChunkRawSizes[ChunkIndex]; - TempUploadStats.ReadFromDiskBytes += ChunkRawSize; 
- if (TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderIndexes.size()) - { - FilteredCompressedBytesPerSecond.Stop(); - } - if (!m_AbortFlag) - { - AsyncUploadLooseChunk(Content.ChunkedContent.ChunkHashes[ChunkIndex], ChunkRawSize, std::move(Payload)); - } + throw; } - }); - } + } + return; + } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { - ZEN_UNUSED(PendingWork); - FilteredCompressedBytesPerSecond.Update(TempLooseChunksStats.CompressedChunkRawBytes.load()); - FilteredGenerateBlockBytesPerSecond.Update(GeneratedBlockByteCount.load()); - FilteredUploadedBytesPerSecond.Update(UploadedCompressedChunkSize.load() + UploadedBlockSize.load()); - uint64_t UploadedRawSize = UploadedRawChunkSize.load() + UploadedBlockSize.load(); - uint64_t UploadedCompressedSize = UploadedCompressedChunkSize.load() + UploadedBlockSize.load(); - - std::string Details = fmt::format( - "Compressed {}/{} ({}/{}{}) chunks. " - "Uploaded {}/{} ({}/{}) blobs " - "({}{})", - TempLooseChunksStats.CompressedChunkCount.load(), - LooseChunkOrderIndexes.size(), - NiceBytes(TempLooseChunksStats.CompressedChunkRawBytes), - NiceBytes(TotalLooseChunksSize), - (TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderIndexes.size()) - ? "" - : fmt::format(" {}B/s", NiceNum(FilteredCompressedBytesPerSecond.GetCurrent())), - - UploadedBlockCount.load() + UploadedChunkCount.load(), - UploadBlockCount + UploadChunkCount, - NiceBytes(UploadedRawSize), - NiceBytes(TotalRawSize), - - NiceBytes(UploadedCompressedSize), - (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) - ? 
"" - : fmt::format(" {}bits/s", NiceNum(FilteredUploadedBytesPerSecond.GetCurrent()))); - - Progress.UpdateState({.Task = "Uploading blobs ", - .Details = Details, - .TotalCount = gsl::narrow<uint64_t>(TotalRawSize), - .RemainingCount = gsl::narrow<uint64_t>(TotalRawSize - UploadedRawSize), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart"); + try + { + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + if (m_AbortFlag) + { + return; + } + if (m_Options.IsVerbose) + { + ZEN_INFO("Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize)); + } + Context.TempUploadStats.ChunksBytes += Payload.GetSize(); + Context.TempUploadStats.ChunkCount++; + Context.UploadedCompressedChunkSize += Payload.GetSize(); + Context.UploadedRawChunkSize += RawSize; + if (Context.UploadedChunkCount.fetch_add(1) + 1 == Context.UploadChunkCount && + Context.UploadedBlockCount == Context.UploadBlockCount) + { + Context.FilteredUploadedBytesPerSecond.Stop(); + } }); - - ZEN_ASSERT(m_AbortFlag || QueuedPendingInMemoryBlocksForUpload.load() == 0); - - Progress.Finish(); - - TempUploadStats.ElapsedWallTimeUS += FilteredUploadedBytesPerSecond.GetElapsedTimeUS(); - TempLooseChunksStats.CompressChunksElapsedWallTimeUS += FilteredCompressedBytesPerSecond.GetElapsedTimeUS(); - } } CompositeBuffer @@ -6432,7 +6624,7 @@ BuildsOperationUploadFolder::CompressChunk(const ChunkedFolderContent& Content, throw std::runtime_error(fmt::format("Fetched chunk {} has invalid size", ChunkHash)); } - const bool ShouldCompressChunk = IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex); + const bool ShouldCompressChunk = IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex); const 
OodleCompressionLevel CompressionLevel = ShouldCompressChunk ? OodleCompressionLevel::VeryFast : OodleCompressionLevel::None; if (ShouldCompressChunk) @@ -6521,7 +6713,8 @@ BuildsOperationUploadFolder::CompressChunk(const ChunkedFolderContent& Content, return std::move(CompressedBlob).GetCompressed(); } -BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(OperationLogOutput& OperationLogOutput, +BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(LoggerRef Log, + ProgressBase& Progress, BuildStorageBase& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -6532,7 +6725,8 @@ BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(OperationLogO const std::string_view BuildPartName, const Options& Options) -: m_LogOutput(OperationLogOutput) +: m_Log(Log) +, m_Progress(Progress) , m_Storage(Storage) , m_AbortFlag(AbortFlag) , m_PauseFlag(PauseFlag) @@ -6551,89 +6745,24 @@ BuildsOperationValidateBuildPart::Execute() ZEN_TRACE_CPU("ValidateBuildPart"); try { - enum class TaskSteps : uint32_t - { - FetchBuild, - FetchBuildPart, - ValidateBlobs, - Cleanup, - StepCount - }; - auto EndProgress = - MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); + MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); Stopwatch Timer; auto _ = MakeGuard([&]() { if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Validated build part {}/{} ('{}') in {}", - m_BuildId, - m_BuildPartId, - m_BuildPartName, - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Validated build part {}/{} ('{}') in {}", + m_BuildId, + m_BuildPartId, + m_BuildPartName, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } }); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuild, (uint32_t)TaskSteps::StepCount); - - CbObject Build = m_Storage.GetBuild(m_BuildId); - if 
(!m_BuildPartName.empty()) - { - m_BuildPartId = Build["parts"sv].AsObjectView()[m_BuildPartName].AsObjectId(); - if (m_BuildPartId == Oid::Zero) - { - throw std::runtime_error(fmt::format("Build {} does not have a part named '{}'", m_BuildId, m_BuildPartName)); - } - } - m_ValidateStats.BuildBlobSize = Build.GetSize(); - uint64_t PreferredMultipartChunkSize = 32u * 1024u * 1024u; - if (auto ChunkSize = Build["chunkSize"sv].AsUInt64(); ChunkSize != 0) - { - PreferredMultipartChunkSize = ChunkSize; - } - - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuildPart, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuild, (uint32_t)TaskSteps::StepCount); - CbObject BuildPart = m_Storage.GetBuildPart(m_BuildId, m_BuildPartId); - m_ValidateStats.BuildPartSize = BuildPart.GetSize(); - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Validating build part {}/{} ({})", - m_BuildId, - m_BuildPartId, - NiceBytes(BuildPart.GetSize())); - } - std::vector<IoHash> ChunkAttachments; - if (const CbObjectView ChunkAttachmentsView = BuildPart["chunkAttachments"sv].AsObjectView()) - { - for (CbFieldView LooseFileView : ChunkAttachmentsView["rawHashes"sv]) - { - ChunkAttachments.push_back(LooseFileView.AsBinaryAttachment()); - } - } - m_ValidateStats.ChunkAttachmentCount = ChunkAttachments.size(); - std::vector<IoHash> BlockAttachments; - if (const CbObjectView BlockAttachmentsView = BuildPart["blockAttachments"sv].AsObjectView()) - { - { - for (CbFieldView BlocksView : BlockAttachmentsView["rawHashes"sv]) - { - BlockAttachments.push_back(BlocksView.AsBinaryAttachment()); - } - } - } - m_ValidateStats.BlockAttachmentCount = BlockAttachments.size(); - - std::vector<ChunkBlockDescription> VerifyBlockDescriptions = - ParseChunkBlockDescriptionList(m_Storage.GetBlockMetadatas(m_BuildId, BlockAttachments)); - if (VerifyBlockDescriptions.size() != BlockAttachments.size()) - { - throw 
std::runtime_error(fmt::format("Uploaded blocks metadata could not all be found, {} blocks metadata is missing", - BlockAttachments.size() - VerifyBlockDescriptions.size())); - } + ResolvedBuildPart Resolved = ResolveBuildPart(); ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); @@ -6643,150 +6772,23 @@ BuildsOperationValidateBuildPart::Execute() CreateDirectories(TempFolder); auto __ = MakeGuard([this, TempFolder]() { CleanAndRemoveDirectory(m_IOWorkerPool, m_AbortFlag, m_PauseFlag, TempFolder); }); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ValidateBlobs, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ValidateBlobs, (uint32_t)TaskSteps::StepCount); - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Validate Blobs")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Validate Blobs"); - uint64_t AttachmentsToVerifyCount = ChunkAttachments.size() + BlockAttachments.size(); - FilteredRate FilteredDownloadedBytesPerSecond; - FilteredRate FilteredVerifiedBytesPerSecond; + const uint64_t AttachmentsToVerifyCount = Resolved.ChunkAttachments.size() + Resolved.BlockAttachments.size(); + FilteredRate FilteredDownloadedBytesPerSecond; + FilteredRate FilteredVerifiedBytesPerSecond; - std::atomic<uint64_t> MultipartAttachmentCount = 0; + ValidateBlobsContext Context{.Work = Work, + .AttachmentsToVerifyCount = AttachmentsToVerifyCount, + .FilteredDownloadedBytesPerSecond = FilteredDownloadedBytesPerSecond, + .FilteredVerifiedBytesPerSecond = FilteredVerifiedBytesPerSecond}; - for (const IoHash& ChunkAttachment : ChunkAttachments) - { - Work.ScheduleWork( - m_NetworkPool, - [this, - &Work, - AttachmentsToVerifyCount, - &TempFolder, - PreferredMultipartChunkSize, - &FilteredDownloadedBytesPerSecond, - &FilteredVerifiedBytesPerSecond, - 
&ChunkAttachments, - ChunkAttachment = IoHash(ChunkAttachment)](std::atomic<bool>&) { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("ValidateBuildPart_GetChunk"); - - FilteredDownloadedBytesPerSecond.Start(); - DownloadLargeBlob( - m_Storage, - TempFolder, - m_BuildId, - ChunkAttachment, - PreferredMultipartChunkSize, - Work, - m_NetworkPool, - m_DownloadStats.DownloadedChunkByteCount, - m_DownloadStats.MultipartAttachmentCount, - [this, - &Work, - AttachmentsToVerifyCount, - &FilteredDownloadedBytesPerSecond, - &FilteredVerifiedBytesPerSecond, - ChunkHash = IoHash(ChunkAttachment)](IoBuffer&& Payload) { - m_DownloadStats.DownloadedChunkCount++; - Payload.SetContentType(ZenContentType::kCompressedBinary); - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - AttachmentsToVerifyCount, - &FilteredDownloadedBytesPerSecond, - &FilteredVerifiedBytesPerSecond, - Payload = IoBuffer(std::move(Payload)), - ChunkHash](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("ValidateBuildPart_Validate"); - - if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == - AttachmentsToVerifyCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - - FilteredVerifiedBytesPerSecond.Start(); - - uint64_t CompressedSize; - uint64_t DecompressedSize; - ValidateBlob(m_AbortFlag, std::move(Payload), ChunkHash, CompressedSize, DecompressedSize); - m_ValidateStats.VerifiedAttachmentCount++; - m_ValidateStats.VerifiedByteCount += DecompressedSize; - if (m_ValidateStats.VerifiedAttachmentCount.load() == AttachmentsToVerifyCount) - { - FilteredVerifiedBytesPerSecond.Stop(); - } - } - }); - } - }); - } - }); - } - - for (const IoHash& BlockAttachment : BlockAttachments) - { - Work.ScheduleWork( - m_NetworkPool, - [this, - &Work, - AttachmentsToVerifyCount, - &FilteredDownloadedBytesPerSecond, - &FilteredVerifiedBytesPerSecond, - BlockAttachment = IoHash(BlockAttachment)](std::atomic<bool>&) { - if (!m_AbortFlag) - { - 
ZEN_TRACE_CPU("ValidateBuildPart_GetBlock"); - - FilteredDownloadedBytesPerSecond.Start(); - IoBuffer Payload = m_Storage.GetBuildBlob(m_BuildId, BlockAttachment); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); - if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == AttachmentsToVerifyCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - if (!Payload) - { - throw std::runtime_error(fmt::format("Block attachment {} could not be found", BlockAttachment)); - } - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &FilteredVerifiedBytesPerSecond, - AttachmentsToVerifyCount, - Payload = std::move(Payload), - BlockAttachment](std::atomic<bool>&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("ValidateBuildPart_ValidateBlock"); - - FilteredVerifiedBytesPerSecond.Start(); - - uint64_t CompressedSize; - uint64_t DecompressedSize; - ValidateChunkBlock(std::move(Payload), BlockAttachment, CompressedSize, DecompressedSize); - m_ValidateStats.VerifiedAttachmentCount++; - m_ValidateStats.VerifiedByteCount += DecompressedSize; - if (m_ValidateStats.VerifiedAttachmentCount.load() == AttachmentsToVerifyCount) - { - FilteredVerifiedBytesPerSecond.Stop(); - } - } - }); - } - } - }); - } + ScheduleChunkAttachmentValidation(Context, Resolved.ChunkAttachments, TempFolder, Resolved.PreferredMultipartChunkSize); + ScheduleBlockAttachmentValidation(Context, Resolved.BlockAttachments); - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(PendingWork); const uint64_t DownloadedAttachmentCount = m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount; @@ -6805,20 +6807,20 @@ BuildsOperationValidateBuildPart::Execute() NiceBytes(m_ValidateStats.VerifiedByteCount.load()), 
NiceNum(FilteredVerifiedBytesPerSecond.GetCurrent())); - Progress.UpdateState( + ProgressBar->UpdateState( {.Task = "Validating blobs ", .Details = Details, .TotalCount = gsl::narrow<uint64_t>(AttachmentsToVerifyCount * 2), .RemainingCount = gsl::narrow<uint64_t>(AttachmentsToVerifyCount * 2 - (DownloadedAttachmentCount + m_ValidateStats.VerifiedAttachmentCount.load())), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, false); }); - Progress.Finish(); + ProgressBar->Finish(); m_ValidateStats.ElapsedWallTimeUS = Timer.GetElapsedTimeUs(); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount); } catch (const std::exception&) { @@ -6827,7 +6829,189 @@ BuildsOperationValidateBuildPart::Execute() } } -BuildsOperationPrimeCache::BuildsOperationPrimeCache(OperationLogOutput& OperationLogOutput, +BuildsOperationValidateBuildPart::ResolvedBuildPart +BuildsOperationValidateBuildPart::ResolveBuildPart() +{ + ResolvedBuildPart Result; + Result.PreferredMultipartChunkSize = 32u * 1024u * 1024u; + + CbObject Build = m_Storage.GetBuild(m_BuildId); + if (!m_BuildPartName.empty()) + { + m_BuildPartId = Build["parts"sv].AsObjectView()[m_BuildPartName].AsObjectId(); + if (m_BuildPartId == Oid::Zero) + { + throw std::runtime_error(fmt::format("Build {} does not have a part named '{}'", m_BuildId, m_BuildPartName)); + } + } + m_ValidateStats.BuildBlobSize = Build.GetSize(); + if (auto ChunkSize = Build["chunkSize"sv].AsUInt64(); ChunkSize != 0) + { + Result.PreferredMultipartChunkSize = ChunkSize; + } + + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuildPart, (uint32_t)TaskSteps::StepCount); + + CbObject BuildPart = m_Storage.GetBuildPart(m_BuildId, m_BuildPartId); + m_ValidateStats.BuildPartSize = 
BuildPart.GetSize(); + if (!m_Options.IsQuiet) + { + ZEN_INFO("Validating build part {}/{} ({})", m_BuildId, m_BuildPartId, NiceBytes(BuildPart.GetSize())); + } + if (const CbObjectView ChunkAttachmentsView = BuildPart["chunkAttachments"sv].AsObjectView()) + { + for (CbFieldView LooseFileView : ChunkAttachmentsView["rawHashes"sv]) + { + Result.ChunkAttachments.push_back(LooseFileView.AsBinaryAttachment()); + } + } + m_ValidateStats.ChunkAttachmentCount = Result.ChunkAttachments.size(); + if (const CbObjectView BlockAttachmentsView = BuildPart["blockAttachments"sv].AsObjectView()) + { + for (CbFieldView BlocksView : BlockAttachmentsView["rawHashes"sv]) + { + Result.BlockAttachments.push_back(BlocksView.AsBinaryAttachment()); + } + } + m_ValidateStats.BlockAttachmentCount = Result.BlockAttachments.size(); + + std::vector<ChunkBlockDescription> VerifyBlockDescriptions = + ParseChunkBlockDescriptionList(m_Storage.GetBlockMetadatas(m_BuildId, Result.BlockAttachments)); + if (VerifyBlockDescriptions.size() != Result.BlockAttachments.size()) + { + throw std::runtime_error(fmt::format("Uploaded blocks metadata could not all be found, {} blocks metadata is missing", + Result.BlockAttachments.size() - VerifyBlockDescriptions.size())); + } + + return Result; +} + +void +BuildsOperationValidateBuildPart::ScheduleChunkAttachmentValidation(ValidateBlobsContext& Context, + std::span<const IoHash> ChunkAttachments, + const std::filesystem::path& TempFolder, + uint64_t PreferredMultipartChunkSize) +{ + for (const IoHash& ChunkAttachment : ChunkAttachments) + { + Context.Work.ScheduleWork( + m_NetworkPool, + [this, &Context, &TempFolder, PreferredMultipartChunkSize, ChunkAttachment = IoHash(ChunkAttachment)](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("ValidateBuildPart_GetChunk"); + + Context.FilteredDownloadedBytesPerSecond.Start(); + DownloadLargeBlob( + m_Storage, + TempFolder, + m_BuildId, + ChunkAttachment, + PreferredMultipartChunkSize, + Context.Work, + 
m_NetworkPool, + m_DownloadStats.DownloadedChunkByteCount, + m_DownloadStats.MultipartAttachmentCount, + [this, &Context, ChunkHash = IoHash(ChunkAttachment)](IoBuffer&& Payload) { + m_DownloadStats.DownloadedChunkCount++; + Payload.SetContentType(ZenContentType::kCompressedBinary); + if (!m_AbortFlag) + { + Context.Work.ScheduleWork( + m_IOWorkerPool, + [this, &Context, Payload = IoBuffer(std::move(Payload)), ChunkHash](std::atomic<bool>&) mutable { + if (!m_AbortFlag) + { + ValidateDownloadedChunk(Context, ChunkHash, std::move(Payload)); + } + }); + } + }); + } + }); + } +} + +void +BuildsOperationValidateBuildPart::ScheduleBlockAttachmentValidation(ValidateBlobsContext& Context, std::span<const IoHash> BlockAttachments) +{ + for (const IoHash& BlockAttachment : BlockAttachments) + { + Context.Work.ScheduleWork(m_NetworkPool, [this, &Context, BlockAttachment = IoHash(BlockAttachment)](std::atomic<bool>&) { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("ValidateBuildPart_GetBlock"); + + Context.FilteredDownloadedBytesPerSecond.Start(); + IoBuffer Payload = m_Storage.GetBuildBlob(m_BuildId, BlockAttachment); + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); + if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == Context.AttachmentsToVerifyCount) + { + Context.FilteredDownloadedBytesPerSecond.Stop(); + } + if (!Payload) + { + throw std::runtime_error(fmt::format("Block attachment {} could not be found", BlockAttachment)); + } + if (!m_AbortFlag) + { + Context.Work.ScheduleWork(m_IOWorkerPool, + [this, &Context, Payload = std::move(Payload), BlockAttachment](std::atomic<bool>&) mutable { + if (!m_AbortFlag) + { + ValidateDownloadedBlock(Context, BlockAttachment, std::move(Payload)); + } + }); + } + } + }); + } +} + +void +BuildsOperationValidateBuildPart::ValidateDownloadedChunk(ValidateBlobsContext& Context, const IoHash& ChunkHash, IoBuffer Payload) +{ + 
ZEN_TRACE_CPU("ValidateBuildPart_Validate"); + + if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == Context.AttachmentsToVerifyCount) + { + Context.FilteredDownloadedBytesPerSecond.Stop(); + } + + Context.FilteredVerifiedBytesPerSecond.Start(); + + uint64_t CompressedSize; + uint64_t DecompressedSize; + ValidateBlob(m_AbortFlag, std::move(Payload), ChunkHash, CompressedSize, DecompressedSize); + m_ValidateStats.VerifiedAttachmentCount++; + m_ValidateStats.VerifiedByteCount += DecompressedSize; + if (m_ValidateStats.VerifiedAttachmentCount.load() == Context.AttachmentsToVerifyCount) + { + Context.FilteredVerifiedBytesPerSecond.Stop(); + } +} + +void +BuildsOperationValidateBuildPart::ValidateDownloadedBlock(ValidateBlobsContext& Context, const IoHash& BlockAttachment, IoBuffer Payload) +{ + ZEN_TRACE_CPU("ValidateBuildPart_ValidateBlock"); + + Context.FilteredVerifiedBytesPerSecond.Start(); + + uint64_t CompressedSize; + uint64_t DecompressedSize; + ValidateChunkBlock(std::move(Payload), BlockAttachment, CompressedSize, DecompressedSize); + m_ValidateStats.VerifiedAttachmentCount++; + m_ValidateStats.VerifiedByteCount += DecompressedSize; + if (m_ValidateStats.VerifiedAttachmentCount.load() == Context.AttachmentsToVerifyCount) + { + Context.FilteredVerifiedBytesPerSecond.Stop(); + } +} + +BuildsOperationPrimeCache::BuildsOperationPrimeCache(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -6836,7 +7020,8 @@ BuildsOperationPrimeCache::BuildsOperationPrimeCache(OperationLogOutput& Opera std::span<const Oid> BuildPartIds, const Options& Options, BuildStorageCache::Statistics& StorageCacheStats) -: m_LogOutput(OperationLogOutput) +: m_Log(Log) +, m_Progress(Progress) , m_Storage(Storage) , m_AbortFlag(AbortFlag) , m_PauseFlag(PauseFlag) @@ -6858,9 +7043,69 @@ BuildsOperationPrimeCache::Execute() Stopwatch PrimeTimer; tsl::robin_map<IoHash, uint64_t, 
IoHash::Hasher> LooseChunkRawSizes; + tsl::robin_set<IoHash, IoHash::Hasher> BuildBlobs; + CollectReferencedBlobs(BuildBlobs, LooseChunkRawSizes); + + if (!m_Options.IsQuiet) + { + ZEN_INFO("Found {} referenced blobs", BuildBlobs.size()); + } - tsl::robin_set<IoHash, IoHash::Hasher> BuildBlobs; + if (BuildBlobs.empty()) + { + return; + } + std::vector<IoHash> BlobsToDownload = FilterAlreadyCachedBlobs(BuildBlobs); + + if (BlobsToDownload.empty()) + { + return; + } + + std::atomic<uint64_t> MultipartAttachmentCount; + std::atomic<size_t> CompletedDownloadCount; + FilteredRate FilteredDownloadedBytesPerSecond; + + ScheduleBlobDownloads(BlobsToDownload, + LooseChunkRawSizes, + MultipartAttachmentCount, + CompletedDownloadCount, + FilteredDownloadedBytesPerSecond); + + if (m_AbortFlag) + { + return; + } + + if (m_Storage.CacheStorage) + { + m_Storage.CacheStorage->Flush(m_Progress.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { + ZEN_UNUSED(Remaining); + if (!m_Options.IsQuiet) + { + ZEN_INFO("Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name); + } + return !m_AbortFlag; + }); + } + + if (!m_Options.IsQuiet) + { + uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load(); + ZEN_INFO("Downloaded {} ({}bits/s) in {}. {} as multipart. 
Completed in {}", + NiceBytes(DownloadedBytes), + NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)), + NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000), + MultipartAttachmentCount.load(), + NiceTimeSpanMs(PrimeTimer.GetElapsedTimeMs())); + } +} + +void +BuildsOperationPrimeCache::CollectReferencedBlobs(tsl::robin_set<IoHash, IoHash::Hasher>& OutBuildBlobs, + tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& OutLooseChunkRawSizes) +{ for (const Oid& BuildPartId : m_BuildPartIds) { CbObject BuildPart = m_Storage.BuildStorage->GetBuildPart(m_BuildId, BuildPartId); @@ -6878,26 +7123,20 @@ BuildsOperationPrimeCache::Execute() ChunkRawSizes.size())); } - BuildBlobs.reserve(ChunkAttachments.size() + BlockAttachments.size()); - BuildBlobs.insert(BlockAttachments.begin(), BlockAttachments.end()); - BuildBlobs.insert(ChunkAttachments.begin(), ChunkAttachments.end()); + OutBuildBlobs.reserve(ChunkAttachments.size() + BlockAttachments.size()); + OutBuildBlobs.insert(BlockAttachments.begin(), BlockAttachments.end()); + OutBuildBlobs.insert(ChunkAttachments.begin(), ChunkAttachments.end()); for (size_t ChunkAttachmentIndex = 0; ChunkAttachmentIndex < ChunkAttachments.size(); ChunkAttachmentIndex++) { - LooseChunkRawSizes.insert_or_assign(ChunkAttachments[ChunkAttachmentIndex], ChunkRawSizes[ChunkAttachmentIndex]); + OutLooseChunkRawSizes.insert_or_assign(ChunkAttachments[ChunkAttachmentIndex], ChunkRawSizes[ChunkAttachmentIndex]); } } +} - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Found {} referenced blobs", BuildBlobs.size()); - } - - if (BuildBlobs.empty()) - { - return; - } - +std::vector<IoHash> +BuildsOperationPrimeCache::FilterAlreadyCachedBlobs(const tsl::robin_set<IoHash, IoHash::Hasher>& BuildBlobs) +{ std::vector<IoHash> BlobsToDownload; BlobsToDownload.reserve(BuildBlobs.size()); @@ -6923,11 +7162,10 @@ BuildsOperationPrimeCache::Execute() if (FoundCount > 0 && 
!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Remote cache : Found {} out of {} needed blobs in {}", - FoundCount, - BuildBlobs.size(), - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Remote cache : Found {} out of {} needed blobs in {}", + FoundCount, + BuildBlobs.size(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } } } @@ -6935,169 +7173,170 @@ BuildsOperationPrimeCache::Execute() { BlobsToDownload.insert(BlobsToDownload.end(), BuildBlobs.begin(), BuildBlobs.end()); } + return BlobsToDownload; +} - if (BlobsToDownload.empty()) - { - return; - } - - std::atomic<uint64_t> MultipartAttachmentCount; - std::atomic<size_t> CompletedDownloadCount; - FilteredRate FilteredDownloadedBytesPerSecond; - - { - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Downloading")); - OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr); - - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - - const size_t BlobCount = BlobsToDownload.size(); - - for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++) - { - Work.ScheduleWork(m_NetworkPool, - [this, - &Work, - &BlobsToDownload, - BlobCount, - &LooseChunkRawSizes, - &CompletedDownloadCount, - &FilteredDownloadedBytesPerSecond, - &MultipartAttachmentCount, - BlobIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - const IoHash& BlobHash = BlobsToDownload[BlobIndex]; - - bool IsLargeBlob = false; +void +BuildsOperationPrimeCache::ScheduleBlobDownloads(std::span<const IoHash> BlobsToDownload, + const tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& LooseChunkRawSizes, + std::atomic<uint64_t>& MultipartAttachmentCount, + std::atomic<size_t>& CompletedDownloadCount, + FilteredRate& FilteredDownloadedBytesPerSecond) +{ + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Downloading"); - if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) - { - IsLargeBlob 
= It->second >= m_Options.LargeAttachmentSize; - } + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - FilteredDownloadedBytesPerSecond.Start(); + const size_t BlobCount = BlobsToDownload.size(); - if (IsLargeBlob) - { - DownloadLargeBlob( - *m_Storage.BuildStorage, - m_TempPath, - m_BuildId, - BlobHash, - m_Options.PreferredMultipartChunkSize, - Work, - m_NetworkPool, - m_DownloadStats.DownloadedChunkByteCount, - MultipartAttachmentCount, - [this, BlobCount, BlobHash, &FilteredDownloadedBytesPerSecond, &CompletedDownloadCount]( - IoBuffer&& Payload) { - m_DownloadStats.DownloadedChunkCount++; - m_DownloadStats.RequestsCompleteCount++; + for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++) + { + Work.ScheduleWork( + m_NetworkPool, + [this, + &Work, + BlobsToDownload, + BlobCount, + &LooseChunkRawSizes, + &CompletedDownloadCount, + &FilteredDownloadedBytesPerSecond, + &MultipartAttachmentCount, + BlobIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + const IoHash& BlobHash = BlobsToDownload[BlobIndex]; + bool IsLargeBlob = false; + if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) + { + IsLargeBlob = It->second >= m_Options.LargeAttachmentSize; + } - if (!m_AbortFlag) - { - if (Payload && m_Storage.CacheStorage) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); - } - } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - }); - } - else - { - IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); - m_DownloadStats.RequestsCompleteCount++; + FilteredDownloadedBytesPerSecond.Start(); - if (!m_AbortFlag) - { - if (Payload && m_Storage.CacheStorage) - { - 
m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(std::move(Payload)))); - } - } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - } - } - }); - } + if (IsLargeBlob) + { + DownloadLargeBlobForCache(Work, + BlobHash, + BlobCount, + CompletedDownloadCount, + MultipartAttachmentCount, + FilteredDownloadedBytesPerSecond); + } + else + { + DownloadSingleBlobForCache(BlobHash, BlobCount, CompletedDownloadCount, FilteredDownloadedBytesPerSecond); + } + } + }); + } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { - ZEN_UNUSED(PendingWork); + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); - uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load(); - FilteredDownloadedBytesPerSecond.Update(DownloadedBytes); - - std::string DownloadRateString = (CompletedDownloadCount == BlobCount) - ? "" - : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8)); - std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.", - m_StorageCacheStats.PutBlobCount.load(), - NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) - : ""; - - std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}", - CompletedDownloadCount.load(), - BlobCount, - NiceBytes(DownloadedBytes), - DownloadRateString, - UploadDetails); - Progress.UpdateState({.Task = "Downloading", + uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load(); + FilteredDownloadedBytesPerSecond.Update(DownloadedBytes); + + std::string DownloadRateString = (CompletedDownloadCount == BlobCount) + ? 
"" + : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8)); + std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.", + m_StorageCacheStats.PutBlobCount.load(), + NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) + : ""; + + std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}", + CompletedDownloadCount.load(), + BlobCount, + NiceBytes(DownloadedBytes), + DownloadRateString, + UploadDetails); + ProgressBar->UpdateState({.Task = "Downloading", .Details = Details, .TotalCount = BlobCount, .RemainingCount = BlobCount - CompletedDownloadCount.load(), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, false); - }); + }); + + FilteredDownloadedBytesPerSecond.Stop(); + ProgressBar->Finish(); +} - FilteredDownloadedBytesPerSecond.Stop(); +void +BuildsOperationPrimeCache::DownloadLargeBlobForCache(ParallelWork& Work, + const IoHash& BlobHash, + size_t BlobCount, + std::atomic<size_t>& CompletedDownloadCount, + std::atomic<uint64_t>& MultipartAttachmentCount, + FilteredRate& FilteredDownloadedBytesPerSecond) +{ + DownloadLargeBlob(*m_Storage.BuildStorage, + m_TempPath, + m_BuildId, + BlobHash, + m_Options.PreferredMultipartChunkSize, + Work, + m_NetworkPool, + m_DownloadStats.DownloadedChunkByteCount, + MultipartAttachmentCount, + [this, BlobCount, BlobHash, &FilteredDownloadedBytesPerSecond, &CompletedDownloadCount](IoBuffer&& Payload) { + m_DownloadStats.DownloadedChunkCount++; + m_DownloadStats.RequestsCompleteCount++; + + if (!m_AbortFlag) + { + if (Payload && m_Storage.CacheStorage) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); + } + } + if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } + }); +} - 
Progress.Finish(); - } - if (m_AbortFlag) +void +BuildsOperationPrimeCache::DownloadSingleBlobForCache(const IoHash& BlobHash, + size_t BlobCount, + std::atomic<size_t>& CompletedDownloadCount, + FilteredRate& FilteredDownloadedBytesPerSecond) +{ + IoBuffer Payload; + try { - return; - } + Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); - if (m_Storage.CacheStorage) + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); + m_DownloadStats.RequestsCompleteCount++; + } + catch (const std::exception&) { - m_Storage.CacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { - ZEN_UNUSED(Remaining); - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name); - } - return !m_AbortFlag; - }); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - if (!m_Options.IsQuiet) + if (!m_AbortFlag) { - uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load(); - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Downloaded {} ({}bits/s) in {}. {} as multipart. 
Completed in {}", - NiceBytes(DownloadedBytes), - NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)), - NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000), - MultipartAttachmentCount.load(), - NiceTimeSpanMs(PrimeTimer.GetElapsedTimeMs())); + if (Payload && m_Storage.CacheStorage) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(std::move(Payload)))); + } + if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } } } @@ -7212,7 +7451,7 @@ ResolveBuildPartNames(CbObjectView BuildObject, } ChunkedFolderContent -GetRemoteContent(OperationLogOutput& Output, +GetRemoteContent(LoggerRef InLog, StorageInstance& Storage, const Oid& BuildId, const std::vector<std::pair<Oid, std::string>>& BuildParts, @@ -7228,6 +7467,7 @@ GetRemoteContent(OperationLogOutput& Output, bool DoExtraContentVerify) { ZEN_TRACE_CPU("GetRemoteContent"); + ZEN_SCOPED_LOG(InLog); Stopwatch GetBuildPartTimer; const Oid BuildPartId = BuildParts[0].first; @@ -7235,13 +7475,12 @@ GetRemoteContent(OperationLogOutput& Output, CbObject BuildPartManifest = Storage.BuildStorage->GetBuildPart(BuildId, BuildPartId); if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, - "GetBuildPart {} ('{}') took {}. Payload size: {}", - BuildPartId, - BuildPartName, - NiceTimeSpanMs(GetBuildPartTimer.GetElapsedTimeMs()), - NiceBytes(BuildPartManifest.GetSize())); - ZEN_OPERATION_LOG_INFO(Output, "{}", GetCbObjectAsNiceString(BuildPartManifest, " "sv, "\n"sv)); + ZEN_INFO("GetBuildPart {} ('{}') took {}. 
Payload size: {}", + BuildPartId, + BuildPartName, + NiceTimeSpanMs(GetBuildPartTimer.GetElapsedTimeMs()), + NiceBytes(BuildPartManifest.GetSize())); + ZEN_INFO("{}", GetCbObjectAsNiceString(BuildPartManifest, " "sv, "\n"sv)); } { @@ -7251,17 +7490,16 @@ GetRemoteContent(OperationLogOutput& Output, OutChunkController = CreateChunkingController(ChunkerName, Parameters); } - auto ParseBuildPartManifest = [&Output, IsQuiet, IsVerbose, DoExtraContentVerify]( - StorageInstance& Storage, - const Oid& BuildId, - const Oid& BuildPartId, - CbObject BuildPartManifest, - std::span<const std::string> IncludeWildcards, - std::span<const std::string> ExcludeWildcards, - const BuildManifest::Part* OptionalManifest, - ChunkedFolderContent& OutRemoteContent, - std::vector<ChunkBlockDescription>& OutBlockDescriptions, - std::vector<IoHash>& OutLooseChunkHashes) { + auto ParseBuildPartManifest = [&Log, IsQuiet, IsVerbose, DoExtraContentVerify](StorageInstance& Storage, + const Oid& BuildId, + const Oid& BuildPartId, + CbObject BuildPartManifest, + std::span<const std::string> IncludeWildcards, + std::span<const std::string> ExcludeWildcards, + const BuildManifest::Part* OptionalManifest, + ChunkedFolderContent& OutRemoteContent, + std::vector<ChunkBlockDescription>& OutBlockDescriptions, + std::vector<IoHash>& OutLooseChunkHashes) { std::vector<uint32_t> AbsoluteChunkOrders; std::vector<uint64_t> LooseChunkRawSizes; std::vector<IoHash> BlockRawHashes; @@ -7284,13 +7522,13 @@ GetRemoteContent(OperationLogOutput& Output, { if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size()); + ZEN_INFO("Fetching metadata for {} blocks", BlockRawHashes.size()); } Stopwatch GetBlockMetadataTimer; bool AttemptFallback = false; - OutBlockDescriptions = GetBlockDescriptions(Output, + OutBlockDescriptions = GetBlockDescriptions(Log(), *Storage.BuildStorage, Storage.CacheStorage.get(), BuildId, @@ -7301,11 +7539,10 @@ GetRemoteContent(OperationLogOutput& 
Output, if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, - "GetBlockMetadata for {} took {}. Found {} blocks", - BuildPartId, - NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), - OutBlockDescriptions.size()); + ZEN_INFO("GetBlockMetadata for {} took {}. Found {} blocks", + BuildPartId, + NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), + OutBlockDescriptions.size()); } } @@ -7414,12 +7651,11 @@ GetRemoteContent(OperationLogOutput& Output, CbObject OverlayBuildPartManifest = Storage.BuildStorage->GetBuildPart(BuildId, OverlayBuildPartId); if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, - "GetBuildPart {} ('{}') took {}. Payload size: {}", - OverlayBuildPartId, - OverlayBuildPartName, - NiceTimeSpanMs(GetOverlayBuildPartTimer.GetElapsedTimeMs()), - NiceBytes(OverlayBuildPartManifest.GetSize())); + ZEN_INFO("GetBuildPart {} ('{}') took {}. Payload size: {}", + OverlayBuildPartId, + OverlayBuildPartName, + NiceTimeSpanMs(GetOverlayBuildPartTimer.GetElapsedTimeMs()), + NiceBytes(OverlayBuildPartManifest.GetSize())); } ChunkedFolderContent OverlayPartContent; @@ -7589,7 +7825,7 @@ namespace buildstorageoperations_testutils { { TestState(const std::filesystem::path& InRootPath) : RootPath(InRootPath) - , LogOutput(CreateStandardLogOutput(Log)) + , LogOutput(CreateStandardProgress(Log)) , ChunkController(CreateStandardChunkingController(StandardChunkingControllerSettings{})) , ChunkCache(CreateMemoryChunkingCache()) , WorkerPool(2) @@ -7631,7 +7867,8 @@ namespace buildstorageoperations_testutils { { const std::filesystem::path SourcePath = RootPath / Source; CbObject MetaData; - BuildsOperationUploadFolder Upload(*LogOutput, + BuildsOperationUploadFolder Upload(Log, + *LogOutput, Storage, AbortFlag, PauseFlag, @@ -7649,7 +7886,8 @@ namespace buildstorageoperations_testutils { { for (auto Part : Parts) { - BuildsOperationValidateBuildPart Validate(*LogOutput, + BuildsOperationValidateBuildPart Validate(Log, + *LogOutput, *Storage.BuildStorage, AbortFlag, 
PauseFlag, @@ -7693,7 +7931,7 @@ namespace buildstorageoperations_testutils { std::vector<ChunkBlockDescription> BlockDescriptions; std::vector<IoHash> LooseChunkHashes; - ChunkedFolderContent RemoteContent = GetRemoteContent(*LogOutput, + ChunkedFolderContent RemoteContent = GetRemoteContent(Log, Storage, BuildId, AllBuildParts, @@ -7774,7 +8012,8 @@ namespace buildstorageoperations_testutils { const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalContent); const ChunkedContentLookup RemoteLookup = BuildChunkedContentLookup(RemoteContent); - BuildsOperationUpdateFolder Download(*LogOutput, + BuildsOperationUpdateFolder Download(Log, + *LogOutput, Storage, AbortFlag, PauseFlag, @@ -7837,8 +8076,8 @@ namespace buildstorageoperations_testutils { std::filesystem::path SystemRootDir; std::filesystem::path ZenFolderPath; - LoggerRef Log = ConsoleLog(); - std::unique_ptr<OperationLogOutput> LogOutput; + LoggerRef Log = ConsoleLog(); + std::unique_ptr<ProgressBase> LogOutput; std::unique_ptr<ChunkingController> ChunkController; std::unique_ptr<ChunkingCache> ChunkCache; @@ -7990,7 +8229,8 @@ TEST_CASE("buildstorageoperations.memorychunkingcache") { const std::filesystem::path SourcePath = SourceFolder.Path() / "source"; CbObject MetaData; - BuildsOperationUploadFolder Upload(*State.LogOutput, + BuildsOperationUploadFolder Upload(State.Log, + *State.LogOutput, State.Storage, State.AbortFlag, State.PauseFlag, @@ -8020,7 +8260,8 @@ TEST_CASE("buildstorageoperations.memorychunkingcache") { const std::filesystem::path SourcePath = SourceFolder.Path() / "source"; CbObject MetaData; - BuildsOperationUploadFolder Upload(*State.LogOutput, + BuildsOperationUploadFolder Upload(State.Log, + *State.LogOutput, State.Storage, State.AbortFlag, State.PauseFlag, diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp index 2ae726e29..144964e37 100644 --- a/src/zenremotestore/builds/buildstorageutil.cpp +++ 
b/src/zenremotestore/builds/buildstorageutil.cpp @@ -9,7 +9,6 @@ #include <zenremotestore/builds/jupiterbuildstorage.h> #include <zenremotestore/chunking/chunkblock.h> #include <zenremotestore/jupiter/jupiterhost.h> -#include <zenremotestore/operationlogoutput.h> #include <zenutil/zenserverprocess.h> namespace zen { @@ -32,7 +31,7 @@ namespace { } // namespace BuildStorageResolveResult -ResolveBuildStorage(OperationLogOutput& Output, +ResolveBuildStorage(LoggerRef InLog, const HttpClientSettings& ClientSettings, std::string_view Host, std::string_view OverrideHost, @@ -40,6 +39,8 @@ ResolveBuildStorage(OperationLogOutput& Output, ZenCacheResolveMode ZenResolveMode, bool Verbose) { + ZEN_SCOPED_LOG(InLog); + bool AllowZenCacheDiscovery = ZenResolveMode == ZenCacheResolveMode::Discovery || ZenResolveMode == ZenCacheResolveMode::All; bool AllowLocalZenCache = ZenResolveMode == ZenCacheResolveMode::LocalHost || ZenResolveMode == ZenCacheResolveMode::All; @@ -80,10 +81,9 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, - "Querying servers at '{}/api/v1/status/servers'\n Connection settings:{}", - DiscoveryHost, - ConnectionSettingsToString(ClientSettings)); + ZEN_INFO("Querying servers at '{}/api/v1/status/servers'\n Connection settings:{}", + DiscoveryHost, + ConnectionSettingsToString(ClientSettings)); } DiscoveryResponse = DiscoverJupiterEndpoints(DiscoveryHost, ClientSettings); @@ -93,14 +93,14 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, "Testing server endpoint at '{}/health/live'. Assume http2: {}", OverrideHost, HostAssumeHttp2); + ZEN_INFO("Testing server endpoint at '{}/health/live'. 
Assume http2: {}", OverrideHost, HostAssumeHttp2); } if (JupiterEndpointTestResult TestResult = TestJupiterEndpoint(OverrideHost, HostAssumeHttp2, ClientSettings.Verbose); TestResult.Success) { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost); + ZEN_INFO("Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost); } HostUrl = OverrideHost; HostName = GetHostNameFromUrl(OverrideHost); @@ -125,10 +125,9 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, - "Testing server endpoint at '{}/health/live'. Assume http2: {}", - ServerEndpoint.BaseUrl, - ServerEndpoint.AssumeHttp2); + ZEN_INFO("Testing server endpoint at '{}/health/live'. Assume http2: {}", + ServerEndpoint.BaseUrl, + ServerEndpoint.AssumeHttp2); } if (JupiterEndpointTestResult TestResult = @@ -137,7 +136,7 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl); + ZEN_INFO("Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl); } HostUrl = ServerEndpoint.BaseUrl; @@ -149,10 +148,7 @@ ResolveBuildStorage(OperationLogOutput& Output, } else { - ZEN_OPERATION_LOG_DEBUG(Output, - "Unable to reach host {}. Reason: {}", - ServerEndpoint.BaseUrl, - TestResult.FailureReason); + ZEN_DEBUG("Unable to reach host {}. Reason: {}", ServerEndpoint.BaseUrl, TestResult.FailureReason); } } } @@ -173,10 +169,9 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, - "Testing cache endpoint at '{}/status/builds'. Assume http2: {}", - CacheEndpoint.BaseUrl, - CacheEndpoint.AssumeHttp2); + ZEN_INFO("Testing cache endpoint at '{}/status/builds'. 
Assume http2: {}", + CacheEndpoint.BaseUrl, + CacheEndpoint.AssumeHttp2); } if (ZenCacheEndpointTestResult TestResult = @@ -185,7 +180,7 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, "Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl); + ZEN_INFO("Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl); } CacheUrl = CacheEndpoint.BaseUrl; @@ -225,7 +220,7 @@ ResolveBuildStorage(OperationLogOutput& Output, { if (Verbose) { - ZEN_OPERATION_LOG_INFO(Output, "Testing cache endpoint at '{}/status/builds'. Assume http2: {}", ZenCacheHost, false); + ZEN_INFO("Testing cache endpoint at '{}/status/builds'. Assume http2: {}", ZenCacheHost, false); } if (ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose); TestResult.Success) @@ -272,7 +267,7 @@ ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas) } std::vector<ChunkBlockDescription> -GetBlockDescriptions(OperationLogOutput& Output, +GetBlockDescriptions(LoggerRef InLog, BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, @@ -282,6 +277,7 @@ GetBlockDescriptions(OperationLogOutput& Output, bool IsVerbose) { using namespace std::literals; + ZEN_SCOPED_LOG(InLog); std::vector<ChunkBlockDescription> UnorderedList; tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup; @@ -322,7 +318,7 @@ GetBlockDescriptions(OperationLogOutput& Output, if (Description.BlockHash == IoHash::Zero) { - ZEN_OPERATION_LOG_WARN(Output, "Unexpected/invalid block metadata received from remote store, skipping block"); + ZEN_WARN("Unexpected/invalid block metadata received from remote store, skipping block"); } else { @@ -383,7 +379,7 @@ GetBlockDescriptions(OperationLogOutput& Output, } if (AttemptFallback) { - ZEN_OPERATION_LOG_WARN(Output, "{} Attemping fallback options.", ErrorDescription); + ZEN_WARN("{} Attemping fallback 
options.", ErrorDescription); std::vector<ChunkBlockDescription> AugmentedBlockDescriptions; AugmentedBlockDescriptions.reserve(BlockRawHashes.size()); std::vector<ChunkBlockDescription> FoundBlocks = ParseChunkBlockDescriptionList(Storage.FindBlocks(BuildId, (uint64_t)-1)); @@ -406,7 +402,7 @@ GetBlockDescriptions(OperationLogOutput& Output, { if (!IsQuiet) { - ZEN_OPERATION_LOG_INFO(Output, "Found block {} via context find successfully", BlockHash); + ZEN_INFO("Found block {} via context find successfully", BlockHash); } AugmentedBlockDescriptions.emplace_back(std::move(*ListBlocksIt)); } diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp index ad4c4bc89..d837ce07f 100644 --- a/src/zenremotestore/builds/jupiterbuildstorage.cpp +++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp @@ -263,7 +263,7 @@ public: std::vector<std::function<void()>> WorkList; for (auto& WorkItem : WorkItems) { - WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes]() { + WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes = std::move(OnSentBytes)]() { Stopwatch ExecutionTimer; auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); bool IsComplete = false; @@ -444,11 +444,13 @@ public: virtual bool GetExtendedStatistics(ExtendedStatistics& OutStats) override { - OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size()); - for (auto& It : m_ReceivedBytesPerSource) - { - OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second]); - } + m_SourceLock.WithSharedLock([this, &OutStats]() { + OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size()); + for (auto& It : m_ReceivedBytesPerSource) + { + OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second].load(std::memory_order_relaxed)); + } + }); return true; } @@ -521,15 +523,29 @@ private: } if 
(!Result.Source.empty()) { - if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); - It != m_ReceivedBytesPerSource.end()) - { - m_SourceBytes[It->second] += Result.ReceivedBytes; - } - else + if (!m_SourceLock.WithSharedLock([&]() { + if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); + It != m_ReceivedBytesPerSource.end()) + { + m_SourceBytes[It->second] += Result.ReceivedBytes; + return true; + } + return false; + })) { - m_ReceivedBytesPerSource.insert_or_assign(Result.Source, m_SourceBytes.size()); - m_SourceBytes.push_back(Result.ReceivedBytes); + m_SourceLock.WithExclusiveLock([&]() { + if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); + It != m_ReceivedBytesPerSource.end()) + { + m_SourceBytes[It->second] += Result.ReceivedBytes; + } + else if (m_SourceCount < MaxSourceCount) + { + size_t Index = m_SourceCount++; + m_ReceivedBytesPerSource.insert_or_assign(Result.Source, Index); + m_SourceBytes[Index] += Result.ReceivedBytes; + } + }); } } } @@ -540,8 +556,11 @@ private: const std::string m_Bucket; const std::filesystem::path m_TempFolderPath; - tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource; - std::vector<uint64_t> m_SourceBytes; + RwLock m_SourceLock; + tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource; + static constexpr size_t MaxSourceCount = 8u; + std::array<std::atomic<uint64_t>, MaxSourceCount> m_SourceBytes; + size_t m_SourceCount = 0; }; std::unique_ptr<BuildStorageBase> diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index 0fe3c09ce..f29112f53 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -7,7 +7,6 @@ #include <zencore/logging.h> #include <zencore/timer.h> #include <zencore/trace.h> -#include <zenremotestore/operationlogoutput.h> 
#include <numeric> @@ -445,7 +444,7 @@ IterateChunkBlock(const SharedBuffer& BlockPayload, }; std::vector<size_t> -FindReuseBlocks(OperationLogOutput& Output, +FindReuseBlocks(LoggerRef InLog, const uint8_t BlockReuseMinPercentLimit, const bool IsVerbose, ReuseBlocksStatistics& Stats, @@ -455,6 +454,7 @@ FindReuseBlocks(OperationLogOutput& Output, std::vector<uint32_t>& OutUnusedChunkIndexes) { ZEN_TRACE_CPU("FindReuseBlocks"); + ZEN_SCOPED_LOG(InLog); // Find all blocks with a usage level higher than MinPercentLimit // Pick out the blocks with usage higher or equal to MinPercentLimit @@ -521,11 +521,10 @@ FindReuseBlocks(OperationLogOutput& Output, { if (IsVerbose) { - ZEN_OPERATION_LOG_INFO(Output, - "Reusing block {}. {} attachments found, usage level: {}%", - KnownBlock.BlockHash, - FoundAttachmentCount, - ReusePercent); + ZEN_INFO("Reusing block {}. {} attachments found, usage level: {}%", + KnownBlock.BlockHash, + FoundAttachmentCount, + ReusePercent); } ReuseBlockIndexes.push_back(KnownBlockIndex); @@ -534,12 +533,13 @@ FindReuseBlocks(OperationLogOutput& Output, } else if (FoundAttachmentCount > 0) { - // if (IsVerbose) - //{ - // ZEN_OPERATION_LOG_INFO(Output, "Skipping block {}. {} attachments found, usage level: {}%", - // KnownBlock.BlockHash, - // FoundAttachmentCount, ReusePercent); - //} + if (IsVerbose) + { + ZEN_INFO("Skipping block {}. {} attachments found, usage level: {}%", + KnownBlock.BlockHash, + FoundAttachmentCount, + ReusePercent); + } Stats.RejectedBlockCount++; Stats.RejectedChunkCount += FoundAttachmentCount; Stats.RejectedByteCount += ReuseSize; @@ -583,11 +583,10 @@ FindReuseBlocks(OperationLogOutput& Output, { if (IsVerbose) { - ZEN_OPERATION_LOG_INFO(Output, - "Reusing block {}. {} attachments found, usage level: {}%", - KnownBlock.BlockHash, - FoundChunkIndexes.size(), - ReusePercent); + ZEN_INFO("Reusing block {}. 
{} attachments found, usage level: {}%", + KnownBlock.BlockHash, + FoundChunkIndexes.size(), + ReusePercent); } FilteredReuseBlockIndexes.push_back(KnownBlockIndex); @@ -604,11 +603,10 @@ FindReuseBlocks(OperationLogOutput& Output, } else { - // if (IsVerbose) - //{ - // ZEN_OPERATION_LOG_INFO(Output, "Skipping block {}. filtered usage level: {}%", KnownBlock.BlockHash, - // ReusePercent); - //} + if (IsVerbose) + { + ZEN_INFO("Skipping block {}. filtered usage level: {}%", KnownBlock.BlockHash, ReusePercent); + } Stats.RejectedBlockCount++; Stats.RejectedChunkCount += FoundChunkIndexes.size(); Stats.RejectedByteCount += AdjustedReuseSize; @@ -629,10 +627,8 @@ FindReuseBlocks(OperationLogOutput& Output, return FilteredReuseBlockIndexes; } -ChunkBlockAnalyser::ChunkBlockAnalyser(OperationLogOutput& LogOutput, - std::span<const ChunkBlockDescription> BlockDescriptions, - const Options& Options) -: m_LogOutput(LogOutput) +ChunkBlockAnalyser::ChunkBlockAnalyser(LoggerRef Log, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options) +: m_Log(Log) , m_BlockDescriptions(BlockDescriptions) , m_Options(Options) { @@ -899,20 +895,20 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span<const NeededBlock> if (Ideal < FullDownloadTotalSize && !m_Options.IsQuiet) { const double AchievedPercent = Ideal == 0 ? 100.0 : (100.0 * Actual) / Ideal; - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of " - "possible {} using {} extra ranges " - "via {} extra requests. 
Completed in {}", - NeededBlocks.size(), - NiceBytes(FullDownloadTotalSize), - NiceBytes(IdealDownloadTotalSize), - NiceBytes(ActualDownloadTotalSize), - NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize), - AchievedPercent, - NiceBytes(Ideal), - RangeCount - MinRequestCount, - RequestCount - MinRequestCount, - NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs())); + ZEN_INFO( + "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of " + "possible {} using {} extra ranges " + "via {} extra requests. Completed in {}", + NeededBlocks.size(), + NiceBytes(FullDownloadTotalSize), + NiceBytes(IdealDownloadTotalSize), + NiceBytes(ActualDownloadTotalSize), + NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize), + AchievedPercent, + NiceBytes(Ideal), + RangeCount - MinRequestCount, + RequestCount - MinRequestCount, + NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs())); } } @@ -1001,8 +997,7 @@ TEST_CASE("chunkblock.reuseblocks") BlockDescriptions.emplace_back(std::move(Block)); } - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); { // We use just about all the chunks - should result in use of both blocks @@ -1019,14 +1014,8 @@ TEST_CASE("chunkblock.reuseblocks") std::iota(ManyChunkIndexes.begin(), ManyChunkIndexes.end(), 0); std::vector<uint32_t> UnusedChunkIndexes; - std::vector<size_t> ReusedBlocks = FindReuseBlocks(*LogOutput, - 80, - false, - ReuseBlocksStats, - BlockDescriptions, - ManyChunkHashes, - ManyChunkIndexes, - UnusedChunkIndexes); + std::vector<size_t> ReusedBlocks = + FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, BlockDescriptions, ManyChunkHashes, ManyChunkIndexes, UnusedChunkIndexes); CHECK_EQ(2u, ReusedBlocks.size()); CHECK_EQ(0u, UnusedChunkIndexes.size()); @@ -1047,7 +1036,7 @@ TEST_CASE("chunkblock.reuseblocks") std::iota(ManyChunkIndexes.begin(), ManyChunkIndexes.end(), 0); std::vector<uint32_t> 
UnusedChunkIndexes; - std::vector<size_t> ReusedBlocks = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks = FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, @@ -1076,7 +1065,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks - should result in no use of blocks due to 80% limit std::vector<uint32_t> UnusedChunkIndexes80Percent; ReuseBlocksStatistics ReuseBlocksStats; - std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, @@ -1092,7 +1081,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks - should result in use of both blocks due to 40% limit std::vector<uint32_t> UnusedChunkIndexes40Percent; ReuseBlocksStatistics ReuseBlocksStats; - std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef, 40, false, ReuseBlocksStats, @@ -1122,7 +1111,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks for first block - should result in use of one blocks due to 80% limit ReuseBlocksStatistics ReuseBlocksStats; std::vector<uint32_t> UnusedChunkIndexes80Percent; - std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, @@ -1139,7 +1128,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks - should result in use of both blocks due to 40% limit ReuseBlocksStatistics ReuseBlocksStats; std::vector<uint32_t> UnusedChunkIndexes40Percent; - std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef, 40, false, ReuseBlocksStats, @@ -1178,7 +1167,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks for first block - should result in use of one blocks due to 80% limit ReuseBlocksStatistics ReuseBlocksStats; 
std::vector<uint32_t> UnusedChunkIndexes80Percent; - std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, @@ -1195,7 +1184,7 @@ TEST_CASE("chunkblock.reuseblocks") // We use half the chunks - should result in use of both blocks due to 40% limit ReuseBlocksStatistics ReuseBlocksStats; std::vector<uint32_t> UnusedChunkIndexes40Percent; - std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput, + std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef, 40, false, ReuseBlocksStats, @@ -1214,7 +1203,7 @@ namespace chunkblock_analyser_testutils { // Build a ChunkBlockDescription without any real payload. // Hashes are derived deterministically from (BlockSeed XOR ChunkIndex) so that the same - // seed produces the same hashes — useful for deduplication tests. + // seed produces the same hashes - useful for deduplication tests. static ChunkBlockDescription MakeBlockDesc(uint64_t HeaderSize, std::initializer_list<uint32_t> CompressedLengths, uint32_t BlockSeed = 0) @@ -1257,7 +1246,7 @@ namespace chunkblock_analyser_testutils { TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap") { using RD = chunkblock_impl::RangeDescriptor; - // Gap between ranges 0-1 is 50, gap between 1-2 is 150 → pair 0-1 gets merged + // Gap between ranges 0-1 is 50, gap between 1-2 is 150 -> pair 0-1 gets merged std::vector<RD> Ranges = { {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, {.RangeStart = 150, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, @@ -1279,7 +1268,7 @@ TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap") TEST_CASE("chunkblock.mergecheapestrange.tiebreak_smaller_merged") { using RD = chunkblock_impl::RangeDescriptor; - // Gap 0-1 == gap 1-2 == 100; merged size 0-1 (250) < merged size 1-2 (350) → pair 0-1 wins + // Gap 0-1 == gap 1-2 == 100; merged 
size 0-1 (250) < merged size 1-2 (350) -> pair 0-1 wins std::vector<RD> Ranges = { {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, @@ -1304,7 +1293,7 @@ TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency") { using RD = chunkblock_impl::RangeDescriptor; // With MaxRangeCountPerRequest unlimited, RequestCount=1 - // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 → all ranges preserved + // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 -> all ranges preserved std::vector<RD> ExactRanges = { {.RangeStart = 0, .RangeLength = 1000, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, {.RangeStart = 2000, .RangeLength = 1000, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, @@ -1325,7 +1314,7 @@ TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency") TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block") { using RD = chunkblock_impl::RangeDescriptor; - // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 → full block (empty result) + // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 -> full block (empty result) std::vector<RD> ExactRanges = { {.RangeStart = 100, .RangeLength = 900, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 3}, }; @@ -1344,7 +1333,7 @@ TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block") TEST_CASE("chunkblock.optimizeranges.maxrangesperblock_clamp") { using RD = chunkblock_impl::RangeDescriptor; - // 5 input ranges; MaxRangesPerBlock=2 clamps to ≤2 before the cost model runs + // 5 input ranges; MaxRangesPerBlock=2 clamps to <=2 before the cost model runs std::vector<RD> ExactRanges = { {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, {.RangeStart = 300, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, @@ -1378,8 +1367,8 @@ 
TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge") uint64_t TotalBlockSize = 1000; double LatencySec = 1.0; uint64_t SpeedBytesPerSec = 500; - // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 → preserved - // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 → merged + // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 -> preserved + // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 -> merged uint64_t MaxRangesPerBlock = 1024; auto Unlimited = @@ -1394,7 +1383,7 @@ TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge") TEST_CASE("chunkblock.optimizeranges.unlimited_rangecountperrequest_no_extra_cost") { using RD = chunkblock_impl::RangeDescriptor; - // MaxRangeCountPerRequest=-1 → RequestCount always 1, even with many ranges and high latency + // MaxRangeCountPerRequest=-1 -> RequestCount always 1, even with many ranges and high latency std::vector<RD> ExactRanges = { {.RangeStart = 0, .RangeLength = 50, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, @@ -1418,7 +1407,7 @@ TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path") { using RD = chunkblock_impl::RangeDescriptor; // Exactly 2 ranges; cost model demands merge; exercises the RangeCount==2 direct-merge branch - // After direct merge → 1 range with small slack → full block (empty) + // After direct merge -> 1 range with small slack -> full block (empty) std::vector<RD> ExactRanges = { {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 2}, {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 2}, @@ -1429,8 +1418,8 @@ TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path") uint64_t MaxRangeCountPerReq = (uint64_t)-1; 
uint64_t MaxRangesPerBlock = 1024; - // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 → direct merge - // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 → full block + // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 -> direct merge + // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 -> full block auto Result = chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); @@ -1441,12 +1430,11 @@ TEST_CASE("chunkblock.getneeded.all_chunks") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); auto HashMap = MakeHashMap({Block}); auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; }); @@ -1464,12 +1452,11 @@ TEST_CASE("chunkblock.getneeded.no_chunks") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); auto HashMap = MakeHashMap({Block}); auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return false; }); @@ -1481,12 +1468,11 @@ TEST_CASE("chunkblock.getneeded.subset_within_block") { using namespace chunkblock_analyser_testutils; - 
LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); auto HashMap = MakeHashMap({Block}); // Indices 0 and 2 are needed; 1 and 3 are not @@ -1503,12 +1489,11 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); - // Block 0: {H0, H1, SharedH, H3} — 3 of 4 needed (H3 not needed); slack = 100 - // Block 1: {H4, H5, SharedH, H6} — only SharedH needed; slack = 300 - // Block 0 has less slack → processed first → SharedH assigned to block 0 + // Block 0: {H0, H1, SharedH, H3} - 3 of 4 needed (H3 not needed); slack = 100 + // Block 1: {H4, H5, SharedH, H6} - only SharedH needed; slack = 300 + // Block 0 has less slack -> processed first -> SharedH assigned to block 0 IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_dedup", 18)); IoHash H0 = IoHash::HashBuffer(MemoryView("block0_chunk0", 13)); IoHash H1 = IoHash::HashBuffer(MemoryView("block0_chunk1", 13)); @@ -1531,9 +1516,9 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins") std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + ChunkBlockAnalyser Analyser(LogRef, Blocks, Options); - // Map: H0→0, H1→1, SharedH→2, H3→3, H4→4, H5→5, H6→6 + // Map: H0->0, H1->1, SharedH->2, H3->3, H4->4, H5->5, H6->6 auto HashMap = MakeHashMap(Blocks); // Need H0(0), H1(1), SharedH(2) from block 0; SharedH from block 1 (already index 2) // H3(3) not needed; 
H4,H5,H6 not needed @@ -1541,7 +1526,7 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins") // Block 0 slack=100 (H3 unused), block 1 slack=300 (H4,H5,H6 unused) // Block 0 processed first; picks up H0, H1, SharedH - // Block 1 tries SharedH but it's already picked up → empty → not added + // Block 1 tries SharedH but it's already picked up -> empty -> not added REQUIRE_EQ(1u, NeededBlocks.size()); CHECK_EQ(0u, NeededBlocks[0].BlockIndex); REQUIRE_EQ(3u, NeededBlocks[0].ChunkIndexes.size()); @@ -1554,8 +1539,7 @@ TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); // SharedH appears in both blocks; should appear in the result exactly once IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_nodup", 18)); @@ -1578,16 +1562,16 @@ TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup") std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + ChunkBlockAnalyser Analyser(LogRef, Blocks, Options); - // Map: SharedH→0, H0→1, H1→2, H2→3, H3→4 + // Map: SharedH->0, H0->1, H1->2, H2->3, H3->4 // Only SharedH (index 0) needed; no other chunks auto HashMap = MakeHashMap(Blocks); auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0; }); - // Block 0: SharedH needed, H0 not needed → slack=100 - // Block 1: SharedH needed, H1/H2/H3 not needed → slack=300 - // Block 0 processed first → picks up SharedH; Block 1 skips it + // Block 0: SharedH needed, H0 not needed -> slack=100 + // Block 1: SharedH needed, H1/H2/H3 not needed -> slack=300 + // Block 0 processed first -> picks up SharedH; Block 1 skips it // Count total occurrences of SharedH across all NeededBlocks uint32_t SharedOccurrences = 0; @@ -1609,13 +1593,12 @@ 
TEST_CASE("chunkblock.getneeded.skips_unrequested_chunks") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); - // Block has 4 chunks but only 2 appear in the hash map → ChunkIndexes has exactly those 2 + // Block has 4 chunks but only 2 appear in the hash map -> ChunkIndexes has exactly those 2 auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); // Only put chunks at positions 0 and 2 in the map tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> HashMap; @@ -1635,19 +1618,18 @@ TEST_CASE("chunkblock.getneeded.two_blocks_both_contribute") { using namespace chunkblock_analyser_testutils; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); // Block 0: all 4 needed (slack=0); block 1: 3 of 4 needed (slack=100) - // Both blocks contribute chunks → 2 NeededBlocks in result + // Both blocks contribute chunks -> 2 NeededBlocks in result auto Block0 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/0); auto Block1 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/200); std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; ChunkBlockAnalyser::Options Options; - ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + ChunkBlockAnalyser Analyser(LogRef, Blocks, Options); - // HashMap: Block0 hashes → indices 0-3, Block1 hashes → indices 4-7 + // HashMap: Block0 hashes -> indices 0-3, Block1 hashes -> indices 4-7 auto HashMap = MakeHashMap(Blocks); // Need all Block0 chunks (0-3) and Block1 chunks 0-2 (indices 4-6); not chunk index 7 (Block1 chunk 3) auto NeededBlocks = Analyser.GetNeeded(HashMap, 
[](uint32_t ChunkIndex) { return ChunkIndex <= 6; }); @@ -1666,15 +1648,14 @@ TEST_CASE("chunkblock.calc.off_mode") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); - // HeaderSize > 0, chunks size matches → CanDoPartialBlockDownload = true + // HeaderSize > 0, chunks size matches -> CanDoPartialBlockDownload = true // But mode Off forces full block regardless auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; std::vector<Mode> Modes = {Mode::Off}; @@ -1691,17 +1672,16 @@ TEST_CASE("chunkblock.calc.exact_mode") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; - // Need chunks 0 and 2 → 2 non-contiguous ranges; Exact mode passes them straight through + // Need chunks 0 and 2 -> 2 non-contiguous ranges; Exact mode passes them straight through std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = 
{{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; std::vector<Mode> Modes = {Mode::Exact}; @@ -1728,18 +1708,17 @@ TEST_CASE("chunkblock.calc.singlerange_mode") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); - // Default HostLatencySec=-1 → OptimizeRanges not called after SingleRange collapse + // Default HostLatencySec=-1 -> OptimizeRanges not called after SingleRange collapse ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; - // Need chunks 0 and 2 → 2 ranges that get collapsed to 1 + // Need chunks 0 and 2 -> 2 ranges that get collapsed to 1 std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; std::vector<Mode> Modes = {Mode::SingleRange}; @@ -1761,16 +1740,15 @@ TEST_CASE("chunkblock.calc.multirange_mode") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); - // Low latency: RequestTimeAsBytes=100 << slack → OptimizeRanges preserves ranges + // Low latency: RequestTimeAsBytes=100 << slack -> OptimizeRanges preserves ranges ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; Options.HostLatencySec = 0.001; Options.HostSpeedBytesPerSec = 100000; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const 
ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; @@ -1792,17 +1770,16 @@ TEST_CASE("chunkblock.calc.multirangehighspeed_mode") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); - // Block slack ≈ 714 bytes (TotalBlockSize≈1114, RangeTotalSize=400 for chunks 0+2) - // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 → ranges preserved + // Block slack ~= 714 bytes (TotalBlockSize~=1114, RangeTotalSize=400 for chunks 0+2) + // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 -> ranges preserved ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; Options.HostHighSpeedLatencySec = 0.001; Options.HostHighSpeedBytesPerSec = 400000; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; @@ -1824,17 +1801,16 @@ TEST_CASE("chunkblock.calc.all_chunks_needed_full_block") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; Options.HostLatencySec = 0.001; Options.HostSpeedBytesPerSec = 100000; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), 
Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); - // All 4 chunks needed → short-circuit to full block regardless of mode + // All 4 chunks needed -> short-circuit to full block regardless of mode std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 1, 2, 3}}}; std::vector<Mode> Modes = {Mode::Exact}; @@ -1850,14 +1826,13 @@ TEST_CASE("chunkblock.calc.headersize_zero_forces_full_block") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); - // HeaderSize=0 → CanDoPartialBlockDownload=false → full block even in Exact mode + // HeaderSize=0 -> CanDoPartialBlockDownload=false -> full block even in Exact mode auto Block = MakeBlockDesc(0, {100, 200, 300, 400}); ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; std::vector<Mode> Modes = {Mode::Exact}; @@ -1874,25 +1849,24 @@ TEST_CASE("chunkblock.calc.low_maxrangecountperrequest") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); - // 5 chunks of 100 bytes each; need chunks 0, 2, 4 → 3 non-contiguous ranges - // With MaxRangeCountPerRequest=1 and high latency, cost model merges aggressively → full block + // 5 chunks of 100 bytes each; need chunks 0, 2, 4 -> 3 non-contiguous ranges + // With MaxRangeCountPerRequest=1 and 
high latency, cost model merges aggressively -> full block auto Block = MakeBlockDesc(10, {100, 100, 100, 100, 100}); ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; Options.HostLatencySec = 0.1; Options.HostSpeedBytesPerSec = 1000; Options.HostMaxRangeCountPerRequest = 1; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2, 4}}}; std::vector<Mode> Modes = {Mode::MultiRange}; auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); - // Cost model drives merging: 3 requests × 1000 × 0.1 = 300 > slack ≈ 210+headersize + // Cost model drives merging: 3 requests x 1000 x 0.1 = 300 > slack ~= 210+headersize // After merges converges to full block REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); CHECK_EQ(0u, Result.FullBlockIndexes[0]); @@ -1904,14 +1878,13 @@ TEST_CASE("chunkblock.calc.no_latency_skips_optimize") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); - // Default HostLatencySec=-1 → OptimizeRanges not called; raw GetBlockRanges result used + // Default HostLatencySec=-1 -> OptimizeRanges not called; raw GetBlockRanges result used ChunkBlockAnalyser::Options Options; Options.IsQuiet = true; - ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; @@ -1920,7 +1893,7 @@ 
TEST_CASE("chunkblock.calc.no_latency_skips_optimize") auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); - // No optimize pass → exact ranges from GetBlockRanges + // No optimize pass -> exact ranges from GetBlockRanges CHECK(Result.FullBlockIndexes.empty()); REQUIRE_EQ(2u, Result.BlockRanges.size()); CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); @@ -1934,8 +1907,7 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes") using namespace chunkblock_analyser_testutils; using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; - LoggerRef LogRef = Log(); - std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + LoggerRef LogRef = Log(); // 3 blocks with different modes: Off, Exact, MultiRange auto Block0 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/0); @@ -1948,7 +1920,7 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes") Options.HostSpeedBytesPerSec = 100000; std::vector<ChunkBlockDescription> Blocks = {Block0, Block1, Block2}; - ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + ChunkBlockAnalyser Analyser(LogRef, Blocks, Options); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + 50; @@ -1961,11 +1933,11 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes") auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); - // Block 0: Off → FullBlockIndexes + // Block 0: Off -> FullBlockIndexes REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); CHECK_EQ(0u, Result.FullBlockIndexes[0]); - // Block 1: Exact → 2 ranges; Block 2: MultiRange (low latency) → 2 ranges + // Block 1: Exact -> 2 ranges; Block 2: MultiRange (low latency) -> 2 ranges // Total: 4 ranges REQUIRE_EQ(4u, Result.BlockRanges.size()); @@ -2058,7 +2030,7 @@ TEST_CASE("chunkblock.getblockranges.non_contiguous") { using namespace chunkblock_analyser_testutils; - // Chunks 0 and 2 needed, chunk 1 skipped → two separate ranges + // Chunks 0 and 2 
needed, chunk 1 skipped -> two separate ranges auto Block = MakeBlockDesc(50, {100, 200, 300}); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; @@ -2082,7 +2054,7 @@ TEST_CASE("chunkblock.getblockranges.contiguous_run") { using namespace chunkblock_analyser_testutils; - // Chunks 1, 2, 3 needed (consecutive) → one merged range + // Chunks 1, 2, 3 needed (consecutive) -> one merged range auto Block = MakeBlockDesc(50, {50, 100, 150, 200, 250}); uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 0d2eded58..2d29e5efd 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -24,7 +24,7 @@ namespace zen { class CloneQueryInterface; -class OperationLogOutput; +class ProgressBase; class BuildStorageBase; class HttpClient; class ParallelWork; @@ -128,7 +128,6 @@ public: std::uint64_t PreferredMultipartChunkSize = 32u * 1024u * 1024u; EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::Mixed; bool WipeTargetFolder = false; - bool PrimeCacheOnly = false; bool EnableOtherDownloadsScavenging = true; bool EnableTargetFolderScavenging = true; bool ValidateCompletedSequences = true; @@ -137,7 +136,8 @@ public: bool PopulateCache = true; }; - BuildsOperationUpdateFolder(OperationLogOutput& OperationLogOutput, + BuildsOperationUpdateFolder(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -209,6 +209,45 @@ private: uint64_t ElapsedTimeMs = 0; }; + struct LooseChunkHashWorkData + { + std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs; + uint32_t RemoteChunkIndex = (uint32_t)-1; + }; 
+ + struct FinalizeTarget + { + IoHash RawHash; + uint32_t RemotePathIndex; + }; + + struct LocalPathCategorization + { + std::vector<uint32_t> FilesToCache; + std::vector<uint32_t> RemoveLocalPathIndexes; + tsl::robin_map<uint32_t, uint32_t> RemotePathIndexToLocalPathIndex; + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceHashToLocalPathIndex; + uint64_t MatchCount = 0; + uint64_t PathMismatchCount = 0; + uint64_t HashMismatchCount = 0; + uint64_t SkippedCount = 0; + uint64_t DeleteCount = 0; + }; + + struct WriteChunksContext + { + ParallelWork& Work; + BufferedWriteFileCache& WriteCache; + std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters; + std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags; + std::atomic<uint64_t>& WritePartsComplete; + uint64_t TotalPartWriteCount; + uint64_t TotalRequestCount; + const BlobsExistsResult& ExistsResult; + FilteredRate& FilteredDownloadedBytesPerSecond; + FilteredRate& FilteredWrittenBytesPerSecond; + }; + void ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedChunkHashesFound, tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedSequenceHashesFound); void ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedBlocksFound); @@ -261,12 +300,16 @@ private: void DownloadBuildBlob(uint32_t RemoteChunkIndex, const BlobsExistsResult& ExistsResult, ParallelWork& Work, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& Payload)>&& OnDownloaded); - void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, - size_t BlockRangeIndex, - size_t BlockRangeCount, - const BlobsExistsResult& ExistsResult, + void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, + size_t BlockRangeIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + uint64_t TotalRequestCount, + FilteredRate& 
FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, @@ -323,8 +366,8 @@ private: std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags, BufferedWriteFileCache& WriteCache); - void AsyncWriteDownloadedChunk(const std::filesystem::path& ZenFolderPath, - uint32_t RemoteChunkIndex, + void AsyncWriteDownloadedChunk(uint32_t RemoteChunkIndex, + const BlobsExistsResult& ExistsResult, std::vector<const ChunkedContentLookup::ChunkSequenceLocation*>&& ChunkTargetPtrs, BufferedWriteFileCache& WriteCache, ParallelWork& Work, @@ -332,8 +375,7 @@ private: std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters, std::atomic<uint64_t>& WritePartsComplete, const uint64_t TotalPartWriteCount, - FilteredRate& FilteredWrittenBytesPerSecond, - bool EnableBacklog); + FilteredRate& FilteredWrittenBytesPerSecond); void VerifyAndCompleteChunkSequencesAsync(std::span<const uint32_t> RemoteSequenceIndexes, ParallelWork& Work); bool CompleteSequenceChunk(uint32_t RemoteSequenceIndex, std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters); @@ -343,7 +385,112 @@ private: void FinalizeChunkSequences(std::span<const uint32_t> RemoteSequenceIndexes); void VerifySequence(uint32_t RemoteSequenceIndex); - OperationLogOutput& m_LogOutput; + void InitializeSequenceCounters(std::vector<std::atomic<uint32_t>>& OutSequenceCounters, + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutSequencesLeftToFind, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedChunkHashesFound, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedSequenceHashesFound); + + void MatchScavengedSequencesToRemote(std::span<const ChunkedFolderContent> Contents, + std::span<const ChunkedContentLookup> Lookups, + std::span<const std::filesystem::path> Paths, + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& InOutSequencesLeftToFind, + 
std::vector<std::atomic<uint32_t>>& InOutSequenceCounters, + std::vector<ScavengedSequenceCopyOperation>& OutCopyOperations, + uint64_t& OutScavengedPathsCount); + + uint64_t CalculateBytesToWriteAndFlagNeededChunks(std::span<const std::atomic<uint32_t>> SequenceCounters, + const std::vector<bool>& NeedsCopyFromLocalFileFlags, + std::span<std::atomic<bool>> OutNeedsCopyFromSourceFlags); + + void ClassifyCachedAndFetchBlocks(std::span<const ChunkBlockAnalyser::NeededBlock> NeededBlocks, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedBlocksFound, + uint64_t& TotalPartWriteCount, + std::vector<uint32_t>& OutCachedChunkBlockIndexes, + std::vector<uint32_t>& OutFetchBlockIndexes); + + std::vector<uint32_t> DetermineNeededLooseChunkIndexes(std::span<const std::atomic<uint32_t>> SequenceCounters, + const std::vector<bool>& NeedsCopyFromLocalFileFlags, + std::span<std::atomic<bool>> NeedsCopyFromSourceFlags); + + BlobsExistsResult QueryBlobCacheExists(std::span<const uint32_t> NeededLooseChunkIndexes, std::span<const uint32_t> FetchBlockIndexes); + + std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> DeterminePartialDownloadModes(const BlobsExistsResult& ExistsResult); + + std::vector<LooseChunkHashWorkData> BuildLooseChunkHashWorks(std::span<const uint32_t> NeededLooseChunkIndexes, + std::span<const std::atomic<uint32_t>> SequenceCounters); + + void VerifyWriteChunksComplete(std::span<const std::atomic<uint32_t>> SequenceCounters, + uint64_t BytesToWrite, + uint64_t BytesToValidate); + + std::vector<FinalizeTarget> BuildSortedFinalizeTargets(); + + void ScanScavengeSources(std::span<const ScavengeSource> Sources, + std::vector<ChunkedFolderContent>& OutContents, + std::vector<ChunkedContentLookup>& OutLookups, + std::vector<std::filesystem::path>& OutPaths); + + LocalPathCategorization CategorizeLocalPaths(const tsl::robin_map<std::string, uint32_t>& RemotePathToRemoteIndex); + + void ScheduleLocalFileCaching(std::span<const uint32_t> 
FilesToCache, + std::atomic<uint64_t>& OutCachedCount, + std::atomic<uint64_t>& OutCachedByteCount); + + void ScheduleScavengedSequenceWrites(WriteChunksContext& Context, + std::span<const ScavengedSequenceCopyOperation> CopyOperations, + const std::vector<ChunkedFolderContent>& ScavengedContents, + const std::vector<std::filesystem::path>& ScavengedPaths); + + void ScheduleLooseChunkWrites(WriteChunksContext& Context, std::vector<LooseChunkHashWorkData>& LooseChunkHashWorks); + + void ScheduleLocalChunkCopies(WriteChunksContext& Context, + std::span<const CopyChunkData> CopyChunkDatas, + CloneQueryInterface* CloneQuery, + const std::vector<ChunkedFolderContent>& ScavengedContents, + const std::vector<ChunkedContentLookup>& ScavengedLookups, + const std::vector<std::filesystem::path>& ScavengedPaths); + + void ScheduleCachedBlockWrites(WriteChunksContext& Context, std::span<const uint32_t> CachedBlockIndexes); + + void SchedulePartialBlockDownloads(WriteChunksContext& Context, const ChunkBlockAnalyser::BlockResult& PartialBlocks); + + void WritePartialBlockToCache(WriteChunksContext& Context, + size_t BlockRangeStartIndex, + IoBuffer BlockPartialBuffer, + const std::filesystem::path& BlockChunkPath, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths, + const ChunkBlockAnalyser::BlockResult& PartialBlocks); + + void ScheduleFullBlockDownloads(WriteChunksContext& Context, std::span<const uint32_t> FullBlockIndexes); + + void WriteFullBlockToCache(WriteChunksContext& Context, + uint32_t BlockIndex, + IoBuffer BlockBuffer, + const std::filesystem::path& BlockChunkPath); + + void ScheduleLocalFileRemovals(ParallelWork& Work, + std::span<const uint32_t> RemoveLocalPathIndexes, + std::atomic<uint64_t>& DeletedCount); + + void ScheduleTargetFinalization(ParallelWork& Work, + std::span<const FinalizeTarget> Targets, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex, + const tsl::robin_map<uint32_t, uint32_t>& 
RemotePathIndexToLocalPathIndex, + FolderContent& OutLocalFolderState, + std::atomic<uint64_t>& TargetsComplete); + + void FinalizeTargetGroup(size_t BaseOffset, + size_t Count, + std::span<const FinalizeTarget> Targets, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex, + const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex, + FolderContent& OutLocalFolderState, + std::atomic<uint64_t>& TargetsComplete); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + ProgressBase& m_Progress; StorageInstance& m_Storage; std::atomic<bool>& m_AbortFlag; std::atomic<bool>& m_PauseFlag; @@ -492,7 +639,8 @@ public: bool PopulateCache = true; }; - BuildsOperationUploadFolder(OperationLogOutput& OperationLogOutput, + BuildsOperationUploadFolder(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -569,6 +717,28 @@ private: GenerateBlocksStatistics& GenerateBlocksStats, UploadStatistics& UploadStats); + struct GenerateBuildBlocksContext + { + ParallelWork& Work; + WorkerThreadPool& GenerateBlobsPool; + WorkerThreadPool& UploadBlocksPool; + FilteredRate& FilteredGeneratedBytesPerSecond; + FilteredRate& FilteredUploadedBytesPerSecond; + std::atomic<uint64_t>& QueuedPendingBlocksForUpload; + RwLock& Lock; + GeneratedBlocks& OutBlocks; + GenerateBlocksStatistics& GenerateBlocksStats; + UploadStatistics& UploadStats; + size_t NewBlockCount; + }; + + void ScheduleBlockGeneration(GenerateBuildBlocksContext& Context, + const ChunkedFolderContent& Content, + const ChunkedContentLookup& Lookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks); + + void UploadGeneratedBlock(GenerateBuildBlocksContext& Context, size_t BlockIndex, CompressedBuffer Payload); + std::vector<uint32_t> CalculateAbsoluteChunkOrders(const std::span<const IoHash> LocalChunkHashes, const std::span<const uint32_t> LocalChunkOrder, const tsl::robin_map<IoHash, uint32_t, 
IoHash::Hasher>& ChunkHashToLocalChunkIndex, @@ -609,6 +779,58 @@ private: uint32_t PartStepOffset, uint32_t StepCount); + ChunkedFolderContent ScanPartContent(const UploadPart& Part, + ChunkingController& ChunkController, + ChunkingCache& ChunkCache, + ChunkingStatistics& ChunkingStats); + + void ConsumePrepareBuildResult(); + + void ClassifyChunksByBlockEligibility(const ChunkedFolderContent& LocalContent, + std::vector<uint32_t>& OutLooseChunkIndexes, + std::vector<uint32_t>& OutNewBlockChunkIndexes, + std::vector<size_t>& OutReuseBlockIndexes, + LooseChunksStatistics& LooseChunksStats, + FindBlocksStatistics& FindBlocksStats, + ReuseBlocksStatistics& ReuseBlocksStats); + + struct BuiltPartManifest + { + CbObject PartManifest; + std::vector<ChunkBlockDescription> AllChunkBlockDescriptions; + std::vector<IoHash> AllChunkBlockHashes; + }; + + BuiltPartManifest BuildPartManifestObject(const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + ChunkingController& ChunkController, + std::span<const size_t> ReuseBlockIndexes, + const GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes); + + void UploadAttachmentBatch(std::span<IoHash> RawHashes, + std::vector<IoHash>& OutUnknownChunks, + const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks, + GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + UploadStatistics& UploadStats, + LooseChunksStatistics& LooseChunksStats); + + void FinalizeBuildPartWithRetries(const UploadPart& Part, + const IoHash& PartHash, + std::vector<IoHash>& InOutUnknownChunks, + const ChunkedFolderContent& LocalContent, + const ChunkedContentLookup& LocalLookup, + const std::vector<std::vector<uint32_t>>& NewBlockChunks, + GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + UploadStatistics& UploadStats, + LooseChunksStatistics& LooseChunksStats); + + void 
UploadMissingBlockMetadata(GeneratedBlocks& NewBlocks, UploadStatistics& UploadStats); + void UploadPartBlobs(const ChunkedFolderContent& Content, const ChunkedContentLookup& Lookup, std::span<IoHash> RawHashes, @@ -620,18 +842,72 @@ private: LooseChunksStatistics& TempLooseChunksStats, std::vector<IoHash>& OutUnknownChunks); + struct UploadPartClassification + { + std::vector<size_t> BlockIndexes; + std::vector<uint32_t> LooseChunkOrderIndexes; + uint64_t TotalBlocksSize = 0; + uint64_t TotalLooseChunksSize = 0; + }; + + UploadPartClassification ClassifyUploadRawHashes(std::span<IoHash> RawHashes, + const ChunkedFolderContent& Content, + const ChunkedContentLookup& Lookup, + const GeneratedBlocks& NewBlocks, + std::span<const uint32_t> LooseChunkIndexes, + std::vector<IoHash>& OutUnknownChunks); + + struct UploadPartBlobsContext + { + ParallelWork& Work; + WorkerThreadPool& ReadChunkPool; + WorkerThreadPool& UploadChunkPool; + FilteredRate& FilteredGenerateBlockBytesPerSecond; + FilteredRate& FilteredCompressedBytesPerSecond; + FilteredRate& FilteredUploadedBytesPerSecond; + std::atomic<size_t>& UploadedBlockSize; + std::atomic<size_t>& UploadedBlockCount; + std::atomic<size_t>& UploadedRawChunkSize; + std::atomic<size_t>& UploadedCompressedChunkSize; + std::atomic<uint32_t>& UploadedChunkCount; + std::atomic<uint64_t>& GeneratedBlockCount; + std::atomic<uint64_t>& GeneratedBlockByteCount; + std::atomic<uint64_t>& QueuedPendingInMemoryBlocksForUpload; + size_t UploadBlockCount; + uint32_t UploadChunkCount; + uint64_t LargeAttachmentSize; + GeneratedBlocks& NewBlocks; + const ChunkedFolderContent& Content; + const ChunkedContentLookup& Lookup; + const std::vector<std::vector<uint32_t>>& NewBlockChunks; + std::span<const uint32_t> LooseChunkIndexes; + UploadStatistics& TempUploadStats; + LooseChunksStatistics& TempLooseChunksStats; + }; + + void ScheduleBlockGenerationAndUpload(UploadPartBlobsContext& Context, std::span<const size_t> BlockIndexes); + + void 
ScheduleLooseChunkCompressionAndUpload(UploadPartBlobsContext& Context, std::span<const uint32_t> LooseChunkOrderIndexes); + + void UploadBlockPayload(UploadPartBlobsContext& Context, size_t BlockIndex, const IoHash& BlockHash, CompositeBuffer Payload); + + void UploadLooseChunkPayload(UploadPartBlobsContext& Context, const IoHash& RawHash, uint64_t RawSize, CompositeBuffer Payload); + CompositeBuffer CompressChunk(const ChunkedFolderContent& Content, const ChunkedContentLookup& Lookup, uint32_t ChunkIndex, LooseChunksStatistics& TempLooseChunksStats); - OperationLogOutput& m_LogOutput; - StorageInstance& m_Storage; - std::atomic<bool>& m_AbortFlag; - std::atomic<bool>& m_PauseFlag; - WorkerThreadPool& m_IOWorkerPool; - WorkerThreadPool& m_NetworkPool; - const Oid m_BuildId; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + ProgressBase& m_Progress; + StorageInstance& m_Storage; + std::atomic<bool>& m_AbortFlag; + std::atomic<bool>& m_PauseFlag; + WorkerThreadPool& m_IOWorkerPool; + WorkerThreadPool& m_NetworkPool; + const Oid m_BuildId; const std::filesystem::path m_Path; const bool m_CreateBuild; // ?? Member? 
@@ -665,7 +941,8 @@ public: bool IsQuiet = false; bool IsVerbose = false; }; - BuildsOperationValidateBuildPart(OperationLogOutput& OperationLogOutput, + BuildsOperationValidateBuildPart(LoggerRef Log, + ProgressBase& Progress, BuildStorageBase& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -682,21 +959,61 @@ public: DownloadStatistics m_DownloadStats; private: + enum class TaskSteps : uint32_t + { + FetchBuild, + FetchBuildPart, + ValidateBlobs, + Cleanup, + StepCount + }; + ChunkBlockDescription ValidateChunkBlock(IoBuffer&& Payload, const IoHash& BlobHash, uint64_t& OutCompressedSize, uint64_t& OutDecompressedSize); - OperationLogOutput& m_LogOutput; - BuildStorageBase& m_Storage; - std::atomic<bool>& m_AbortFlag; - std::atomic<bool>& m_PauseFlag; - WorkerThreadPool& m_IOWorkerPool; - WorkerThreadPool& m_NetworkPool; - const Oid m_BuildId; - Oid m_BuildPartId; - const std::string m_BuildPartName; - const Options m_Options; + struct ValidateBlobsContext + { + ParallelWork& Work; + uint64_t AttachmentsToVerifyCount; + FilteredRate& FilteredDownloadedBytesPerSecond; + FilteredRate& FilteredVerifiedBytesPerSecond; + }; + + struct ResolvedBuildPart + { + std::vector<IoHash> ChunkAttachments; + std::vector<IoHash> BlockAttachments; + uint64_t PreferredMultipartChunkSize = 0; + }; + + ResolvedBuildPart ResolveBuildPart(); + + void ScheduleChunkAttachmentValidation(ValidateBlobsContext& Context, + std::span<const IoHash> ChunkAttachments, + const std::filesystem::path& TempFolder, + uint64_t PreferredMultipartChunkSize); + + void ScheduleBlockAttachmentValidation(ValidateBlobsContext& Context, std::span<const IoHash> BlockAttachments); + + void ValidateDownloadedChunk(ValidateBlobsContext& Context, const IoHash& ChunkHash, IoBuffer Payload); + + void ValidateDownloadedBlock(ValidateBlobsContext& Context, const IoHash& BlockAttachment, IoBuffer Payload); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + ProgressBase& m_Progress; + 
BuildStorageBase& m_Storage; + std::atomic<bool>& m_AbortFlag; + std::atomic<bool>& m_PauseFlag; + WorkerThreadPool& m_IOWorkerPool; + WorkerThreadPool& m_NetworkPool; + const Oid m_BuildId; + Oid m_BuildPartId; + const std::string m_BuildPartName; + const Options m_Options; }; class BuildsOperationPrimeCache @@ -712,7 +1029,8 @@ public: bool ForceUpload = false; }; - BuildsOperationPrimeCache(OperationLogOutput& OperationLogOutput, + BuildsOperationPrimeCache(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -727,7 +1045,33 @@ public: DownloadStatistics m_DownloadStats; private: - OperationLogOutput& m_LogOutput; + LoggerRef Log() { return m_Log; } + + void CollectReferencedBlobs(tsl::robin_set<IoHash, IoHash::Hasher>& OutBuildBlobs, + tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& OutLooseChunkRawSizes); + + std::vector<IoHash> FilterAlreadyCachedBlobs(const tsl::robin_set<IoHash, IoHash::Hasher>& BuildBlobs); + + void ScheduleBlobDownloads(std::span<const IoHash> BlobsToDownload, + const tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& LooseChunkRawSizes, + std::atomic<uint64_t>& MultipartAttachmentCount, + std::atomic<size_t>& CompletedDownloadCount, + FilteredRate& FilteredDownloadedBytesPerSecond); + + void DownloadLargeBlobForCache(ParallelWork& Work, + const IoHash& BlobHash, + size_t BlobCount, + std::atomic<size_t>& CompletedDownloadCount, + std::atomic<uint64_t>& MultipartAttachmentCount, + FilteredRate& FilteredDownloadedBytesPerSecond); + + void DownloadSingleBlobForCache(const IoHash& BlobHash, + size_t BlobCount, + std::atomic<size_t>& CompletedDownloadCount, + FilteredRate& FilteredDownloadedBytesPerSecond); + + LoggerRef m_Log; + ProgressBase& m_Progress; StorageInstance& m_Storage; std::atomic<bool>& m_AbortFlag; std::atomic<bool>& m_PauseFlag; @@ -755,7 +1099,7 @@ std::vector<std::pair<Oid, std::string>> ResolveBuildPartNames(CbObjectView struct BuildManifest; 
-ChunkedFolderContent GetRemoteContent(OperationLogOutput& Output, +ChunkedFolderContent GetRemoteContent(LoggerRef InLog, StorageInstance& Storage, const Oid& BuildId, const std::vector<std::pair<Oid, std::string>>& BuildParts, diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h index 7306188ca..c55c930bc 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h @@ -8,7 +8,6 @@ namespace zen { -class OperationLogOutput; class BuildStorageBase; class BuildStorageCache; @@ -38,7 +37,7 @@ enum class ZenCacheResolveMode All }; -BuildStorageResolveResult ResolveBuildStorage(OperationLogOutput& Output, +BuildStorageResolveResult ResolveBuildStorage(LoggerRef InLog, const HttpClientSettings& ClientSettings, std::string_view Host, std::string_view OverrideHost, @@ -46,7 +45,7 @@ BuildStorageResolveResult ResolveBuildStorage(OperationLogOutput& Output, ZenCacheResolveMode ZenResolveMode, bool Verbose); -std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Output, +std::vector<ChunkBlockDescription> GetBlockDescriptions(LoggerRef InLog, BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index e3a5f6539..73d037542 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/iohash.h> +#include <zencore/logbase.h> #include <zencore/compactbinary.h> #include <zencore/compress.h> @@ -64,9 +65,7 @@ struct ReuseBlocksStatistics } }; -class OperationLogOutput; - -std::vector<size_t> FindReuseBlocks(OperationLogOutput& Output, +std::vector<size_t> 
FindReuseBlocks(LoggerRef InLog, const uint8_t BlockReuseMinPercentLimit, const bool IsVerbose, ReuseBlocksStatistics& Stats, @@ -91,7 +90,7 @@ public: uint64_t MaxRangesPerBlock = 1024u; }; - ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options); + ChunkBlockAnalyser(LoggerRef Log, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options); struct BlockRangeDescriptor { @@ -130,7 +129,9 @@ public: std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes); private: - OperationLogOutput& m_LogOutput; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; const std::span<const ChunkBlockDescription> m_BlockDescriptions; const Options m_Options; }; diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h deleted file mode 100644 index 32b95f50f..000000000 --- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/fmtutils.h> -#include <zencore/logbase.h> - -namespace zen { - -class OperationLogOutput -{ -public: - virtual ~OperationLogOutput() {} - virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) = 0; - - virtual void SetLogOperationName(std::string_view Name) = 0; - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0; - virtual uint32_t GetProgressUpdateDelayMS() = 0; - - class ProgressBar - { - public: - struct State - { - bool operator==(const State&) const = default; - std::string Task; - std::string Details; - uint64_t TotalCount = 0; - uint64_t RemainingCount = 0; - enum class EStatus - { - Running, - Aborted, - Paused - }; - EStatus Status = EStatus::Running; - - static EStatus CalculateStatus(bool IsAborted, bool IsPaused) - { - if (IsAborted) - { - return EStatus::Aborted; - } - if (IsPaused) - { - return EStatus::Paused; - } - return EStatus::Running; - } - }; - - virtual ~ProgressBar() {} - - virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0; - virtual void Finish() = 0; - }; - - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) = 0; -}; - -OperationLogOutput* CreateStandardLogOutput(LoggerRef Log); - -#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - (OutputTarget).EmitLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } while (false) - -#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) 
ZEN_OPERATION_LOG(OutputTarget, zen::logging::Warn, fmtstr, ##__VA_ARGS__) - -} // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h b/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h index a07ede6f6..db5b27d3f 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h @@ -20,7 +20,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { class BuildStorageBase; -class OperationLogOutput; +class ProgressBase; struct StorageInstance; class ProjectStoreOperationOplogState @@ -34,10 +34,7 @@ public: std::filesystem::path TempFolderPath; }; - ProjectStoreOperationOplogState(OperationLogOutput& OperationLogOutput, - StorageInstance& Storage, - const Oid& BuildId, - const Options& Options); + ProjectStoreOperationOplogState(LoggerRef Log, StorageInstance& Storage, const Oid& BuildId, const Options& Options); CbObjectView LoadBuildObject(); CbObjectView LoadBuildPartsObject(); @@ -51,10 +48,12 @@ public: const Oid& GetBuildPartId(); private: - OperationLogOutput& m_LogOutput; - StorageInstance& m_Storage; - const Oid m_BuildId; - const Options m_Options; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + StorageInstance& m_Storage; + const Oid m_BuildId; + const Options m_Options; Oid m_BuildPartId = Oid::Zero; CbObject m_BuildObject; @@ -79,7 +78,8 @@ public: bool PopulateCache = true; }; - ProjectStoreOperationDownloadAttachments(OperationLogOutput& OperationLogOutput, + ProjectStoreOperationDownloadAttachments(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -92,12 +92,15 @@ public: void Execute(); private: - OperationLogOutput& m_LogOutput; - StorageInstance& m_Storage; - std::atomic<bool>& m_AbortFlag; - std::atomic<bool>& m_PauseFlag; - WorkerThreadPool& m_IOWorkerPool; - 
WorkerThreadPool& m_NetworkPool; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + ProgressBase& m_Progress; + StorageInstance& m_Storage; + std::atomic<bool>& m_AbortFlag; + std::atomic<bool>& m_PauseFlag; + WorkerThreadPool& m_IOWorkerPool; + WorkerThreadPool& m_NetworkPool; ProjectStoreOperationOplogState& m_State; const tsl::robin_set<IoHash, IoHash::Hasher> m_AttachmentHashes; diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 8df892053..b81708341 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -152,7 +152,8 @@ struct RemoteStoreOptions typedef std::function<CompositeBuffer(const IoHash& AttachmentHash)> TGetAttachmentBufferFunc; -CbObject BuildContainer(CidStore& ChunkStore, +CbObject BuildContainer(LoggerRef InLog, + CidStore& ChunkStore, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, WorkerThreadPool& WorkerPool, @@ -205,7 +206,8 @@ RemoteProjectStore::Result SaveOplogContainer( const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment, JobContext* OptionalContext); -void SaveOplog(CidStore& ChunkStore, +void SaveOplog(LoggerRef InLog, + CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, @@ -222,6 +224,7 @@ void SaveOplog(CidStore& ChunkStore, struct LoadOplogContext { + LoggerRef Log; CidStore& ChunkStore; RemoteProjectStore& RemoteStore; BuildStorageCache* OptionalCache = nullptr; diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index a9788cb4e..d610d1fc8 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -673,7 +673,7 @@ JupiterSession::PutMultipartBuildBlob(std::string_view Namespace, 
size_t RetryPartIndex = PartNameToIndex.at(RetryPartId); const MultipartUploadResponse::Part& RetryPart = Workload->PartDescription.Parts[RetryPartIndex]; IoBuffer RetryPartPayload = - Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte - 1); + Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte); std::string RetryMultipartUploadResponseRequestString = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}/uploadMultipart{}&supportsRedirect={}", Namespace, diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp deleted file mode 100644 index 5ed844c9d..000000000 --- a/src/zenremotestore/operationlogoutput.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenremotestore/operationlogoutput.h> - -#include <zencore/logging.h> -#include <zencore/logging/logger.h> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <gsl/gsl-lite.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen { - -class StandardLogOutput; - -class StandardLogOutputProgressBar : public OperationLogOutput::ProgressBar -{ -public: - StandardLogOutputProgressBar(StandardLogOutput& Output, std::string_view InSubTask) : m_Output(Output), m_SubTask(InSubTask) {} - - virtual void UpdateState(const State& NewState, bool DoLinebreak) override; - virtual void Finish() override; - -private: - StandardLogOutput& m_Output; - std::string m_SubTask; - State m_State; -}; - -class StandardLogOutput : public OperationLogOutput -{ -public: - StandardLogOutput(LoggerRef& Log) : m_Log(Log) {} - virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override - { - if (m_Log.ShouldLog(Point.Level)) - { - m_Log->Log(Point, Args); - } - } - - virtual void SetLogOperationName(std::string_view Name) override - { - m_LogOperationName = Name; - ZEN_OPERATION_LOG_INFO(*this, "{}", m_LogOperationName); - } - virtual void SetLogOperationProgress(uint32_t StepIndex, 
uint32_t StepCount) override - { - [[maybe_unused]] const size_t PercentDone = StepCount > 0u ? gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u; - ZEN_OPERATION_LOG_INFO(*this, "{}: {}%", m_LogOperationName, PercentDone); - } - virtual uint32_t GetProgressUpdateDelayMS() override { return 2000; } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override - { - return new StandardLogOutputProgressBar(*this, InSubTask); - } - -private: - LoggerRef m_Log; - std::string m_LogOperationName; - LoggerRef Log() { return m_Log; } -}; - -void -StandardLogOutputProgressBar::UpdateState(const State& NewState, bool DoLinebreak) -{ - ZEN_UNUSED(DoLinebreak); - [[maybe_unused]] const size_t PercentDone = - NewState.TotalCount > 0u ? gsl::narrow<uint8_t>((100 * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount) : 0u; - std::string Task = NewState.Task; - switch (NewState.Status) - { - case State::EStatus::Aborted: - Task = "Aborting"; - break; - case State::EStatus::Paused: - Task = "Paused"; - break; - default: - break; - } - ZEN_OPERATION_LOG_INFO(m_Output, "{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? 
"" : fmt::format(" {}", NewState.Details)); - m_State = NewState; -} -void -StandardLogOutputProgressBar::Finish() -{ - if (m_State.RemainingCount > 0) - { - State NewState = m_State; - NewState.RemainingCount = 0; - NewState.Details = ""; - UpdateState(NewState, /*DoLinebreak*/ true); - } -} - -OperationLogOutput* -CreateStandardLogOutput(LoggerRef Log) -{ - return new StandardLogOutput(Log); -} - -} // namespace zen diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index e95d9118c..d7596263b 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -9,7 +9,6 @@ #include <zenremotestore/builds/buildstorageutil.h> #include <zenremotestore/builds/jupiterbuildstorage.h> -#include <zenremotestore/operationlogoutput.h> #include <numeric> @@ -436,8 +435,6 @@ public: BuildStorageCache* OptionalCache, const Oid& CacheBuildId) override { - std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log())); - ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); ZEN_ASSERT(OptionalCache == nullptr || CacheBuildId == m_BuildId); @@ -447,7 +444,7 @@ public: try { - Result.Blocks = zen::GetBlockDescriptions(*Output, + Result.Blocks = zen::GetBlockDescriptions(Log(), *m_BuildStorage, OptionalCache, m_BuildId, diff --git a/src/zenremotestore/projectstore/projectstoreoperations.cpp b/src/zenremotestore/projectstore/projectstoreoperations.cpp index 36dc4d868..ba4b74825 100644 --- a/src/zenremotestore/projectstore/projectstoreoperations.cpp +++ b/src/zenremotestore/projectstore/projectstoreoperations.cpp @@ -3,13 +3,14 @@ #include <zenremotestore/projectstore/projectstoreoperations.h> #include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> #include <zencore/parallelwork.h> #include <zencore/scopeguard.h> #include <zencore/timer.h> #include <zenremotestore/builds/buildstorageutil.h> #include 
<zenremotestore/chunking/chunkedfile.h> -#include <zenremotestore/operationlogoutput.h> #include <zenremotestore/projectstore/remoteprojectstore.h> +#include <zenutil/progress.h> namespace zen { @@ -17,11 +18,11 @@ using namespace std::literals; //////////////////////////// ProjectStoreOperationOplogState -ProjectStoreOperationOplogState::ProjectStoreOperationOplogState(OperationLogOutput& OperationLogOutput, - StorageInstance& Storage, - const Oid& BuildId, - const Options& Options) -: m_LogOutput(OperationLogOutput) +ProjectStoreOperationOplogState::ProjectStoreOperationOplogState(LoggerRef Log, + StorageInstance& Storage, + const Oid& BuildId, + const Options& Options) +: m_Log(Log) , m_Storage(Storage) , m_BuildId(BuildId) , m_Options(Options) @@ -48,10 +49,7 @@ ProjectStoreOperationOplogState::LoadBuildObject() { if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Read build {} from locally cached file in {}", - m_BuildId, - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Read build {} from locally cached file in {}", m_BuildId, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } return m_BuildObject; } @@ -61,11 +59,10 @@ ProjectStoreOperationOplogState::LoadBuildObject() m_BuildObject = m_Storage.BuildStorage->GetBuild(m_BuildId); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Fetched build {} from {} in {}", - m_BuildId, - m_Storage.BuildStorageHttp->GetBaseUri(), - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Fetched build {} from {} in {}", + m_BuildId, + m_Storage.BuildStorageHttp->GetBaseUri(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } CreateDirectories(CachedBuildObjectPath.parent_path()); TemporaryFile::SafeWriteFile(CachedBuildObjectPath, m_BuildObject.GetBuffer().GetView()); @@ -122,11 +119,10 @@ ProjectStoreOperationOplogState::LoadBuildPartsObject() { if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Read build part {}/{} from locally cached file in {}", - m_BuildId, - BuildPartId, 
- NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Read build part {}/{} from locally cached file in {}", + m_BuildId, + BuildPartId, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } return m_BuildPartsObject; } @@ -136,12 +132,11 @@ ProjectStoreOperationOplogState::LoadBuildPartsObject() m_BuildPartsObject = m_Storage.BuildStorage->GetBuildPart(m_BuildId, BuildPartId); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Fetched build part {}/{} from {} in {}", - m_BuildId, - BuildPartId, - m_Storage.BuildStorageHttp->GetBaseUri(), - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Fetched build part {}/{} from {} in {}", + m_BuildId, + BuildPartId, + m_Storage.BuildStorageHttp->GetBaseUri(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } CreateDirectories(CachedBuildPartObjectPath.parent_path()); TemporaryFile::SafeWriteFile(CachedBuildPartObjectPath, m_BuildPartsObject.GetBuffer().GetView()); @@ -168,11 +163,7 @@ ProjectStoreOperationOplogState::LoadOpsSectionObject() } else if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Read {}/{}/ops from locally cached file in {}", - BuildPartId, - m_BuildId, - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Read {}/{}/ops from locally cached file in {}", BuildPartId, m_BuildId, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); return m_OpsSectionObject; } } @@ -193,11 +184,10 @@ ProjectStoreOperationOplogState::LoadOpsSectionObject() } if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Decompressed and validated oplog payload {} -> {} in {}", - NiceBytes(OpsSection.GetSize()), - NiceBytes(m_OpsSectionObject.GetSize()), - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Decompressed and validated oplog payload {} -> {} in {}", + NiceBytes(OpsSection.GetSize()), + NiceBytes(m_OpsSectionObject.GetSize()), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } if (m_OpsSectionObject) { @@ -226,12 +216,11 @@ 
ProjectStoreOperationOplogState::LoadArrayFromBuildPart(std::string_view ArrayNa { if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Read {}/{}/{} from locally cached file in {}", - BuildPartId, - m_BuildId, - ArrayName, - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Read {}/{}/{} from locally cached file in {}", + BuildPartId, + m_BuildId, + ArrayName, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } CbArray Result = CbArray(SharedBuffer(std::move(Payload))); return Result; @@ -290,7 +279,8 @@ ProjectStoreOperationOplogState::LoadChunksArray() //////////////////////////// ProjectStoreOperationDownloadAttachments -ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachments(OperationLogOutput& OperationLogOutput, +ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachments(LoggerRef Log, + ProgressBase& Progress, StorageInstance& Storage, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, @@ -299,7 +289,8 @@ ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachmen ProjectStoreOperationOplogState& State, std::span<const IoHash> AttachmentHashes, const Options& Options) -: m_LogOutput(OperationLogOutput) +: m_Log(Log) +, m_Progress(Progress) , m_Storage(Storage) , m_AbortFlag(AbortFlag) , m_PauseFlag(PauseFlag) @@ -325,9 +316,9 @@ ProjectStoreOperationDownloadAttachments::Execute() }; auto EndProgress = - MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); + MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); }); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ReadAttachmentData, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ReadAttachmentData, (uint32_t)TaskSteps::StepCount); Stopwatch Timer; tsl::robin_map<IoHash, uint64_t, IoHash::Hasher> ChunkSizes; @@ -415,13 +406,12 @@ 
ProjectStoreOperationDownloadAttachments::Execute() FilesToDechunk.size() > 0 ? fmt::format("\n{} file{} needs to be dechunked", FilesToDechunk.size(), FilesToDechunk.size() == 1 ? "" : "s") : ""; - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Need to download {} block{} and {} chunk{}{}", - BlocksToDownload.size(), - BlocksToDownload.size() == 1 ? "" : "s", - LooseChunksToDownload.size(), - LooseChunksToDownload.size() == 1 ? "" : "s", - DechunkInfo); + ZEN_INFO("Need to download {} block{} and {} chunk{}{}", + BlocksToDownload.size(), + BlocksToDownload.size() == 1 ? "" : "s", + LooseChunksToDownload.size(), + LooseChunksToDownload.size() == 1 ? "" : "s", + DechunkInfo); } auto GetBuildBlob = [this](const IoHash& RawHash, const std::filesystem::path& OutputPath) { @@ -470,18 +460,15 @@ ProjectStoreOperationDownloadAttachments::Execute() std::filesystem::path TempAttachmentPath = MakeSafeAbsolutePath(m_Options.AttachmentOutputPath) / ".tmp"; CreateDirectories(TempAttachmentPath); auto _0 = MakeGuard([this, &TempAttachmentPath]() { - if (true) + if (!m_Options.IsQuiet) { - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Cleaning up temporary directory"); - } - CleanDirectory(TempAttachmentPath, true); - RemoveDir(TempAttachmentPath); + ZEN_INFO("Cleaning up temporary directory"); } + CleanDirectory(TempAttachmentPath, true); + RemoveDir(TempAttachmentPath); }); - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Download, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Download, (uint32_t)TaskSteps::StepCount); std::filesystem::path BlocksPath = TempAttachmentPath / "blocks"; CreateDirectories(BlocksPath); @@ -492,11 +479,9 @@ ProjectStoreOperationDownloadAttachments::Execute() std::filesystem::path LooseChunksPath = TempAttachmentPath / "loosechunks"; CreateDirectories(LooseChunksPath); - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Downloading")); - 
OperationLogOutput::ProgressBar& DownloadProgressBar(*ProgressBarPtr); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Downloading"); - std::atomic<bool> PauseFlag; - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); std::atomic<size_t> LooseChunksCompleted; std::atomic<size_t> BlocksCompleted; @@ -511,7 +496,7 @@ ProjectStoreOperationDownloadAttachments::Execute() if (m_Options.ForceDownload || !IsFile(LooseChunkOutputPath)) { GetBuildBlob(RawHash, LooseChunkOutputPath); - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Downloaded chunk {}", RawHash); + ZEN_DEBUG("Downloaded chunk {}", RawHash); } Work.ScheduleWork(m_IOWorkerPool, [&, LooseChunkIndex, LooseChunkOutputPath](std::atomic<bool>&) { @@ -547,7 +532,7 @@ ProjectStoreOperationDownloadAttachments::Execute() { ChunkOutput.Close(); RemoveFile(ChunkOutputPath); - throw std::runtime_error(fmt::format("Failed to decompress chunk {} to ", RawHash, ChunkOutputPath)); + throw std::runtime_error(fmt::format("Failed to decompress chunk {} to '{}'", RawHash, ChunkOutputPath)); } } else @@ -555,7 +540,7 @@ ProjectStoreOperationDownloadAttachments::Execute() TemporaryFile::SafeWriteFile(ChunkOutputPath, CompressedChunk.GetCompressed()); } - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Wrote loose chunk {} to '{}'", RawHash, ChunkOutputPath); + ZEN_DEBUG("Wrote loose chunk {} to '{}'", RawHash, ChunkOutputPath); LooseChunksCompleted++; }); }); @@ -572,7 +557,7 @@ ProjectStoreOperationDownloadAttachments::Execute() if (m_Options.ForceDownload || !IsFile(BlockOutputPath)) { GetBuildBlob(RawHash, BlockOutputPath); - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Downloaded block {}", RawHash); + ZEN_DEBUG("Downloaded block {}", RawHash); } Work.ScheduleWork(m_IOWorkerPool, [&, BlockIndex, BlockOutputPath](std::atomic<bool>&) { @@ -607,7 +592,7 @@ 
ProjectStoreOperationDownloadAttachments::Execute() ChunkOutput.Close(); RemoveFile(ChunkOutputPath); throw std::runtime_error( - fmt::format("Failed to decompress chunk {} to ", ChunkHash, ChunkOutputPath)); + fmt::format("Failed to decompress chunk {} to '{}'", ChunkHash, ChunkOutputPath)); } } else @@ -615,7 +600,7 @@ ProjectStoreOperationDownloadAttachments::Execute() TemporaryFile::SafeWriteFile(ChunkOutputPath, CompressedChunk.GetCompressed()); } - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Wrote block chunk {} to '{}'", ChunkHash, ChunkOutputPath); + ZEN_DEBUG("Wrote block chunk {} to '{}'", ChunkHash, ChunkOutputPath); } if (ChunkedFileRawHashes.contains(ChunkHash)) { @@ -635,7 +620,7 @@ ProjectStoreOperationDownloadAttachments::Execute() }); } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(IsAborted, IsPaused, PendingWork); std::string Details = fmt::format("{}/{} blocks, {}/{} chunks downloaded", @@ -643,39 +628,37 @@ ProjectStoreOperationDownloadAttachments::Execute() BlocksToDownload.size(), LooseChunksCompleted.load(), LooseChunksToDownload.size()); - DownloadProgressBar.UpdateState({.Task = "Downloading", - .Details = Details, - .TotalCount = BlocksToDownload.size() + LooseChunksToDownload.size(), - .RemainingCount = BlocksToDownload.size() + LooseChunksToDownload.size() - - (BlocksCompleted.load() + LooseChunksCompleted.load()), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = "Downloading", + .Details = Details, + .TotalCount = BlocksToDownload.size() + LooseChunksToDownload.size(), + .RemainingCount = BlocksToDownload.size() + LooseChunksToDownload.size() - + (BlocksCompleted.load() + LooseChunksCompleted.load()), + .Status = 
ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }); - DownloadProgressBar.Finish(); + ProgressBar->Finish(); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "{} block{} downloaded, {} loose chunk{} downloaded in {}", - BlocksToDownload.size(), - BlocksToDownload.size() == 1 ? "" : "s", - LooseChunksToDownload.size(), - LooseChunksToDownload.size() == 1 ? "" : "s", - NiceTimeSpanMs(DownloadTimer.GetElapsedTimeMs())); + ZEN_INFO("{} block{} downloaded, {} loose chunk{} downloaded in {}", + BlocksToDownload.size(), + BlocksToDownload.size() == 1 ? "" : "s", + LooseChunksToDownload.size(), + LooseChunksToDownload.size() == 1 ? "" : "s", + NiceTimeSpanMs(DownloadTimer.GetElapsedTimeMs())); } } if (!ChunkedFileInfos.empty()) { - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::AnalyzeDechunk, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::AnalyzeDechunk, (uint32_t)TaskSteps::StepCount); std::filesystem::path ChunkedFilesPath = TempAttachmentPath / "chunkedfiles"; CreateDirectories(ChunkedFilesPath); try { - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Dechunking")); - OperationLogOutput::ProgressBar& DechunkingProgressBar(*ProgressBarPtr); + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Dechunking"); std::atomic<uint64_t> ChunksWritten; @@ -729,7 +712,7 @@ ProjectStoreOperationDownloadAttachments::Execute() PrepareFileForScatteredWrite(OpenChunkedFiles.back()->Handle(), ChunkedFileInfo.RawSize); } - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Dechunk, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Dechunk, (uint32_t)TaskSteps::StepCount); std::vector<std::atomic<uint8_t>> ChunkWrittenFlags(ChunkOpenFileTargets.size()); @@ -755,7 +738,7 @@ ProjectStoreOperationDownloadAttachments::Execute() })) { std::error_code DummyEc; - 
throw std::runtime_error(fmt::format("Failed to decompress chunk {} at offset {} to {}", + throw std::runtime_error(fmt::format("Failed to decompress chunk {} at offset {} to '{}'", CompressedChunkBuffer.DecodeRawHash(), ChunkTarget.Offset, PathFromHandle(OutputFile.Handle(), DummyEc))); @@ -768,8 +751,7 @@ ProjectStoreOperationDownloadAttachments::Execute() { Stopwatch DechunkTimer; - std::atomic<bool> PauseFlag; - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); std::vector<IoHash> LooseChunks(LooseChunksToDownload.begin(), LooseChunksToDownload.end()); @@ -819,26 +801,24 @@ ProjectStoreOperationDownloadAttachments::Execute() } }); } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(IsAborted, IsPaused, PendingWork); std::string Details = fmt::format("{}/{} chunks written", ChunksWritten.load(), ChunkOpenFileTargets.size()); - DechunkingProgressBar.UpdateState( - {.Task = "Dechunking ", - .Details = Details, - .TotalCount = ChunkOpenFileTargets.size(), - .RemainingCount = ChunkOpenFileTargets.size() - ChunksWritten.load(), - .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = "Dechunking ", + .Details = Details, + .TotalCount = ChunkOpenFileTargets.size(), + .RemainingCount = ChunkOpenFileTargets.size() - ChunksWritten.load(), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }); - DechunkingProgressBar.Finish(); + ProgressBar->Finish(); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "{} file{} dechunked in {}", - ChunkedFileInfos.size(), - ChunkedFileInfos.size() == 1 ? 
"" : "s", - NiceTimeSpanMs(DechunkTimer.GetElapsedTimeMs())); + ZEN_INFO("{} file{} dechunked in {}", + ChunkedFileInfos.size(), + ChunkedFileInfos.size() == 1 ? "" : "s", + NiceTimeSpanMs(DechunkTimer.GetElapsedTimeMs())); } } } @@ -853,12 +833,10 @@ ProjectStoreOperationDownloadAttachments::Execute() throw; } { - Stopwatch VerifyTimer; - std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Verifying")); - OperationLogOutput::ProgressBar& VerifyProgressBar(*ProgressBarPtr); + Stopwatch VerifyTimer; + std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Verifying"); - std::atomic<bool> PauseFlag; - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); std::atomic<size_t> DechunkedFilesMoved; @@ -875,43 +853,41 @@ ProjectStoreOperationDownloadAttachments::Execute() } std::filesystem::path ChunkOutputPath = m_Options.AttachmentOutputPath / fmt::format("{}", ChunkedFileInfo.RawHash); RenameFile(ChunkedFilePath, ChunkOutputPath); - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Moved dechunked file {} to '{}'", ChunkedFileInfo.RawHash, ChunkOutputPath); + ZEN_DEBUG("Moved dechunked file {} to '{}'", ChunkedFileInfo.RawHash, ChunkOutputPath); DechunkedFilesMoved++; }); } - Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { + Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { ZEN_UNUSED(IsAborted, IsPaused, PendingWork); std::string Details = fmt::format("{}/{} files verified", DechunkedFilesMoved.load(), ChunkedFileInfos.size()); - VerifyProgressBar.UpdateState({.Task = "Verifying ", - .Details = Details, - .TotalCount = ChunkedFileInfos.size(), - .RemainingCount = ChunkedFileInfos.size() - DechunkedFilesMoved.load(), - .Status = 
OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, - false); + ProgressBar->UpdateState({.Task = "Verifying ", + .Details = Details, + .TotalCount = ChunkedFileInfos.size(), + .RemainingCount = ChunkedFileInfos.size() - DechunkedFilesMoved.load(), + .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)}, + false); }); - VerifyProgressBar.Finish(); + ProgressBar->Finish(); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Verified {} chunked file{} in {}", - ChunkedFileInfos.size(), - ChunkedFileInfos.size() == 1 ? "" : "s", - NiceTimeSpanMs(VerifyTimer.GetElapsedTimeMs())); + ZEN_INFO("Verified {} chunked file{} in {}", + ChunkedFileInfos.size(), + ChunkedFileInfos.size() == 1 ? "" : "s", + NiceTimeSpanMs(VerifyTimer.GetElapsedTimeMs())); } } } if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Downloaded {} attachment{} to '{}' in {}", - m_AttachmentHashes.size(), - m_AttachmentHashes.size() == 1 ? "" : "s", - m_Options.AttachmentOutputPath, - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + ZEN_INFO("Downloaded {} attachment{} to '{}' in {}", + m_AttachmentHashes.size(), + m_AttachmentHashes.size() == 1 ? 
"" : "s", + m_Options.AttachmentOutputPath, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount); + m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount); } } // namespace zen diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 1a9dc10ef..f43f0813a 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -8,6 +8,8 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/logging/broadcastsink.h> +#include <zencore/logging/logger.h> #include <zencore/parallelwork.h> #include <zencore/scopeguard.h> #include <zencore/stream.h> @@ -18,8 +20,9 @@ #include <zenremotestore/builds/buildstoragecache.h> #include <zenremotestore/chunking/chunkedcontent.h> #include <zenremotestore/chunking/chunkedfile.h> -#include <zenremotestore/operationlogoutput.h> #include <zenstore/cidstore.h> +#include <zenutil/logging.h> +#include <zenutil/progress.h> #include <numeric> #include <unordered_map> @@ -392,7 +395,10 @@ namespace remotestore_impl { OodleCompressor Compressor, OodleCompressionLevel CompressionLevel) { - ZEN_ASSERT(!IsFile(AttachmentPath)); + if (IsFile(AttachmentPath)) + { + ZEN_WARN("Temp attachment file already exists at '{}', truncating", AttachmentPath); + } BasicFile CompressedFile; std::error_code Ec; CompressedFile.Open(AttachmentPath, BasicFile::Mode::kTruncateDelete, Ec); @@ -448,6 +454,7 @@ namespace remotestore_impl { }; CbObject RewriteOplog( + LoggerRef InLog, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, bool IgnoreMissingAttachments, @@ -456,6 +463,7 @@ namespace remotestore_impl { std::unordered_map<IoHash, FoundAttachment, IoHash::Hasher>& UploadAttachments, // TODO: Rename to OutUploadAttachments 
JobContext* OptionalContext) { + ZEN_SCOPED_LOG(InLog); size_t OpCount = 0; CreateDirectories(AttachmentTempPath); @@ -929,7 +937,6 @@ namespace remotestore_impl { { return; } - ZEN_ASSERT(UploadAttachment->Size != 0); if (!UploadAttachment->RawPath.empty()) { if (UploadAttachment->Size > (MaxChunkEmbedSize * 2)) @@ -1140,31 +1147,51 @@ namespace remotestore_impl { std::atomic<uint64_t> ChunksCompleteCount = 0; }; - class JobContextLogOutput : public OperationLogOutput + class JobContextSink : public logging::Sink { public: - JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} - virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override + explicit JobContextSink(JobContext* Context) : m_Context(Context) {} + + void Log(const logging::LogMessage& Msg) override { - if (m_OptionalContext) + if (m_Context) { - fmt::basic_memory_buffer<char, 250> MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Point.FormatString, Args); - remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); + m_Context->ReportMessage(Msg.GetPayload()); } } - virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); } - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() override { return 0; } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override + void Flush() override {} + void SetFormatter(std::unique_ptr<logging::Formatter>) override {} + + private: + JobContext* m_Context; + }; + + class JobContextLogger + { + public: + explicit JobContextLogger(JobContext* OptionalContext) { - ZEN_UNUSED(InSubTask); - return nullptr; + if (!OptionalContext) + { + return; + } + logging::SinkPtr ContextSink(new JobContextSink(OptionalContext)); + Ref<logging::BroadcastSink> DefaultSink = 
GetDefaultBroadcastSink(); + std::vector<logging::SinkPtr> Sinks; + if (DefaultSink) + { + Sinks.push_back(DefaultSink); + } + Sinks.push_back(std::move(ContextSink)); + Ref<logging::BroadcastSink> Broadcast(new logging::BroadcastSink(std::move(Sinks))); + m_Log = Ref<logging::Logger>(new logging::Logger("jobcontext", Broadcast)); } + LoggerRef Log() const { return m_Log ? LoggerRef(*m_Log) : zen::Log(); } + private: - JobContext* m_OptionalContext; + Ref<logging::Logger> m_Log; }; void DownloadAndSaveBlockChunks(LoadOplogContext& Context, @@ -1185,6 +1212,7 @@ namespace remotestore_impl { &LoadAttachmentsTimer, &DownloadStartMS](std::atomic<bool>& AbortFlag) { ZEN_TRACE_CPU("DownloadBlockChunks"); + ZEN_SCOPED_LOG(Context.Log); if (AbortFlag) { @@ -1300,6 +1328,7 @@ namespace remotestore_impl { &AllNeededPartialChunkHashesLookup, ChunkDownloadedFlags](std::atomic<bool>& AbortFlag) { ZEN_TRACE_CPU("DownloadBlock"); + ZEN_SCOPED_LOG(Context.Log); if (AbortFlag) { @@ -1366,6 +1395,7 @@ namespace remotestore_impl { &AllNeededPartialChunkHashesLookup, ChunkDownloadedFlags, Bytes = std::move(BlobBuffer)](std::atomic<bool>& AbortFlag) { + ZEN_SCOPED_LOG(Context.Log); if (AbortFlag) { return; @@ -1715,6 +1745,7 @@ namespace remotestore_impl { ChunkDownloadedFlags, RetriesLeft](std::atomic<bool>& AbortFlag) { ZEN_TRACE_CPU("DownloadBlockRanges"); + ZEN_SCOPED_LOG(Context.Log); try { uint64_t Unset = (std::uint64_t)-1; @@ -1760,6 +1791,7 @@ namespace remotestore_impl { OffsetAndLengths = std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())]( std::atomic<bool>& AbortFlag) { + ZEN_SCOPED_LOG(Context.Log); try { ZEN_ASSERT(BlockPayload.Size() > 0); @@ -1972,6 +2004,7 @@ namespace remotestore_impl { Context.NetworkWorkerPool, [&Context, &AttachmentWork, RawHash, &LoadAttachmentsTimer, &DownloadStartMS, &Info](std::atomic<bool>& AbortFlag) { ZEN_TRACE_CPU("DownloadAttachment"); + ZEN_SCOPED_LOG(Context.Log); if (AbortFlag) { @@ -2061,7 
+2094,8 @@ namespace remotestore_impl { WorkerThreadPool::EMode::EnableBacklog); }; - void AsyncCreateBlock(ParallelWork& Work, + void AsyncCreateBlock(LoggerRef InLog, + ParallelWork& Work, WorkerThreadPool& WorkerPool, std::vector<std::pair<IoHash, FetchChunkFunc>>&& ChunksInBlock, RwLock& SectionsLock, @@ -2071,9 +2105,10 @@ namespace remotestore_impl { JobContext* OptionalContext) { Work.ScheduleWork(WorkerPool, - [&Blocks, &SectionsLock, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, OptionalContext]( + [InLog, &Blocks, &SectionsLock, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, OptionalContext]( std::atomic<bool>& AbortFlag) mutable { ZEN_TRACE_CPU("CreateBlock"); + ZEN_SCOPED_LOG(InLog); if (remotestore_impl::IsCancelled(OptionalContext)) { @@ -2452,7 +2487,8 @@ GetBlocksFromOplog(CbObjectView ContainerObject, std::span<const IoHash> Include } CbObject -BuildContainer(CidStore& ChunkStore, +BuildContainer(LoggerRef InLog, + CidStore& ChunkStore, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, size_t MaxBlockSize, @@ -2472,7 +2508,8 @@ BuildContainer(CidStore& ChunkStore, { using namespace std::literals; - std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(OptionalContext)); + ZEN_SCOPED_LOG(InLog); + remotestore_impl::JobContextLogger JobContextOutput(OptionalContext); Stopwatch Timer; @@ -2485,7 +2522,8 @@ BuildContainer(CidStore& ChunkStore, size_t TotalOpCount = Oplog.GetOplogEntryCount(); Stopwatch RewriteOplogTimer; - CbObject SectionOps = remotestore_impl::RewriteOplog(Project, + CbObject SectionOps = remotestore_impl::RewriteOplog(InLog, + Project, Oplog, IgnoreMissingAttachments, EmbedLooseFiles, @@ -2605,7 +2643,7 @@ BuildContainer(CidStore& ChunkStore, std::vector<uint32_t> UnusedChunkIndexes; ReuseBlocksStatistics ReuseBlocksStats; - ReusedBlockIndexes = FindReuseBlocks(*LogOutput, + ReusedBlockIndexes = FindReuseBlocks(JobContextOutput.Log(), 
/*BlockReuseMinPercentLimit*/ 80, /*IsVerbose*/ false, ReuseBlocksStats, @@ -2749,7 +2787,8 @@ BuildContainer(CidStore& ChunkStore, .MaxChunkEmbedSize = MaxChunkEmbedSize, .IsCancelledFunc = [OptionalContext]() { return remotestore_impl::IsCancelled(OptionalContext); }}); - auto OnNewBlock = [&Work, + auto OnNewBlock = [&Log, + &Work, &WorkerPool, BuildBlocks, &BlockCreateProgressTimer, @@ -2774,7 +2813,8 @@ BuildContainer(CidStore& ChunkStore, size_t BlockIndex = remotestore_impl::AddBlock(BlocksLock, Blocks); if (BuildBlocks) { - remotestore_impl::AsyncCreateBlock(Work, + remotestore_impl::AsyncCreateBlock(Log(), + Work, WorkerPool, std::move(ChunksInBlock), BlocksLock, @@ -3007,6 +3047,17 @@ BuildContainer(CidStore& ChunkStore, return {}; } + // Reused blocks were not composed (their chunks were erased from UploadAttachments) but must + // still appear in the container so that a fresh receiver knows to download them. + if (BuildBlocks) + { + for (size_t KnownBlockIndex : ReusedBlockIndexes) + { + const ChunkBlockDescription& Reused = KnownBlocks[KnownBlockIndex]; + Blocks.push_back(Reused); + } + } + CbObjectWriter OplogContainerWriter; RwLock::SharedLockScope _(BlocksLock); OplogContainerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer()); @@ -3096,7 +3147,8 @@ BuildContainer(CidStore& ChunkStore, } CbObject -BuildContainer(CidStore& ChunkStore, +BuildContainer(LoggerRef InLog, + CidStore& ChunkStore, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, WorkerThreadPool& WorkerPool, @@ -3112,7 +3164,8 @@ BuildContainer(CidStore& ChunkStore, const std::function<void(std::vector<std::pair<IoHash, FetchChunkFunc>>&&)>& OnBlockChunks, bool EmbedLooseFiles) { - return BuildContainer(ChunkStore, + return BuildContainer(InLog, + ChunkStore, Project, Oplog, MaxBlockSize, @@ -3132,7 +3185,8 @@ BuildContainer(CidStore& ChunkStore, } void -SaveOplog(CidStore& ChunkStore, +SaveOplog(LoggerRef InLog, + CidStore& ChunkStore, 
RemoteProjectStore& RemoteStore, ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, @@ -3149,6 +3203,7 @@ SaveOplog(CidStore& ChunkStore, { using namespace std::literals; + ZEN_SCOPED_LOG(InLog); Stopwatch Timer; remotestore_impl::UploadInfo Info; @@ -3168,8 +3223,8 @@ SaveOplog(CidStore& ChunkStore, std::unordered_map<IoHash, remotestore_impl::CreatedBlock, IoHash::Hasher> CreatedBlocks; tsl::robin_map<IoHash, TGetAttachmentBufferFunc, IoHash::Hasher> LooseLargeFiles; - auto MakeTempBlock = [AttachmentTempPath, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock, - ChunkBlockDescription&& Block) { + auto MakeTempBlock = [&Log, AttachmentTempPath, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock, + ChunkBlockDescription&& Block) { std::filesystem::path BlockPath = AttachmentTempPath; BlockPath.append(Block.BlockHash.ToHexString()); IoBuffer BlockBuffer = WriteToTempFile(std::move(CompressedBlock).GetCompressed(), BlockPath); @@ -3180,8 +3235,8 @@ SaveOplog(CidStore& ChunkStore, ZEN_DEBUG("Saved temp block to '{}', {}", AttachmentTempPath, NiceBytes(BlockSize)); }; - auto UploadBlock = [&RemoteStore, &RemoteStoreInfo, &Info, OptionalContext](CompressedBuffer&& CompressedBlock, - ChunkBlockDescription&& Block) { + auto UploadBlock = [&Log, &RemoteStore, &RemoteStoreInfo, &Info, OptionalContext](CompressedBuffer&& CompressedBlock, + ChunkBlockDescription&& Block) { IoHash BlockHash = Block.BlockHash; uint64_t CompressedSize = CompressedBlock.GetCompressedSize(); RemoteProjectStore::SaveAttachmentResult Result = @@ -3201,13 +3256,13 @@ SaveOplog(CidStore& ChunkStore, }; std::vector<std::vector<std::pair<IoHash, FetchChunkFunc>>> BlockChunks; - auto OnBlockChunks = [&BlockChunks](std::vector<std::pair<IoHash, FetchChunkFunc>>&& Chunks) { + auto OnBlockChunks = [&Log, &BlockChunks](std::vector<std::pair<IoHash, FetchChunkFunc>>&& Chunks) { BlockChunks.push_back({std::make_move_iterator(Chunks.begin()), 
std::make_move_iterator(Chunks.end())}); ZEN_DEBUG("Found {} block chunks", Chunks.size()); }; - auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments, &LooseLargeFiles](const IoHash& AttachmentHash, - TGetAttachmentBufferFunc&& GetBufferFunc) { + auto OnLargeAttachment = [&Log, &AttachmentsLock, &LargeAttachments, &LooseLargeFiles](const IoHash& AttachmentHash, + TGetAttachmentBufferFunc&& GetBufferFunc) { { RwLock::ExclusiveLockScope _(AttachmentsLock); LargeAttachments.insert(AttachmentHash); @@ -3286,7 +3341,8 @@ SaveOplog(CidStore& ChunkStore, } } - CbObject OplogContainerObject = BuildContainer(ChunkStore, + CbObject OplogContainerObject = BuildContainer(InLog, + ChunkStore, Project, Oplog, MaxBlockSize, @@ -3694,7 +3750,8 @@ LoadOplog(LoadOplogContext&& Context) { using namespace std::literals; - std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(Context.OptionalJobContext)); + ZEN_SCOPED_LOG(Context.Log); + remotestore_impl::JobContextLogger JobContextOutput(Context.OptionalJobContext); remotestore_impl::DownloadInfo Info; @@ -3985,7 +4042,7 @@ LoadOplog(LoadOplogContext&& Context) ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size()); ChunkBlockAnalyser PartialAnalyser( - *LogOutput, + JobContextOutput.Log(), BlockDescriptions.Blocks, ChunkBlockAnalyser::Options{.IsQuiet = false, .IsVerbose = false, @@ -4108,10 +4165,10 @@ LoadOplog(LoadOplogContext&& Context) std::filesystem::path TempFileName = TempFilePath / Chunked.RawHash.ToHexString(); DechunkWork.ScheduleWork( Context.WorkerPool, - [&Context, TempFileName, &FilesToDechunk, ChunkedIndex, &Info](std::atomic<bool>& AbortFlag) { + [&Log, &Context, TempFileName, &FilesToDechunk, ChunkedIndex, &Info](std::atomic<bool>& AbortFlag) { ZEN_TRACE_CPU("DechunkAttachment"); - auto _ = MakeGuard([&TempFileName] { + auto _ = MakeGuard([&Log, &TempFileName] { std::error_code Ec; if (IsFile(TempFileName, Ec)) { @@ -4712,7 +4769,8 @@ 
TEST_CASE_TEMPLATE("project.store.export", WorkerThreadPool& NetworkPool = Pools.NetworkPool; WorkerThreadPool& WorkerPool = Pools.WorkerPool; - SaveOplog(CidStore, + SaveOplog(Log(), + CidStore, *RemoteStore, *Project.Get(), *Oplog, @@ -4732,7 +4790,8 @@ TEST_CASE_TEMPLATE("project.store.export", CapturingJobContext Ctx; auto DoLoad = [&](bool Force, bool Clean) { - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .OptionalCache = nullptr, .CacheBuildId = Oid::Zero, @@ -4793,7 +4852,8 @@ SetupExportStore(CidStore& CidStore, /*.ForceEnableTempBlocks =*/false}; std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options); - SaveOplog(CidStore, + SaveOplog(Log(), + CidStore, *RemoteStore, Project, *Oplog, @@ -4856,7 +4916,8 @@ SetupPartialBlockExportStore(WorkerThreadPool& NetworkPool, WorkerThreadPool& Wo /*.ForceDisableBlocks =*/false, /*.ForceEnableTempBlocks =*/false}; std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options); - SaveOplog(LocalCidStore, + SaveOplog(Log(), + LocalCidStore, *RemoteStore, *LocalProject, *Oplog, @@ -5024,7 +5085,8 @@ TEST_CASE("project.store.import.context_settings") bool PopulateCache, bool ForceDownload) -> void { Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog(fmt::format("import_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *RemoteStore, .OptionalCache = OptCache, .CacheBuildId = CacheBuildId, @@ -5131,7 +5193,8 @@ TEST_CASE("project.store.import.context_settings") // StoreMaxRangeCountPerRequest=128 -> all three ranges sent in one LoadAttachmentRanges call. 
Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_multi_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = nullptr, .CacheBuildId = CacheBuildId, @@ -5163,7 +5226,8 @@ TEST_CASE("project.store.import.context_settings") SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash); Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_single_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = nullptr, .CacheBuildId = CacheBuildId, @@ -5194,7 +5258,8 @@ TEST_CASE("project.store.import.context_settings") // Phase 1: full block download from remote populates the cache. { Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p1_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -5226,7 +5291,8 @@ TEST_CASE("project.store.import.context_settings") SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p2_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = Phase2CidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -5259,7 +5325,8 @@ TEST_CASE("project.store.import.context_settings") // Phase 1: full block download from remote into cache. 
{ Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p1_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -5291,7 +5358,8 @@ TEST_CASE("project.store.import.context_settings") SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p2_{}", OpJobIndex++), {}); - LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = Phase2CidStore, .RemoteStore = *PartialRemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -5373,7 +5441,8 @@ RunSaveOplog(CidStore& CidStore, { *OutRemoteStore = RemoteStore; } - SaveOplog(CidStore, + SaveOplog(Log(), + CidStore, *RemoteStore, Project, Oplog, @@ -5476,7 +5545,8 @@ TEST_CASE("project.store.embed_loose_files_true") /*ForceDisableBlocks=*/false, &RemoteStore); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_embed_true_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -5530,7 +5600,8 @@ TEST_CASE("project.store.embed_loose_files_false" * doctest::skip()) // superse &RemoteStore); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_embed_false_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -5693,7 +5764,8 @@ TEST_CASE("project.store.export.large_file_attachment_direct") &RemoteStore); Ref<ProjectStore::Oplog> 
ImportOplog = Project->NewOplog("oplog_large_direct_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -5750,7 +5822,8 @@ TEST_CASE("project.store.export.large_file_attachment_via_temp") &RemoteStore); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_large_via_temp_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -5804,7 +5877,8 @@ TEST_CASE("project.store.export.large_chunk_from_cidstore") &RemoteStore); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_large_cid_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -5867,7 +5941,8 @@ TEST_CASE("project.store.export.block_reuse") BlockHashesAfterFirst.push_back(B.BlockHash); } - SaveOplog(CidStore, + SaveOplog(Log(), + CidStore, *RemoteStore, *Project, *Oplog, @@ -5944,7 +6019,8 @@ TEST_CASE("project.store.export.max_chunks_per_block") CHECK(KnownBlocks.Blocks.size() >= 2); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_max_chunks_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6027,7 +6103,8 @@ TEST_CASE("project.store.export.max_data_per_block") CHECK(KnownBlocks.Blocks.size() >= 2); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_max_data_per_block_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + 
.ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6155,7 +6232,8 @@ TEST_CASE("project.store.embed_loose_files_zero_data_hash") &RemoteStore); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_zero_data_hash_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6209,7 +6287,8 @@ TEST_CASE("project.store.embed_loose_files_already_resolved") &RemoteStore1); Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_already_resolved_import", {}); - LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore1, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6296,7 +6375,8 @@ TEST_CASE("project.store.import.missing_attachment") Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_att_throw", {}); REQUIRE(ImportOplog); CapturingJobContext Ctx; - CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6313,7 +6393,8 @@ TEST_CASE("project.store.import.missing_attachment") Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_att_ignore", {}); REQUIRE(ImportOplog); CapturingJobContext Ctx; - CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6358,7 +6439,8 @@ TEST_CASE("project.store.import.error.load_container_failure") REQUIRE(ImportOplog); CapturingJobContext Ctx; - CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.ChunkStore 
= CidStore, + CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore = *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -6785,6 +6867,7 @@ TEST_CASE("buildcontainer.public_overload_smoke") std::atomic<int> BlockCallCount{0}; CbObject Container = BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -6828,6 +6911,7 @@ TEST_CASE("buildcontainer.build_blocks_false_on_block_chunks") std::atomic<int> BlockChunksCallCount{0}; CbObject Container = BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -6893,6 +6977,7 @@ TEST_CASE("buildcontainer.ignore_missing_binary_attachment_warn") { CapturingJobContext Ctx; BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -6916,6 +7001,7 @@ TEST_CASE("buildcontainer.ignore_missing_binary_attachment_warn") SUBCASE("throw") { CHECK_THROWS(BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -6967,6 +7053,7 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn") { CapturingJobContext Ctx; BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -6990,6 +7077,7 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn") SUBCASE("throw") { CHECK_THROWS(BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7008,6 +7096,61 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn") } } +TEST_CASE("buildcontainer.zero_byte_file_attachment") +{ + // A zero-byte file on disk is a valid attachment. BuildContainer must process + // it without hitting ZEN_ASSERT(UploadAttachment->Size != 0) in + // ResolveAttachments. The empty file flows through the compress-inline path + // and becomes a LooseUploadAttachment with raw size 0. 
+ using namespace projectstore_testutils; + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + std::unique_ptr<ProjectStore> ProjectStoreDummy; + Ref<ProjectStore::Project> Project = MakeTestProject(CidStore, Gc, TempDir.Path(), ProjectStoreDummy); + + std::filesystem::path RootDir = TempDir.Path() / "root"; + auto FileAtts = CreateFileAttachments(RootDir, std::initializer_list<size_t>{512}); + + Ref<ProjectStore::Oplog> Oplog = Project->NewOplog("bc_zero_byte_file", {}); + REQUIRE(Oplog); + Oplog->AppendNewOplogEntry(CreateFilesOplogPackage(Oid::NewOid(), RootDir, FileAtts)); + + // Truncate the file to zero bytes after the oplog entry is created. + // The file still exists on disk so RewriteOplog's IsFile() check passes, + // but MakeFromFile returns a zero-size buffer. + std::filesystem::resize_file(FileAtts[0].second, 0); + + WorkerThreadPool WorkerPool(GetWorkerCount()); + + CbObject Container = BuildContainer( + Log(), + CidStore, + *Project, + *Oplog, + WorkerPool, + 64u * 1024u, + 1000, + 32u * 1024u, + 64u * 1024u * 1024u, + /*BuildBlocks=*/true, + /*IgnoreMissingAttachments=*/false, + /*AllowChunking=*/true, + [](CompressedBuffer&&, ChunkBlockDescription&&) {}, + [](const IoHash&, TGetAttachmentBufferFunc&&) {}, + [](std::vector<std::pair<IoHash, FetchChunkFunc>>&&) {}, + /*EmbedLooseFiles=*/true); + + CHECK(Container.GetSize() > 0); + + // The zero-byte attachment is packed into a block via the compress-inline path. 
+ CbArrayView Blocks = Container["blocks"sv].AsArrayView(); + CHECK(Blocks.Num() > 0); +} + TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite") { // EmbedLooseFiles=false: RewriteOp is skipped for file-op entries; they pass through @@ -7030,6 +7173,7 @@ TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite") WorkerThreadPool WorkerPool(GetWorkerCount()); CbObject Container = BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7080,6 +7224,7 @@ TEST_CASE("buildcontainer.allow_chunking_false") { std::atomic<int> LargeAttachmentCallCount{0}; BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7103,6 +7248,7 @@ TEST_CASE("buildcontainer.allow_chunking_false") // Chunking branch in FindChunkSizes is taken, but the ~4 KB chunk still exceeds MaxChunkEmbedSize -> OnLargeAttachment. std::atomic<int> LargeAttachmentCallCount{0}; BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7144,6 +7290,7 @@ TEST_CASE("buildcontainer.async_on_block_exception_propagates") WorkerThreadPool WorkerPool(GetWorkerCount()); CHECK_THROWS_AS(BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7184,6 +7331,7 @@ TEST_CASE("buildcontainer.on_large_attachment_exception_propagates") WorkerThreadPool WorkerPool(GetWorkerCount()); CHECK_THROWS_AS(BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7226,6 +7374,7 @@ TEST_CASE("buildcontainer.context_cancellation_aborts") Ctx.m_Cancel = true; CHECK_NOTHROW(BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7265,6 +7414,7 @@ TEST_CASE("buildcontainer.context_progress_reporting") CapturingJobContext Ctx; BuildContainer( + Log(), CidStore, *Project, *Oplog, @@ -7428,7 +7578,8 @@ TEST_CASE("loadoplog.missing_block_attachment_ignored") CapturingJobContext Ctx; Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_block_import", {}); - CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = CidStore, .RemoteStore 
= *RemoteStore, .Oplog = *ImportOplog, .NetworkWorkerPool = NetworkPool, @@ -7501,7 +7652,8 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache") { Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog("oplog_clean_cache_p1", {}); - LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *RemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -7517,7 +7669,8 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache") { Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog("oplog_clean_cache_p2", {}); - CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, .RemoteStore = *RemoteStore, .OptionalCache = Cache.get(), .CacheBuildId = CacheBuildId, @@ -7532,6 +7685,158 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache") } } +TEST_CASE("project.store.export.block_reuse_fresh_receiver") +{ + // Regression test: after a second export that reuses existing blocks, a fresh import must still + // receive all chunks. The bug: FindReuseBlocks erases reused-block chunks from UploadAttachments, + // but never adds the reused blocks to the container's "blocks" section. A fresh receiver then + // silently misses those chunks because ParseOplogContainer never sees them. 
+ using namespace projectstore_testutils; + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + ScopedTemporaryDirectory ExportDir; + + // -- Export side ---------------------------------------------------------- + GcManager ExportGc; + CidStore ExportCidStore(ExportGc); + CidStoreConfiguration ExportCidConfig = {.RootDirectory = TempDir.Path() / "export_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ExportCidStore.Initialize(ExportCidConfig); + + std::filesystem::path ExportBasePath = TempDir.Path() / "export_projectstore"; + ProjectStore ExportProjectStore(ExportCidStore, ExportBasePath, ExportGc, ProjectStore::Configuration{}); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + Ref<ProjectStore::Project> ExportProject(ExportProjectStore.NewProject(ExportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + // 20 KB with None encoding: compressed ~ 20 KB < MaxChunkEmbedSize (32 KB) -> packed into blocks. + Ref<ProjectStore::Oplog> Oplog = ExportProject->NewOplog("oplog_reuse_rt", {}); + REQUIRE(Oplog); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage( + Oid::NewOid(), + CreateAttachments(std::initializer_list<size_t>{20u * 1024u, 20u * 1024u}, OodleCompressionLevel::None))); + + TestWorkerPools Pools; + WorkerThreadPool& NetworkPool = Pools.NetworkPool; + WorkerThreadPool& WorkerPool = Pools.WorkerPool; + + constexpr size_t MaxBlockSize = 64u * 1024u; + constexpr size_t MaxChunksPerBlock = 1000; + constexpr size_t MaxChunkEmbedSize = 32u * 1024u; + constexpr size_t ChunkFileSizeLimit = 64u * 1024u * 1024u; + + // First export: creates blocks on disk. 
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, + .MaxChunksPerBlock = MaxChunksPerBlock, + .MaxChunkEmbedSize = MaxChunkEmbedSize, + .ChunkFileSizeLimit = ChunkFileSizeLimit}, + /*.FolderPath =*/ExportDir.Path(), + /*.Name =*/std::string("oplog_reuse_rt"), + /*.OptionalBaseName =*/std::string(), + /*.ForceDisableBlocks =*/false, + /*.ForceEnableTempBlocks =*/false}; + + std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options); + SaveOplog(Log(), + ExportCidStore, + *RemoteStore, + *ExportProject, + *Oplog, + NetworkPool, + WorkerPool, + MaxBlockSize, + MaxChunksPerBlock, + MaxChunkEmbedSize, + ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); + + // Verify first export produced blocks. + RemoteProjectStore::GetKnownBlocksResult KnownAfterFirst = RemoteStore->GetKnownBlocks(); + REQUIRE(!KnownAfterFirst.Blocks.empty()); + + // Second export to the SAME store: triggers block reuse via GetKnownBlocks. + SaveOplog(Log(), + ExportCidStore, + *RemoteStore, + *ExportProject, + *Oplog, + NetworkPool, + WorkerPool, + MaxBlockSize, + MaxChunksPerBlock, + MaxChunkEmbedSize, + ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); + + // Verify the container has no duplicate block entries. + { + RemoteProjectStore::LoadContainerResult ContainerResult = RemoteStore->LoadContainer(); + REQUIRE(ContainerResult.ErrorCode == 0); + std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(ContainerResult.ContainerObject); + REQUIRE(!BlockHashes.empty()); + std::unordered_set<IoHash, IoHash::Hasher> UniqueBlockHashes(BlockHashes.begin(), BlockHashes.end()); + CHECK(UniqueBlockHashes.size() == BlockHashes.size()); + } + + // Collect all attachment hashes referenced by the oplog ops. 
+ std::unordered_set<IoHash, IoHash::Hasher> ExpectedHashes; + Oplog->IterateOplogWithKey([&](int, const Oid&, CbObjectView Op) { + Op.IterateAttachments([&](CbFieldView FieldView) { ExpectedHashes.insert(FieldView.AsAttachment()); }); + }); + REQUIRE(!ExpectedHashes.empty()); + + // -- Import side (fresh, empty CAS) -------------------------------------- + GcManager ImportGc; + CidStore ImportCidStore(ImportGc); + CidStoreConfiguration ImportCidConfig = {.RootDirectory = TempDir.Path() / "import_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ImportCidStore.Initialize(ImportCidConfig); + + std::filesystem::path ImportBasePath = TempDir.Path() / "import_projectstore"; + ProjectStore ImportProjectStore(ImportCidStore, ImportBasePath, ImportGc, ProjectStore::Configuration{}); + Ref<ProjectStore::Project> ImportProject(ImportProjectStore.NewProject(ImportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog("oplog_reuse_rt_import", {}); + REQUIRE(ImportOplog); + + LoadOplog(LoadOplogContext{.Log = Log(), + .ChunkStore = ImportCidStore, + .RemoteStore = *RemoteStore, + .Oplog = *ImportOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = true, + .IgnoreMissingAttachments = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All}); + + // Every attachment hash from the original oplog must be present in the import CAS. 
+ for (const IoHash& Hash : ExpectedHashes) + { + CHECK_MESSAGE(ImportCidStore.ContainsChunk(Hash), "Missing chunk after import: ", Hash); + } +} + TEST_SUITE_END(); #endif // ZEN_WITH_TESTS diff --git a/src/zenremotestore/zenremotestore.cpp b/src/zenremotestore/zenremotestore.cpp index 0b205b296..9642f8470 100644 --- a/src/zenremotestore/zenremotestore.cpp +++ b/src/zenremotestore/zenremotestore.cpp @@ -9,7 +9,6 @@ #include <zenremotestore/chunking/chunkedcontent.h> #include <zenremotestore/chunking/chunkedfile.h> #include <zenremotestore/chunking/chunkingcache.h> -#include <zenremotestore/filesystemutils.h> #include <zenremotestore/projectstore/remoteprojectstore.h> #if ZEN_WITH_TESTS @@ -27,7 +26,6 @@ zenremotestore_forcelinktests() chunkedcontent_forcelink(); chunkedfile_forcelink(); chunkingcache_forcelink(); - filesystemutils_forcelink(); remoteprojectstore_forcelink(); } diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp deleted file mode 100644 index 4cd6b411f..000000000 --- a/src/zens3-testbed/main.cpp +++ /dev/null @@ -1,526 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -// Simple test bed for exercising the zens3 module against a real S3 bucket. -// -// Usage: -// zens3-testbed --bucket <name> --region <region> [command] [args...] 
-// -// Credentials are read from environment variables: -// AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY -// -// Commands: -// put <key> <file> Upload a local file -// get <key> [file] Download an object (prints to stdout if no file given) -// head <key> Check if object exists, show metadata -// delete <key> Delete an object -// list [prefix] List objects with optional prefix -// multipart-put <key> <file> [part-size-mb] Upload via multipart -// roundtrip <key> Upload test data, download, verify, delete - -#include <zenutil/cloud/imdscredentials.h> -#include <zenutil/cloud/s3client.h> - -#include <zencore/except_fmt.h> -#include <zencore/filesystem.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/string.h> - -#include <zencore/memory/newdelete.h> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <fmt/format.h> -#include <cxxopts.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -#include <cstdlib> -#include <fstream> -#include <iostream> - -namespace { - -using namespace zen; - -std::string -GetEnvVar(const char* Name) -{ - const char* Value = std::getenv(Name); - return Value ? 
std::string(Value) : std::string(); -} - -IoBuffer -ReadFileToBuffer(const std::filesystem::path& Path) -{ - return zen::ReadFile(Path).Flatten(); -} - -void -WriteBufferToFile(const IoBuffer& Buffer, const std::filesystem::path& Path) -{ - std::ofstream File(Path, std::ios::binary); - if (!File) - { - throw zen::runtime_error("failed to open '{}' for writing", Path.string()); - } - File.write(reinterpret_cast<const char*>(Buffer.GetData()), static_cast<std::streamsize>(Buffer.GetSize())); -} - -S3Client -CreateClient(const cxxopts::ParseResult& Args) -{ - S3ClientOptions Options; - Options.BucketName = Args["bucket"].as<std::string>(); - Options.Region = Args["region"].as<std::string>(); - - if (Args.count("imds")) - { - // Use IMDS credential provider for EC2 instances - ImdsCredentialProviderOptions ImdsOpts; - if (Args.count("imds-endpoint")) - { - ImdsOpts.Endpoint = Args["imds-endpoint"].as<std::string>(); - } - Options.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider(ImdsOpts)); - } - else - { - std::string AccessKey = GetEnvVar("AWS_ACCESS_KEY_ID"); - std::string SecretKey = GetEnvVar("AWS_SECRET_ACCESS_KEY"); - std::string SessionToken = GetEnvVar("AWS_SESSION_TOKEN"); - - if (AccessKey.empty() || SecretKey.empty()) - { - throw zen::runtime_error("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables must be set"); - } - - Options.Credentials.AccessKeyId = std::move(AccessKey); - Options.Credentials.SecretAccessKey = std::move(SecretKey); - Options.Credentials.SessionToken = std::move(SessionToken); - } - - if (Args.count("endpoint")) - { - Options.Endpoint = Args["endpoint"].as<std::string>(); - } - - if (Args.count("path-style")) - { - Options.PathStyle = true; - } - - if (Args.count("timeout")) - { - Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000); - } - - return S3Client(Options); -} - -int -CmdPut(S3Client& Client, const std::vector<std::string>& Positional) -{ - if 
(Positional.size() < 3) - { - fmt::print(stderr, "Usage: zens3-testbed ... put <key> <file>\n"); - return 1; - } - - const auto& Key = Positional[1]; - const auto& FilePath = Positional[2]; - - IoBuffer Content = ReadFileToBuffer(FilePath); - fmt::print("Uploading '{}' ({} bytes) to s3://{}/{}\n", FilePath, Content.GetSize(), Client.BucketName(), Key); - - S3Result Result = Client.PutObject(Key, Content); - if (!Result) - { - fmt::print(stderr, "PUT failed: {}\n", Result.Error); - return 1; - } - - fmt::print("OK\n"); - return 0; -} - -int -CmdGet(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 2) - { - fmt::print(stderr, "Usage: zens3-testbed ... get <key> [file]\n"); - return 1; - } - - const auto& Key = Positional[1]; - - S3GetObjectResult Result = Client.GetObject(Key); - if (!Result) - { - fmt::print(stderr, "GET failed: {}\n", Result.Error); - return 1; - } - - if (Positional.size() >= 3) - { - const auto& FilePath = Positional[2]; - WriteBufferToFile(Result.Content, FilePath); - fmt::print("Downloaded {} bytes to '{}'\n", Result.Content.GetSize(), FilePath); - } - else - { - // Print to stdout - std::string_view Text = Result.AsText(); - std::cout.write(Text.data(), static_cast<std::streamsize>(Text.size())); - std::cout << std::endl; - } - - return 0; -} - -int -CmdHead(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 2) - { - fmt::print(stderr, "Usage: zens3-testbed ... 
head <key>\n"); - return 1; - } - - const auto& Key = Positional[1]; - - S3HeadObjectResult Result = Client.HeadObject(Key); - - if (!Result) - { - fmt::print(stderr, "HEAD failed: {}\n", Result.Error); - return 1; - } - - if (Result.Status == HeadObjectResult::NotFound) - { - fmt::print("Object '{}' does not exist\n", Key); - return 1; - } - - fmt::print("Key: {}\n", Result.Info.Key); - fmt::print("Size: {} bytes\n", Result.Info.Size); - fmt::print("ETag: {}\n", Result.Info.ETag); - fmt::print("Last-Modified: {}\n", Result.Info.LastModified); - return 0; -} - -int -CmdDelete(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 2) - { - fmt::print(stderr, "Usage: zens3-testbed ... delete <key>\n"); - return 1; - } - - const auto& Key = Positional[1]; - - S3Result Result = Client.DeleteObject(Key); - if (!Result) - { - fmt::print(stderr, "DELETE failed: {}\n", Result.Error); - return 1; - } - - fmt::print("Deleted '{}'\n", Key); - return 0; -} - -int -CmdList(S3Client& Client, const std::vector<std::string>& Positional) -{ - std::string Prefix; - if (Positional.size() >= 2) - { - Prefix = Positional[1]; - } - - S3ListObjectsResult Result = Client.ListObjects(Prefix); - if (!Result) - { - fmt::print(stderr, "LIST failed: {}\n", Result.Error); - return 1; - } - - fmt::print("{} objects found:\n", Result.Objects.size()); - for (const auto& Obj : Result.Objects) - { - fmt::print(" {:>12} {} {}\n", Obj.Size, Obj.LastModified, Obj.Key); - } - - return 0; -} - -int -CmdMultipartPut(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 3) - { - fmt::print(stderr, "Usage: zens3-testbed ... 
multipart-put <key> <file> [part-size-mb]\n"); - return 1; - } - - const auto& Key = Positional[1]; - const auto& FilePath = Positional[2]; - - uint64_t PartSize = 8 * 1024 * 1024; // 8 MB default - if (Positional.size() >= 4) - { - PartSize = std::stoull(Positional[3]) * 1024 * 1024; - } - - IoBuffer Content = ReadFileToBuffer(FilePath); - fmt::print("Multipart uploading '{}' ({} bytes, part size {} MB) to s3://{}/{}\n", - FilePath, - Content.GetSize(), - PartSize / (1024 * 1024), - Client.BucketName(), - Key); - - S3Result Result = Client.PutObjectMultipart(Key, Content, PartSize); - if (!Result) - { - fmt::print(stderr, "Multipart PUT failed: {}\n", Result.Error); - return 1; - } - - fmt::print("OK\n"); - return 0; -} - -int -CmdRoundtrip(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 2) - { - fmt::print(stderr, "Usage: zens3-testbed ... roundtrip <key>\n"); - return 1; - } - - const auto& Key = Positional[1]; - - // Generate test data - const size_t TestSize = 1024 * 64; // 64 KB - std::vector<uint8_t> TestData(TestSize); - for (size_t i = 0; i < TestSize; ++i) - { - TestData[i] = static_cast<uint8_t>(i & 0xFF); - } - - IoBuffer UploadContent(IoBuffer::Clone, TestData.data(), TestData.size()); - - fmt::print("=== Roundtrip test for key '{}' ===\n\n", Key); - - // PUT - fmt::print("[1/4] PUT {} bytes...\n", TestSize); - S3Result Result = Client.PutObject(Key, UploadContent); - if (!Result) - { - fmt::print(stderr, " FAILED: {}\n", Result.Error); - return 1; - } - fmt::print(" OK\n"); - - // HEAD - fmt::print("[2/4] HEAD...\n"); - S3HeadObjectResult HeadResult = Client.HeadObject(Key); - if (HeadResult.Status != HeadObjectResult::Found) - { - fmt::print(stderr, " FAILED: {}\n", !HeadResult ? 
HeadResult.Error : "not found"); - return 1; - } - fmt::print(" OK (size={}, etag={})\n", HeadResult.Info.Size, HeadResult.Info.ETag); - - if (HeadResult.Info.Size != TestSize) - { - fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, HeadResult.Info.Size); - return 1; - } - - // GET - fmt::print("[3/4] GET and verify...\n"); - S3GetObjectResult GetResult = Client.GetObject(Key); - if (!GetResult) - { - fmt::print(stderr, " FAILED: {}\n", GetResult.Error); - return 1; - } - - if (GetResult.Content.GetSize() != TestSize) - { - fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, GetResult.Content.GetSize()); - return 1; - } - - if (memcmp(GetResult.Content.GetData(), TestData.data(), TestSize) != 0) - { - fmt::print(stderr, " DATA MISMATCH\n"); - return 1; - } - fmt::print(" OK (verified {} bytes)\n", TestSize); - - // DELETE - fmt::print("[4/4] DELETE...\n"); - Result = Client.DeleteObject(Key); - if (!Result) - { - fmt::print(stderr, " FAILED: {}\n", Result.Error); - return 1; - } - fmt::print(" OK\n"); - - fmt::print("\n=== Roundtrip test PASSED ===\n"); - return 0; -} - -int -CmdPresign(S3Client& Client, const std::vector<std::string>& Positional) -{ - if (Positional.size() < 2) - { - fmt::print(stderr, "Usage: zens3-testbed ... 
presign <key> [method] [expires-seconds]\n"); - return 1; - } - - const auto& Key = Positional[1]; - - std::string Method = "GET"; - if (Positional.size() >= 3) - { - Method = Positional[2]; - } - - std::chrono::seconds ExpiresIn(3600); - if (Positional.size() >= 4) - { - ExpiresIn = std::chrono::seconds(std::stoul(Positional[3])); - } - - std::string Url; - if (Method == "PUT") - { - Url = Client.GeneratePresignedPutUrl(Key, ExpiresIn); - } - else - { - Url = Client.GeneratePresignedGetUrl(Key, ExpiresIn); - } - - fmt::print("{}\n", Url); - return 0; -} - -} // namespace - -int -main(int argc, char* argv[]) -{ - using namespace zen; - - logging::InitializeLogging(); - - cxxopts::Options Options("zens3-testbed", "Test bed for exercising S3 operations via the zens3 module"); - - // clang-format off - Options.add_options() - ("b,bucket", "S3 bucket name", cxxopts::value<std::string>()) - ("r,region", "AWS region", cxxopts::value<std::string>()->default_value("us-east-1")) - ("e,endpoint", "Custom S3 endpoint URL", cxxopts::value<std::string>()) - ("path-style", "Use path-style addressing (for MinIO, etc.)") - ("imds", "Use EC2 IMDS for credentials instead of env vars") - ("imds-endpoint", "Custom IMDS endpoint URL (for testing)", cxxopts::value<std::string>()) - ("timeout", "Request timeout in seconds", cxxopts::value<int>()->default_value("30")) - ("v,verbose", "Enable verbose logging") - ("h,help", "Show help") - ("positional", "Command and arguments", cxxopts::value<std::vector<std::string>>()); - // clang-format on - - Options.parse_positional({"positional"}); - Options.positional_help("<command> [args...]"); - - try - { - auto Result = Options.parse(argc, argv); - - if (Result.count("help") || !Result.count("positional")) - { - fmt::print("{}\n", Options.help()); - fmt::print("Commands:\n"); - fmt::print(" put <key> <file> Upload a local file\n"); - fmt::print(" get <key> [file] Download (to file or stdout)\n"); - fmt::print(" head <key> Show object 
metadata\n"); - fmt::print(" delete <key> Delete an object\n"); - fmt::print(" list [prefix] List objects\n"); - fmt::print(" multipart-put <key> <file> [part-mb] Multipart upload\n"); - fmt::print(" roundtrip <key> Upload/download/verify/delete\n"); - fmt::print(" presign <key> [method] [expires-sec] Generate pre-signed URL\n"); - fmt::print("\nCredentials via AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars,\n"); - fmt::print("or use --imds to fetch from EC2 Instance Metadata Service.\n"); - return 0; - } - - if (!Result.count("bucket")) - { - fmt::print(stderr, "Error: --bucket is required\n"); - return 1; - } - - if (Result.count("verbose")) - { - logging::SetLogLevel(logging::Debug); - } - - auto Client = CreateClient(Result); - - const auto& Positional = Result["positional"].as<std::vector<std::string>>(); - const auto& Command = Positional[0]; - - if (Command == "put") - { - return CmdPut(Client, Positional); - } - else if (Command == "get") - { - return CmdGet(Client, Positional); - } - else if (Command == "head") - { - return CmdHead(Client, Positional); - } - else if (Command == "delete") - { - return CmdDelete(Client, Positional); - } - else if (Command == "list") - { - return CmdList(Client, Positional); - } - else if (Command == "multipart-put") - { - return CmdMultipartPut(Client, Positional); - } - else if (Command == "roundtrip") - { - return CmdRoundtrip(Client, Positional); - } - else if (Command == "presign") - { - return CmdPresign(Client, Positional); - } - else - { - fmt::print(stderr, "Unknown command: '{}'\n", Command); - return 1; - } - } - catch (const std::exception& Ex) - { - fmt::print(stderr, "Error: {}\n", Ex.what()); - return 1; - } -} diff --git a/src/zens3-testbed/xmake.lua b/src/zens3-testbed/xmake.lua deleted file mode 100644 index 168ab9de9..000000000 --- a/src/zens3-testbed/xmake.lua +++ /dev/null @@ -1,8 +0,0 @@ --- Copyright Epic Games, Inc. All Rights Reserved. 
- -target("zens3-testbed") - set_kind("binary") - set_group("tools") - add_files("*.cpp") - add_deps("zenutil", "zencore") - add_deps("cxxopts", "fmt") diff --git a/src/zenserver-test/buildstore-tests.cpp b/src/zenserver-test/buildstore-tests.cpp index cf9b10896..1f2157993 100644 --- a/src/zenserver-test/buildstore-tests.cpp +++ b/src/zenserver-test/buildstore-tests.cpp @@ -148,8 +148,6 @@ TEST_CASE("buildstore.blobs") } { - // Single-range Get - ZenServerInstance Instance(TestEnv); const uint16_t PortNumber = @@ -158,6 +156,7 @@ TEST_CASE("buildstore.blobs") HttpClient Client(Instance.GetBaseUri() + "/builds/"); + // Single-range Get { const IoHash& RawHash = CompressedBlobsHashes.front(); uint64_t BlobSize = CompressedBlobsSizes.front(); @@ -183,20 +182,63 @@ TEST_CASE("buildstore.blobs") MemoryView RangeView = Payload.GetView(); CHECK(ActualRange.EqualBytes(RangeView)); } - } - { - // Single-range Post + { + // GET blob not found + IoHash FakeHash = IoHash::HashBuffer("nonexistent", 11); + HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, FakeHash)); + CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound); + } - ZenServerInstance Instance(TestEnv); + { + // GET with out-of-bounds range + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); - const uint16_t PortNumber = - Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); - CHECK(PortNumber != 0); + HttpClient::KeyValueMap Headers; + Headers.Entries.insert({"Range", fmt::format("bytes={}-{}", BlobSize + 100, BlobSize + 200)}); - HttpClient Client(Instance.GetBaseUri() + "/builds/"); + HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), Headers); + CHECK_EQ(Result.StatusCode, HttpResponseCode::RangeNotSatisfiable); + } { + // GET with multi-range header (uses Download for multipart boundary 
parsing) + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + uint64_t Range1Start = 0; + uint64_t Range1End = BlobSize / 4 - 1; + uint64_t Range2Start = BlobSize / 2; + uint64_t Range2End = BlobSize / 2 + BlobSize / 4 - 1; + + HttpClient::KeyValueMap Headers; + Headers.Entries.insert({"Range", fmt::format("bytes={}-{},{}-{}", Range1Start, Range1End, Range2Start, Range2End)}); + + HttpClient::Response Result = + Client.Download(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), SystemRootPath, Headers); + CHECK_EQ(Result.StatusCode, HttpResponseCode::PartialContent); + REQUIRE_EQ(Result.Ranges.size(), 2); + + HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + + uint64_t Range1Len = Range1End - Range1Start + 1; + uint64_t Range2Len = Range2End - Range2Start + 1; + + MemoryView ExpectedRange1 = FullBlobResult.ResponsePayload.GetView().Mid(Range1Start, Range1Len); + MemoryView ExpectedRange2 = FullBlobResult.ResponsePayload.GetView().Mid(Range2Start, Range2Len); + + MemoryView ActualRange1 = Result.ResponsePayload.GetView().Mid(Result.Ranges[0].OffsetInPayload, Range1Len); + MemoryView ActualRange2 = Result.ResponsePayload.GetView().Mid(Result.Ranges[1].OffsetInPayload, Range2Len); + + CHECK(ExpectedRange1.EqualBytes(ActualRange1)); + CHECK(ExpectedRange2.EqualBytes(ActualRange2)); + } + + // Single-range Post + { uint64_t RangeSizeSum = 0; const IoHash& RawHash = CompressedBlobsHashes.front(); @@ -259,19 +301,96 @@ TEST_CASE("buildstore.blobs") Offset += Range.second; } } - } - { - // Multi-range + { + // POST with wrong accept type + const IoHash& RawHash = CompressedBlobsHashes.front(); - ZenServerInstance Instance(TestEnv); + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + Writer.BeginObject(); + 
Writer.AddInteger("offset"sv, uint64_t(0)); + Writer.AddInteger("length"sv, uint64_t(10)); + Writer.EndObject(); + Writer.EndArray(); - const uint16_t PortNumber = - Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); - CHECK(PortNumber != 0); + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kBinary)); + CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest); + } - HttpClient Client(Instance.GetBaseUri() + "/builds/"); + { + // POST with missing payload + const IoHash& RawHash = CompressedBlobsHashes.front(); + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCbPackage)); + CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest); + } + + { + // POST with empty ranges array + const IoHash& RawHash = CompressedBlobsHashes.front(); + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + Writer.EndArray(); + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest); + } + + { + // POST with range count exceeding maximum + const IoHash& RawHash = CompressedBlobsHashes.front(); + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + for (uint32_t I = 0; I < 257; I++) + { + Writer.BeginObject(); + Writer.AddInteger("offset"sv, uint64_t(0)); + Writer.AddInteger("length"sv, uint64_t(1)); + Writer.EndObject(); + } + Writer.EndArray(); + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest); + } + + { + // POST with 
out-of-bounds range returns length=0 + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + Writer.BeginObject(); + Writer.AddInteger("offset"sv, BlobSize + 100); + Writer.AddInteger("length"sv, uint64_t(50)); + Writer.EndObject(); + Writer.EndArray(); + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + REQUIRE(Result); + CbPackage ResponsePackage = ParsePackageMessage(Result.ResponsePayload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + REQUIRE_EQ(RangeArray.Num(), uint64_t(1)); + CbObjectView Range = (*begin(RangeArray)).AsObjectView(); + CHECK_EQ(Range["offset"sv].AsUInt64(), BlobSize + 100); + CHECK_EQ(Range["length"sv].AsUInt64(), uint64_t(0)); + } + + // Multi-range { uint64_t RangeSizeSum = 0; diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp index 14748e214..e54e7060d 100644 --- a/src/zenserver-test/cache-tests.cpp +++ b/src/zenserver-test/cache-tests.cpp @@ -9,6 +9,7 @@ # include <zencore/compactbinarypackage.h> # include <zencore/compress.h> # include <zencore/fmtutils.h> +# include <zenhttp/localrefpolicy.h> # include <zenhttp/packageformat.h> # include <zenstore/cache/cachepolicy.h> # include <zencore/filesystem.h> @@ -25,6 +26,13 @@ namespace zen::tests { TEST_SUITE_BEGIN("server.cache"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. 
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("zcache.basic") { using namespace std::literals; @@ -164,143 +172,85 @@ TEST_CASE("zcache.cbpackage") return true; }; - SUBCASE("PUT/GET returns correct package") - { - std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir(); + std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir(); - ZenServerInstance Instance1(TestEnv); - Instance1.SetDataDir(TestDir); - const uint16_t PortNumber = Instance1.SpawnServerAndWaitUntilReady(); - const std::string BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + ZenServerInstance RemoteInstance(TestEnv); + RemoteInstance.SetDataDir(RemoteDataDir); + const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady(); - HttpClient Http{BaseUri}; + ZenServerInstance LocalInstance(TestEnv); + LocalInstance.SetDataDir(LocalDataDir); + LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(), + fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber)); + const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady(); + CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput()); - const std::string_view Bucket = "mosdef"sv; - zen::IoHash Key; - zen::CbPackage ExpectedPackage = CreateTestPackage(Key); + const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); + const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber); - // PUT - { - zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); - HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), Body); - CHECK(Result.StatusCode == HttpResponseCode::Created); - } + HttpClient LocalHttp{LocalBaseUri}; + HttpClient RemoteHttp{RemoteBaseUri}; - // GET - { - HttpClient::Response Result = Http.Get(fmt::format("/{}/{}", Bucket, Key), 
{{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Result.StatusCode == HttpResponseCode::OK); - - zen::CbPackage Package; - const bool Ok = Package.TryLoad(Result.ResponsePayload); - CHECK(Ok); - CHECK(IsEqual(Package, ExpectedPackage)); - } - } + const std::string_view Bucket = "mosdef"sv; - SUBCASE("PUT propagates upstream") + // Phase 1: PUT/GET returns correct package (via local) { - // Setup local and remote server - std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir(); - std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir(); - - ZenServerInstance RemoteInstance(TestEnv); - RemoteInstance.SetDataDir(RemoteDataDir); - const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady(); - - ZenServerInstance LocalInstance(TestEnv); - LocalInstance.SetDataDir(LocalDataDir); - LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(), - fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber)); - const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady(); - CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput()); - - const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); - const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber); - - const std::string_view Bucket = "mosdef"sv; - zen::IoHash Key; - zen::CbPackage ExpectedPackage = CreateTestPackage(Key); - - HttpClient LocalHttp{LocalBaseUri}; - HttpClient RemoteHttp{RemoteBaseUri}; - - // Store the cache record package in the local instance - { - zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); - HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), Body); - - CHECK(Result.StatusCode == HttpResponseCode::Created); - } + zen::IoHash Key1; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key1); - // The cache record can be retrieved as a package from the local instance - { - HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}", 
Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Result.StatusCode == HttpResponseCode::OK); + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + HttpClient::Response PutResult = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key1), Body); + CHECK(PutResult.StatusCode == HttpResponseCode::Created); - zen::CbPackage Package; - const bool Ok = Package.TryLoad(Result.ResponsePayload); - CHECK(Ok); - CHECK(IsEqual(Package, ExpectedPackage)); - } + HttpClient::Response GetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key1), {{"Accept", "application/x-ue-cbpkg"}}); + CHECK(GetResult.StatusCode == HttpResponseCode::OK); - // The cache record can be retrieved as a package from the remote instance - { - HttpClient::Response Result = RemoteHttp.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Result.StatusCode == HttpResponseCode::OK); - - zen::CbPackage Package; - const bool Ok = Package.TryLoad(Result.ResponsePayload); - CHECK(Ok); - CHECK(IsEqual(Package, ExpectedPackage)); - } + zen::CbPackage Package; + const bool Ok = Package.TryLoad(GetResult.ResponsePayload); + CHECK(Ok); + CHECK(IsEqual(Package, ExpectedPackage)); } - SUBCASE("GET finds upstream when missing in local") + // Phase 2: PUT propagates upstream { - // Setup local and remote server - std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir(); - std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir(); + zen::IoHash Key2; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key2); - ZenServerInstance RemoteInstance(TestEnv); - RemoteInstance.SetDataDir(RemoteDataDir); - const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady(); + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + HttpClient::Response PutResult = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key2), Body); + CHECK(PutResult.StatusCode == HttpResponseCode::Created); - ZenServerInstance LocalInstance(TestEnv); - 
LocalInstance.SetDataDir(LocalDataDir); - LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(), - fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber)); - const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady(); - CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput()); + HttpClient::Response LocalGetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key2), {{"Accept", "application/x-ue-cbpkg"}}); + CHECK(LocalGetResult.StatusCode == HttpResponseCode::OK); - const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); - const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber); + zen::CbPackage LocalPackage; + CHECK(LocalPackage.TryLoad(LocalGetResult.ResponsePayload)); + CHECK(IsEqual(LocalPackage, ExpectedPackage)); - HttpClient LocalHttp{LocalBaseUri}; - HttpClient RemoteHttp{RemoteBaseUri}; + HttpClient::Response RemoteGetResult = RemoteHttp.Get(fmt::format("/{}/{}", Bucket, Key2), {{"Accept", "application/x-ue-cbpkg"}}); + CHECK(RemoteGetResult.StatusCode == HttpResponseCode::OK); - const std::string_view Bucket = "mosdef"sv; - zen::IoHash Key; - zen::CbPackage ExpectedPackage = CreateTestPackage(Key); + zen::CbPackage RemotePackage; + CHECK(RemotePackage.TryLoad(RemoteGetResult.ResponsePayload)); + CHECK(IsEqual(RemotePackage, ExpectedPackage)); + } - // Store the cache record package in upstream cache - { - zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); - HttpClient::Response Result = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), Body); + // Phase 3: GET finds upstream when missing in local + { + zen::IoHash Key3; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key3); - CHECK(Result.StatusCode == HttpResponseCode::Created); - } + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + HttpClient::Response PutResult = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key3), Body); + CHECK(PutResult.StatusCode == 
HttpResponseCode::Created); - // The cache record can be retrieved as a package from the local cache - { - HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Result.StatusCode == HttpResponseCode::OK); + HttpClient::Response GetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key3), {{"Accept", "application/x-ue-cbpkg"}}); + CHECK(GetResult.StatusCode == HttpResponseCode::OK); - zen::CbPackage Package; - const bool Ok = Package.TryLoad(Result.ResponsePayload); - CHECK(Ok); - CHECK(IsEqual(Package, ExpectedPackage)); - } + zen::CbPackage Package; + CHECK(Package.TryLoad(GetResult.ResponsePayload)); + CHECK(IsEqual(Package, ExpectedPackage)); } } @@ -340,25 +290,25 @@ TEST_CASE("zcache.policy") return Package; }; - SUBCASE("query - 'local' does not query upstream (binary)") - { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - const uint16_t UpstreamPort = UpstreamCfg.Port; + ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); + ZenServerInstance UpstreamInst(TestEnv); + UpstreamCfg.Spawn(UpstreamInst); - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamPort); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port); + ZenServerInstance LocalInst(TestEnv); + LocalCfg.Spawn(LocalInst); + + HttpClient LocalHttp{LocalCfg.BaseUri}; + HttpClient RemoteHttp{UpstreamCfg.BaseUri}; - const std::string_view Bucket = "legacy"sv; + // query - 'local' does not query upstream (binary) + // Uses size 1024 for unique key + { + const auto Bucket = "legacy"sv; zen::IoHash Key; IoBuffer BinaryValue = GenerateData(1024, Key); - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient RemoteHttp{UpstreamCfg.BaseUri}; - { HttpClient::Response Result = 
RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue); CHECK(Result.StatusCode == HttpResponseCode::Created); @@ -377,26 +327,14 @@ TEST_CASE("zcache.policy") } } - SUBCASE("store - 'local' does not store upstream (binary)") + // store - 'local' does not store upstream (binary) + // Uses size 2048 for unique key { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - const uint16_t UpstreamPort = UpstreamCfg.Port; - - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamPort); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); - const auto Bucket = "legacy"sv; zen::IoHash Key; - IoBuffer BinaryValue = GenerateData(1024, Key); + IoBuffer BinaryValue = GenerateData(2048, Key); - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient RemoteHttp{UpstreamCfg.BaseUri}; - - // Store binary cache value locally { HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,StoreLocal", Bucket, Key), BinaryValue, @@ -415,25 +353,14 @@ TEST_CASE("zcache.policy") } } - SUBCASE("store - 'local/remote' stores local and upstream (binary)") + // store - 'local/remote' stores local and upstream (binary) + // Uses size 4096 for unique key { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); - const auto Bucket = "legacy"sv; zen::IoHash Key; - IoBuffer BinaryValue = GenerateData(1024, Key); - - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient RemoteHttp{UpstreamCfg.BaseUri}; + IoBuffer BinaryValue = GenerateData(4096, Key); - // Store binary cache value locally and upstream { HttpClient::Response Result = 
LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,Store", Bucket, Key), BinaryValue, @@ -452,27 +379,16 @@ TEST_CASE("zcache.policy") } } - SUBCASE("query - 'local' does not query upstream (cbpackage)") + // query - 'local' does not query upstream (cbpackage) + // Uses bucket "policy4" to isolate from other cbpackage scenarios (deterministic key) { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); - - const auto Bucket = "legacy"sv; + const auto Bucket = "policy4"sv; zen::IoHash Key; zen::IoHash PayloadId; zen::CbPackage Package = GeneratePackage(Key, PayloadId); IoBuffer Buf = SerializeToBuffer(Package); - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient RemoteHttp{UpstreamCfg.BaseUri}; - - // Store package upstream { HttpClient::Response Result = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), Buf); CHECK(Result.StatusCode == HttpResponseCode::Created); @@ -491,27 +407,16 @@ TEST_CASE("zcache.policy") } } - SUBCASE("store - 'local' does not store upstream (cbpackage)") + // store - 'local' does not store upstream (cbpackage) + // Uses bucket "policy5" to isolate from other cbpackage scenarios (deterministic key) { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); - - const auto Bucket = "legacy"sv; + const auto Bucket = "policy5"sv; zen::IoHash Key; zen::IoHash PayloadId; zen::CbPackage Package = GeneratePackage(Key, PayloadId); IoBuffer Buf = SerializeToBuffer(Package); - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient 
RemoteHttp{UpstreamCfg.BaseUri}; - - // Store package locally { HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,StoreLocal", Bucket, Key), Buf); CHECK(Result.StatusCode == HttpResponseCode::Created); @@ -528,27 +433,16 @@ TEST_CASE("zcache.policy") } } - SUBCASE("store - 'local/remote' stores local and upstream (cbpackage)") + // store - 'local/remote' stores local and upstream (cbpackage) + // Uses bucket "policy6" to isolate from other cbpackage scenarios (deterministic key) { - ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance UpstreamInst(TestEnv); - UpstreamCfg.Spawn(UpstreamInst); - - ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port); - ZenServerInstance LocalInst(TestEnv); - LocalCfg.Spawn(LocalInst); - - const auto Bucket = "legacy"sv; + const auto Bucket = "policy6"sv; zen::IoHash Key; zen::IoHash PayloadId; zen::CbPackage Package = GeneratePackage(Key, PayloadId); IoBuffer Buf = SerializeToBuffer(Package); - HttpClient LocalHttp{LocalCfg.BaseUri}; - HttpClient RemoteHttp{UpstreamCfg.BaseUri}; - - // Store package locally and upstream { HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,Store", Bucket, Key), Buf); CHECK(Result.StatusCode == HttpResponseCode::Created); @@ -565,78 +459,62 @@ TEST_CASE("zcache.policy") } } - SUBCASE("skip - 'data' returns cache record without attachments/empty payload") + // skip - 'data' returns cache record without attachments/empty payload + // Uses bucket "skiptest7" to isolate from other cbpackage scenarios { - ZenConfig Cfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance Instance(TestEnv); - Cfg.Spawn(Instance); - - const auto Bucket = "test"sv; + const auto Bucket = "skiptest7"sv; zen::IoHash Key; zen::IoHash PayloadId; zen::CbPackage Package = GeneratePackage(Key, PayloadId); IoBuffer Buf = SerializeToBuffer(Package); - HttpClient Http{Cfg.BaseUri}; - - // Store 
package { - HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), Buf); + HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), Buf); CHECK(Result.StatusCode == HttpResponseCode::Created); } - // Get package { HttpClient::Response Result = - Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}}); + LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}}); CHECK(Result); CbPackage ResponsePackage; CHECK(ResponsePackage.TryLoad(Result.ResponsePayload)); CHECK(ResponsePackage.GetAttachments().size() == 0); } - // Get record { HttpClient::Response Result = - Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cb"}}); + LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cb"}}); CHECK(Result); CbObject ResponseObject = zen::LoadCompactBinaryObject(Result.ResponsePayload); CHECK(ResponseObject); } - // Get payload { - HttpClient::Response Result = - Http.Get(fmt::format("/{}/{}/{}?Policy=Default,SkipData", Bucket, Key, PayloadId), {{"Accept", "application/x-ue-comp"}}); + HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}/{}?Policy=Default,SkipData", Bucket, Key, PayloadId), + {{"Accept", "application/x-ue-comp"}}); CHECK(Result); CHECK(Result.ResponsePayload.GetSize() == 0); } } - SUBCASE("skip - 'data' returns empty binary value") + // skip - 'data' returns empty binary value + // Uses size 8192 for unique key (avoids collision with size 1024/2048/4096 above) { - ZenConfig Cfg = ZenConfig::New(TestEnv.GetNewPortNumber()); - ZenServerInstance Instance(TestEnv); - Cfg.Spawn(Instance); - - const auto Bucket = "test"sv; + const auto Bucket = "skiptest8"sv; zen::IoHash Key; - IoBuffer BinaryValue = GenerateData(1024, Key); - - HttpClient Http{Cfg.BaseUri}; + IoBuffer BinaryValue = GenerateData(8192, Key); - 
// Store binary cache value { - HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue); + HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue); CHECK(Result.StatusCode == HttpResponseCode::Created); } - // Get package { HttpClient::Response Result = - Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/octet-stream"}}); + LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/octet-stream"}}); CHECK(Result); CHECK(Result.ResponsePayload.GetSize() == 0); } @@ -743,7 +621,11 @@ TEST_CASE("zcache.rpc") if (Result.StatusCode == HttpResponseCode::OK) { - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); + ParseFlags PFlags = EnumHasAllFlags(AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); CHECK(!Response.IsNull()); OutResult.Response = std::move(Response); CHECK(OutResult.Result.Parse(OutResult.Response)); @@ -1745,8 +1627,13 @@ TEST_CASE("zcache.rpc.partialchunks") CHECK(Result.StatusCode == HttpResponseCode::OK); - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); - bool Loaded = !Response.IsNull(); + ParseFlags PFlags = EnumHasAllFlags(Options.AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? 
&AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); + bool Loaded = !Response.IsNull(); CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load."); cacherequests::GetCacheChunksResult GetCacheChunksResult; CHECK(GetCacheChunksResult.Parse(Response)); diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp index 021052a3b..a4755adec 100644 --- a/src/zenserver-test/compute-tests.cpp +++ b/src/zenserver-test/compute-tests.cpp @@ -21,6 +21,7 @@ # include <zenhttp/httpserver.h> # include <zenhttp/websocket.h> # include <zencompute/computeservice.h> +# include <zencore/fmtutils.h> # include <zenstore/zenstore.h> # include <zenutil/zenserverprocess.h> @@ -36,6 +37,8 @@ using namespace std::literals; static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; +static constexpr std::string_view kFailVersion = "fa11fa11-fa11-fa11-fa11-fa11fa11fa11"; +static constexpr std::string_view kCrashVersion = "c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"; // In-memory implementation of ChunkResolver for test use. // Stores compressed data keyed by decompressed content hash. 
@@ -104,6 +107,16 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -115,7 +128,7 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage)); REQUIRE_MESSAGE(RegisterResp, - fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText())); + fmt::format("Worker registration failed: status={}, body={}", RegisterResp.StatusCode, RegisterResp.ToText())); return WorkerId; } @@ -220,6 +233,83 @@ BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemor return ActionWriter.Save(); } +// Build a Fail action CbPackage. The worker exits with the given exit code. +static CbPackage +BuildFailActionPackage(int ExitCode) +{ + // The Fail function throws before reading inputs, but the action structure + // still requires a valid input attachment for the runner to manifest. 
+ std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Fail"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kFailVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "ExitCode"sv << static_cast<uint64_t>(ExitCode); + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Crash action CbPackage. The worker process crashes hard. +// Mode: "abort" (default) or "nullptr" (null pointer dereference). 
+static CbPackage +BuildCrashActionPackage(std::string_view Mode = "abort"sv) +{ + std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Crash"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kCrashVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "Mode"sv << Mode; + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + static HttpClient::Response PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000) { @@ -267,6 +357,41 @@ PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int L return false; } +static void +WaitForActionRunning(zen::compute::ComputeServiceSession& Session, uint64_t TimeoutMs = 10'000) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + if (Session.GetActionCounts().Running > 0) + { + return; + } + Sleep(50); + } + FAIL("Timed out waiting for action to reach Running state"); +} + +static void +WaitForAnyActionRunningHttp(HttpClient& Client, uint64_t TimeoutMs = 10'000) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + HttpClient::Response Resp = 
Client.Get("/jobs/running"sv); + if (Resp) + { + CbObject Obj = Resp.AsObject(); + if (Obj["running"sv].AsArrayView().Num() > 0) + { + return; + } + } + Sleep(50); + } + FAIL("Timed out waiting for any action to reach Running state"); +} + static std::string GetRot13Output(const CbPackage& ResultPackage) { @@ -340,8 +465,9 @@ public: } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } @@ -469,6 +595,16 @@ BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -526,7 +662,7 @@ TEST_CASE("function.rot13") // Submit action via legacy /jobs/{worker} endpoint const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); @@ -536,7 +672,7 @@ TEST_CASE("function.rot13") HttpClient::Response ResultResp = 
PollForResult(Client, ResultUrl); REQUIRE_MESSAGE( ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -562,7 +698,7 @@ TEST_CASE("function.workers") const IoHash WorkerId = RegisterWorker(Client, TestEnv); - // GET /workers — the registered worker should appear in the listing + // GET /workers - the registered worker should appear in the listing HttpClient::Response ListResp = Client.Get("/workers"sv); REQUIRE_MESSAGE(ListResp, "Failed to list workers after registration"); @@ -578,10 +714,10 @@ TEST_CASE("function.workers") REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString())); - // GET /workers/{worker} — descriptor should match what was registered + // GET /workers/{worker} - descriptor should match what was registered const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response DescResp = Client.Get(WorkerUrl); - REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode))); + REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", DescResp.StatusCode)); CbObject Desc = DescResp.AsObject(); CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion)); @@ -607,7 +743,7 @@ TEST_CASE("function.workers") CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor"); CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor"); - // GET /workers/{unknown} — should return 404 + // GET /workers/{unknown} - should return 404 const std::string UnknownUrl = fmt::format("/workers/{}", 
IoHash::Zero.ToHexString()); HttpClient::Response NotFoundResp = Client.Get(UnknownUrl); CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound); @@ -627,7 +763,7 @@ TEST_CASE("function.queues.lifecycle") // Create a queue HttpClient::Response CreateResp = Client.Post("/queues"sv); - REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); @@ -651,8 +787,7 @@ TEST_CASE("function.queues.lifecycle") // Submit action via queue-scoped endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission"); @@ -668,9 +803,8 @@ TEST_CASE("function.queues.lifecycle") // Retrieve result via queue-scoped /jobs/{lsn} endpoint const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); HttpClient::Response ResultResp = Client.Get(ResultUrl); - REQUIRE_MESSAGE( - ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result: status={}\nServer 
log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -712,13 +846,13 @@ TEST_CASE("function.queues.cancel") // Submit a job const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); // Cancel the queue const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // Verify queue status shows cancelled HttpClient::Response StatusResp = Client.Get(QueueUrl); @@ -740,10 +874,10 @@ TEST_CASE("function.queues.remote") const IoHash WorkerId = RegisterWorker(Client, TestEnv); - // Create a remote queue — response includes both an integer queue_id and an OID queue_token + // Create a remote queue - response includes both an integer queue_id and an OID queue_token HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); CbObject CreateObj = CreateResp.AsObject(); const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString()); @@ 
-753,7 +887,7 @@ TEST_CASE("function.queues.remote") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); REQUIRE_MESSAGE(SubmitResp, - fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + fmt::format("Remote queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission"); @@ -769,7 +903,7 @@ TEST_CASE("function.queues.remote") HttpClient::Response ResultResp = Client.Get(ResultUrl); REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}", - int(ResultResp.StatusCode), + ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" @@ -801,20 +935,19 @@ TEST_CASE("function.queues.cancel_running") // Submit a Sleep job long enough that it will still be running when we cancel const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); // Wait for the worker process to start executing before cancelling - Sleep(1'000); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + 
WaitForAnyActionRunningHttp(Client); // Cancel the queue, which should interrupt the running Sleep job - const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the /completed endpoint once the process exits const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); @@ -849,7 +982,7 @@ TEST_CASE("function.queues.remote_cancel") // Create a remote queue to obtain an OID token for token-addressed cancellation HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString()); REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); @@ -857,20 +990,19 @@ TEST_CASE("function.queues.remote_cancel") // Submit a long-running Sleep job via the token-addressed endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = 
SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); // Wait for the worker process to start executing before cancelling - Sleep(1'000); + const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); + WaitForAnyActionRunningHttp(Client); // Cancel the queue via its OID token - const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Remote queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the token-addressed /completed endpoint const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); @@ -910,13 +1042,13 @@ TEST_CASE("function.queues.drain") // Submit a long-running job so we can verify it completes even after drain const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000)); - REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode))); + REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", Submit1.StatusCode)); const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32(); // Drain the queue const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId); HttpClient::Response DrainResp = Client.Post(DrainUrl); - REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText())); + REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", DrainResp.StatusCode, DrainResp.ToText())); CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), 
"draining"); // Second submission should be rejected with 424 @@ -965,7 +1097,7 @@ TEST_CASE("function.priority") // jobs by priority when the slot becomes free. const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", BlockerResp.StatusCode)); // Submit 3 low-priority Rot13 jobs const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); @@ -982,7 +1114,7 @@ TEST_CASE("function.priority") REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed"); const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32(); - // Submit 1 high-priority Rot13 job — should execute before the low-priority ones + // Submit 1 high-priority Rot13 job - should execute before the low-priority ones const std::string HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString()); HttpClient::Response HighResp = Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv)); REQUIRE_MESSAGE(HighResp, "High-priority job submission failed"); @@ -1104,6 +1236,305 @@ TEST_CASE("function.priority") } ////////////////////////////////////////////////////////////////////////// +// Process exit code tests +// +// These tests exercise how the compute service handles worker processes +// that exit with non-zero exit codes, including retry behaviour and +// final failure reporting. 
+ +TEST_CASE("function.exit_code") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + auto CreateQueue = [&](int MaxRetries) -> std::pair<int, std::string> { + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << MaxRetries; + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + return {QueueId, fmt::format("/queues/{}", QueueId)}; + }; + + // Scenario 1: failed_action - immediate failure with max_retries=0 + { + auto [QueueId, QueueUrl] = CreateQueue(0); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(42)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Fail job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Fail job submission"); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + 
REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); + + const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + CHECK_EQ(ResultResp.StatusCode, HttpResponseCode::OK); + } + + // Scenario 2: auto_retry - retried twice before permanent failure + { + auto [QueueId, QueueUrl] = CreateQueue(2); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : 
HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 2); + break; + } + } + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + } + + // Scenario 3: reschedule_failed - manual reschedule rejected after retry limit + { + auto [QueueId, QueueUrl] = CreateQueue(1); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(7)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue completed list\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + const std::string RescheduleUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response RescheduleResp = Client.Post(RescheduleUrl); + CHECK_EQ(RescheduleResp.StatusCode, HttpResponseCode::Conflict); + } + + // Scenario 4: mixed_success_and_failure - one success and one failure in the same queue + { + auto [QueueId, QueueUrl] = CreateQueue(0); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + + HttpClient::Response SuccessResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv)); + REQUIRE_MESSAGE(SuccessResp, "Rot13 job submission failed"); + const int LsnSuccess = SuccessResp.AsObject()["lsn"sv].AsInt32(); + + 
HttpClient::Response FailResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(FailResp, "Fail job submission failed"); + const int LsnFail = FailResp.AsObject()["lsn"sv].AsInt32(); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnSuccess), + fmt::format("Success LSN {} did not complete\nServer log:\n{}", LsnSuccess, Instance.GetLogOutput())); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnFail), + fmt::format("Fail LSN {} did not complete\nServer log:\n{}", LsnFail, Instance.GetLogOutput())); + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + } +} + +TEST_CASE("function.crash") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + auto CreateQueue = [&](int MaxRetries) -> std::pair<int, std::string> { + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << MaxRetries; + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + return {QueueId, fmt::format("/queues/{}", 
QueueId)}; + }; + + // Scenario 1: abort - worker process calls std::abort(), no retries + { + auto [QueueId, QueueUrl] = CreateQueue(0); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); + } + + // Scenario 2: nullptr - worker process dereferences null, no retries + { + auto [QueueId, QueueUrl] = CreateQueue(0); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + 
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("nullptr"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + } + + // Scenario 3: auto_retry - crash retried once before permanent failure + { + auto [QueueId, QueueUrl] = CreateQueue(1); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : 
HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 1); + break; + } + } + + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + } +} + +////////////////////////////////////////////////////////////////////////// // Remote worker synchronization tests // // These tests exercise the orchestrator discovery path where new compute @@ -1139,7 +1570,6 @@ TEST_CASE("function.remote.worker_sync_on_discovery") // Trigger immediate orchestrator re-query and wait for runner setup Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Submit Rot13 action via session CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); @@ -1162,9 +1592,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. 
Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); @@ -1199,7 +1628,6 @@ TEST_CASE("function.remote.late_runner_discovery") // Wait for W1 discovery Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Baseline: submit Rot13 action and verify it completes on W1 { @@ -1241,27 +1669,37 @@ TEST_CASE("function.remote.late_runner_discovery") // Wait for W2 discovery Session.NotifyOrchestratorChanged(); - Sleep(2'000); - // Verify W2 received the worker by querying its /compute/workers endpoint directly + // Poll W2 until the worker has been synced (SyncWorkersToRunner is async) { - const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2); - HttpClient Client(ComputeBaseUri); - HttpClient::Response ListResp = Client.Get("/workers"sv); - REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2"); + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2); + HttpClient Client(ComputeBaseUri); + bool WorkerFound = false; + Stopwatch Timer; - bool WorkerFound = false; - for (auto& Item : ListResp.AsObject()["workers"sv]) + while (Timer.GetElapsedTimeMs() < 10'000) { - if (Item.AsHash() == WorkerPackage.GetObjectHash()) + HttpClient::Response ListResp = Client.Get("/workers"sv); + if (ListResp) + { + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerPackage.GetObjectHash()) + { + WorkerFound = true; + break; + } + } + } + if (WorkerFound) { - WorkerFound = true; break; } + Sleep(50); } REQUIRE_MESSAGE(WorkerFound, - fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}", + fmt::format("Worker not found on W2 after discovery - SyncWorkersToRunner may have failed\nW2 log:\n{}", Instance2.GetLogOutput())); } @@ -1322,7 +1760,6 @@ TEST_CASE("function.remote.queue_association") // Wait for scheduler to 
discover the runner Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Create a local queue and submit action to it auto QueueResult = Session.CreateQueue(); @@ -1349,9 +1786,8 @@ TEST_CASE("function.remote.queue_association") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); @@ -1401,7 +1837,6 @@ TEST_CASE("function.remote.queue_cancel_propagation") // Wait for scheduler to discover the runner Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Create a local queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1414,9 +1849,9 @@ TEST_CASE("function.remote.queue_cancel_propagation") REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); // Wait for the action to start running on the remote - Sleep(2'000); + WaitForActionRunning(Session); - // Cancel the local queue — this should propagate to the remote + // Cancel the local queue - this should propagate to the remote Session.CancelQueue(QueueId); // Poll for the action to complete (as cancelled) @@ -1481,13 +1916,13 @@ TEST_CASE("function.abandon_running_http") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", 
SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); // Wait for the process to start running - Sleep(1'000); + WaitForAnyActionRunningHttp(Client); // Verify the ready endpoint returns OK before abandon { @@ -1498,7 +1933,7 @@ TEST_CASE("function.abandon_running_http") // Trigger abandon via the HTTP endpoint HttpClient::Response AbandonResp = Client.Post("/abandon"sv); REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, - fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + fmt::format("Abandon request failed: status={}, body={}", AbandonResp.StatusCode, AbandonResp.ToText())); // Ready endpoint should now return 503 { @@ -1529,7 +1964,7 @@ TEST_CASE("function.abandon_running_http") CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state"); } -TEST_CASE("function.session.abandon_pending") +TEST_CASE("function.session.abandon_pending" * doctest::skip()) { // Create a session with no runners so actions stay pending InMemoryChunkResolver Resolver; @@ -1540,7 +1975,7 @@ TEST_CASE("function.session.abandon_pending") CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); - // Enqueue several actions — they will stay pending because there are no runners + // Enqueue several actions - they will stay pending because there are no runners auto QueueResult = Session.CreateQueue(); REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); @@ -1553,7 +1988,7 @@ TEST_CASE("function.session.abandon_pending") REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2"); REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); - // Transition to Abandoned — should mark all pending actions as Abandoned + // Transition to Abandoned - should mark all pending actions as Abandoned bool Transitioned = 
Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); @@ -1577,7 +2012,7 @@ TEST_CASE("function.session.abandon_pending") } Sleep(100); } - CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, Code)); } // Queue should show 0 active, 3 abandoned @@ -1589,7 +2024,7 @@ TEST_CASE("function.session.abandon_pending") auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state"); - // Abandoned → Sunset should be valid + // Abandoned -> Sunset should be valid CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset)); Session.Shutdown(); @@ -1618,7 +2053,6 @@ TEST_CASE("function.session.abandon_running") // Wait for scheduler to discover the runner Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Create a queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1631,9 +2065,9 @@ TEST_CASE("function.session.abandon_running") REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); // Wait for the action to start running on the remote - Sleep(2'000); + WaitForActionRunning(Session); - // Transition to Abandoned — should abandon the running action + // Transition to Abandoned - should abandon the running action bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(!Session.IsHealthy()); @@ -1689,7 +2123,6 @@ TEST_CASE("function.remote.abandon_propagation") // Wait for scheduler to discover the runner Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Create a local queue and submit a long-running Sleep 
action auto QueueResult = Session.CreateQueue(); @@ -1702,9 +2135,9 @@ TEST_CASE("function.remote.abandon_propagation") REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); // Wait for the action to start running on the remote - Sleep(2'000); + WaitForActionRunning(Session); - // Transition to Abandoned — should abandon the running action and propagate + // Transition to Abandoned - should abandon the running action and propagate bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); @@ -1762,7 +2195,6 @@ TEST_CASE("function.remote.shutdown_cancels_queues") Session.RegisterWorker(WorkerPackage); Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Create a queue and submit a long-running action so the remote queue is established auto QueueResult = Session.CreateQueue(); @@ -1775,7 +2207,7 @@ TEST_CASE("function.remote.shutdown_cancels_queues") REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); // Wait for the action to start running on the remote - Sleep(2'000); + WaitForActionRunning(Session); // Verify the remote has a non-implicit queue before shutdown HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); @@ -1795,7 +2227,7 @@ TEST_CASE("function.remote.shutdown_cancels_queues") REQUIRE_MESSAGE(RemoteQueueFound, "Expected remote queue to exist before shutdown"); } - // Shut down the session — this should cancel all remote queues + // Shut down the session - this should cancel all remote queues Session.Shutdown(); // Verify the remote queue is now cancelled @@ -1837,7 +2269,6 @@ TEST_CASE("function.remote.shutdown_rejects_new_work") // Wait for runner discovery Session.NotifyOrchestratorChanged(); - Sleep(2'000); // Baseline: submit an action and verify it completes { @@ -1865,7 +2296,7 @@ TEST_CASE("function.remote.shutdown_rejects_new_work") CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); } - // Shut down — the remote runner should now reject new work + // 
Shut down - the remote runner should now reject new work Session.Shutdown(); // Attempting to enqueue after shutdown should fail (session is in Sunset state) @@ -1894,7 +2325,7 @@ TEST_CASE("function.session.retract_pending") REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action"); // Let the scheduler process the pending action - Sleep(500); + Sleep(50); // Retract the pending action auto Result = Session.RetractAction(Enqueued.Lsn); @@ -1903,7 +2334,7 @@ TEST_CASE("function.session.retract_pending") // The action should be re-enqueued as pending (still no runners, so stays pending). // Let the scheduler process the retracted action back to pending. - Sleep(500); + Sleep(50); // Queue should still show 1 active (the action was rescheduled, not completed) auto Status = Session.GetQueueStatus(QueueResult.QueueId); @@ -1955,7 +2386,7 @@ TEST_CASE("function.session.retract_not_terminal") REQUIRE_MESSAGE(Code == HttpResponseCode::OK, "Action did not complete within timeout"); - // Retract should fail — action already completed (no longer in pending/running maps) + // Retract should fail - action already completed (no longer in pending/running maps) auto RetractResult = Session.RetractAction(Enqueued.Lsn); CHECK(!RetractResult.Success); @@ -1979,24 +2410,24 @@ TEST_CASE("function.retract_http") // Submit a long-running Sleep action to occupy the single execution slot const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", BlockerResp.StatusCode)); - // Submit a second action — it will stay pending because the slot is occupied + // Submit a second action - it will stay pending because the slot is occupied HttpClient::Response SubmitResp = 
Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); - // Wait for the scheduler to process the pending action into m_PendingActions - Sleep(1'000); + // Wait for the blocker action to start running (occupying the single slot) + WaitForAnyActionRunningHttp(Client); // Retract the pending action via POST /jobs/{lsn}/retract const std::string RetractUrl = fmt::format("/jobs/{}/retract", Lsn); HttpClient::Response RetractResp = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK, fmt::format("Retract failed: status={}, body={}\nServer log:\n{}", - int(RetractResp.StatusCode), + RetractResp.StatusCode, RetractResp.ToText(), Instance.GetLogOutput())); @@ -2008,10 +2439,45 @@ TEST_CASE("function.retract_http") } // A second retract should also succeed (action is back to pending) - Sleep(500); + Sleep(50); HttpClient::Response RetractResp2 = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK, - fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText())); + fmt::format("Second retract failed: status={}, body={}", RetractResp2.StatusCode, RetractResp2.ToText())); +} + +TEST_CASE("function.session.immediate_query_after_enqueue") +{ + // Verify that actions are immediately visible to GetActionResult and + // FindActionResult right after enqueue, without waiting for the + // scheduler thread to process the update. 
+ + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + CbObject ActionObj = BuildRot13ActionForSession("immediate-query"sv, Resolver); + + auto EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Failed to enqueue action"); + + // Query by LSN immediately - must not return NotFound + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(EnqueueRes.Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::Accepted, + fmt::format("GetActionResult returned {} immediately after enqueue, expected Accepted", Code)); + + // Query by ActionId immediately - must not return NotFound + const IoHash ActionId = ActionObj.GetHash(); + CbPackage FindResult; + HttpResponseCode FindCode = Session.FindActionResult(ActionId, FindResult); + CHECK_MESSAGE(FindCode == HttpResponseCode::Accepted, + fmt::format("FindActionResult returned {} immediately after enqueue, expected Accepted", FindCode)); + + Session.Shutdown(); } TEST_SUITE_END(); diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp index b2da552fc..35a840e5d 100644 --- a/src/zenserver-test/hub-tests.cpp +++ b/src/zenserver-test/hub-tests.cpp @@ -329,17 +329,36 @@ TEST_CASE("hub.lifecycle.children") CHECK_EQ(Result.AsText(), "GhijklmNop"sv); } - Result = Client.Post("modules/abc/deprovision"); + // Deprovision all modules at once + Result = Client.Post("deprovision"); REQUIRE(Result); + CHECK_EQ(Result.StatusCode, HttpResponseCode::Accepted); + { + CbObject Body = Result.AsObject(); + CbArrayView AcceptedArr = Body["Accepted"].AsArrayView(); + CHECK_EQ(AcceptedArr.Num(), 2u); + bool FoundAbc = false; + bool FoundDef = false; + for (CbFieldView F : AcceptedArr) + { + if (F.AsString() == "abc"sv) + { + FoundAbc = true; + } + else if (F.AsString() == 
"def"sv) + { + FoundDef = true; + } + } + CHECK(FoundAbc); + CHECK(FoundDef); + } REQUIRE(WaitForModuleGone(Client, "abc")); + REQUIRE(WaitForModuleGone(Client, "def")); { HttpClient ModClient(fmt::format("http://localhost:{}", AbcPort), kFastTimeout); CHECK(WaitForPortUnreachable(ModClient)); } - - Result = Client.Post("modules/def/deprovision"); - REQUIRE(Result); - REQUIRE(WaitForModuleGone(Client, "def")); { HttpClient ModClient(fmt::format("http://localhost:{}", DefPort), kFastTimeout); CHECK(WaitForPortUnreachable(ModClient)); @@ -349,6 +368,10 @@ TEST_CASE("hub.lifecycle.children") Result = Client.Get("status"); REQUIRE(Result); CHECK_EQ(Result.AsObject()["modules"].AsArrayView().Num(), 0u); + + // Deprovision-all with no modules + Result = Client.Post("deprovision"); + CHECK(Result); } static bool @@ -377,7 +400,7 @@ TEST_CASE("hub.consul.kv") consul::ConsulProcess ConsulProc; ConsulProc.SpawnConsulAgent(); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); Client.SetKeyValue("zen/hub/testkey", "testvalue"); std::string RetrievedValue = Client.GetKeyValue("zen/hub/testkey"); @@ -399,7 +422,7 @@ TEST_CASE("hub.consul.hub.registration") "--consul-health-interval-seconds=5 --consul-deregister-after-seconds=60"); REQUIRE(PortNumber != 0); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); // Verify custom intervals flowed through to the registered check @@ -480,7 +503,7 @@ TEST_CASE("hub.consul.hub.registration.token") // Use a plain client -- dev-mode Consul doesn't enforce ACLs, but the // server has exercised the ConsulTokenEnv -> GetEnvVariable -> ConsulClient path. 
- consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); @@ -501,7 +524,7 @@ TEST_CASE("hub.consul.provision.registration") Instance.SpawnServerAndWaitUntilReady("--consul-endpoint=http://localhost:8500/ --instance-id=test-instance"); REQUIRE(PortNumber != 0); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); @@ -762,9 +785,10 @@ TEST_CASE("hub.hibernate.errors") CHECK(!Result); CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound); + // Obliterate of an unknown module succeeds (cleans up backend data for dehydrated modules) Result = Client.Delete("modules/unknown"); - CHECK(!Result); - CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound); + CHECK(Result); + CHECK_EQ(Result.StatusCode, HttpResponseCode::Accepted); // Double-provision: second call while first is in-flight returns 202 Accepted with the same port. Result = Client.Post("modules/errmod/provision"); diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp index 2e530ff92..cb0926ddc 100644 --- a/src/zenserver-test/logging-tests.cpp +++ b/src/zenserver-test/logging-tests.cpp @@ -69,8 +69,8 @@ TEST_CASE("logging.file.default") // --quiet sets the console sink level to WARN. The formatted "[info] ..." // entry written by the default logger's console sink must therefore not appear -// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_* -// macros — may still emit plain-text messages without a level marker, so we +// in captured stdout. 
(The "console" named logger - used by ZEN_CONSOLE_* +// macros - may still emit plain-text messages without a level marker, so we // check for the absence of the FullFormatter "[info]" prefix rather than the // message text itself.) TEST_CASE("logging.console.quiet") @@ -146,28 +146,6 @@ TEST_CASE("logging.file.json") CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog); CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog); CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); -} - -// --log-id <id> is automatically set to the server instance name in test mode. -// The JSON formatter emits this value as the "id" field, so every entry in a -// .json log file must carry a non-empty "id". -TEST_CASE("logging.log_id") -{ - const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); - const std::filesystem::path LogFile = TestDir / "test.json"; - - ZenServerInstance Instance(TestEnv); - Instance.SetDataDir(TestDir); - - const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); - const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); - CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); - - Instance.Shutdown(); - - CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); - const std::string FileLog = ReadFileToString(LogFile); - // The JSON formatter writes the log-id as: "id": "<value>", CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog); } diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp index 1f6a7675c..99c92e15f 100644 --- a/src/zenserver-test/objectstore-tests.cpp +++ b/src/zenserver-test/objectstore-tests.cpp @@ -19,18 +19,22 @@ using namespace std::literals; TEST_SUITE_BEGIN("server.objectstore"); -TEST_CASE("objectstore.blobs") +TEST_CASE("objectstore") { - std::string_view Bucket = "bkt"sv; + ZenServerInstance Instance(TestEnv); + + const uint16_t Port = 
Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled"); + CHECK(Port != 0); - std::vector<IoHash> CompressedBlobsHashes; - std::vector<uint64_t> BlobsSizes; - std::vector<uint64_t> CompressedBlobsSizes; + // --- objectstore.blobs --- { - ZenServerInstance Instance(TestEnv); + INFO("objectstore.blobs"); + + std::string_view Bucket = "bkt"sv; - const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--objectstore-enabled")); - CHECK(PortNumber != 0); + std::vector<IoHash> CompressedBlobsHashes; + std::vector<uint64_t> BlobsSizes; + std::vector<uint64_t> CompressedBlobsSizes; HttpClient Client(Instance.GetBaseUri() + "/obj/"); @@ -68,94 +72,238 @@ TEST_CASE("objectstore.blobs") CHECK_EQ(RawSize, BlobsSizes[I]); } } + + // --- objectstore.s3client --- + { + INFO("objectstore.s3client"); + + S3ClientOptions Opts; + Opts.BucketName = "s3test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "testkey"; + Opts.Credentials.SecretAccessKey = "testsecret"; + + S3Client Client(Opts); + + // -- PUT + GET roundtrip -- + std::string_view TestData = "hello from s3client via objectstore"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error); + + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error); + CHECK(GetRes.AsText() == TestData); + + // -- PUT overwrites -- + IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv)); + IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv)); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess()); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess()); + + S3GetObjectResult 
OverwriteGet = Client.GetObject("overwrite/file.txt"); + REQUIRE(OverwriteGet.IsSuccess()); + CHECK(OverwriteGet.AsText() == "overwritten"sv); + + // -- GET not found -- + S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat"); + CHECK_FALSE(NotFoundGet.IsSuccess()); + + // -- HEAD found -- + std::string_view HeadData = "head test data"sv; + IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData)); + REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess()); + + S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt"); + REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == HeadData.size()); + + // -- HEAD not found -- + S3HeadObjectResult HeadNotFound = Client.HeadObject("nonexistent/file.dat"); + CHECK(HeadNotFound.IsSuccess()); + CHECK(HeadNotFound.Status == HeadObjectResult::NotFound); + + // -- LIST objects -- + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("listing/item-{}.txt", i); + std::string Payload = fmt::format("content-{}", i); + IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess()); + } + + S3ListObjectsResult ListRes = Client.ListObjects("listing/"); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + REQUIRE(ListRes.Objects.size() == 3); + + std::vector<std::string> Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + CHECK(Obj.Size > 0); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "listing/item-0.txt"); + CHECK(Keys[1] == "listing/item-1.txt"); + CHECK(Keys[2] == "listing/item-2.txt"); + + // -- LIST empty prefix -- + S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/"); + REQUIRE(EmptyList.IsSuccess()); + CHECK(EmptyList.Objects.empty()); + } + + // --- objectstore.range-requests --- + { + INFO("objectstore.range-requests"); + + 
HttpClient Client(Instance.GetBaseUri() + "/obj/"); + + IoBuffer Blob = CreateRandomBlob(1024); + MemoryView BlobView = Blob.GetView(); + std::string ObjectPath = "bucket/bkt/range-test/data.bin"; + + HttpClient::Response PutResult = Client.Put(ObjectPath, IoBuffer(Blob)); + REQUIRE(PutResult); + + // Full GET without Range header + { + HttpClient::Response Result = Client.Get(ObjectPath); + CHECK(Result.StatusCode == HttpResponseCode::OK); + CHECK_EQ(Result.ResponsePayload.GetSize(), 1024u); + CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView)); + } + + // Single range: bytes 100-199 + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=100-199"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + CHECK_EQ(Result.ResponsePayload.GetSize(), 100u); + CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(100, 100))); + } + + // Range starting at zero: bytes 0-49 + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + CHECK_EQ(Result.ResponsePayload.GetSize(), 50u); + CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(0, 50))); + } + + // Range at end of file: bytes 1000-1023 + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=1000-1023"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + CHECK_EQ(Result.ResponsePayload.GetSize(), 24u); + CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(1000, 24))); + } + + // Multiple ranges: bytes 0-49 and 100-149 + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49,100-149"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + + std::string_view Body(reinterpret_cast<const char*>(Result.ResponsePayload.GetData()), Result.ResponsePayload.GetSize()); + + // Verify multipart structure contains both range payloads + CHECK(Body.find("Content-Range: bytes 0-49/1024") != 
std::string_view::npos); + CHECK(Body.find("Content-Range: bytes 100-149/1024") != std::string_view::npos); + + // Extract and verify actual data for first range + auto FindPartData = [&](std::string_view ContentRange) -> std::string_view { + size_t Pos = Body.find(ContentRange); + if (Pos == std::string_view::npos) + { + return {}; + } + // Skip past the Content-Range line and the blank line separator + Pos = Body.find("\r\n\r\n", Pos); + if (Pos == std::string_view::npos) + { + return {}; + } + Pos += 4; + size_t End = Body.find("\r\n--", Pos); + if (End == std::string_view::npos) + { + return {}; + } + return Body.substr(Pos, End - Pos); + }; + + std::string_view Part1 = FindPartData("Content-Range: bytes 0-49/1024"); + CHECK_EQ(Part1.size(), 50u); + CHECK(MemoryView(Part1.data(), Part1.size()).EqualBytes(BlobView.Mid(0, 50))); + + std::string_view Part2 = FindPartData("Content-Range: bytes 100-149/1024"); + CHECK_EQ(Part2.size(), 50u); + CHECK(MemoryView(Part2.data(), Part2.size()).EqualBytes(BlobView.Mid(100, 50))); + } + + // Out-of-bounds single range + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=2000-2099"}}); + CHECK(Result.StatusCode == HttpResponseCode::RangeNotSatisfiable); + } + + // Out-of-bounds multi-range + { + HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49,2000-2099"}}); + CHECK(Result.StatusCode == HttpResponseCode::RangeNotSatisfiable); + } + } } -TEST_CASE("objectstore.s3client") +TEST_CASE("objectstore.range-requests-download") { ZenServerInstance Instance(TestEnv); const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled"); - CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); - - // S3Client in path-style builds paths as /{bucket}/{key}. - // The objectstore routes objects at bucket/{bucket}/{key} relative to its base. - // Point the S3Client endpoint at {server}/obj/bucket so the paths line up. 
- S3ClientOptions Opts; - Opts.BucketName = "s3test"; - Opts.Region = "us-east-1"; - Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port); - Opts.PathStyle = true; - Opts.Credentials.AccessKeyId = "testkey"; - Opts.Credentials.SecretAccessKey = "testsecret"; - - S3Client Client(Opts); - - // -- PUT + GET roundtrip -- - std::string_view TestData = "hello from s3client via objectstore"sv; - IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); - S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); - REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error); - - S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); - REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error); - CHECK(GetRes.AsText() == TestData); - - // -- PUT overwrites -- - IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv)); - IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv)); - REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess()); - REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess()); - - S3GetObjectResult OverwriteGet = Client.GetObject("overwrite/file.txt"); - REQUIRE(OverwriteGet.IsSuccess()); - CHECK(OverwriteGet.AsText() == "overwritten"sv); - - // -- GET not found -- - S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat"); - CHECK_FALSE(NotFoundGet.IsSuccess()); - - // -- HEAD found -- - std::string_view HeadData = "head test data"sv; - IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData)); - REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess()); - - S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt"); - REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error); - CHECK(HeadRes.Status == HeadObjectResult::Found); - CHECK(HeadRes.Info.Size == HeadData.size()); - - // -- HEAD not found -- - S3HeadObjectResult HeadNotFound = 
Client.HeadObject("nonexistent/file.dat"); - CHECK(HeadNotFound.IsSuccess()); - CHECK(HeadNotFound.Status == HeadObjectResult::NotFound); - - // -- LIST objects -- - for (int i = 0; i < 3; ++i) + REQUIRE(Port != 0); + + HttpClient Client(Instance.GetBaseUri() + "/obj/"); + + IoBuffer Blob = CreateRandomBlob(1024); + MemoryView BlobView = Blob.GetView(); + std::string ObjectPath = "bucket/bkt/range-download-test/data.bin"; + + HttpClient::Response PutResult = Client.Put(ObjectPath, IoBuffer(Blob)); + REQUIRE(PutResult); + + ScopedTemporaryDirectory DownloadDir; + + // Single range via Download: verify Ranges is populated and GetRanges maps correctly { - std::string Key = fmt::format("listing/item-{}.txt", i); - std::string Payload = fmt::format("content-{}", i); - IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); - REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess()); - } + HttpClient::Response Result = Client.Download(ObjectPath, DownloadDir.Path(), {{"Range", "bytes=100-199"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + REQUIRE_EQ(Result.Ranges.size(), 1u); + CHECK_EQ(Result.Ranges[0].RangeOffset, 100u); + CHECK_EQ(Result.Ranges[0].RangeLength, 100u); - S3ListObjectsResult ListRes = Client.ListObjects("listing/"); - REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); - REQUIRE(ListRes.Objects.size() == 3); + std::vector<std::pair<uint64_t, uint64_t>> RequestedRanges = {{100, 100}}; + std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = Result.GetRanges(RequestedRanges); + REQUIRE_EQ(PayloadRanges.size(), 1u); + CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[0].first, PayloadRanges[0].second).EqualBytes(BlobView.Mid(100, 100))); + } - std::vector<std::string> Keys; - for (const S3ObjectInfo& Obj : ListRes.Objects) + // Multi-range via Download: verify Ranges is populated for both parts and GetRanges maps correctly { - Keys.push_back(Obj.Key); - CHECK(Obj.Size > 0); + HttpClient::Response Result = 
Client.Download(ObjectPath, DownloadDir.Path(), {{"Range", "bytes=0-49,100-149"}}); + CHECK(Result.StatusCode == HttpResponseCode::PartialContent); + REQUIRE_EQ(Result.Ranges.size(), 2u); + CHECK_EQ(Result.Ranges[0].RangeOffset, 0u); + CHECK_EQ(Result.Ranges[0].RangeLength, 50u); + CHECK_EQ(Result.Ranges[1].RangeOffset, 100u); + CHECK_EQ(Result.Ranges[1].RangeLength, 50u); + + std::vector<std::pair<uint64_t, uint64_t>> RequestedRanges = {{0, 50}, {100, 50}}; + std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = Result.GetRanges(RequestedRanges); + REQUIRE_EQ(PayloadRanges.size(), 2u); + CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[0].first, PayloadRanges[0].second).EqualBytes(BlobView.Mid(0, 50))); + CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[1].first, PayloadRanges[1].second).EqualBytes(BlobView.Mid(100, 50))); } - std::sort(Keys.begin(), Keys.end()); - CHECK(Keys[0] == "listing/item-0.txt"); - CHECK(Keys[1] == "listing/item-1.txt"); - CHECK(Keys[2] == "listing/item-2.txt"); - - // -- LIST empty prefix -- - S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/"); - REQUIRE(EmptyList.IsSuccess()); - CHECK(EmptyList.Objects.empty()); } TEST_SUITE_END(); diff --git a/src/zenserver-test/process-tests.cpp b/src/zenserver-test/process-tests.cpp index ae11bb294..3f6476810 100644 --- a/src/zenserver-test/process-tests.cpp +++ b/src/zenserver-test/process-tests.cpp @@ -115,7 +115,7 @@ TEST_CASE("pipe.raii_cleanup") { StdoutPipeHandles Pipe; REQUIRE(CreateStdoutPipe(Pipe)); - // Pipe goes out of scope here — destructor should close both ends + // Pipe goes out of scope here - destructor should close both ends } } @@ -155,7 +155,7 @@ TEST_CASE("pipe.move_semantics") CHECK(Moved.WriteFd == -1); # endif - // Assigned goes out of scope — destructor closes handles + // Assigned goes out of scope - destructor closes handles } TEST_CASE("pipe.close_is_idempotent") diff --git a/src/zenserver-test/projectstore-tests.cpp 
b/src/zenserver-test/projectstore-tests.cpp index a37ecb6be..49d985abb 100644 --- a/src/zenserver-test/projectstore-tests.cpp +++ b/src/zenserver-test/projectstore-tests.cpp @@ -22,6 +22,7 @@ ZEN_THIRD_PARTY_INCLUDES_START ZEN_THIRD_PARTY_INCLUDES_END # include <random> +# include <thread> namespace zen::tests { @@ -40,326 +41,429 @@ TEST_CASE("project.basic") const uint16_t PortNumber = Instance1.SpawnServerAndWaitUntilReady(); - std::mt19937_64 mt; - - zen::StringBuilder<64> BaseUri; - BaseUri << fmt::format("http://localhost:{}", PortNumber); + std::string ServerUri = fmt::format("http://localhost:{}", PortNumber); std::filesystem::path BinPath = zen::GetRunningExecutablePath(); std::filesystem::path RootPath = BinPath.parent_path().parent_path(); BinPath = BinPath.lexically_relative(RootPath); - SUBCASE("build store init") + auto CreateProjectAndOplog = [&](std::string_view ProjectName, std::string_view OplogName) -> std::string { + HttpClient Http{ServerUri}; + + zen::CbObjectWriter Body; + Body << "id" << ProjectName; + Body << "root" << RootPath.c_str(); + Body << "project" + << "/zooom"; + Body << "engine" + << "/zooom"; + IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer(); + auto Response = Http.Post(fmt::format("/prj/{}", ProjectName), BodyBuf); + REQUIRE(Response.StatusCode == HttpResponseCode::Created); + + std::string OplogUri = fmt::format("{}/prj/{}/oplog/{}", ServerUri, ProjectName, OplogName); + HttpClient OplogHttp{OplogUri}; + auto OplogResponse = OplogHttp.Post(""sv, IoBuffer{}, ZenContentType::kCbObject); + REQUIRE(OplogResponse.StatusCode == HttpResponseCode::Created); + + return OplogUri; + }; + + // Create a file at a path exceeding Windows MAX_PATH (260 chars) for long filename testing + std::filesystem::path LongPathDir = RootPath / "longpathtest"; + for (int I = 0; I < 5; ++I) { - { - HttpClient Http{BaseUri}; + LongPathDir /= std::string(50, char('a' + I)); + } + std::filesystem::path LongFilePath = LongPathDir / "testfile.bin"; + 
std::filesystem::path LongRelPath = LongFilePath.lexically_relative(RootPath); - { - zen::CbObjectWriter Body; - Body << "id" - << "test"; - Body << "root" << RootPath.c_str(); - Body << "project" - << "/zooom"; - Body << "engine" - << "/zooom"; - - zen::BinaryWriter MemOut; - IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer(); - - auto Response = Http.Post("/prj/test"sv, BodyBuf); - CHECK(Response.StatusCode == HttpResponseCode::Created); - } + const uint8_t LongPathFileData[] = {0xDE, 0xAD, 0xBE, 0xEF}; + CreateDirectories(MakeSafeAbsolutePath(LongPathDir)); + WriteFile(MakeSafeAbsolutePath(LongFilePath), IoBufferBuilder::MakeCloneFromMemory(LongPathFileData, sizeof(LongPathFileData))); + CHECK(LongRelPath.string().length() > 260); - { - auto Response = Http.Get("/prj/test"sv); - REQUIRE(Response.StatusCode == HttpResponseCode::OK); + std::string LongClientPath = "/{engine}/client"; + for (int I = 0; I < 5; ++I) + { + LongClientPath += '/'; + LongClientPath.append(50, char('a' + I)); + } + LongClientPath += "/longfile.bin"; + CHECK(LongClientPath.length() > 260); - CbObject ResponseObject = Response.AsObject(); + const std::string_view LongPathChunkId{ + "00000000" + "00000000" + "00020000"}; + auto LongPathFileOid = zen::Oid::FromHexString(LongPathChunkId); - CHECK(ResponseObject["id"].AsString() == "test"sv); - CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str())); - } + // --- build store persistence --- + // First section also verifies project and oplog creation responses. 
+ { + HttpClient ServerHttp{ServerUri}; + + { + zen::CbObjectWriter Body; + Body << "id" + << "test_persist"; + Body << "root" << RootPath.c_str(); + Body << "project" + << "/zooom"; + Body << "engine" + << "/zooom"; + IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer(); + + auto Response = ServerHttp.Post("/prj/test_persist"sv, BodyBuf); + CHECK(Response.StatusCode == HttpResponseCode::Created); + } + + { + auto Response = ServerHttp.Get("/prj/test_persist"sv); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); + + CbObject ResponseObject = Response.AsObject(); + + CHECK(ResponseObject["id"].AsString() == "test_persist"sv); + CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str())); } - BaseUri << "/prj/test/oplog/foobar"; + std::string OplogUri = fmt::format("{}/prj/test_persist/oplog/oplog_persist", ServerUri); { - HttpClient Http{BaseUri}; + HttpClient OplogHttp{OplogUri}; { - auto Response = Http.Post(""sv, IoBuffer{}, ZenContentType::kCbObject); + auto Response = OplogHttp.Post(""sv, IoBuffer{}, ZenContentType::kCbObject); CHECK(Response.StatusCode == HttpResponseCode::Created); } { - auto Response = Http.Get(""sv); + auto Response = OplogHttp.Get(""sv); REQUIRE(Response.StatusCode == HttpResponseCode::OK); CbObject ResponseObject = Response.AsObject(); - CHECK(ResponseObject["id"].AsString() == "foobar"sv); - CHECK(ResponseObject["project"].AsString() == "test"sv); + CHECK(ResponseObject["id"].AsString() == "oplog_persist"sv); + CHECK(ResponseObject["project"].AsString() == "test_persist"sv); } } - // Create a file at a path exceeding Windows MAX_PATH (260 chars) for long filename testing - std::filesystem::path LongPathDir = RootPath / "longpathtest"; - for (int I = 0; I < 5; ++I) + uint8_t AttachData[] = {1, 2, 3}; + + zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3})); + zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()}; + + 
zen::CbObjectWriter OpWriter; + OpWriter << "key" + << "foo" + << "attachment" << Attach; + + const std::string_view ChunkId{ + "00000000" + "00000000" + "00010000"}; + auto FileOid = zen::Oid::FromHexString(ChunkId); + + OpWriter.BeginArray("files"); + OpWriter.BeginObject(); + OpWriter << "id" << FileOid; + OpWriter << "clientpath" + << "/{engine}/client/side/path"; + OpWriter << "serverpath" << BinPath.c_str(); + OpWriter.EndObject(); + OpWriter.BeginObject(); + OpWriter << "id" << LongPathFileOid; + OpWriter << "clientpath" << LongClientPath; + OpWriter << "serverpath" << LongRelPath.c_str(); + OpWriter.EndObject(); + OpWriter.EndArray(); + + zen::CbObject Op = OpWriter.Save(); + + zen::CbPackage OpPackage(Op); + OpPackage.AddAttachment(Attach); + + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + + HttpClient Http{OplogUri}; + { - LongPathDir /= std::string(50, char('a' + I)); - } - std::filesystem::path LongFilePath = LongPathDir / "testfile.bin"; - std::filesystem::path LongRelPath = LongFilePath.lexically_relative(RootPath); + auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); - const uint8_t LongPathFileData[] = {0xDE, 0xAD, 0xBE, 0xEF}; - CreateDirectories(MakeSafeAbsolutePath(LongPathDir)); - WriteFile(MakeSafeAbsolutePath(LongFilePath), IoBufferBuilder::MakeCloneFromMemory(LongPathFileData, sizeof(LongPathFileData))); - CHECK(LongRelPath.string().length() > 260); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::Created); + } - std::string LongClientPath = "/{engine}/client"; - for (int I = 0; I < 5; ++I) { - LongClientPath += '/'; - LongClientPath.append(50, char('a' + I)); + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << ChunkId; + auto Response = Http.Get(ChunkGetUri); + + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); } - LongClientPath += "/longfile.bin"; - CHECK(LongClientPath.length() > 260); - const std::string_view 
LongPathChunkId{ - "00000000" - "00000000" - "00020000"}; - auto LongPathFileOid = zen::Oid::FromHexString(LongPathChunkId); + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << ChunkId << "?offset=1&size=10"; + auto Response = Http.Get(ChunkGetUri); + + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); + CHECK(Response.ResponsePayload.GetSize() == 10); + } - SUBCASE("build store persistence") { - uint8_t AttachData[] = {1, 2, 3}; - - zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3})); - zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()}; - - zen::CbObjectWriter OpWriter; - OpWriter << "key" - << "foo" - << "attachment" << Attach; - - const std::string_view ChunkId{ - "00000000" - "00000000" - "00010000"}; - auto FileOid = zen::Oid::FromHexString(ChunkId); - - OpWriter.BeginArray("files"); - OpWriter.BeginObject(); - OpWriter << "id" << FileOid; - OpWriter << "clientpath" - << "/{engine}/client/side/path"; - OpWriter << "serverpath" << BinPath.c_str(); - OpWriter.EndObject(); - OpWriter.BeginObject(); - OpWriter << "id" << LongPathFileOid; - OpWriter << "clientpath" << LongClientPath; - OpWriter << "serverpath" << LongRelPath.c_str(); - OpWriter.EndObject(); - OpWriter.EndArray(); - - zen::CbObject Op = OpWriter.Save(); - - zen::CbPackage OpPackage(Op); - OpPackage.AddAttachment(Attach); - - zen::BinaryWriter MemOut; - legacy::SaveCbPackage(OpPackage, MemOut); - - HttpClient Http{BaseUri}; + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << LongPathChunkId; + auto Response = Http.Get(ChunkGetUri); - { - auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); + CHECK(Response.ResponsePayload.GetSize() == sizeof(LongPathFileData)); + } - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::Created); - } + 
ZEN_INFO("+++++++"); + } - // Read file data + // --- snapshot --- + { + std::string OplogUri = CreateProjectAndOplog("test_snap", "oplog_snap"); - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << ChunkId; - auto Response = Http.Get(ChunkGetUri); + zen::CbObjectWriter OpWriter; + OpWriter << "key" + << "foo"; - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); - } + const std::string_view ChunkId{ + "00000000" + "00000000" + "00010000"}; + auto FileOid = zen::Oid::FromHexString(ChunkId); + + OpWriter.BeginArray("files"); + OpWriter.BeginObject(); + OpWriter << "id" << FileOid; + OpWriter << "clientpath" + << "/{engine}/client/side/path"; + OpWriter << "serverpath" << BinPath.c_str(); + OpWriter.EndObject(); + OpWriter.BeginObject(); + OpWriter << "id" << LongPathFileOid; + OpWriter << "clientpath" << LongClientPath; + OpWriter << "serverpath" << LongRelPath.c_str(); + OpWriter.EndObject(); + OpWriter.EndArray(); + + zen::CbObject Op = OpWriter.Save(); + + zen::CbPackage OpPackage(Op); - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << ChunkId << "?offset=1&size=10"; - auto Response = Http.Get(ChunkGetUri); + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); - CHECK(Response.ResponsePayload.GetSize() == 10); - } + HttpClient Http{OplogUri}; - // Read long-path file data - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << LongPathChunkId; - auto Response = Http.Get(ChunkGetUri); + { + auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); - CHECK(Response.ResponsePayload.GetSize() == sizeof(LongPathFileData)); - } + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::Created); + } + + // Read file data, it is raw and uncompressed + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri 
<< "/" << ChunkId; + auto Response = Http.Get(ChunkGetUri); - ZEN_INFO("+++++++"); + REQUIRE(Response); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); + + IoBuffer Data = Response.ResponsePayload; + IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); + CHECK(ReferenceData.GetSize() == Data.GetSize()); + CHECK(ReferenceData.GetView().EqualBytes(Data.GetView())); } - SUBCASE("snapshot") + // Read long-path file data, it is raw and uncompressed { - zen::CbObjectWriter OpWriter; - OpWriter << "key" - << "foo"; - - const std::string_view ChunkId{ - "00000000" - "00000000" - "00010000"}; - auto FileOid = zen::Oid::FromHexString(ChunkId); - - OpWriter.BeginArray("files"); - OpWriter.BeginObject(); - OpWriter << "id" << FileOid; - OpWriter << "clientpath" - << "/{engine}/client/side/path"; - OpWriter << "serverpath" << BinPath.c_str(); - OpWriter.EndObject(); - OpWriter.BeginObject(); - OpWriter << "id" << LongPathFileOid; - OpWriter << "clientpath" << LongClientPath; - OpWriter << "serverpath" << LongRelPath.c_str(); - OpWriter.EndObject(); - OpWriter.EndArray(); - - zen::CbObject Op = OpWriter.Save(); - - zen::CbPackage OpPackage(Op); - - zen::BinaryWriter MemOut; - legacy::SaveCbPackage(OpPackage, MemOut); - - HttpClient Http{BaseUri}; + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << LongPathChunkId; + auto Response = Http.Get(ChunkGetUri); - { - auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); + REQUIRE(Response); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::Created); - } + IoBuffer Data = Response.ResponsePayload; + MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)}; + CHECK(Data.GetSize() == sizeof(LongPathFileData)); + CHECK(Data.GetView().EqualBytes(ExpectedView)); + } - // Read file data, it is raw and uncompressed - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << 
ChunkId; - auto Response = Http.Get(ChunkGetUri); + { + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); }); + auto Response = Http.Post("/rpc"sv, Payload); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); + } - REQUIRE(Response); - REQUIRE(Response.StatusCode == HttpResponseCode::OK); + // Read chunk data, it is now compressed + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << ChunkId; + auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); - IoBuffer Data = Response.ResponsePayload; - IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); - CHECK(ReferenceData.GetSize() == Data.GetSize()); - CHECK(ReferenceData.GetView().EqualBytes(Data.GetView())); - } + REQUIRE(Response); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); + + IoBuffer Data = Response.ResponsePayload; + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); + REQUIRE(Compressed); + IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); + IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); + CHECK(RawSize == ReferenceData.GetSize()); + CHECK(ReferenceData.GetSize() == DataDecompressed.GetSize()); + CHECK(ReferenceData.GetView().EqualBytes(DataDecompressed.GetView())); + } - // Read long-path file data, it is raw and uncompressed - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << LongPathChunkId; - auto Response = Http.Get(ChunkGetUri); + // Read compressed long-path file data after snapshot + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << LongPathChunkId; + auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); - REQUIRE(Response); - REQUIRE(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response); + REQUIRE(Response.StatusCode == 
HttpResponseCode::OK); + + IoBuffer Data = Response.ResponsePayload; + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); + REQUIRE(Compressed); + IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); + MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)}; + CHECK(RawSize == sizeof(LongPathFileData)); + CHECK(DataDecompressed.GetSize() == sizeof(LongPathFileData)); + CHECK(DataDecompressed.GetView().EqualBytes(ExpectedView)); + } - IoBuffer Data = Response.ResponsePayload; - MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)}; - CHECK(Data.GetSize() == sizeof(LongPathFileData)); - CHECK(Data.GetView().EqualBytes(ExpectedView)); - } + ZEN_INFO("+++++++"); + } - { - IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); }); - auto Response = Http.Post("/rpc"sv, Payload); - REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); - } + // --- snapshot zero byte file --- + { + std::string OplogUri = CreateProjectAndOplog("test_zero", "oplog_zero"); - // Read chunk data, it is now compressed - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << ChunkId; - auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); + std::filesystem::path EmptyFileRelPath = std::filesystem::path("zerobyte_snapshot_test") / "empty.bin"; + std::filesystem::path EmptyFileAbsPath = RootPath / EmptyFileRelPath; + CreateDirectories(MakeSafeAbsolutePath(EmptyFileAbsPath.parent_path())); + WriteFile(MakeSafeAbsolutePath(EmptyFileAbsPath), IoBuffer{}); + REQUIRE(IsFile(MakeSafeAbsolutePath(EmptyFileAbsPath))); - REQUIRE(Response); - REQUIRE(Response.StatusCode == HttpResponseCode::OK); + const std::string_view EmptyChunkId{ + "00000000" + "00000000" + "00030000"}; + auto EmptyFileOid = zen::Oid::FromHexString(EmptyChunkId); + + zen::CbObjectWriter OpWriter; + 
OpWriter << "key" + << "zero_byte_test"; + OpWriter.BeginArray("files"); + OpWriter.BeginObject(); + OpWriter << "id" << EmptyFileOid; + OpWriter << "clientpath" + << "/{engine}/empty_file"; + OpWriter << "serverpath" << EmptyFileRelPath.c_str(); + OpWriter.EndObject(); + OpWriter.EndArray(); + + zen::CbObject Op = OpWriter.Save(); + zen::CbPackage OpPackage(Op); - IoBuffer Data = Response.ResponsePayload; - IoHash RawHash; - uint64_t RawSize; - CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); - REQUIRE(Compressed); - IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); - IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); - CHECK(RawSize == ReferenceData.GetSize()); - CHECK(ReferenceData.GetSize() == DataDecompressed.GetSize()); - CHECK(ReferenceData.GetView().EqualBytes(DataDecompressed.GetView())); - } + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); - // Read compressed long-path file data after snapshot - { - zen::StringBuilder<128> ChunkGetUri; - ChunkGetUri << "/" << LongPathChunkId; - auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); + HttpClient Http{OplogUri}; - REQUIRE(Response); - REQUIRE(Response.StatusCode == HttpResponseCode::OK); + { + auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::Created); + } - IoBuffer Data = Response.ResponsePayload; - IoHash RawHash; - uint64_t RawSize; - CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); - REQUIRE(Compressed); - IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); - MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)}; - CHECK(RawSize == sizeof(LongPathFileData)); - CHECK(DataDecompressed.GetSize() == sizeof(LongPathFileData)); - 
CHECK(DataDecompressed.GetView().EqualBytes(ExpectedView)); - } + // Read file data before snapshot - raw and uncompressed, 0 bytes. + // http.sys converts a 200 OK with empty body to 204 No Content, so + // accept either status code. + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << EmptyChunkId; + auto Response = Http.Get(ChunkGetUri); - ZEN_INFO("+++++++"); + REQUIRE(Response); + CHECK((Response.StatusCode == HttpResponseCode::OK || Response.StatusCode == HttpResponseCode::NoContent)); + CHECK(Response.ResponsePayload.GetSize() == 0); } - SUBCASE("test chunk not found error") + // Trigger snapshot. { - HttpClient Http{BaseUri}; + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); }); + auto Response = Http.Post("/rpc"sv, Payload); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); + } - for (size_t I = 0; I < 65; I++) - { - zen::StringBuilder<128> PostUri; - PostUri << "/f77c781846caead318084604/info"; - auto Response = Http.Get(PostUri); + // Read chunk after snapshot - compressed, decompresses to 0 bytes. 
+ { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << EmptyChunkId; + auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); - REQUIRE(!Response.Error); - CHECK(Response.StatusCode == HttpResponseCode::NotFound); - } + REQUIRE(Response); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); + + IoBuffer Data = Response.ResponsePayload; + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); + REQUIRE(Compressed); + CHECK(RawSize == 0); + IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); + CHECK(DataDecompressed.GetSize() == 0); } - // Cleanup long-path test directory { std::error_code Ec; - DeleteDirectories(MakeSafeAbsolutePath(RootPath / "longpathtest"), Ec); + DeleteDirectories(MakeSafeAbsolutePath(RootPath / "zerobyte_snapshot_test"), Ec); } + + ZEN_INFO("+++++++"); + } + + // --- test chunk not found error --- + { + std::string OplogUri = CreateProjectAndOplog("test_notfound", "oplog_notfound"); + HttpClient Http{OplogUri}; + + for (size_t I = 0; I < 65; I++) + { + zen::StringBuilder<128> PostUri; + PostUri << "/f77c781846caead318084604/info"; + auto Response = Http.Get(PostUri); + + REQUIRE(!Response.Error); + CHECK(Response.StatusCode == HttpResponseCode::NotFound); + } + } + + // Cleanup long-path test directory + { + std::error_code Ec; + DeleteDirectories(MakeSafeAbsolutePath(RootPath / "longpathtest"), Ec); } } @@ -656,86 +760,102 @@ TEST_CASE("project.remote") } }; - SUBCASE("File") + // --- Zen --- + // NOTE: Zen export must run before file-based exports from the same source + // oplog. A prior file export leaves server-side state that causes a + // subsequent zen-protocol export from the same oplog to abort. 
{ + INFO("Zen"); ScopedTemporaryDirectory TempDir; { - IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { + std::string ExportSourceUri = Servers.GetInstance(0).GetBaseUri(); + std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri(); + MakeProject(ExportTargetUri, "proj0_zen"); + MakeOplog(ExportTargetUri, "proj0_zen", "oplog0_zen"); + + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer << "method"sv << "export"sv; Writer << "params" << BeginObject; { Writer << "maxblocksize"sv << 3072u; Writer << "maxchunkembedsize"sv << 1296u; - Writer << "chunkfilesizelimit"sv << 5u * 1024u; Writer << "maxchunksperblock"sv << 16u; + Writer << "chunkfilesizelimit"sv << 5u * 1024u; Writer << "force"sv << false; - Writer << "file"sv << BeginObject; + Writer << "zen"sv << BeginObject; { - Writer << "path"sv << path; - Writer << "name"sv - << "proj0_oplog0"sv; + Writer << "url"sv << ExportTargetUri.substr(7); + Writer << "project" + << "proj0_zen"; + Writer << "oplog" + << "oplog0_zen"; } - Writer << EndObject; // "file" + Writer << EndObject; // "zen" } Writer << EndObject; // "params" }); - HttpClient Http{Servers.GetInstance(0).GetBaseUri()}; - + HttpClient Http{Servers.GetInstance(0).GetBaseUri()}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload); HttpWaitForCompletion(Servers.GetInstance(0), Response); } + ValidateAttachments(1, "proj0_zen", "oplog0_zen"); + ValidateOplog(1, "proj0_zen", "oplog0_zen"); + { - MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); - MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri(); + std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri(); + MakeProject(ImportTargetUri, "proj1"); + MakeOplog(ImportTargetUri, "proj1", "oplog1"); - IoBuffer Payload = 
MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer << "method"sv << "import"sv; Writer << "params" << BeginObject; { Writer << "force"sv << false; - Writer << "file"sv << BeginObject; + Writer << "zen"sv << BeginObject; { - Writer << "path"sv << path; - Writer << "name"sv - << "proj0_oplog0"sv; + Writer << "url"sv << ImportSourceUri.substr(7); + Writer << "project" + << "proj0_zen"; + Writer << "oplog" + << "oplog0_zen"; } - Writer << EndObject; // "file" + Writer << EndObject; // "zen" } Writer << EndObject; // "params" }); - HttpClient Http{Servers.GetInstance(1).GetBaseUri()}; - - HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload); - HttpWaitForCompletion(Servers.GetInstance(1), Response); + HttpClient Http{Servers.GetInstance(2).GetBaseUri()}; + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj1", "oplog1"), Payload); + HttpWaitForCompletion(Servers.GetInstance(2), Response); } - ValidateAttachments(1, "proj0_copy", "oplog0_copy"); - ValidateOplog(1, "proj0_copy", "oplog0_copy"); + ValidateAttachments(2, "proj1", "oplog1"); + ValidateOplog(2, "proj1", "oplog1"); } - SUBCASE("File disable blocks") + // --- File --- { + INFO("File"); ScopedTemporaryDirectory TempDir; { - IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { Writer << "method"sv << "export"sv; Writer << "params" << BeginObject; { Writer << "maxblocksize"sv << 3072u; Writer << "maxchunkembedsize"sv << 1296u; - Writer << "maxchunksperblock"sv << 16u; Writer << "chunkfilesizelimit"sv << 5u * 1024u; - Writer << "force"sv << false; + Writer << "maxchunksperblock"sv << 16u; + Writer << "force"sv << true; Writer << "file"sv << BeginObject; { - Writer 
<< "path"sv << TempDir.Path().string(); + Writer << "path"sv << path; Writer << "name"sv << "proj0_oplog0"sv; - Writer << "disableblocks"sv << true; } Writer << EndObject; // "file" } @@ -748,9 +868,10 @@ TEST_CASE("project.remote") HttpWaitForCompletion(Servers.GetInstance(0), Response); } { - MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); - MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); - IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_file"); + MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_file", "oplog0_file"); + + IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { Writer << "method"sv << "import"sv; Writer << "params" << BeginObject; @@ -758,7 +879,7 @@ TEST_CASE("project.remote") Writer << "force"sv << false; Writer << "file"sv << BeginObject; { - Writer << "path"sv << TempDir.Path().string(); + Writer << "path"sv << path; Writer << "name"sv << "proj0_oplog0"sv; } @@ -769,15 +890,16 @@ TEST_CASE("project.remote") HttpClient Http{Servers.GetInstance(1).GetBaseUri()}; - HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload); + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_file", "oplog0_file"), Payload); HttpWaitForCompletion(Servers.GetInstance(1), Response); } - ValidateAttachments(1, "proj0_copy", "oplog0_copy"); - ValidateOplog(1, "proj0_copy", "oplog0_copy"); + ValidateAttachments(1, "proj0_file", "oplog0_file"); + ValidateOplog(1, "proj0_file", "oplog0_file"); } - SUBCASE("File force temp blocks") + // --- File disable blocks --- { + INFO("File disable blocks"); ScopedTemporaryDirectory TempDir; { IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { @@ -789,26 +911,27 @@ TEST_CASE("project.remote") Writer << "maxchunkembedsize"sv << 1296u; 
Writer << "maxchunksperblock"sv << 16u; Writer << "chunkfilesizelimit"sv << 5u * 1024u; - Writer << "force"sv << false; + Writer << "force"sv << true; Writer << "file"sv << BeginObject; { Writer << "path"sv << TempDir.Path().string(); Writer << "name"sv << "proj0_oplog0"sv; - Writer << "enabletempblocks"sv << true; + Writer << "disableblocks"sv << true; } Writer << EndObject; // "file" } Writer << EndObject; // "params" }); - HttpClient Http{Servers.GetInstance(0).GetBaseUri()}; + HttpClient Http{Servers.GetInstance(0).GetBaseUri()}; + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload); HttpWaitForCompletion(Servers.GetInstance(0), Response); } { - MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); - MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_noblock"); + MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_noblock", "oplog0_noblock"); IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer << "method"sv << "import"sv; @@ -826,23 +949,20 @@ TEST_CASE("project.remote") Writer << EndObject; // "params" }); - HttpClient Http{Servers.GetInstance(1).GetBaseUri()}; - HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload); + HttpClient Http{Servers.GetInstance(1).GetBaseUri()}; + + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_noblock", "oplog0_noblock"), Payload); HttpWaitForCompletion(Servers.GetInstance(1), Response); } - ValidateAttachments(1, "proj0_copy", "oplog0_copy"); - ValidateOplog(1, "proj0_copy", "oplog0_copy"); + ValidateAttachments(1, "proj0_noblock", "oplog0_noblock"); + ValidateOplog(1, "proj0_noblock", "oplog0_noblock"); } - SUBCASE("Zen") + // --- File force temp blocks --- { + INFO("File force temp blocks"); ScopedTemporaryDirectory TempDir; { - std::string ExportSourceUri = 
Servers.GetInstance(0).GetBaseUri(); - std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri(); - MakeProject(ExportTargetUri, "proj0_copy"); - MakeOplog(ExportTargetUri, "proj0_copy", "oplog0_copy"); - IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer << "method"sv << "export"sv; @@ -852,14 +972,13 @@ TEST_CASE("project.remote") Writer << "maxchunkembedsize"sv << 1296u; Writer << "maxchunksperblock"sv << 16u; Writer << "chunkfilesizelimit"sv << 5u * 1024u; - Writer << "force"sv << false; - Writer << "zen"sv << BeginObject; + Writer << "force"sv << true; + Writer << "file"sv << BeginObject; { - Writer << "url"sv << ExportTargetUri.substr(7); - Writer << "project" - << "proj0_copy"; - Writer << "oplog" - << "oplog0_copy"; + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; + Writer << "enabletempblocks"sv << true; } Writer << EndObject; // "file" } @@ -870,40 +989,32 @@ TEST_CASE("project.remote") HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload); HttpWaitForCompletion(Servers.GetInstance(0), Response); } - ValidateAttachments(1, "proj0_copy", "oplog0_copy"); - ValidateOplog(1, "proj0_copy", "oplog0_copy"); - { - std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri(); - std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri(); - MakeProject(ImportTargetUri, "proj1"); - MakeOplog(ImportTargetUri, "proj1", "oplog1"); - + MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_tmpblock"); + MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_tmpblock", "oplog0_tmpblock"); IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer << "method"sv << "import"sv; Writer << "params" << BeginObject; { Writer << "force"sv << false; - Writer << "zen"sv << BeginObject; + Writer << "file"sv << BeginObject; { - Writer << "url"sv << ImportSourceUri.substr(7); - Writer << "project" - << "proj0_copy"; - Writer << 
"oplog" - << "oplog0_copy"; + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; } Writer << EndObject; // "file" } Writer << EndObject; // "params" }); - HttpClient Http{Servers.GetInstance(2).GetBaseUri()}; - HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj1", "oplog1"), Payload); - HttpWaitForCompletion(Servers.GetInstance(2), Response); + HttpClient Http{Servers.GetInstance(1).GetBaseUri()}; + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_tmpblock", "oplog0_tmpblock"), Payload); + HttpWaitForCompletion(Servers.GetInstance(1), Response); } - ValidateAttachments(2, "proj1", "oplog1"); - ValidateOplog(2, "proj1", "oplog1"); + ValidateAttachments(1, "proj0_tmpblock", "oplog0_tmpblock"); + ValidateOplog(1, "proj0_tmpblock", "oplog0_tmpblock"); } } @@ -1154,6 +1265,368 @@ TEST_CASE("project.rpcappendop") } } +TEST_CASE("project.file.data.transitions") +{ + using namespace utils; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(); + + zen::StringBuilder<64> ServerBaseUri; + ServerBaseUri << fmt::format("http://localhost:{}", PortNumber); + + // Set up a root directory with a test file on disk for path-referenced serving + std::filesystem::path RootDir = TestDir / "root"; + std::filesystem::path TestFilePath = RootDir / "content" / "testfile.bin"; + std::filesystem::path RelServerPath = std::filesystem::path("content") / "testfile.bin"; + CreateDirectories(TestFilePath.parent_path()); + IoBuffer FileBlob = CreateRandomBlob(4096); + WriteFile(TestFilePath, FileBlob); + + // Create a compressed blob to use as a CAS-referenced attachment (content differs from FileBlob) + CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(CreateRandomBlob(2048))); + + // Fixed chunk IDs for the file 
entry across sub-tests + const std::string_view FileChunkIdStr{ + "aa000000" + "bb000000" + "cc000001"}; + Oid FileOid = Oid::FromHexString(FileChunkIdStr); + + HttpClient Http{ServerBaseUri}; + + auto MakeProject = [&](std::string_view ProjectName) { + CbObjectWriter Project; + Project.AddString("id"sv, ProjectName); + Project.AddString("root"sv, PathToUtf8(RootDir.c_str())); + Project.AddString("engine"sv, ""sv); + Project.AddString("project"sv, ""sv); + Project.AddString("projectfile"sv, ""sv); + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), Project.Save()); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeProject")); + }; + + auto MakeOplog = [&](std::string_view ProjectName, std::string_view OplogName) { + HttpClient::Response Response = + Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeOplog")); + }; + + auto PostOplogEntry = [&](std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) { + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + IoBuffer Body{IoBuffer::Wrap, MemOut.GetData(), MemOut.GetSize()}; + Body.SetContentType(HttpContentType::kCbPackage); + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("PostOplogEntry")); + }; + + auto GetChunk = [&](std::string_view ProjectName) -> HttpClient::Response { + return Http.Get(fmt::format("/prj/{}/oplog/oplog/{}", ProjectName, FileChunkIdStr)); + }; + + // Extract the raw decompressed bytes from a chunk response, handling both compressed and uncompressed payloads + auto GetDecompressedPayload = [](const HttpClient::Response& Response) -> IoBuffer { + if (Response.ResponsePayload.GetContentType() == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t 
RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Response.ResponsePayload), RawHash, RawSize); + REQUIRE(Compressed); + return Compressed.Decompress().AsIoBuffer(); + } + return Response.ResponsePayload; + }; + + auto TriggerGcAndWait = [&]() { + HttpClient::Response TriggerResponse = Http.Post("/admin/gc?smallobjects=true"sv, IoBuffer{}); + REQUIRE_MESSAGE(TriggerResponse.IsSuccess(), TriggerResponse.ErrorMessage("TriggerGc")); + + for (int Attempt = 0; Attempt < 100; ++Attempt) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + HttpClient::Response StatusResponse = Http.Get("/admin/gc"sv); + REQUIRE_MESSAGE(StatusResponse.IsSuccess(), StatusResponse.ErrorMessage("GcStatus")); + CbObject StatusObj = StatusResponse.AsObject(); + if (StatusObj["Status"sv].AsString() == "Idle"sv) + { + return; + } + } + FAIL("GC did not complete within timeout"); + }; + + auto BuildPathReferencedFileOp = [&](const Oid& KeyId) -> CbPackage { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(KeyId); + Object.BeginArray("files"sv); + Object.BeginObject(); + Object << "id"sv << FileOid; + Object << "serverpath"sv << RelServerPath.string(); + Object << "clientpath"sv + << "/{engine}/testfile.bin"sv; + Object.EndObject(); + Object.EndArray(); + Package.SetObject(Object.Save()); + return Package; + }; + + auto BuildHashReferencedFileOp = [&](const Oid& KeyId, const CompressedBuffer& Blob) -> CbPackage { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(KeyId); + CbAttachment Attach(Blob, Blob.DecodeRawHash()); + Object.BeginArray("files"sv); + Object.BeginObject(); + Object << "id"sv << FileOid; + Object << "data"sv << Attach; + Object << "clientpath"sv + << "/{engine}/testfile.bin"sv; + Object.EndObject(); + Object.EndArray(); + Package.AddAttachment(Attach); + Package.SetObject(Object.Save()); + return Package; + }; + + // --- path-referenced file is retrievable --- + { + 
MakeProject("proj_path"sv); + MakeOplog("proj_path"sv, "oplog"sv); + + CbPackage Op = BuildPathReferencedFileOp(Oid::NewOid()); + PostOplogEntry("proj_path"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_path"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + // --- hash-referenced file is retrievable --- + { + MakeProject("proj_hash"sv); + MakeOplog("proj_hash"sv, "oplog"sv); + + CbPackage Op = BuildHashReferencedFileOp(Oid::NewOid(), CompressedBlob); + PostOplogEntry("proj_hash"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_hash"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize()); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + struct TransitionVariant + { + std::string_view Suffix; + bool SameOpKey; + bool RunGc; + }; + + static constexpr TransitionVariant Variants[] = { + {"_nk", false, false}, + {"_sk", true, false}, + {"_nk_gc", false, true}, + {"_sk_gc", true, true}, + }; + + // --- hash-referenced to path-referenced transition with different content --- + for (const TransitionVariant& V : Variants) + { + std::string ProjName = fmt::format("proj_h2pd{}", V.Suffix); + MakeProject(ProjName); + MakeOplog(ProjName, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey = V.SameOpKey ? 
FirstOpKey : Oid::NewOid(); + + { + CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, CompressedBlob); + PostOplogEntry(ProjName, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + { + CbPackage Op = BuildPathReferencedFileOp(SecondOpKey); + PostOplogEntry(ProjName, "oplog"sv, Op); + } + + if (V.RunGc) + { + TriggerGcAndWait(); + } + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + // --- hash-referenced to path-referenced transition with identical content --- + { + CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView())); + + for (const TransitionVariant& V : Variants) + { + std::string ProjName = fmt::format("proj_h2ps{}", V.Suffix); + MakeProject(ProjName); + MakeOplog(ProjName, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey = V.SameOpKey ? 
FirstOpKey : Oid::NewOid(); + + { + CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, MatchingBlob); + PostOplogEntry(ProjName, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + { + CbPackage Op = BuildPathReferencedFileOp(SecondOpKey); + PostOplogEntry(ProjName, "oplog"sv, Op); + } + + if (V.RunGc) + { + TriggerGcAndWait(); + } + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + } + + // --- path-referenced to hash-referenced transition with different content --- + for (const TransitionVariant& V : Variants) + { + std::string ProjName = fmt::format("proj_p2hd{}", V.Suffix); + MakeProject(ProjName); + MakeOplog(ProjName, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey = V.SameOpKey ? 
FirstOpKey : Oid::NewOid(); + + { + CbPackage Op = BuildPathReferencedFileOp(FirstOpKey); + PostOplogEntry(ProjName, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + { + CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, CompressedBlob); + PostOplogEntry(ProjName, "oplog"sv, Op); + } + + if (V.RunGc) + { + TriggerGcAndWait(); + } + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize()); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + // --- path-referenced to hash-referenced transition with identical content --- + { + CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView())); + + for (const TransitionVariant& V : Variants) + { + std::string ProjName = fmt::format("proj_p2hs{}", V.Suffix); + MakeProject(ProjName); + MakeOplog(ProjName, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey = V.SameOpKey ? 
FirstOpKey : Oid::NewOid(); + + { + CbPackage Op = BuildPathReferencedFileOp(FirstOpKey); + PostOplogEntry(ProjName, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + { + CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, MatchingBlob); + PostOplogEntry(ProjName, "oplog"sv, Op); + } + + if (V.RunGc) + { + TriggerGcAndWait(); + } + + HttpClient::Response Response = GetChunk(ProjName); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + } +} + TEST_SUITE_END(); } // namespace zen::tests diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua index 7b208bbc7..c240712ea 100644 --- a/src/zenserver-test/xmake.lua +++ b/src/zenserver-test/xmake.lua @@ -9,7 +9,7 @@ target("zenserver-test") add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore") add_deps("zenserver", {inherit=false}) add_deps("zentest-appstub", {inherit=false}) - add_packages("http_parser") + add_packages("llhttp") if has_config("zennomad") then add_deps("zennomad") diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index cf7ffe4e4..d713f693f 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -199,7 +199,7 @@ TEST_CASE("default.single") HttpClient Http{fmt::format("http://localhost:{}", PortNumber)}; - for (int i = 0; i < 100; ++i) + for (int i = 0; i < 20; ++i) { auto res = Http.Get("/test/hello"sv); ++RequestCount; @@ -238,7 +238,6 @@ TEST_CASE("default.loopback") ZEN_INFO("Running 
loopback server test..."); - SUBCASE("ipv4 endpoint connectivity") { HttpClient Http{fmt::format("http://127.0.0.1:{}", PortNumber)}; @@ -247,7 +246,6 @@ TEST_CASE("default.loopback") CHECK(res); } - SUBCASE("ipv6 endpoint connectivity") { HttpClient Http{fmt::format("http://[::1]:{}", PortNumber)}; @@ -287,7 +285,7 @@ TEST_CASE("multi.basic") HttpClient Http{fmt::format("http://localhost:{}", PortNumber)}; - for (int i = 0; i < 100; ++i) + for (int i = 0; i < 20; ++i) { auto res = Http.Get("/test/hello"sv); ++RequestCount; @@ -401,13 +399,11 @@ TEST_CASE("http.unixsocket") Settings.UnixSocketPath = SocketPath; HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}}; - SUBCASE("GET over unix socket") { HttpClient::Response Res = Http.Get("/testing/hello"); CHECK(Res.IsSuccess()); } - SUBCASE("POST echo over unix socket") { IoBuffer Body{IoBuffer::Wrap, "unix-test", 9}; HttpClient::Response Res = Http.Post("/testing/echo", Body); @@ -431,13 +427,11 @@ TEST_CASE("http.nonetwork") Settings.UnixSocketPath = SocketPath; HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}}; - SUBCASE("GET over unix socket succeeds") { HttpClient::Response Res = Http.Get("/testing/hello"); CHECK(Res.IsSuccess()); } - SUBCASE("TCP connection is refused") { asio::io_context IoContext; asio::ip::tcp::socket Socket(IoContext); diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 1673cea6c..b110f7538 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -22,6 +22,8 @@ # if ZEN_WITH_HORDE # include <zenhorde/hordeconfig.h> # include <zenhorde/hordeprovisioner.h> +# include <zenhttp/httpclientauth.h> +# include <zenutil/authutils.h> # endif # if ZEN_WITH_NOMAD # include <zennomad/nomadconfig.h> @@ -67,6 +69,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("compute", "", + "coordinator-session", + "Session ID 
of the orchestrator (for stale-instance rejection)", + cxxopts::value<std::string>(m_ServerOptions.CoordinatorSession)->default_value(""), + ""); + + Options.add_option("compute", + "", + "announce-url", + "Override URL announced to the coordinator (e.g. relay-visible endpoint)", + cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""), + ""); + + Options.add_option("compute", + "", "idms", "Enable IDMS cloud detection; optionally specify a custom probe endpoint", cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), @@ -79,6 +95,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), ""); + Options.add_option("compute", + "", + "provision-clean", + "Pass --clean to provisioned worker instances so they wipe state on startup", + cxxopts::value<bool>(m_ServerOptions.ProvisionClean)->default_value("false"), + ""); + + Options.add_option("compute", + "", + "provision-tracehost", + "Pass --tracehost to provisioned worker instances for remote trace collection", + cxxopts::value<std::string>(m_ServerOptions.ProvisionTraceHost)->default_value(""), + ""); + # if ZEN_WITH_HORDE // Horde provisioning options Options.add_option("horde", @@ -139,6 +169,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("horde", "", + "horde-drain-grace-period", + "Grace period in seconds for draining agents before force-kill", + cxxopts::value<int>(m_ServerOptions.HordeConfig.DrainGracePeriodSeconds)->default_value("300"), + ""); + + Options.add_option("horde", + "", "horde-host", "Host address for Horde agents to connect back to", cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""), @@ -164,6 +201,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Port number for Zen service communication", 
cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), ""); + + Options.add_option("horde", + "", + "horde-oidctoken-exe-path", + "Path to OidcToken executable for automatic Horde authentication", + cxxopts::value<std::string>(m_HordeOidcTokenExePath)->default_value(""), + ""); # endif # if ZEN_WITH_NOMAD @@ -313,6 +357,30 @@ ZenComputeServerConfigurator::ValidateOptions() # if ZEN_WITH_HORDE horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); + + // Set up OidcToken-based authentication if no static token was provided + if (m_ServerOptions.HordeConfig.AuthToken.empty() && !m_ServerOptions.HordeConfig.ServerUrl.empty()) + { + std::filesystem::path OidcExePath = FindOidcTokenExePath(m_HordeOidcTokenExePath); + if (!OidcExePath.empty()) + { + ZEN_INFO("using OidcToken executable for Horde authentication: {}", OidcExePath); + auto Provider = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath, + m_ServerOptions.HordeConfig.ServerUrl, + /*Quiet=*/true, + /*Unattended=*/false, + /*Hidden=*/true, + /*IsHordeUrl=*/true); + if (Provider) + { + m_ServerOptions.HordeConfig.AccessTokenProvider = std::move(*Provider); + } + else + { + ZEN_WARN("OidcToken authentication failed; Horde requests will be unauthenticated"); + } + } + } # endif # if ZEN_WITH_NOMAD @@ -347,6 +415,8 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_CoordinatorSession = ServerConfig.CoordinatorSession; + m_AnnounceUrl = ServerConfig.AnnounceUrl; m_InstanceId = ServerConfig.InstanceId; m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; @@ -379,13 +449,20 @@ ZenComputeServer::Cleanup() m_AnnounceTimer.cancel(); # if ZEN_WITH_HORDE - // Shut down Horde provisioner first — this signals all agent threads + // Disconnect the provisioner state provider before 
destroying the + // provisioner so the orchestrator HTTP layer cannot call into it. + if (m_OrchestratorService) + { + m_OrchestratorService->SetProvisionerStateProvider(nullptr); + } + + // Shut down Horde provisioner - this signals all agent threads // to exit and joins them before we tear down HTTP services. m_HordeProvisioner.reset(); # endif # if ZEN_WITH_NOMAD - // Shut down Nomad provisioner — stops the management thread and + // Shut down Nomad provisioner - stops the management thread and // sends stop requests for all tracked jobs. m_NomadProvisioner.reset(); # endif @@ -419,12 +496,12 @@ ZenComputeServer::Cleanup() m_IoRunner.join(); } - ShutdownServices(); - if (m_Http) { m_Http->Close(); } + + ShutdownServices(); } catch (const std::exception& Ex) { @@ -444,11 +521,12 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) ZEN_TRACE_CPU("ZenComputeServer::InitializeServices"); ZEN_INFO("initializing compute services"); - CidStoreConfiguration Config; - Config.RootDirectory = m_DataRoot / "cas"; + m_ActionStore = std::make_unique<MemoryCidStore>(); - m_CidStore = std::make_unique<CidStore>(m_GcManager); - m_CidStore->Initialize(Config); + CidStoreConfiguration WorkerStoreConfig; + WorkerStoreConfig.RootDirectory = m_DataRoot / "cas"; + m_WorkerStore = std::make_unique<CidStore>(m_GcManager); + m_WorkerStore->Initialize(WorkerStoreConfig); if (!ServerConfig.IdmsEndpoint.empty()) { @@ -476,10 +554,12 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) std::make_unique<zen::compute::HttpOrchestratorService>(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket); ZEN_INFO("instantiating function service"); - m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_CidStore, + m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_ActionStore, + *m_WorkerStore, m_StatsService, ServerConfig.DataDir / "functions", ServerConfig.MaxConcurrentActions); + 
m_ComputeService->SetShutdownCallback([this] { RequestExit(0); }); m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); @@ -504,7 +584,11 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) OrchestratorEndpoint << '/'; } - m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint); + m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); } } # endif @@ -535,7 +619,14 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) : std::filesystem::path(HordeConfig.BinariesPath); std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde"; - m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint); + m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, + BinariesPath, + WorkingDir, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); + m_OrchestratorService->SetProvisionerStateProvider(m_HordeProvisioner.get()); } } # endif @@ -563,6 +654,10 @@ ZenComputeServer::GetInstanceId() const std::string ZenComputeServer::GetAnnounceUrl() const { + if (!m_AnnounceUrl.empty()) + { + return m_AnnounceUrl; + } return m_Http->GetServiceUri(nullptr); } @@ -633,6 +728,11 @@ ZenComputeServer::BuildAnnounceBody() << "nomad"; } + if (!m_CoordinatorSession.empty()) + { + AnnounceBody << "coordinator_session" << m_CoordinatorSession; + } + ResolveCloudMetadata(); if (m_CloudMetadata) { @@ -779,8 +879,10 @@ ZenComputeServer::ProvisionerMaintenanceTick() # if ZEN_WITH_HORDE if (m_HordeProvisioner) { - m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + // Re-apply current target to spawn agent threads for 
any that have + // exited since the last tick, without overwriting a user-set target. auto Stats = m_HordeProvisioner->GetStats(); + m_HordeProvisioner->SetTargetCoreCount(Stats.TargetCoreCount); ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", Stats.TargetCoreCount, Stats.EstimatedCoreCount, @@ -882,12 +984,14 @@ ZenComputeServer::Run() OnReady(); + StartSelfSession("zencompute"); + PostAnnounce(); EnqueueAnnounceTimer(); InitializeOrchestratorWebSocket(); # if ZEN_WITH_HORDE - // Start Horde provisioning if configured — request maximum allowed cores. + // Start Horde provisioning if configured - request maximum allowed cores. // SetTargetCoreCount clamps to HordeConfig::MaxCores internally. if (m_HordeProvisioner) { @@ -899,7 +1003,7 @@ ZenComputeServer::Run() # endif # if ZEN_WITH_NOMAD - // Start Nomad provisioning if configured — request maximum allowed cores. + // Start Nomad provisioning if configured - request maximum allowed cores. // SetTargetCoreCount clamps to NomadConfig::MaxCores internally. if (m_NomadProvisioner) { diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 8f4edc0f0..aa9c1a5b3 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -10,6 +10,7 @@ # include <zencore/system.h> # include <zenhttp/httpwsclient.h> # include <zenstore/gc.h> +# include <zenstore/memorycidstore.h> # include "frontend/frontend.h" namespace cxxopts { @@ -41,7 +42,6 @@ class NomadProvisioner; namespace zen { -class CidStore; class HttpApiService; struct ZenComputeServerConfig : public ZenServerConfig @@ -49,9 +49,13 @@ struct ZenComputeServerConfig : public ZenServerConfig std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications std::string CoordinatorEndpoint; + std::string CoordinatorSession; ///< Session ID for stale-instance rejection + std::string AnnounceUrl; ///< Override for self-announced URL (e.g. 
relay-visible endpoint) std::string IdmsEndpoint; int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) - bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + bool EnableWorkerWebSocket = false; // Use WebSocket for worker<->orchestrator link + bool ProvisionClean = false; // Pass --clean to provisioned workers + std::string ProvisionTraceHost; // Pass --tracehost to provisioned workers # if ZEN_WITH_HORDE horde::HordeConfig HordeConfig; @@ -84,6 +88,7 @@ private: # if ZEN_WITH_HORDE std::string m_HordeModeStr = "direct"; std::string m_HordeEncryptionStr = "none"; + std::string m_HordeOidcTokenExePath; # endif # if ZEN_WITH_NOMAD @@ -131,7 +136,8 @@ public: private: GcManager m_GcManager; GcScheduler m_GcScheduler{m_GcManager}; - std::unique_ptr<CidStore> m_CidStore; + std::unique_ptr<MemoryCidStore> m_ActionStore; + std::unique_ptr<CidStore> m_WorkerStore; std::unique_ptr<HttpApiService> m_ApiService; std::unique_ptr<zen::compute::HttpComputeService> m_ComputeService; std::unique_ptr<zen::compute::HttpOrchestratorService> m_OrchestratorService; @@ -146,6 +152,8 @@ private: # endif SystemMetricsTracker m_MetricsTracker; std::string m_CoordinatorEndpoint; + std::string m_CoordinatorSession; + std::string m_AnnounceUrl; std::string m_InstanceId; asio::steady_timer m_AnnounceTimer{m_IoContext}; @@ -163,7 +171,7 @@ private: std::string GetInstanceId() const; CbObject BuildAnnounceBody(); - // Worker→orchestrator WebSocket client + // Worker->orchestrator WebSocket client struct OrchestratorWsHandler : public IWsClientHandler { ZenComputeServer& Server; diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index daad154bc..6449159fd 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -12,6 +12,7 @@ #include <zencore/compactbinaryutil.h> #include <zencore/compactbinaryvalidation.h> #include <zencore/except.h> +#include <zencore/filesystem.h> #include 
<zencore/fmtutils.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -478,15 +479,27 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir)); } - ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); - ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); - ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); - ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + SystemRootDir = ExpandEnvironmentVariables(SystemRootDir); + ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); + + DataDir = ExpandEnvironmentVariables(DataDir); + ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); + + ContentDir = ExpandEnvironmentVariables(ContentDir); + ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); + + ConfigFile = ExpandEnvironmentVariables(ConfigFile); + ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); + + BaseSnapshotDir = ExpandEnvironmentVariables(BaseSnapshotDir); + ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + + SecurityConfigPath = ExpandEnvironmentVariables(SecurityConfigPath); ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); if (!UnixSocketPath.empty()) { + UnixSocketPath = ExpandEnvironmentVariables(UnixSocketPath); ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath); } diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp index f3d8dbfe3..e1a8fed7d 100644 --- a/src/zenserver/diag/logging.cpp +++ b/src/zenserver/diag/logging.cpp @@ -112,7 +112,7 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) const zen::Oid ServerSessionId = zen::GetSessionId(); logging::Registry::Instance().ApplyAll([&](auto Logger) { - static constinit logging::LogPoint 
SessionIdPoint{{}, logging::Info, "server session id: {}"}; + static constinit logging::LogPoint SessionIdPoint{0, 0, logging::Info, "server session id: {}"}; ZEN_MEMSCOPE(ELLMTag::Logging); Logger->Log(SessionIdPoint, fmt::make_format_args(ServerSessionId)); }); diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp index 52ec5b8b3..812536074 100644 --- a/src/zenserver/frontend/frontend.cpp +++ b/src/zenserver/frontend/frontend.cpp @@ -160,7 +160,7 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) ContentType = ParseContentType(DotExt); - // Extensions used only for static file serving — not in the global + // Extensions used only for static file serving - not in the global // ParseContentType table because that table also drives URI extension // stripping for content negotiation, and we don't want /api/foo.txt to // have its extension removed. diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html deleted file mode 100644 index c07bbb692..000000000 --- a/src/zenserver/frontend/html/compute/compute.html +++ /dev/null @@ -1,925 +0,0 @@ -<!DOCTYPE html> -<html lang="en"> -<head> - <meta charset="UTF-8"> - <meta name="viewport" content="width=device-width, initial-scale=1.0"> - <title>Zen Compute Dashboard</title> - <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> - <link rel="stylesheet" type="text/css" href="../zen.css" /> - <script src="../util/sanitize.js"></script> - <script src="../theme.js"></script> - <script src="../banner.js" defer></script> - <script src="../nav.js" defer></script> - <style> - .grid { - grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); - } - - .chart-container { - position: relative; - height: 300px; - margin-top: 20px; - } - - .stats-row { - display: flex; - justify-content: space-between; - margin-bottom: 12px; - padding: 8px 0; - border-bottom: 1px solid 
var(--theme_border_subtle); - } - - .stats-row:last-child { - border-bottom: none; - margin-bottom: 0; - } - - .stats-label { - color: var(--theme_g1); - font-size: 13px; - } - - .stats-value { - color: var(--theme_bright); - font-weight: 600; - font-size: 13px; - } - - .rate-stats { - display: grid; - grid-template-columns: repeat(3, 1fr); - gap: 16px; - margin-top: 16px; - } - - .rate-item { - text-align: center; - } - - .rate-value { - font-size: 20px; - font-weight: 600; - color: var(--theme_p0); - } - - .rate-label { - font-size: 11px; - color: var(--theme_g1); - margin-top: 4px; - text-transform: uppercase; - } - - .worker-row { - cursor: pointer; - transition: background 0.15s; - } - - .worker-row:hover { - background: var(--theme_p4); - } - - .worker-row.selected { - background: var(--theme_p3); - } - - .worker-detail { - margin-top: 20px; - border-top: 1px solid var(--theme_g2); - padding-top: 16px; - } - - .worker-detail-title { - font-size: 15px; - font-weight: 600; - color: var(--theme_bright); - margin-bottom: 12px; - } - - .detail-section { - margin-bottom: 16px; - } - - .detail-section-label { - font-size: 11px; - font-weight: 600; - color: var(--theme_g1); - text-transform: uppercase; - letter-spacing: 0.5px; - margin-bottom: 6px; - } - - .detail-table { - width: 100%; - border-collapse: collapse; - font-size: 12px; - } - - .detail-table td { - padding: 4px 8px; - color: var(--theme_g0); - border-bottom: 1px solid var(--theme_border_subtle); - vertical-align: top; - } - - .detail-table td:first-child { - color: var(--theme_g1); - width: 40%; - font-family: monospace; - } - - .detail-table tr:last-child td { - border-bottom: none; - } - - .detail-mono { - font-family: monospace; - font-size: 11px; - color: var(--theme_g1); - } - - .detail-tag { - display: inline-block; - padding: 2px 8px; - border-radius: 4px; - background: var(--theme_border_subtle); - color: var(--theme_g0); - font-size: 11px; - margin: 2px 4px 2px 0; - } - </style> -</head> 
-<body> - <div class="container" style="max-width: 1400px; margin: 0 auto;"> - <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner> - <zen-nav> - <a href="/dashboard/">Home</a> - <a href="compute.html">Node</a> - <a href="orchestrator.html">Orchestrator</a> - </zen-nav> - <div class="timestamp">Last updated: <span id="last-update">Never</span></div> - - <div id="error-container"></div> - - <!-- Action Queue Stats --> - <div class="section-title">Action Queue</div> - <div class="grid"> - <div class="card"> - <div class="card-title">Pending Actions</div> - <div class="metric-value" id="actions-pending">-</div> - <div class="metric-label">Waiting to be scheduled</div> - </div> - <div class="card"> - <div class="card-title">Running Actions</div> - <div class="metric-value" id="actions-running">-</div> - <div class="metric-label">Currently executing</div> - </div> - <div class="card"> - <div class="card-title">Completed Actions</div> - <div class="metric-value" id="actions-complete">-</div> - <div class="metric-label">Results available</div> - </div> - </div> - - <!-- Action Queue Chart --> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Action Queue History</div> - <div class="chart-container"> - <canvas id="queue-chart"></canvas> - </div> - </div> - - <!-- Performance Metrics --> - <div class="section-title">Performance Metrics</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Completion Rate</div> - <div class="rate-stats"> - <div class="rate-item"> - <div class="rate-value" id="rate-1">-</div> - <div class="rate-label">1 min rate</div> - </div> - <div class="rate-item"> - <div class="rate-value" id="rate-5">-</div> - <div class="rate-label">5 min rate</div> - </div> - <div class="rate-item"> - <div class="rate-value" id="rate-15">-</div> - <div class="rate-label">15 min rate</div> - </div> - </div> - <div style="margin-top: 20px;"> - <div 
class="stats-row"> - <span class="stats-label">Total Retired</span> - <span class="stats-value" id="retired-count">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Mean Rate</span> - <span class="stats-value" id="rate-mean">-</span> - </div> - </div> - </div> - - <!-- Workers --> - <div class="section-title">Workers</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Worker Status</div> - <div class="stats-row"> - <span class="stats-label">Registered Workers</span> - <span class="stats-value" id="worker-count">-</span> - </div> - <div id="worker-table-container" style="margin-top: 16px; display: none;"> - <table id="worker-table"> - <thead> - <tr> - <th>Name</th> - <th>Platform</th> - <th style="text-align: right;">Cores</th> - <th style="text-align: right;">Timeout</th> - <th style="text-align: right;">Functions</th> - <th>Worker ID</th> - </tr> - </thead> - <tbody id="worker-table-body"></tbody> - </table> - <div id="worker-detail" class="worker-detail" style="display: none;"></div> - </div> - </div> - - <!-- Queues --> - <div class="section-title">Queues</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Queue Status</div> - <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div> - <div id="queue-list-container" style="display: none;"> - <table id="queue-list-table"> - <thead> - <tr> - <th style="text-align: right; width: 60px;">ID</th> - <th style="text-align: center; width: 80px;">Status</th> - <th style="text-align: right;">Active</th> - <th style="text-align: right;">Completed</th> - <th style="text-align: right;">Failed</th> - <th style="text-align: right;">Abandoned</th> - <th style="text-align: right;">Cancelled</th> - <th>Token</th> - </tr> - </thead> - <tbody id="queue-list-body"></tbody> - </table> - </div> - </div> - - <!-- Action History --> - <div class="section-title">Recent Actions</div> - <div class="card" style="margin-bottom: 
30px;"> - <div class="card-title">Action History</div> - <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div> - <div id="action-history-container" style="display: none;"> - <table id="action-history-table"> - <thead> - <tr> - <th style="text-align: right; width: 60px;">LSN</th> - <th style="text-align: right; width: 60px;">Queue</th> - <th style="text-align: center; width: 70px;">Status</th> - <th>Function</th> - <th style="text-align: right; width: 80px;">Started</th> - <th style="text-align: right; width: 80px;">Finished</th> - <th style="text-align: right; width: 80px;">Duration</th> - <th>Worker ID</th> - <th>Action ID</th> - </tr> - </thead> - <tbody id="action-history-body"></tbody> - </table> - </div> - </div> - - <!-- System Resources --> - <div class="section-title">System Resources</div> - <div class="grid"> - <div class="card"> - <div class="card-title">CPU Usage</div> - <div class="metric-value" id="cpu-usage">-</div> - <div class="metric-label">Percent</div> - <div class="progress-bar"> - <div class="progress-fill" id="cpu-progress" style="width: 0%"></div> - </div> - <div style="position: relative; height: 60px; margin-top: 12px;"> - <canvas id="cpu-chart"></canvas> - </div> - <div style="margin-top: 12px;"> - <div class="stats-row"> - <span class="stats-label">Packages</span> - <span class="stats-value" id="cpu-packages">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Physical Cores</span> - <span class="stats-value" id="cpu-cores">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Logical Processors</span> - <span class="stats-value" id="cpu-lp">-</span> - </div> - </div> - </div> - <div class="card"> - <div class="card-title">Memory</div> - <div class="stats-row"> - <span class="stats-label">Used</span> - <span class="stats-value" id="memory-used">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Total</span> - <span 
class="stats-value" id="memory-total">-</span> - </div> - <div class="progress-bar"> - <div class="progress-fill" id="memory-progress" style="width: 0%"></div> - </div> - </div> - <div class="card"> - <div class="card-title">Disk</div> - <div class="stats-row"> - <span class="stats-label">Used</span> - <span class="stats-value" id="disk-used">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Total</span> - <span class="stats-value" id="disk-total">-</span> - </div> - <div class="progress-bar"> - <div class="progress-fill" id="disk-progress" style="width: 0%"></div> - </div> - </div> - </div> - </div> - - <script> - // Configuration - const BASE_URL = window.location.origin; - const REFRESH_INTERVAL = 2000; // 2 seconds - const MAX_HISTORY_POINTS = 60; // Show last 2 minutes - - // Data storage - const history = { - timestamps: [], - pending: [], - running: [], - completed: [], - cpu: [] - }; - - // CPU sparkline chart - const cpuCtx = document.getElementById('cpu-chart').getContext('2d'); - const cpuChart = new Chart(cpuCtx, { - type: 'line', - data: { - labels: [], - datasets: [{ - data: [], - borderColor: '#58a6ff', - backgroundColor: 'rgba(88, 166, 255, 0.15)', - borderWidth: 1.5, - tension: 0.4, - fill: true, - pointRadius: 0 - }] - }, - options: { - responsive: true, - maintainAspectRatio: false, - animation: false, - plugins: { legend: { display: false }, tooltip: { enabled: false } }, - scales: { - x: { display: false }, - y: { display: false, min: 0, max: 100 } - } - } - }); - - // Queue chart setup - const ctx = document.getElementById('queue-chart').getContext('2d'); - const chart = new Chart(ctx, { - type: 'line', - data: { - labels: [], - datasets: [ - { - label: 'Pending', - data: [], - borderColor: '#f0883e', - backgroundColor: 'rgba(240, 136, 62, 0.1)', - tension: 0.4, - fill: true - }, - { - label: 'Running', - data: [], - borderColor: '#58a6ff', - backgroundColor: 'rgba(88, 166, 255, 0.1)', - tension: 0.4, - fill: true - }, - 
{ - label: 'Completed', - data: [], - borderColor: '#3fb950', - backgroundColor: 'rgba(63, 185, 80, 0.1)', - tension: 0.4, - fill: true - } - ] - }, - options: { - responsive: true, - maintainAspectRatio: false, - plugins: { - legend: { - display: true, - labels: { - color: '#8b949e' - } - } - }, - scales: { - x: { - display: false - }, - y: { - beginAtZero: true, - ticks: { - color: '#8b949e' - }, - grid: { - color: '#21262d' - } - } - } - } - }); - - // Helper functions - - function formatBytes(bytes) { - if (bytes === 0) return '0 B'; - const k = 1024; - const sizes = ['B', 'KB', 'MB', 'GB', 'TB']; - const i = Math.floor(Math.log(bytes) / Math.log(k)); - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; - } - - function formatRate(rate) { - return rate.toFixed(2) + '/s'; - } - - function showError(message) { - const container = document.getElementById('error-container'); - container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`; - } - - function clearError() { - document.getElementById('error-container').innerHTML = ''; - } - - function updateTimestamp() { - const now = new Date(); - document.getElementById('last-update').textContent = now.toLocaleTimeString(); - } - - // Fetch functions - async function fetchJSON(endpoint) { - const response = await fetch(`${BASE_URL}${endpoint}`, { - headers: { - 'Accept': 'application/json' - } - }); - if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`); - } - return await response.json(); - } - - async function fetchHealth() { - try { - const response = await fetch(`${BASE_URL}/compute/ready`); - const isHealthy = response.status === 200; - - const banner = document.querySelector('zen-banner'); - - if (isHealthy) { - banner.setAttribute('cluster-status', 'nominal'); - banner.setAttribute('load', '0'); - } else { - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - } - - return isHealthy; - } catch (error) 
{ - const banner = document.querySelector('zen-banner'); - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - throw error; - } - } - - async function fetchStats() { - const data = await fetchJSON('/stats/compute'); - - // Update action counts - document.getElementById('actions-pending').textContent = data.actions_pending || 0; - document.getElementById('actions-running').textContent = data.actions_submitted || 0; - document.getElementById('actions-complete').textContent = data.actions_complete || 0; - - // Update completion rates - if (data.actions_retired) { - document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0); - document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0); - document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0); - document.getElementById('retired-count').textContent = data.actions_retired.count || 0; - document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0); - } - - // Update chart - const now = new Date().toLocaleTimeString(); - history.timestamps.push(now); - history.pending.push(data.actions_pending || 0); - history.running.push(data.actions_submitted || 0); - history.completed.push(data.actions_complete || 0); - - // Keep only last N points - if (history.timestamps.length > MAX_HISTORY_POINTS) { - history.timestamps.shift(); - history.pending.shift(); - history.running.shift(); - history.completed.shift(); - } - - chart.data.labels = history.timestamps; - chart.data.datasets[0].data = history.pending; - chart.data.datasets[1].data = history.running; - chart.data.datasets[2].data = history.completed; - chart.update('none'); - } - - async function fetchSysInfo() { - const data = await fetchJSON('/compute/sysinfo'); - - // Update CPU - const cpuUsage = data.cpu_usage || 0; - document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%'; - 
document.getElementById('cpu-progress').style.width = cpuUsage + '%'; - - const banner = document.querySelector('zen-banner'); - banner.setAttribute('load', cpuUsage.toFixed(1)); - - history.cpu.push(cpuUsage); - if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift(); - cpuChart.data.labels = history.cpu.map(() => ''); - cpuChart.data.datasets[0].data = history.cpu; - cpuChart.update('none'); - - document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-'; - document.getElementById('cpu-cores').textContent = data.core_count ?? '-'; - document.getElementById('cpu-lp').textContent = data.lp_count ?? '-'; - - // Update Memory - const memUsed = data.memory_used || 0; - const memTotal = data.memory_total || 1; - const memPercent = (memUsed / memTotal) * 100; - document.getElementById('memory-used').textContent = formatBytes(memUsed); - document.getElementById('memory-total').textContent = formatBytes(memTotal); - document.getElementById('memory-progress').style.width = memPercent + '%'; - - // Update Disk - const diskUsed = data.disk_used || 0; - const diskTotal = data.disk_total || 1; - const diskPercent = (diskUsed / diskTotal) * 100; - document.getElementById('disk-used').textContent = formatBytes(diskUsed); - document.getElementById('disk-total').textContent = formatBytes(diskTotal); - document.getElementById('disk-progress').style.width = diskPercent + '%'; - } - - // Persists the selected worker ID across refreshes - let selectedWorkerId = null; - - function renderWorkerDetail(id, desc) { - const panel = document.getElementById('worker-detail'); - - if (!desc) { - panel.style.display = 'none'; - return; - } - - function field(label, value) { - return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`; - } - - function monoField(label, value) { - return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`; - } - - // Functions - const functions = desc.functions || []; - const functionsHtml = functions.length === 0 ? 
'<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table">${functions.map(f => - `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>` - ).join('')}</table>`; - - // Executables - const executables = desc.executables || []; - const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0); - const execHtml = executables.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table"> - <tr style="font-size:11px;"> - <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td> - <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td> - <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td> - </tr> - ${executables.map(e => - `<tr> - <td>${escapeHtml(e.name || '-')}</td> - <td class="detail-mono">${escapeHtml(e.hash || '-')}</td> - <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td> - </tr>` - ).join('')} - <tr style="border-top:1px solid var(--theme_g2);"> - <td style="color:var(--theme_g1);padding-top:6px;">Total</td> - <td></td> - <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td> - </tr> - </table>`; - - // Files - const files = desc.files || []; - const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table">${files.map(f => - `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>` - ).join('')}</table>`; - - // Dirs - const dirs = desc.dirs || []; - const dirsHtml = dirs.length === 0 ? 
'<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join(''); - - // Environment - const env = desc.environment || []; - const envHtml = env.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join(''); - - panel.innerHTML = ` - <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div> - <div class="detail-section"> - <table class="detail-table"> - ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)} - ${field('Path', escapeHtml(desc.path || '-'))} - ${field('Platform', escapeHtml(desc.host || '-'))} - ${monoField('Build System', desc.buildsystem_version)} - ${field('Cores', desc.cores)} - ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)} - </table> - </div> - <div class="detail-section"> - <div class="detail-section-label">Functions</div> - ${functionsHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Executables</div> - ${execHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Files</div> - ${filesHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Directories</div> - ${dirsHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Environment</div> - ${envHtml} - </div> - `; - panel.style.display = 'block'; - } - - async function fetchWorkers() { - const data = await fetchJSON('/compute/workers'); - const workerIds = data.workers || []; - - document.getElementById('worker-count').textContent = workerIds.length; - - const container = document.getElementById('worker-table-container'); - const tbody = document.getElementById('worker-table-body'); - - if (workerIds.length === 0) { - container.style.display = 'none'; - selectedWorkerId = null; - return; - } - - const descriptors = await Promise.all( - 
workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null)) - ); - - // Build a map for quick lookup by ID - const descriptorMap = {}; - workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; }); - - tbody.innerHTML = ''; - descriptors.forEach((desc, i) => { - const id = workerIds[i]; - const name = desc ? (desc.name || '-') : '-'; - const host = desc ? (desc.host || '-') : '-'; - const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-'; - const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-'; - const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-'; - - const tr = document.createElement('tr'); - tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : ''); - tr.dataset.workerId = id; - tr.innerHTML = ` - <td style="color: var(--theme_bright);">${escapeHtml(name)}</td> - <td>${escapeHtml(host)}</td> - <td style="text-align: right;">${escapeHtml(String(cores))}</td> - <td style="text-align: right;">${escapeHtml(String(timeout))}</td> - <td style="text-align: right;">${escapeHtml(String(functions))}</td> - <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td> - `; - tr.addEventListener('click', () => { - document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected')); - if (selectedWorkerId === id) { - // Toggle off - selectedWorkerId = null; - document.getElementById('worker-detail').style.display = 'none'; - } else { - selectedWorkerId = id; - tr.classList.add('selected'); - renderWorkerDetail(id, descriptorMap[id]); - } - }); - tbody.appendChild(tr); - }); - - // Re-render detail if selected worker is still present - if (selectedWorkerId && descriptorMap[selectedWorkerId]) { - renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]); - } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) { - selectedWorkerId = null; - document.getElementById('worker-detail').style.display = 
'none'; - } - - container.style.display = 'block'; - } - - // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date. - const FILETIME_EPOCH_OFFSET_MS = 11644473600000n; - function filetimeToDate(ticks) { - if (!ticks) return null; - const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS; - return new Date(Number(ms)); - } - - function formatTime(date) { - if (!date) return '-'; - return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }); - } - - function formatDuration(startDate, endDate) { - if (!startDate || !endDate) return '-'; - const ms = endDate - startDate; - if (ms < 0) return '-'; - if (ms < 1000) return ms + ' ms'; - if (ms < 60000) return (ms / 1000).toFixed(2) + ' s'; - const m = Math.floor(ms / 60000); - const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0'); - return `${m}m ${s}s`; - } - - async function fetchQueues() { - const data = await fetchJSON('/compute/queues'); - const queues = data.queues || []; - - const empty = document.getElementById('queue-list-empty'); - const container = document.getElementById('queue-list-container'); - const tbody = document.getElementById('queue-list-body'); - - if (queues.length === 0) { - empty.style.display = ''; - container.style.display = 'none'; - return; - } - - empty.style.display = 'none'; - tbody.innerHTML = ''; - - for (const q of queues) { - const id = q.queue_id ?? '-'; - const badge = q.state === 'cancelled' - ? '<span class="status-badge failure">cancelled</span>' - : q.state === 'draining' - ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>' - : q.is_complete - ? '<span class="status-badge success">complete</span>' - : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>'; - const token = q.queue_token - ? 
`<span class="detail-mono">${escapeHtml(q.queue_token)}</span>` - : '<span style="color:var(--theme_faint);">-</span>'; - - const tr = document.createElement('tr'); - tr.innerHTML = ` - <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td> - <td style="text-align: center;">${badge}</td> - <td style="text-align: right;">${q.active_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td> - <td>${token}</td> - `; - tbody.appendChild(tr); - } - - container.style.display = 'block'; - } - - async function fetchActionHistory() { - const data = await fetchJSON('/compute/jobs/history?limit=50'); - const entries = data.history || []; - - const empty = document.getElementById('action-history-empty'); - const container = document.getElementById('action-history-container'); - const tbody = document.getElementById('action-history-body'); - - if (entries.length === 0) { - empty.style.display = ''; - container.style.display = 'none'; - return; - } - - empty.style.display = 'none'; - tbody.innerHTML = ''; - - // Entries arrive oldest-first; reverse to show newest at top - for (const entry of [...entries].reverse()) { - const lsn = entry.lsn ?? '-'; - const succeeded = entry.succeeded; - const badge = succeeded == null - ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>' - : succeeded - ? 
'<span class="status-badge success">ok</span>' - : '<span class="status-badge failure">failed</span>'; - const desc = entry.actionDescriptor || {}; - const fn = desc.Function || '-'; - const workerId = entry.workerId || '-'; - const actionId = entry.actionId || '-'; - - const startDate = filetimeToDate(entry.time_Running); - const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); - - const queueId = entry.queueId || 0; - const queueCell = queueId - ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>` - : '<span style="color: var(--theme_faint);">-</span>'; - - const tr = document.createElement('tr'); - tr.innerHTML = ` - <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td> - <td style="text-align: right;">${queueCell}</td> - <td style="text-align: center;">${badge}</td> - <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td> - <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td> - <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td> - `; - tbody.appendChild(tr); - } - - container.style.display = 'block'; - } - - async function updateDashboard() { - try { - await Promise.all([ - fetchHealth(), - fetchStats(), - fetchSysInfo(), - fetchWorkers(), - fetchQueues(), - fetchActionHistory() - ]); - - clearError(); - updateTimestamp(); - } catch (error) { - console.error('Error updating dashboard:', error); - 
showError(error.message); - } - } - - // Start updating - updateDashboard(); - setInterval(updateDashboard, REFRESH_INTERVAL); - </script> -</body> -</html> diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html index b15b34577..41c80d3a3 100644 --- a/src/zenserver/frontend/html/compute/hub.html +++ b/src/zenserver/frontend/html/compute/hub.html @@ -83,7 +83,7 @@ } async function fetchStats() { - var data = await fetchJSON('/hub/stats'); + var data = await fetchJSON('/stats/hub'); var current = data.currentInstanceCount || 0; var max = data.maxInstanceCount || 0; diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html index 9597fd7f3..aaa09aec0 100644 --- a/src/zenserver/frontend/html/compute/index.html +++ b/src/zenserver/frontend/html/compute/index.html @@ -1 +1 @@ -<meta http-equiv="refresh" content="0; url=compute.html" />
\ No newline at end of file +<meta http-equiv="refresh" content="0; url=/dashboard/?page=compute" />
\ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html deleted file mode 100644 index d1a2bb015..000000000 --- a/src/zenserver/frontend/html/compute/orchestrator.html +++ /dev/null @@ -1,669 +0,0 @@ -<!DOCTYPE html> -<html lang="en"> -<head> - <meta charset="UTF-8"> - <meta name="viewport" content="width=device-width, initial-scale=1.0"> - <link rel="stylesheet" type="text/css" href="../zen.css" /> - <script src="../util/sanitize.js"></script> - <script src="../theme.js"></script> - <script src="../banner.js" defer></script> - <script src="../nav.js" defer></script> - <title>Zen Orchestrator Dashboard</title> - <style> - .agent-count { - display: flex; - align-items: center; - gap: 8px; - font-size: 14px; - padding: 8px 16px; - border-radius: 6px; - background: var(--theme_g3); - border: 1px solid var(--theme_g2); - } - - .agent-count .count { - font-size: 20px; - font-weight: 600; - color: var(--theme_bright); - } - </style> -</head> -<body> - <div class="container" style="max-width: 1400px; margin: 0 auto;"> - <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner> - <zen-nav> - <a href="/dashboard/">Home</a> - <a href="compute.html">Node</a> - <a href="orchestrator.html">Orchestrator</a> - </zen-nav> - <div class="header"> - <div> - <div class="timestamp">Last updated: <span id="last-update">Never</span></div> - </div> - <div class="agent-count"> - <span>Agents:</span> - <span class="count" id="agent-count">-</span> - </div> - </div> - - <div id="error-container"></div> - - <div class="card"> - <div class="card-title">Compute Agents</div> - <div id="empty-state" class="empty-state">No agents registered.</div> - <table id="agent-table" style="display: none;"> - <thead> - <tr> - <th style="width: 40px; text-align: center;">Health</th> - <th>Hostname</th> - <th style="text-align: right;">CPUs</th> - <th style="text-align: right;">CPU 
Usage</th> - <th style="text-align: right;">Memory</th> - <th style="text-align: right;">Queues</th> - <th style="text-align: right;">Pending</th> - <th style="text-align: right;">Running</th> - <th style="text-align: right;">Completed</th> - <th style="text-align: right;">Traffic</th> - <th style="text-align: right;">Last Seen</th> - </tr> - </thead> - <tbody id="agent-table-body"></tbody> - </table> - </div> - <div class="card" style="margin-top: 20px;"> - <div class="card-title">Connected Clients</div> - <div id="clients-empty" class="empty-state">No clients connected.</div> - <table id="clients-table" style="display: none;"> - <thead> - <tr> - <th style="width: 40px; text-align: center;">Health</th> - <th>Client ID</th> - <th>Hostname</th> - <th>Address</th> - <th style="text-align: right;">Last Seen</th> - </tr> - </thead> - <tbody id="clients-table-body"></tbody> - </table> - </div> - <div class="card" style="margin-top: 20px;"> - <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;"> - <div class="card-title" style="margin-bottom: 0;">Event History</div> - <div class="history-tabs"> - <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button> - <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button> - </div> - </div> - <div id="history-panel-workers"> - <div id="history-empty" class="empty-state">No provisioning events recorded.</div> - <table id="history-table" style="display: none;"> - <thead> - <tr> - <th>Time</th> - <th>Event</th> - <th>Worker</th> - <th>Hostname</th> - </tr> - </thead> - <tbody id="history-table-body"></tbody> - </table> - </div> - <div id="history-panel-clients" style="display: none;"> - <div id="client-history-empty" class="empty-state">No client events recorded.</div> - <table id="client-history-table" style="display: none;"> - <thead> - <tr> - <th>Time</th> - <th>Event</th> - <th>Client</th> - <th>Hostname</th> 
- </tr> - </thead> - <tbody id="client-history-table-body"></tbody> - </table> - </div> - </div> - </div> - - <script> - const BASE_URL = window.location.origin; - const REFRESH_INTERVAL = 2000; - - function showError(message) { - document.getElementById('error-container').innerHTML = - '<div class="error">Error: ' + escapeHtml(message) + '</div>'; - } - - function clearError() { - document.getElementById('error-container').innerHTML = ''; - } - - function formatLastSeen(dtMs) { - if (dtMs == null) return '-'; - var seconds = Math.floor(dtMs / 1000); - if (seconds < 60) return seconds + 's ago'; - var minutes = Math.floor(seconds / 60); - if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago'; - var hours = Math.floor(minutes / 60); - return hours + 'h ' + (minutes % 60) + 'm ago'; - } - - function healthClass(dtMs, reachable) { - if (reachable === false) return 'health-red'; - if (dtMs == null) return 'health-red'; - var seconds = dtMs / 1000; - if (seconds < 30 && reachable === true) return 'health-green'; - if (seconds < 120) return 'health-yellow'; - return 'health-red'; - } - - function healthTitle(dtMs, reachable) { - var seenStr = dtMs != null ? 
'Last seen ' + formatLastSeen(dtMs) : 'Never seen'; - if (reachable === true) return seenStr + ' · Reachable'; - if (reachable === false) return seenStr + ' · Unreachable'; - return seenStr + ' · Reachability unknown'; - } - - function formatCpuUsage(percent) { - if (percent == null || percent === 0) return '-'; - return percent.toFixed(1) + '%'; - } - - function formatMemory(usedBytes, totalBytes) { - if (!totalBytes) return '-'; - var usedGiB = usedBytes / (1024 * 1024 * 1024); - var totalGiB = totalBytes / (1024 * 1024 * 1024); - return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB'; - } - - function formatBytes(bytes) { - if (!bytes) return '-'; - if (bytes < 1024) return bytes + ' B'; - if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB'; - if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB'; - if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB'; - return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB'; - } - - function formatTraffic(recv, sent) { - if (!recv && !sent) return '-'; - return formatBytes(recv) + ' / ' + formatBytes(sent); - } - - function parseIpFromUri(uri) { - try { - var url = new URL(uri); - var host = url.hostname; - // Strip IPv6 brackets - if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1); - // Only handle IPv4 - var parts = host.split('.'); - if (parts.length !== 4) return null; - var octets = parts.map(Number); - if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null; - return octets; - } catch (e) { - return null; - } - } - - function computeCidr(ips) { - if (ips.length === 0) return null; - if (ips.length === 1) return ips[0].join('.') + '/32'; - - // Convert each IP to a 32-bit integer - var ints = ips.map(function(o) { - return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0; - }); - - // Find common prefix length by ANDing all identical high bits - var common = ~0 
>>> 0; - for (var i = 1; i < ints.length; i++) { - // XOR to find differing bits, then mask away everything from the first difference down - var diff = (ints[0] ^ ints[i]) >>> 0; - if (diff !== 0) { - var bit = 31 - Math.floor(Math.log2(diff)); - var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0; - common = (common & mask) >>> 0; - } - } - - // Count leading ones in the common mask - var prefix = 0; - for (var b = 31; b >= 0; b--) { - if ((common >>> b) & 1) prefix++; - else break; - } - - // Network address - var net = (ints[0] & common) >>> 0; - var a = (net >>> 24) & 0xff; - var bv = (net >>> 16) & 0xff; - var c = (net >>> 8) & 0xff; - var d = net & 0xff; - return a + '.' + bv + '.' + c + '.' + d + '/' + prefix; - } - - function renderDashboard(data) { - var banner = document.querySelector('zen-banner'); - if (data.hostname) { - banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname); - } - var workers = data.workers || []; - - document.getElementById('agent-count').textContent = workers.length; - - if (workers.length === 0) { - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - } else { - banner.setAttribute('cluster-status', 'nominal'); - } - - var emptyState = document.getElementById('empty-state'); - var table = document.getElementById('agent-table'); - var tbody = document.getElementById('agent-table-body'); - - if (workers.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - } else { - emptyState.style.display = 'none'; - table.style.display = ''; - - tbody.innerHTML = ''; - var totalCpus = 0; - var totalWeightedCpuUsage = 0; - var totalMemUsed = 0; - var totalMemTotal = 0; - var totalQueues = 0; - var totalPending = 0; - var totalRunning = 0; - var totalCompleted = 0; - var totalBytesRecv = 0; - var totalBytesSent = 0; - var allIps = []; - for (var i = 0; i < workers.length; i++) { - var w = workers[i]; - var uri = w.uri || ''; - var dt = w.dt; - var dashboardUrl = uri + 
'/dashboard/compute/'; - - var id = w.id || ''; - - var hostname = w.hostname || ''; - var cpus = w.cpus || 0; - totalCpus += cpus; - if (cpus > 0 && typeof w.cpu_usage === 'number') { - totalWeightedCpuUsage += w.cpu_usage * cpus; - } - - var memTotal = w.memory_total || 0; - var memUsed = w.memory_used || 0; - totalMemTotal += memTotal; - totalMemUsed += memUsed; - - var activeQueues = w.active_queues || 0; - totalQueues += activeQueues; - - var actionsPending = w.actions_pending || 0; - var actionsRunning = w.actions_running || 0; - var actionsCompleted = w.actions_completed || 0; - totalPending += actionsPending; - totalRunning += actionsRunning; - totalCompleted += actionsCompleted; - - var bytesRecv = w.bytes_received || 0; - var bytesSent = w.bytes_sent || 0; - totalBytesRecv += bytesRecv; - totalBytesSent += bytesSent; - - var ip = parseIpFromUri(uri); - if (ip) allIps.push(ip); - - var reachable = w.reachable; - var hClass = healthClass(dt, reachable); - var hTitle = healthTitle(dt, reachable); - - var platform = w.platform || ''; - var badges = ''; - if (platform) { - var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' }; - var platColor = platColors[platform] || '#8b949e'; - badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>'; - } - var provisioner = w.provisioner || ''; - if (provisioner) { - var provColors = { horde: '#8957e5', nomad: '#3fb950' }; - var provColor = provColors[provisioner] || '#8b949e'; - badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>'; - } - - var tr = document.createElement('tr'); - tr.title = id; - tr.innerHTML = - '<td style="text-align: 
center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + - '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' + - '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' + - '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' + - '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' + - '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' + - '<td style="text-align: right;">' + actionsPending + '</td>' + - '<td style="text-align: right;">' + actionsRunning + '</td>' + - '<td style="text-align: right;">' + actionsCompleted + '</td>' + - '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' + - '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; - tbody.appendChild(tr); - } - - var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0; - banner.setAttribute('load', clusterLoad.toFixed(1)); - - // Total row - var cidr = computeCidr(allIps); - var totalTr = document.createElement('tr'); - totalTr.className = 'total-row'; - totalTr.innerHTML = - '<td></td>' + - '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? 
' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' + - '<td style="text-align: right;">' + totalCpus + '</td>' + - '<td></td>' + - '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' + - '<td style="text-align: right;">' + totalQueues + '</td>' + - '<td style="text-align: right;">' + totalPending + '</td>' + - '<td style="text-align: right;">' + totalRunning + '</td>' + - '<td style="text-align: right;">' + totalCompleted + '</td>' + - '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' + - '<td></td>'; - tbody.appendChild(totalTr); - } - - clearError(); - document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); - - // Render provisioning history if present in WebSocket payload - if (data.events) { - renderProvisioningHistory(data.events); - } - - // Render connected clients if present - if (data.clients) { - renderClients(data.clients); - } - - // Render client history if present - if (data.client_events) { - renderClientHistory(data.client_events); - } - } - - function eventBadge(type) { - var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' }; - var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' }; - var color = colors[type] || 'var(--theme_g1)'; - var label = labels[type] || type; - return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; - } - - function formatTimestamp(ts) { - if (!ts) return '-'; - // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string - var date; - if (typeof ts === 'number') { - // .NET-style ticks: convert to Unix ms - var unixMs = (ts - 621355968000000000) / 10000; - date = new Date(unixMs); - } else { - date = new Date(ts); - } - if 
(isNaN(date.getTime())) return '-'; - return date.toLocaleTimeString(); - } - - var activeHistoryTab = 'workers'; - - function switchHistoryTab(tab) { - activeHistoryTab = tab; - var tabs = document.querySelectorAll('.history-tab'); - for (var i = 0; i < tabs.length; i++) { - tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab); - } - document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none'; - document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none'; - } - - function renderProvisioningHistory(events) { - var emptyState = document.getElementById('history-empty'); - var table = document.getElementById('history-table'); - var tbody = document.getElementById('history-table-body'); - - if (!events || events.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < events.length; i++) { - var evt = events[i]; - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + - '<td>' + eventBadge(evt.type) + '</td>' + - '<td>' + escapeHtml(evt.worker_id || '') + '</td>' + - '<td>' + escapeHtml(evt.hostname || '') + '</td>'; - tbody.appendChild(tr); - } - } - - function clientHealthClass(dtMs) { - if (dtMs == null) return 'health-red'; - var seconds = dtMs / 1000; - if (seconds < 30) return 'health-green'; - if (seconds < 120) return 'health-yellow'; - return 'health-red'; - } - - function renderClients(clients) { - var emptyState = document.getElementById('clients-empty'); - var table = document.getElementById('clients-table'); - var tbody = document.getElementById('clients-table-body'); - - if (!clients || clients.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - 
table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < clients.length; i++) { - var c = clients[i]; - var dt = c.dt; - var hClass = clientHealthClass(dt); - var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen'; - - var sessionBadge = ''; - if (c.session_id) { - sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>'; - } - - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + - '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' + - '<td>' + escapeHtml(c.hostname || '') + '</td>' + - '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' + - '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; - tbody.appendChild(tr); - } - } - - function clientEventBadge(type) { - var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' }; - var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' }; - var color = colors[type] || 'var(--theme_g1)'; - var label = labels[type] || type; - return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; - } - - function renderClientHistory(events) { - var emptyState = document.getElementById('client-history-empty'); - var table = document.getElementById('client-history-table'); - var tbody = document.getElementById('client-history-table-body'); - - if (!events || events.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - 
table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < events.length; i++) { - var evt = events[i]; - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + - '<td>' + clientEventBadge(evt.type) + '</td>' + - '<td>' + escapeHtml(evt.client_id || '') + '</td>' + - '<td>' + escapeHtml(evt.hostname || '') + '</td>'; - tbody.appendChild(tr); - } - } - - // Fetch-based polling fallback - var pollTimer = null; - - async function fetchProvisioningHistory() { - try { - var response = await fetch(BASE_URL + '/orch/history?limit=50', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderProvisioningHistory(data.events || []); - } - } catch (e) { - console.error('Error fetching provisioning history:', e); - } - } - - async function fetchClients() { - try { - var response = await fetch(BASE_URL + '/orch/clients', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderClients(data.clients || []); - } - } catch (e) { - console.error('Error fetching clients:', e); - } - } - - async function fetchClientHistory() { - try { - var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderClientHistory(data.client_events || []); - } - } catch (e) { - console.error('Error fetching client history:', e); - } - } - - async function fetchDashboard() { - var banner = document.querySelector('zen-banner'); - try { - var response = await fetch(BASE_URL + '/orch/agents', { - headers: { 'Accept': 'application/json' } - }); - - if (!response.ok) { - banner.setAttribute('cluster-status', 'degraded'); - throw new Error('HTTP ' + response.status + ': ' + response.statusText); - } - - renderDashboard(await response.json()); - fetchProvisioningHistory(); - 
fetchClients(); - fetchClientHistory(); - } catch (error) { - console.error('Error updating dashboard:', error); - showError(error.message); - banner.setAttribute('cluster-status', 'offline'); - } - } - - function startPolling() { - if (pollTimer) return; - fetchDashboard(); - pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL); - } - - function stopPolling() { - if (pollTimer) { - clearInterval(pollTimer); - pollTimer = null; - } - } - - // WebSocket connection with automatic reconnect and polling fallback - var ws = null; - - function connectWebSocket() { - var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws'); - - ws.onopen = function() { - stopPolling(); - clearError(); - }; - - ws.onmessage = function(event) { - try { - renderDashboard(JSON.parse(event.data)); - } catch (e) { - console.error('WebSocket message parse error:', e); - } - }; - - ws.onclose = function() { - ws = null; - startPolling(); - setTimeout(connectWebSocket, 3000); - }; - - ws.onerror = function() { - // onclose will fire after onerror - }; - } - - // Fetch orchestrator hostname for the banner - fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } }) - .then(function(r) { return r.ok ? 
r.json() : null; }) - .then(function(d) { - if (d && d.hostname) { - document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname); - } - }) - .catch(function() {}); - - // Initial load via fetch, then try WebSocket - fetchDashboard(); - connectWebSocket(); - </script> -</body> -</html> diff --git a/src/zenserver/frontend/html/pages/builds.js b/src/zenserver/frontend/html/pages/builds.js index 095f0bf29..c63d13b91 100644 --- a/src/zenserver/frontend/html/pages/builds.js +++ b/src/zenserver/frontend/html/pages/builds.js @@ -16,7 +16,7 @@ export class Page extends ZenPage this.set_title("build store"); // Build Store Stats - const stats_section = this.add_section("Build Store Stats"); + const stats_section = this._collapsible_section("Build Store Service Stats"); stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => { window.open("/stats/builds.yaml", "_blank"); }); @@ -39,6 +39,7 @@ export class Page extends ZenPage _render_stats(stats) { + stats = this._merge_last_stats(stats); const grid = this._stats_grid; const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); @@ -49,39 +50,30 @@ export class Page extends ZenPage // Build Store tile { - const blobs = safe(stats, "store.blobs"); - const metadata = safe(stats, "store.metadata"); - if (blobs || metadata) - { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Build Store"); - const columns = tile.tag().classify("tile-columns"); + const blobs = safe(stats, "store.blobs") || {}; + const metadata = safe(stats, "store.metadata") || {}; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Build Store"); + const columns = tile.tag().classify("tile-columns"); - const left = columns.tag().classify("tile-metrics"); - this._metric(left, Friendly.bytes(safe(stats, "store.size.disk") || 0), "disk", true); - if (blobs) - { - 
this._metric(left, Friendly.sep(blobs.count || 0), "blobs"); - this._metric(left, Friendly.sep(blobs.readcount || 0), "blob reads"); - this._metric(left, Friendly.sep(blobs.writecount || 0), "blob writes"); - const blobHitRatio = (blobs.readcount || 0) > 0 - ? (((blobs.hitcount || 0) / blobs.readcount) * 100).toFixed(1) + "%" - : "-"; - this._metric(left, blobHitRatio, "blob hit ratio"); - } + const left = columns.tag().classify("tile-metrics"); + this._metric(left, Friendly.bytes(safe(stats, "store.size.disk") || 0), "disk", true); + this._metric(left, Friendly.sep(blobs.count || 0), "blobs"); + this._metric(left, Friendly.sep(blobs.readcount || 0), "blob reads"); + this._metric(left, Friendly.sep(blobs.writecount || 0), "blob writes"); + const blobHitRatio = (blobs.readcount || 0) > 0 + ? (((blobs.hitcount || 0) / blobs.readcount) * 100).toFixed(1) + "%" + : "-"; + this._metric(left, blobHitRatio, "blob hit ratio"); - const right = columns.tag().classify("tile-metrics"); - if (metadata) - { - this._metric(right, Friendly.sep(metadata.count || 0), "metadata entries", true); - this._metric(right, Friendly.sep(metadata.readcount || 0), "meta reads"); - this._metric(right, Friendly.sep(metadata.writecount || 0), "meta writes"); - const metaHitRatio = (metadata.readcount || 0) > 0 - ? (((metadata.hitcount || 0) / metadata.readcount) * 100).toFixed(1) + "%" - : "-"; - this._metric(right, metaHitRatio, "meta hit ratio"); - } - } + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.sep(metadata.count || 0), "metadata entries", true); + this._metric(right, Friendly.sep(metadata.readcount || 0), "meta reads"); + this._metric(right, Friendly.sep(metadata.writecount || 0), "meta writes"); + const metaHitRatio = (metadata.readcount || 0) > 0 + ? 
(((metadata.hitcount || 0) / metadata.readcount) * 100).toFixed(1) + "%" + : "-"; + this._metric(right, metaHitRatio, "meta hit ratio"); } } diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js index 1fc8227c8..683f7df4f 100644 --- a/src/zenserver/frontend/html/pages/cache.js +++ b/src/zenserver/frontend/html/pages/cache.js @@ -6,7 +6,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" import { Modal } from "../util/modal.js" -import { Table, Toolbar } from "../util/widgets.js" +import { Table, Toolbar, Pager, add_copy_button } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// export class Page extends ZenPage @@ -44,8 +44,6 @@ export class Page extends ZenPage // Cache Namespaces var section = this._collapsible_section("Cache Namespaces"); - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all()); - var columns = [ "namespace", "dir", @@ -56,31 +54,30 @@ export class Page extends ZenPage "actions", ]; - var zcache_info = await new Fetcher().resource("/z$/").json(); this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric); - for (const namespace of zcache_info["Namespaces"] || []) - { - new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => { - const row = this._cache_table.add_row( - "", - data["Configuration"]["RootDir"], - data["Buckets"].length, - data["EntryCount"], - Friendly.bytes(data["StorageSize"].DiskSize), - Friendly.bytes(data["StorageSize"].MemorySize) - ); - var cell = row.get_cell(0); - cell.tag().text(namespace).on_click(() => this.view_namespace(namespace)); - - cell = row.get_cell(-1); - const action_tb = new Toolbar(cell, true); - action_tb.left().add("view").on_click(() => this.view_namespace(namespace)); - action_tb.left().add("drop").on_click(() => 
this.drop_namespace(namespace)); - - row.attr("zs_name", namespace); - }); - } + this._cache_pager = new Pager(section, 25, () => this._render_cache_page(), + Pager.make_search_fn(() => this._cache_data, item => item.namespace)); + const cache_drop_link = document.createElement("span"); + cache_drop_link.className = "dropall zen_action"; + cache_drop_link.style.position = "static"; + cache_drop_link.textContent = "drop-all"; + cache_drop_link.addEventListener("click", () => this.drop_all()); + this._cache_pager.prepend(cache_drop_link); + + const loading = Pager.loading(section); + const zcache_info = await new Fetcher().resource("/z$/").json(); + const namespaces = zcache_info["Namespaces"] || []; + const results = await Promise.allSettled( + namespaces.map(ns => new Fetcher().resource(`/z$/${ns}/`).json().then(data => ({ namespace: ns, data }))) + ); + this._cache_data = results + .filter(r => r.status === "fulfilled") + .map(r => r.value) + .sort((a, b) => a.namespace.localeCompare(b.namespace)); + this._cache_pager.set_total(this._cache_data.length); + this._render_cache_page(); + loading.remove(); // Namespace detail area (inside namespaces section so it collapses together) this._namespace_host = section; @@ -95,84 +92,79 @@ export class Page extends ZenPage } } - _collapsible_section(name) + _render_cache_page() { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; + const { start, end } = this._cache_pager.page_range(); + this._cache_table.clear(start); + for (let i = start; i < end; i++) + { + const item = this._cache_data[i]; + const data = item.data; + const row = this._cache_table.add_row( + "", + data["Configuration"]["RootDir"], + data["Buckets"].length, + data["EntryCount"], + Friendly.bytes(data["StorageSize"].DiskSize), + Friendly.bytes(data["StorageSize"].MemorySize) + ); - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; + const cell = 
row.get_cell(0); + cell.tag().text(item.namespace).on_click(() => this.view_namespace(item.namespace)); + add_copy_button(cell.inner(), item.namespace); + add_copy_button(row.get_cell(1).inner(), data["Configuration"]["RootDir"]); - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); + const action_cell = row.get_cell(-1); + const action_tb = new Toolbar(action_cell, true); + action_tb.left().add("view").on_click(() => this.view_namespace(item.namespace)); + action_tb.left().add("drop").on_click(() => this.drop_namespace(item.namespace)); - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; + row.attr("zs_name", item.namespace); + } } _render_stats(stats) { + stats = this._merge_last_stats(stats); const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); const grid = this._stats_grid; - this._last_stats = stats; grid.inner().innerHTML = ""; // Store I/O tile { - const store = safe(stats, "cache.store"); - if (store) - { - const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed"); - if (this._selected_category === "store") tile.classify("stats-tile-selected"); - tile.on_click(() => this._select_category("store")); - tile.tag().classify("card-title").text("Store I/O"); - const columns = tile.tag().classify("tile-columns"); - - const left = columns.tag().classify("tile-metrics"); - const storeHits = store.hits || 0; - const storeMisses = store.misses || 0; - const storeTotal = storeHits + storeMisses; - const storeRatio = storeTotal > 0 ? 
((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-"; - this._metric(left, storeRatio, "store hit ratio", true); - this._metric(left, Friendly.sep(storeHits), "hits"); - this._metric(left, Friendly.sep(storeMisses), "misses"); - this._metric(left, Friendly.sep(store.writes || 0), "writes"); - this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads"); - this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes"); - - const right = columns.tag().classify("tile-metrics"); - const readRateMean = safe(store, "read.bytes.rate_mean") || 0; - const readRate1 = safe(store, "read.bytes.rate_1") || 0; - const readRate5 = safe(store, "read.bytes.rate_5") || 0; - const writeRateMean = safe(store, "write.bytes.rate_mean") || 0; - const writeRate1 = safe(store, "write.bytes.rate_1") || 0; - const writeRate5 = safe(store, "write.bytes.rate_5") || 0; - this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true); - this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)"); - this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)"); - this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)"); - this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)"); - this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)"); - } + const store = safe(stats, "cache.store") || {}; + const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed"); + if (this._selected_category === "store") tile.classify("stats-tile-selected"); + tile.on_click(() => this._select_category("store")); + tile.tag().classify("card-title").text("Store I/O"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const storeHits = store.hits || 0; + const storeMisses = store.misses || 0; + const storeTotal = storeHits + storeMisses; + const storeRatio = storeTotal > 0 ? 
((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(left, storeRatio, "store hit ratio", true); + this._metric(left, Friendly.sep(storeHits), "hits"); + this._metric(left, Friendly.sep(storeMisses), "misses"); + this._metric(left, Friendly.sep(store.writes || 0), "writes"); + this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads"); + this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes"); + + const right = columns.tag().classify("tile-metrics"); + const readRateMean = safe(store, "read.bytes.rate_mean") || 0; + const readRate1 = safe(store, "read.bytes.rate_1") || 0; + const readRate5 = safe(store, "read.bytes.rate_5") || 0; + const writeRateMean = safe(store, "write.bytes.rate_mean") || 0; + const writeRate1 = safe(store, "write.bytes.rate_1") || 0; + const writeRate5 = safe(store, "write.bytes.rate_5") || 0; + this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true); + this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)"); + this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)"); + this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)"); + this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)"); + this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)"); } // Hit/Miss tile @@ -208,89 +200,83 @@ export class Page extends ZenPage // HTTP Requests tile { - const req = safe(stats, "requests"); - if (req) - { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("HTTP Requests"); - const columns = tile.tag().classify("tile-columns"); + const req = safe(stats, "requests") || {}; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("HTTP Requests"); + const columns = tile.tag().classify("tile-columns"); - const left = columns.tag().classify("tile-metrics"); - const reqData = 
req.requests || req; - this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true); - if (reqData.rate_mean > 0) - { - this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)"); - } - if (reqData.rate_1 > 0) - { - this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)"); - } - if (reqData.rate_5 > 0) - { - this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)"); - } - if (reqData.rate_15 > 0) - { - this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)"); - } - const badRequests = safe(stats, "cache.badrequestcount") || 0; - this._metric(left, Friendly.sep(badRequests), "bad requests"); + const left = columns.tag().classify("tile-metrics"); + const reqData = req.requests || req; + this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true); + if (reqData.rate_mean > 0) + { + this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)"); + } + if (reqData.rate_1 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)"); + } + if (reqData.rate_5 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)"); + } + if (reqData.rate_15 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)"); + } + const badRequests = safe(stats, "cache.badrequestcount") || 0; + this._metric(left, Friendly.sep(badRequests), "bad requests"); - const right = columns.tag().classify("tile-metrics"); - this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true); - if (reqData.t_p75) - { - this._metric(right, Friendly.duration(reqData.t_p75), "p75"); - } - if (reqData.t_p95) - { - this._metric(right, Friendly.duration(reqData.t_p95), "p95"); - } - if (reqData.t_p99) - { - this._metric(right, Friendly.duration(reqData.t_p99), "p99"); - } - if (reqData.t_p999) - { - this._metric(right, Friendly.duration(reqData.t_p999), "p999"); - } - if (reqData.t_max) - { - 
this._metric(right, Friendly.duration(reqData.t_max), "max"); - } + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true); + if (reqData.t_p75) + { + this._metric(right, Friendly.duration(reqData.t_p75), "p75"); + } + if (reqData.t_p95) + { + this._metric(right, Friendly.duration(reqData.t_p95), "p95"); + } + if (reqData.t_p99) + { + this._metric(right, Friendly.duration(reqData.t_p99), "p99"); + } + if (reqData.t_p999) + { + this._metric(right, Friendly.duration(reqData.t_p999), "p999"); + } + if (reqData.t_max) + { + this._metric(right, Friendly.duration(reqData.t_max), "max"); } } // RPC tile { - const rpc = safe(stats, "cache.rpc"); - if (rpc) - { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("RPC"); - const columns = tile.tag().classify("tile-columns"); + const rpc = safe(stats, "cache.rpc") || {}; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("RPC"); + const columns = tile.tag().classify("tile-columns"); - const left = columns.tag().classify("tile-metrics"); - this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true); - this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops"); + const left = columns.tag().classify("tile-metrics"); + this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true); + this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops"); - const right = columns.tag().classify("tile-metrics"); - if (rpc.records) - { - this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls"); - this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops"); - } - if (rpc.values) - { - this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls"); - this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops"); - } - if (rpc.chunks) - { - this._metric(right, Friendly.sep(rpc.chunks.count || 0), 
"chunk calls"); - this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops"); - } + const right = columns.tag().classify("tile-metrics"); + if (rpc.records) + { + this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls"); + this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops"); + } + if (rpc.values) + { + this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls"); + this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops"); + } + if (rpc.chunks) + { + this._metric(right, Friendly.sep(rpc.chunks.count || 0), "chunk calls"); + this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops"); } } @@ -313,7 +299,7 @@ export class Page extends ZenPage this._metric(right, safe(stats, "cid.size.large") != null ? Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large"); } - // Upstream tile (only if upstream is active) + // Upstream tile (only shown when upstream is active) { const upstream = safe(stats, "upstream"); if (upstream) @@ -644,10 +630,9 @@ export class Page extends ZenPage async drop_all() { const drop = async () => { - for (const row of this._cache_table) + for (const item of this._cache_data || []) { - const namespace = row.attr("zs_name"); - await new Fetcher().resource("z$", namespace).delete(); + await new Fetcher().resource("z$", item.namespace).delete(); } this.reload(); }; diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js index ab3d49c27..c2257029e 100644 --- a/src/zenserver/frontend/html/pages/compute.js +++ b/src/zenserver/frontend/html/pages/compute.js @@ -5,7 +5,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" -import { Table } from "../util/widgets.js" +import { Table, add_copy_button } from "../util/widgets.js" const MAX_HISTORY_POINTS = 60; @@ -24,6 +24,12 @@ function formatTime(date) return date.toLocaleTimeString([], { hour: 
"2-digit", minute: "2-digit", second: "2-digit" }); } +function truncateHash(hash) +{ + if (!hash || hash.length <= 15) return hash; + return hash.slice(0, 6) + "\u2026" + hash.slice(-6); +} + function formatDuration(startDate, endDate) { if (!startDate || !endDate) return "-"; @@ -100,39 +106,6 @@ export class Page extends ZenPage }, 2000); } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? 
"none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - async _load_chartjs() { if (window.Chart) @@ -338,11 +311,7 @@ export class Page extends ZenPage { const workerIds = data.workers || []; - if (this._workers_table) - { - this._workers_table.clear(); - } - else + if (!this._workers_table) { this._workers_table = this._workers_host.add_widget( Table, @@ -353,6 +322,7 @@ export class Page extends ZenPage if (workerIds.length === 0) { + this._workers_table.clear(); return; } @@ -382,6 +352,10 @@ export class Page extends ZenPage id, ); + // Worker ID column: monospace for hex readability, copy button + row.get_cell(5).style("fontFamily", "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"); + add_copy_button(row.get_cell(5).inner(), id); + // Make name clickable to expand detail const cell = row.get_cell(0); cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc)); @@ -551,7 +525,7 @@ export class Page extends ZenPage : q.state === "draining" ? "draining" : q.is_complete ? "complete" : "active"; - this._queues_table.add_row( + const qrow = this._queues_table.add_row( id, status, String(q.active_count ?? 0), @@ -561,6 +535,10 @@ export class Page extends ZenPage String(q.cancelled_count ?? 
0), q.queue_token || "-", ); + if (q.queue_token) + { + add_copy_button(qrow.get_cell(7).inner(), q.queue_token); + } } } @@ -579,6 +557,11 @@ export class Page extends ZenPage ["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"], Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 ); + + // Right-align hash column headers to match data cells + const hdr = this._history_table.inner().firstElementChild; + hdr.children[7].style.textAlign = "right"; + hdr.children[8].style.textAlign = "right"; } // Entries arrive oldest-first; reverse to show newest at top @@ -593,7 +576,10 @@ export class Page extends ZenPage const startDate = filetimeToDate(entry.time_Running); const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); - this._history_table.add_row( + const workerId = entry.workerId || "-"; + const actionId = entry.actionId || "-"; + + const row = this._history_table.add_row( lsn, queueId, status, @@ -601,9 +587,17 @@ export class Page extends ZenPage formatTime(startDate), formatTime(endDate), formatDuration(startDate, endDate), - entry.workerId || "-", - entry.actionId || "-", + truncateHash(workerId), + truncateHash(actionId), ); + + // Hash columns: force right-align (AlignNumeric misses hex strings starting with a-f), + // use monospace for readability, and show full value on hover + const mono = "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"; + row.get_cell(7).style("textAlign", "right").style("fontFamily", mono).attr("title", workerId); + if (workerId !== "-") { add_copy_button(row.get_cell(7).inner(), workerId); } + row.get_cell(8).style("textAlign", "right").style("fontFamily", mono).attr("title", actionId); + if (actionId !== "-") { add_copy_button(row.get_cell(8).inner(), actionId); } } } diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 1e4c82e3f..e381f4a71 100644 --- 
a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -168,7 +168,7 @@ export class Page extends ZenPage if (key === "cook.artifacts") { action_tb.left().add("view-raw").on_click(() => { - window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/"); + window.open("/" + ["prj", project, "oplog", oplog, value+".json"].join("/"), "_self"); }); } diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js index 78e3a090c..b2bca9324 100644 --- a/src/zenserver/frontend/html/pages/hub.js +++ b/src/zenserver/frontend/html/pages/hub.js @@ -6,6 +6,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" import { Modal } from "../util/modal.js" +import { flash_highlight, copy_button } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// const STABLE_STATES = new Set(["provisioned", "hibernated", "crashed"]); @@ -20,6 +21,7 @@ function _btn_enabled(state, action) if (action === "hibernate") { return state === "provisioned"; } if (action === "wake") { return state === "hibernated"; } if (action === "deprovision") { return _is_actionable(state); } + if (action === "obliterate") { return _is_actionable(state); } return false; } @@ -82,7 +84,7 @@ export class Page extends ZenPage this.set_title("hub"); // Capacity - const stats_section = this.add_section("Capacity"); + const stats_section = this._collapsible_section("Hub Service Stats"); this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); // Modules @@ -96,20 +98,24 @@ export class Page extends ZenPage this._bulk_label.className = "module-bulk-label"; this._btn_bulk_hibernate = _make_bulk_btn("\u23F8", "Hibernate", () => this._exec_action("hibernate", [...this._selected])); this._btn_bulk_wake = _make_bulk_btn("\u25B6", "Wake", () => this._exec_action("wake", 
[...this._selected])); - this._btn_bulk_deprov = _make_bulk_btn("\u2715", "Deprovision",() => this._confirm_deprovision([...this._selected])); + this._btn_bulk_deprov = _make_bulk_btn("\u23F9", "Deprovision",() => this._confirm_deprovision([...this._selected])); + this._btn_bulk_oblit = _make_bulk_btn("\uD83D\uDD25", "Obliterate", () => this._confirm_obliterate([...this._selected])); const bulk_sep = document.createElement("div"); bulk_sep.className = "module-bulk-sep"; this._btn_hibernate_all = _make_bulk_btn("\u23F8", "Hibernate All", () => this._confirm_all("hibernate", "Hibernate All")); this._btn_wake_all = _make_bulk_btn("\u25B6", "Wake All", () => this._confirm_all("wake", "Wake All")); - this._btn_deprov_all = _make_bulk_btn("\u2715", "Deprovision All",() => this._confirm_all("deprovision", "Deprovision All")); + this._btn_deprov_all = _make_bulk_btn("\u23F9", "Deprovision All",() => this._confirm_all("deprovision", "Deprovision All")); + this._btn_oblit_all = _make_bulk_btn("\uD83D\uDD25", "Obliterate All", () => this._confirm_obliterate(this._modules_data.map(m => m.moduleId))); this._bulk_bar.appendChild(this._bulk_label); this._bulk_bar.appendChild(this._btn_bulk_hibernate); this._bulk_bar.appendChild(this._btn_bulk_wake); this._bulk_bar.appendChild(this._btn_bulk_deprov); + this._bulk_bar.appendChild(this._btn_bulk_oblit); this._bulk_bar.appendChild(bulk_sep); this._bulk_bar.appendChild(this._btn_hibernate_all); this._bulk_bar.appendChild(this._btn_wake_all); this._bulk_bar.appendChild(this._btn_deprov_all); + this._bulk_bar.appendChild(this._btn_oblit_all); mod_host.appendChild(this._bulk_bar); // Module table @@ -152,6 +158,38 @@ export class Page extends ZenPage this._btn_next.className = "module-pager-btn"; this._btn_next.textContent = "Next \u2192"; this._btn_next.addEventListener("click", () => this._go_page(this._page + 1)); + this._btn_provision = _make_bulk_btn("+", "Provision", () => this._show_provision_modal()); + this._btn_obliterate = 
_make_bulk_btn("\uD83D\uDD25", "Obliterate", () => this._show_obliterate_modal()); + this._search_input = document.createElement("input"); + this._search_input.type = "text"; + this._search_input.className = "module-pager-search"; + this._search_input.placeholder = "Search module\u2026"; + this._search_input.addEventListener("keydown", (e) => + { + if (e.key === "Enter") + { + const term = this._search_input.value.trim().toLowerCase(); + if (!term) { return; } + const idx = this._modules_data.findIndex(m => + (m.moduleId || "").toLowerCase().includes(term) + ); + if (idx >= 0) + { + const id = this._modules_data[idx].moduleId; + this._navigate_to_module(id); + this._flash_module(id); + } + else + { + this._search_input.style.outline = "2px solid var(--theme_fail)"; + setTimeout(() => { this._search_input.style.outline = ""; }, 1000); + } + } + }); + + pager.appendChild(this._btn_provision); + pager.appendChild(this._btn_obliterate); + pager.appendChild(this._search_input); pager.appendChild(this._btn_prev); pager.appendChild(this._pager_label); pager.appendChild(this._btn_next); @@ -164,8 +202,11 @@ export class Page extends ZenPage this._row_cache = new Map(); // moduleId → row refs, for in-place DOM updates this._updating = false; this._page = 0; - this._page_size = 50; + this._page_size = 25; this._expanded = new Set(); // moduleIds with open metrics panel + this._pending_highlight = null; // moduleId to navigate+flash after next poll + this._pending_highlight_timer = null; + this._loading = mod_section.tag().classify("pager-loading").text("Loading\u2026").inner(); await this._update(); this._poll_timer = setInterval(() => this._update(), 2000); @@ -184,6 +225,15 @@ export class Page extends ZenPage this._render_capacity(stats); this._render_modules(status); + if (this._loading) { this._loading.remove(); this._loading = null; } + if (this._pending_highlight && this._module_map.has(this._pending_highlight)) + { + const id = this._pending_highlight; + 
this._pending_highlight = null; + clearTimeout(this._pending_highlight_timer); + this._navigate_to_module(id); + this._flash_module(id); + } } catch (e) { /* service unavailable */ } finally { this._updating = false; } @@ -203,27 +253,48 @@ export class Page extends ZenPage { const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Active Modules"); + tile.tag().classify("card-title").text("Instances"); const body = tile.tag().classify("tile-metrics"); this._metric(body, Friendly.sep(current), "currently provisioned", true); + this._metric(body, Friendly.sep(max), "high watermark"); + this._metric(body, Friendly.sep(limit), "maximum allowed"); + if (limit > 0) + { + const pct = ((current / limit) * 100).toFixed(0) + "%"; + this._metric(body, pct, "utilization"); + } } + const machine = data.machine || {}; + const limits = data.resource_limits || {}; + if (machine.disk_total_bytes > 0 || machine.memory_total_mib > 0) { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Peak Modules"); - const body = tile.tag().classify("tile-metrics"); - this._metric(body, Friendly.sep(max), "high watermark", true); - } + const disk_used = Math.max(0, (machine.disk_total_bytes || 0) - (machine.disk_free_bytes || 0)); + const mem_used = Math.max(0, (machine.memory_total_mib || 0) - (machine.memory_avail_mib || 0)) * 1024 * 1024; + const vmem_used = Math.max(0, (machine.virtual_memory_total_mib || 0) - (machine.virtual_memory_avail_mib || 0)) * 1024 * 1024; + const disk_limit = limits.disk_bytes || 0; + const mem_limit = limits.memory_bytes || 0; + const disk_over = disk_limit > 0 && disk_used > disk_limit; + const mem_over = mem_limit > 0 && mem_used > mem_limit; - { const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Instance Limit"); - const body = tile.tag().classify("tile-metrics"); - this._metric(body, 
Friendly.sep(limit), "maximum allowed", true); - if (limit > 0) + if (disk_over || mem_over) { tile.inner().setAttribute("data-over", "true"); } + tile.tag().classify("card-title").text("Resources"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + this._metric(left, Friendly.bytes(disk_used), "disk used", true); + this._metric(left, Friendly.bytes(machine.disk_total_bytes), "disk total"); + if (disk_limit > 0) { this._metric(left, Friendly.bytes(disk_limit), "disk limit"); } + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.bytes(mem_used), "memory used", true); + this._metric(right, Friendly.bytes(machine.memory_total_mib * 1024 * 1024), "memory total"); + if (mem_limit > 0) { this._metric(right, Friendly.bytes(mem_limit), "memory limit"); } + if (machine.virtual_memory_total_mib > 0) { - const pct = ((current / limit) * 100).toFixed(0) + "%"; - this._metric(body, pct, "utilization"); + this._metric(right, Friendly.bytes(vmem_used), "vmem used", true); + this._metric(right, Friendly.bytes(machine.virtual_memory_total_mib * 1024 * 1024), "vmem total"); } } } @@ -274,7 +345,7 @@ export class Page extends ZenPage row.idx.textContent = i + 1; row.cb.checked = this._selected.has(id); row.dot.setAttribute("data-state", state); - if (state === "deprovisioning") + if (state === "deprovisioning" || state === "obliterating") { row.dot.setAttribute("data-prev-state", prev); } @@ -284,10 +355,20 @@ export class Page extends ZenPage } row.state_text.nodeValue = state; row.port_text.nodeValue = m.port ? String(m.port) : ""; + row.copy_port_btn.style.display = m.port ? 
"" : "none"; + if (m.state_change_time) + { + const state_label = state.charAt(0).toUpperCase() + state.slice(1); + row.state_since_label.textContent = state_label + " since"; + row.state_age_label.textContent = state_label + " for"; + row.state_since_node.nodeValue = m.state_change_time; + row.state_age_node.nodeValue = Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime()); + } row.btn_open.disabled = state !== "provisioned"; row.btn_hibernate.disabled = !_btn_enabled(state, "hibernate"); row.btn_wake.disabled = !_btn_enabled(state, "wake"); row.btn_deprov.disabled = !_btn_enabled(state, "deprovision"); + row.btn_oblit.disabled = !_btn_enabled(state, "obliterate"); if (m.process_metrics) { @@ -347,6 +428,8 @@ export class Page extends ZenPage id_wrap.style.cssText = "display:inline-flex;align-items:center;font-family:monospace;font-size:14px;"; id_wrap.appendChild(btn_expand); id_wrap.appendChild(document.createTextNode("\u00A0" + id)); + const copy_id_btn = copy_button(id); + id_wrap.appendChild(copy_id_btn); td_id.appendChild(id_wrap); tr.appendChild(td_id); @@ -354,7 +437,7 @@ export class Page extends ZenPage const dot = document.createElement("span"); dot.className = "module-state-dot"; dot.setAttribute("data-state", state); - if (state === "deprovisioning") + if (state === "deprovisioning" || state === "obliterating") { dot.setAttribute("data-prev-state", prev); } @@ -368,27 +451,33 @@ export class Page extends ZenPage td_port.style.cssText = "font-variant-numeric:tabular-nums;"; const port_node = document.createTextNode(port ? String(port) : ""); td_port.appendChild(port_node); + const copy_port_btn = copy_button(() => port_node.nodeValue); + copy_port_btn.style.display = port ? 
"" : "none"; + td_port.appendChild(copy_port_btn); tr.appendChild(td_port); const td_action = document.createElement("td"); td_action.className = "module-action-cell"; const [wrap_o, btn_o] = _make_action_btn("\u2197", "Open dashboard", () => { - window.open(`${window.location.protocol}//${window.location.hostname}:${port}`, "_blank"); + window.open(`/hub/proxy/${port}/dashboard/`, "_blank"); }); btn_o.disabled = state !== "provisioned"; const [wrap_h, btn_h] = _make_action_btn("\u23F8", "Hibernate", () => this._post_module_action(id, "hibernate").then(() => this._update())); const [wrap_w, btn_w] = _make_action_btn("\u25B6", "Wake", () => this._post_module_action(id, "wake").then(() => this._update())); - const [wrap_d, btn_d] = _make_action_btn("\u2715", "Deprovision", () => this._confirm_deprovision([id])); + const [wrap_d, btn_d] = _make_action_btn("\u23F9", "Deprovision", () => this._confirm_deprovision([id])); + const [wrap_x, btn_x] = _make_action_btn("\uD83D\uDD25", "Obliterate", () => this._confirm_obliterate([id])); btn_h.disabled = !_btn_enabled(state, "hibernate"); btn_w.disabled = !_btn_enabled(state, "wake"); btn_d.disabled = !_btn_enabled(state, "deprovision"); + btn_x.disabled = !_btn_enabled(state, "obliterate"); td_action.appendChild(wrap_h); td_action.appendChild(wrap_w); td_action.appendChild(wrap_d); + td_action.appendChild(wrap_x); td_action.appendChild(wrap_o); tr.appendChild(td_action); - // Build metrics grid from process_metrics keys. + // Build metrics grid: fixed state-time rows followed by process_metrics keys. // Keys are split into two halves and interleaved so the grid fills // top-to-bottom in the left column before continuing in the right column. 
const metric_nodes = new Map(); @@ -396,6 +485,28 @@ export class Page extends ZenPage metrics_td.colSpan = 6; const metrics_grid = document.createElement("div"); metrics_grid.className = "module-metrics-grid"; + + const _add_fixed_pair = (label, value_str) => { + const label_el = document.createElement("span"); + label_el.className = "module-metrics-label"; + label_el.textContent = label; + const value_node = document.createTextNode(value_str); + const value_el = document.createElement("span"); + value_el.className = "module-metrics-value"; + value_el.appendChild(value_node); + metrics_grid.appendChild(label_el); + metrics_grid.appendChild(value_el); + return { label_el, value_node }; + }; + + const state_label = m.state ? m.state.charAt(0).toUpperCase() + m.state.slice(1) : "State"; + const state_since_str = m.state_change_time || ""; + const state_age_str = m.state_change_time + ? Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime()) + : ""; + const { label_el: state_since_label, value_node: state_since_node } = _add_fixed_pair(state_label + " since", state_since_str); + const { label_el: state_age_label, value_node: state_age_node } = _add_fixed_pair(state_label + " for", state_age_str); + const keys = Object.keys(m.process_metrics || {}); const half = Math.ceil(keys.length / 2); const add_metric_pair = (key) => { @@ -423,7 +534,7 @@ export class Page extends ZenPage metrics_td.appendChild(metrics_grid); metrics_tr.appendChild(metrics_td); - row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes }; + row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, copy_port_btn, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, btn_oblit: btn_x, metric_nodes, state_since_node, state_age_node, state_since_label, state_age_label }; 
this._row_cache.set(id, row); } @@ -533,6 +644,7 @@ export class Page extends ZenPage this._btn_bulk_hibernate.disabled = !this._all_selected_in_state("provisioned"); this._btn_bulk_wake.disabled = !this._all_selected_in_state("hibernated"); this._btn_bulk_deprov.disabled = selected === 0; + this._btn_bulk_oblit.disabled = selected === 0; this._select_all_cb.disabled = total === 0; this._select_all_cb.checked = selected === total && total > 0; @@ -545,6 +657,7 @@ export class Page extends ZenPage this._btn_hibernate_all.disabled = empty; this._btn_wake_all.disabled = empty; this._btn_deprov_all.disabled = empty; + this._btn_oblit_all.disabled = empty; } _on_select_all() @@ -590,6 +703,35 @@ export class Page extends ZenPage .option("Deprovision", () => this._exec_action("deprovision", ids)); } + _confirm_obliterate(ids) + { + const warn = "\uD83D\uDD25 WARNING: This action is irreversible! \uD83D\uDD25"; + const detail = "All local and backend data will be permanently destroyed.\nThis cannot be undone."; + let message; + if (ids.length === 1) + { + const id = ids[0]; + const state = this._module_state(id) || "unknown"; + message = `${warn}\n\n${detail}\n\nModule ID: ${id}\nCurrent state: ${state}`; + } + else + { + message = `${warn}\n\nObliterate ${ids.length} modules.\n\n${detail}`; + } + + new Modal() + .title("\uD83D\uDD25 Obliterate") + .message(message) + .option("Cancel", null) + .option("\uD83D\uDD25 Obliterate", () => this._exec_obliterate(ids)); + } + + async _exec_obliterate(ids) + { + await Promise.allSettled(ids.map(id => fetch(`/hub/modules/${encodeURIComponent(id)}`, { method: "DELETE" }))); + await this._update(); + } + _confirm_all(action, label) { // Capture IDs at modal-open time so action targets the displayed list @@ -614,4 +756,191 @@ export class Page extends ZenPage await fetch(`/hub/modules/${moduleId}/${action}`, { method: "POST" }); } + _show_module_input_modal({ title, submit_label, warning, on_submit }) + { + const MODULE_ID_RE = 
/^[A-Za-z0-9][A-Za-z0-9-]*$/; + + const overlay = document.createElement("div"); + overlay.className = "zen_modal"; + + const bg = document.createElement("div"); + bg.className = "zen_modal_bg"; + bg.addEventListener("click", () => overlay.remove()); + overlay.appendChild(bg); + + const dialog = document.createElement("div"); + overlay.appendChild(dialog); + + const title_el = document.createElement("div"); + title_el.className = "zen_modal_title"; + title_el.textContent = title; + dialog.appendChild(title_el); + + const content = document.createElement("div"); + content.className = "zen_modal_message"; + content.style.textAlign = "center"; + + if (warning) + { + const warn = document.createElement("div"); + warn.style.cssText = "color:var(--theme_fail);font-weight:bold;margin-bottom:12px;"; + warn.textContent = warning; + content.appendChild(warn); + } + + const input = document.createElement("input"); + input.type = "text"; + input.placeholder = "module-name"; + input.style.cssText = "width:100%;font-size:14px;padding:8px 12px;"; + content.appendChild(input); + + const error_div = document.createElement("div"); + error_div.style.cssText = "color:var(--theme_fail);font-size:12px;margin-top:8px;min-height:1.2em;"; + content.appendChild(error_div); + + dialog.appendChild(content); + + const buttons = document.createElement("div"); + buttons.className = "zen_modal_buttons"; + + const btn_cancel = document.createElement("div"); + btn_cancel.textContent = "Cancel"; + btn_cancel.addEventListener("click", () => overlay.remove()); + + const btn_submit = document.createElement("div"); + btn_submit.textContent = submit_label; + + buttons.appendChild(btn_cancel); + buttons.appendChild(btn_submit); + dialog.appendChild(buttons); + + let submitting = false; + + const set_submit_enabled = (enabled) => { + btn_submit.style.opacity = enabled ? "" : "0.4"; + btn_submit.style.pointerEvents = enabled ? 
"" : "none"; + }; + + set_submit_enabled(false); + + const validate = () => { + if (submitting) { return false; } + const val = input.value.trim(); + if (val.length === 0) + { + error_div.textContent = ""; + set_submit_enabled(false); + return false; + } + if (!MODULE_ID_RE.test(val)) + { + error_div.textContent = "Only letters, numbers, and hyphens allowed (must start with a letter or number)"; + set_submit_enabled(false); + return false; + } + error_div.textContent = ""; + set_submit_enabled(true); + return true; + }; + + input.addEventListener("input", validate); + + const submit = async () => { + if (submitting) { return; } + const moduleId = input.value.trim(); + if (!MODULE_ID_RE.test(moduleId)) { return; } + + submitting = true; + set_submit_enabled(false); + error_div.textContent = ""; + + try + { + const ok = await on_submit(moduleId); + if (ok) + { + overlay.remove(); + await this._update(); + return; + } + } + catch (e) + { + error_div.textContent = e.message || "Request failed"; + } + submitting = false; + set_submit_enabled(true); + }; + + btn_submit.addEventListener("click", submit); + input.addEventListener("keydown", (e) => { + if (e.key === "Enter" && validate()) { submit(); } + if (e.key === "Escape") { overlay.remove(); } + }); + + document.body.appendChild(overlay); + input.focus(); + + return { error_div }; + } + + _show_provision_modal() + { + const { error_div } = this._show_module_input_modal({ + title: "Provision Module", + submit_label: "Provision", + on_submit: async (moduleId) => { + const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}/provision`, { method: "POST" }); + if (!resp.ok) + { + const msg = await resp.text(); + error_div.textContent = msg || ("HTTP " + resp.status); + return false; + } + // Endpoint returns compact binary (CbObjectWriter), not text + if (resp.status === 200 || resp.status === 202) + { + this._pending_highlight = moduleId; + this._pending_highlight_timer = setTimeout(() => { 
this._pending_highlight = null; }, 5000); + } + return true; + } + }); + } + + _show_obliterate_modal() + { + const { error_div } = this._show_module_input_modal({ + title: "\uD83D\uDD25 Obliterate Module", + submit_label: "\uD83D\uDD25 Obliterate", + warning: "\uD83D\uDD25 WARNING: This action is irreversible! \uD83D\uDD25\nAll local and backend data will be permanently destroyed.", + on_submit: async (moduleId) => { + const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}`, { method: "DELETE" }); + if (resp.ok) + { + return true; + } + const msg = await resp.text(); + error_div.textContent = msg || ("HTTP " + resp.status); + return false; + } + }); + } + + _navigate_to_module(moduleId) + { + const idx = this._modules_data.findIndex(m => m.moduleId === moduleId); + if (idx >= 0) + { + this._page = Math.floor(idx / this._page_size); + this._render_page(); + } + } + + _flash_module(id) + { + const cached = this._row_cache.get(id); + if (cached) { flash_highlight(cached.tr); } + } + } diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js index 4a9290a3c..d11306998 100644 --- a/src/zenserver/frontend/html/pages/orchestrator.js +++ b/src/zenserver/frontend/html/pages/orchestrator.js @@ -5,7 +5,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" -import { Table } from "../util/widgets.js" +import { Table, add_copy_button } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// export class Page extends ZenPage @@ -14,6 +14,14 @@ export class Page extends ZenPage { this.set_title("orchestrator"); + // Provisioner section (hidden until data arrives) + this._prov_section = this._collapsible_section("Provisioner"); + this._prov_section._parent.inner().style.display = "none"; + this._prov_grid = null; + this._prov_target_dirty = false; + this._prov_commit_timer = null; 
+ this._prov_last_target = null; + // Agents section const agents_section = this._collapsible_section("Compute Agents"); this._agents_host = agents_section; @@ -46,48 +54,16 @@ export class Page extends ZenPage this._connect_ws(); } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - async _fetch_all() { try { - const [agents, history, clients, client_history] = await Promise.all([ + const [agents, history, clients, client_history, prov] = await Promise.all([ new Fetcher().resource("/orch/agents").json(), new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null), new Fetcher().resource("/orch/clients").json().catch(() => null), new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null), + new Fetcher().resource("/orch/provisioner/status").json().catch(() => null), ]); this._render_agents(agents); @@ -103,6 +79,7 @@ export class Page extends ZenPage { this._render_client_history(client_history.client_events || []); } + this._render_provisioner(prov); } catch (e) { /* service unavailable */ } } @@ -142,6 +119,7 @@ export class Page extends ZenPage { this._render_client_history(data.client_events); } + 
this._render_provisioner(data.provisioner); } catch (e) { /* ignore parse errors */ } }; @@ -189,7 +167,7 @@ export class Page extends ZenPage return; } - let totalCpus = 0, totalWeightedCpu = 0; + let totalCpus = 0, activeCpus = 0, totalWeightedCpu = 0; let totalMemUsed = 0, totalMemTotal = 0; let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0; let totalRecv = 0, totalSent = 0; @@ -206,8 +184,14 @@ export class Page extends ZenPage const completed = w.actions_completed || 0; const recv = w.bytes_received || 0; const sent = w.bytes_sent || 0; + const provisioner = w.provisioner || ""; + const isProvisioned = provisioner !== ""; totalCpus += cpus; + if (w.provisioner_status === "active") + { + activeCpus += cpus; + } if (cpus > 0 && typeof cpuUsage === "number") { totalWeightedCpu += cpuUsage * cpus; @@ -242,12 +226,49 @@ export class Page extends ZenPage cell.inner().textContent = ""; cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank"); } + + // Visual treatment based on provisioner status + const provStatus = w.provisioner_status || ""; + if (!isProvisioned) + { + row.inner().style.opacity = "0.45"; + } + else + { + const hostCell = row.get_cell(0); + const el = hostCell.inner(); + const badge = document.createElement("span"); + const badgeBase = "display:inline-block;margin-left:6px;padding:1px 5px;border-radius:8px;" + + "font-size:9px;font-weight:600;color:#fff;vertical-align:middle;"; + + if (provStatus === "draining") + { + badge.textContent = "draining"; + badge.style.cssText = badgeBase + "background:var(--theme_warn);"; + row.inner().style.opacity = "0.6"; + } + else if (provStatus === "active") + { + badge.textContent = provisioner; + badge.style.cssText = badgeBase + "background:#8957e5;"; + } + else + { + badge.textContent = "deallocated"; + badge.style.cssText = badgeBase + "background:var(--theme_fail);"; + row.inner().style.opacity = "0.45"; + } + el.appendChild(badge); + } } - // 
Total row + // Total row — show active / total in CPUs column + const cpuLabel = activeCpus < totalCpus + ? Friendly.sep(activeCpus) + " / " + Friendly.sep(totalCpus) + : Friendly.sep(totalCpus); const total = this._agents_table.add_row( "TOTAL", - Friendly.sep(totalCpus), + cpuLabel, "", totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-", Friendly.sep(totalQueues), @@ -277,12 +298,13 @@ export class Page extends ZenPage for (const c of clients) { - this._clients_table.add_row( + const crow = this._clients_table.add_row( c.id || "", c.hostname || "", c.address || "", this._format_last_seen(c.dt), ); + if (c.id) { add_copy_button(crow.get_cell(0).inner(), c.id); } } } @@ -338,6 +360,154 @@ export class Page extends ZenPage } } + _render_provisioner(prov) + { + const container = this._prov_section._parent.inner(); + + if (!prov || !prov.name) + { + container.style.display = "none"; + return; + } + container.style.display = ""; + + if (!this._prov_grid) + { + this._prov_grid = this._prov_section.tag().classify("grid").classify("stats-tiles"); + this._prov_tiles = {}; + + // Target cores tile with editable input + const target_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + target_tile.tag().classify("card-title").text("Target Cores"); + const target_body = target_tile.tag().classify("tile-metrics"); + const target_m = target_body.tag().classify("tile-metric").classify("tile-metric-hero"); + const input = document.createElement("input"); + input.type = "number"; + input.min = "0"; + input.style.cssText = "width:100px;padding:4px 8px;border:1px solid var(--theme_g2);border-radius:4px;" + + "background:var(--theme_g4);color:var(--theme_bright);font-size:20px;font-weight:600;text-align:right;"; + target_m.inner().appendChild(input); + target_m.tag().classify("metric-label").text("target"); + this._prov_tiles.target_input = input; + + input.addEventListener("focus", () => { this._prov_target_dirty = true; 
}); + input.addEventListener("input", () => { + this._prov_target_dirty = true; + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._prov_commit_timer = setTimeout(() => this._commit_provisioner_target(), 800); + }); + input.addEventListener("keydown", (e) => { + if (e.key === "Enter") + { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._commit_provisioner_target(); + input.blur(); + } + }); + input.addEventListener("blur", () => { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._commit_provisioner_target(); + }); + + // Active cores + const active_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + active_tile.tag().classify("card-title").text("Active Cores"); + const active_body = active_tile.tag().classify("tile-metrics"); + this._prov_tiles.active = active_body; + + // Estimated cores + const est_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + est_tile.tag().classify("card-title").text("Estimated Cores"); + const est_body = est_tile.tag().classify("tile-metrics"); + this._prov_tiles.estimated = est_body; + + // Agents + const agents_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + agents_tile.tag().classify("card-title").text("Agents"); + const agents_body = agents_tile.tag().classify("tile-metrics"); + this._prov_tiles.agents = agents_body; + + // Draining + const drain_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + drain_tile.tag().classify("card-title").text("Draining"); + const drain_body = drain_tile.tag().classify("tile-metrics"); + this._prov_tiles.draining = drain_body; + } + + // Update values + const input = this._prov_tiles.target_input; + if (!this._prov_target_dirty && document.activeElement !== input) + { + input.value = prov.target_cores; + } + this._prov_last_target = prov.target_cores; + + // Re-render metric tiles (clear and recreate content) 
+ for (const key of ["active", "estimated", "agents", "draining"]) + { + this._prov_tiles[key].inner().innerHTML = ""; + } + this._metric(this._prov_tiles.active, Friendly.sep(prov.active_cores), "cores", true); + this._metric(this._prov_tiles.estimated, Friendly.sep(prov.estimated_cores), "cores", true); + this._metric(this._prov_tiles.agents, Friendly.sep(prov.agents), "active", true); + this._metric(this._prov_tiles.draining, Friendly.sep(prov.agents_draining || 0), "agents", true); + } + + async _commit_provisioner_target() + { + const input = this._prov_tiles?.target_input; + if (!input || this._prov_committing) + { + return; + } + const value = parseInt(input.value, 10); + if (isNaN(value) || value < 0) + { + return; + } + if (value === this._prov_last_target) + { + this._prov_target_dirty = false; + return; + } + this._prov_committing = true; + try + { + const resp = await fetch("/orch/provisioner/target", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ target_cores: value }), + }); + if (resp.ok) + { + this._prov_target_dirty = false; + console.log("Target cores set to", value); + } + else + { + const text = await resp.text(); + console.error("Failed to set target cores: HTTP", resp.status, text); + } + } + catch (e) + { + console.error("Failed to set target cores:", e); + } + finally + { + this._prov_committing = false; + } + } + _metric(parent, value, label, hero = false) { const m = parent.tag().classify("tile-metric"); diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index cf8d3e3dd..3653abb0e 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -6,6 +6,26 @@ import { WidgetHost } from "../util/widgets.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" +function _deep_merge_stats(base, update) +{ + const result = Object.assign({}, base); + for (const key of 
Object.keys(update)) + { + const bv = result[key]; + const uv = update[key]; + if (uv && typeof uv === "object" && !Array.isArray(uv) + && bv && typeof bv === "object" && !Array.isArray(bv)) + { + result[key] = _deep_merge_stats(bv, uv); + } + else + { + result[key] = uv; + } + } + return result; +} + //////////////////////////////////////////////////////////////////////////////// export class PageBase extends WidgetHost { @@ -282,10 +302,7 @@ export class ZenPage extends PageBase _render_http_requests_tile(grid, req, bad_requests = undefined) { - if (!req) - { - return; - } + req = req || {}; const tile = grid.tag().classify("card").classify("stats-tile"); tile.tag().classify("card-title").text("HTTP Requests"); const columns = tile.tag().classify("tile-columns"); @@ -337,4 +354,47 @@ export class ZenPage extends PageBase this._metric(right, Friendly.duration(reqData.t_max), "max"); } } + + _merge_last_stats(stats) + { + if (this._last_stats) + { + stats = _deep_merge_stats(this._last_stats, stats); + } + this._last_stats = stats; + return stats; + } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = "0.7em"; + heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? 
"none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } } diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js index 2469bf70b..2e76a80f1 100644 --- a/src/zenserver/frontend/html/pages/projects.js +++ b/src/zenserver/frontend/html/pages/projects.js @@ -6,7 +6,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" import { Modal } from "../util/modal.js" -import { Table, Toolbar } from "../util/widgets.js" +import { Table, Toolbar, Pager, add_copy_button } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// export class Page extends ZenPage @@ -39,8 +39,6 @@ export class Page extends ZenPage // Projects list var section = this._collapsible_section("Projects"); - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all()); - var columns = [ "name", "project dir", @@ -51,51 +49,21 @@ export class Page extends ZenPage this._project_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric); - var projects = await new Fetcher().resource("/prj/list").json(); - projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0)); - - for (const project of projects) - { - var row = this._project_table.add_row( - "", - "", - "", - "", - ); - - var cell = row.get_cell(0); - cell.tag().text(project.Id).on_click(() => this.view_project(project.Id)); - - if (project.ProjectRootDir) - { - row.get_cell(1).tag("a").text(project.ProjectRootDir) - .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/")); - } - if (project.EngineRootDir) - { - row.get_cell(2).tag("a").text(project.EngineRootDir) - .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/")); - } - - cell = row.get_cell(-1); - const action_tb = new Toolbar(cell, true).left(); - 
action_tb.add("view").on_click(() => this.view_project(project.Id)); - action_tb.add("drop").on_click(() => this.drop_project(project.Id)); - - row.attr("zs_name", project.Id); - - // Fetch project details to get oplog count - new Fetcher().resource("prj", project.Id).json().then((info) => { - const oplogs = info["oplogs"] || []; - row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right"); - // Right-align the corresponding header cell - const header = this._project_table._element.firstElementChild; - if (header && header.children[4]) - { - header.children[4].style.textAlign = "right"; - } - }).catch(() => {}); - } + this._project_pager = new Pager(section, 25, () => this._render_projects_page(), + Pager.make_search_fn(() => this._projects_data, p => p.Id)); + const drop_link = document.createElement("span"); + drop_link.className = "dropall zen_action"; + drop_link.style.position = "static"; + drop_link.textContent = "drop-all"; + drop_link.addEventListener("click", () => this.drop_all()); + this._project_pager.prepend(drop_link); + + const loading = Pager.loading(section); + this._projects_data = await new Fetcher().resource("/prj/list").json(); + this._projects_data.sort((a, b) => a.Id.localeCompare(b.Id)); + this._project_pager.set_total(this._projects_data.length); + this._render_projects_page(); + loading.remove(); // Project detail area (inside projects section so it collapses together) this._project_host = section; @@ -110,39 +78,6 @@ export class Page extends ZenPage } } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => 
{ - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - _clear_param(name) { this._params.delete(name); @@ -153,6 +88,7 @@ export class Page extends ZenPage _render_stats(stats) { + stats = this._merge_last_stats(stats); const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); const grid = this._stats_grid; @@ -163,54 +99,48 @@ export class Page extends ZenPage // Store Operations tile { - const store = safe(stats, "store"); - if (store) - { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Store Operations"); - const columns = tile.tag().classify("tile-columns"); - - const left = columns.tag().classify("tile-metrics"); - const proj = store.project || {}; - this._metric(left, Friendly.sep(proj.readcount || 0), "project reads", true); - this._metric(left, Friendly.sep(proj.writecount || 0), "project writes"); - this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes"); - - const right = columns.tag().classify("tile-metrics"); - const oplog = store.oplog || {}; - this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true); - this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes"); - this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes"); - } + const store = safe(stats, "store") || {}; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Store Operations"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const proj = store.project || {}; + this._metric(left, Friendly.sep(proj.readcount || 0), "project 
reads", true); + this._metric(left, Friendly.sep(proj.writecount || 0), "project writes"); + this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes"); + + const right = columns.tag().classify("tile-metrics"); + const oplog = store.oplog || {}; + this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true); + this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes"); + this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes"); } // Op & Chunk tile { - const store = safe(stats, "store"); - if (store) - { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Ops & Chunks"); - const columns = tile.tag().classify("tile-columns"); - - const left = columns.tag().classify("tile-metrics"); - const op = store.op || {}; - const opTotal = (op.hitcount || 0) + (op.misscount || 0); - const opRatio = opTotal > 0 ? (((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-"; - this._metric(left, opRatio, "op hit ratio", true); - this._metric(left, Friendly.sep(op.hitcount || 0), "op hits"); - this._metric(left, Friendly.sep(op.misscount || 0), "op misses"); - this._metric(left, Friendly.sep(op.writecount || 0), "op writes"); - - const right = columns.tag().classify("tile-metrics"); - const chunk = store.chunk || {}; - const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0); - const chunkRatio = chunkTotal > 0 ? 
(((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-"; - this._metric(right, chunkRatio, "chunk hit ratio", true); - this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits"); - this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses"); - this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes"); - } + const store = safe(stats, "store") || {}; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Ops & Chunks"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const op = store.op || {}; + const opTotal = (op.hitcount || 0) + (op.misscount || 0); + const opRatio = opTotal > 0 ? (((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(left, opRatio, "op hit ratio", true); + this._metric(left, Friendly.sep(op.hitcount || 0), "op hits"); + this._metric(left, Friendly.sep(op.misscount || 0), "op misses"); + this._metric(left, Friendly.sep(op.writecount || 0), "op writes"); + + const right = columns.tag().classify("tile-metrics"); + const chunk = store.chunk || {}; + const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0); + const chunkRatio = chunkTotal > 0 ? 
(((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(right, chunkRatio, "chunk hit ratio", true); + this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits"); + this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses"); + this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes"); } // Storage tile @@ -231,6 +161,57 @@ export class Page extends ZenPage } } + _render_projects_page() + { + const { start, end } = this._project_pager.page_range(); + this._project_table.clear(start); + for (let i = start; i < end; i++) + { + const project = this._projects_data[i]; + const row = this._project_table.add_row( + "", + "", + "", + "", + ); + + const cell = row.get_cell(0); + cell.tag().text(project.Id).on_click(() => this.view_project(project.Id)); + add_copy_button(cell.inner(), project.Id); + + if (project.ProjectRootDir) + { + row.get_cell(1).tag("a").text(project.ProjectRootDir) + .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/")); + add_copy_button(row.get_cell(1).inner(), project.ProjectRootDir); + } + if (project.EngineRootDir) + { + row.get_cell(2).tag("a").text(project.EngineRootDir) + .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/")); + add_copy_button(row.get_cell(2).inner(), project.EngineRootDir); + } + + const action_cell = row.get_cell(-1); + const action_tb = new Toolbar(action_cell, true).left(); + action_tb.add("view").on_click(() => this.view_project(project.Id)); + action_tb.add("drop").on_click(() => this.drop_project(project.Id)); + + row.attr("zs_name", project.Id); + + new Fetcher().resource("prj", project.Id).json().then((info) => { + const oplogs = info["oplogs"] || []; + row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right"); + }).catch(() => {}); + } + + const header = this._project_table._element.firstElementChild; + if (header && header.children[4]) + { + header.children[4].style.textAlign = "right"; + } + } + 
async view_project(project_id) { // Toggle off if already selected @@ -351,10 +332,9 @@ export class Page extends ZenPage async drop_all() { const drop = async () => { - for (const row of this._project_table) + for (const project of this._projects_data || []) { - const project_id = row.attr("zs_name"); - await new Fetcher().resource("prj", project_id).delete(); + await new Fetcher().resource("prj", project.Id).delete(); } this.reload(); }; diff --git a/src/zenserver/frontend/html/pages/start.js b/src/zenserver/frontend/html/pages/start.js index e5b4d14f1..d06040b2f 100644 --- a/src/zenserver/frontend/html/pages/start.js +++ b/src/zenserver/frontend/html/pages/start.js @@ -6,7 +6,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" import { Friendly } from "../util/friendly.js" import { Modal } from "../util/modal.js" -import { Table, Toolbar } from "../util/widgets.js" +import { Table, Toolbar, Pager } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// export class Page extends ZenPage @@ -50,54 +50,40 @@ export class Page extends ZenPage this._render_stats(all_stats); // project list - var project_table = null; if (available.has("/prj/")) { var section = this.add_section("Cooked Projects"); - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects")); - var columns = [ "name", "project_dir", "engine_dir", "actions", ]; - project_table = section.add_widget(Table, columns); - - var projects = await new Fetcher().resource("/prj/list").json(); - projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0)); - projects = projects.slice(0, 25); - projects.sort((a, b) => a.Id.localeCompare(b.Id)); - - for (const project of projects) - { - var row = project_table.add_row( - "", - project.ProjectRootDir, - project.EngineRootDir, - ); - - var cell = row.get_cell(0); - cell.tag().text(project.Id).on_click((x) => this.view_project(x), 
project.Id); - - var cell = row.get_cell(-1); - var action_tb = new Toolbar(cell, true); - action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id); - action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id); - - row.attr("zs_name", project.Id); - } + this._project_table = section.add_widget(Table, columns); + + this._project_pager = new Pager(section, 25, () => this._render_projects_page(), + Pager.make_search_fn(() => this._projects_data, p => p.Id)); + const drop_link = document.createElement("span"); + drop_link.className = "dropall zen_action"; + drop_link.style.position = "static"; + drop_link.textContent = "drop-all"; + drop_link.addEventListener("click", () => this.drop_all("projects")); + this._project_pager.prepend(drop_link); + + const prj_loading = Pager.loading(section); + this._projects_data = await new Fetcher().resource("/prj/list").json(); + this._projects_data.sort((a, b) => a.Id.localeCompare(b.Id)); + this._project_pager.set_total(this._projects_data.length); + this._render_projects_page(); + prj_loading.remove(); } // cache - var cache_table = null; if (available.has("/z$/")) { var section = this.add_section("Cache"); - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$")); - var columns = [ "namespace", "dir", @@ -107,30 +93,30 @@ export class Page extends ZenPage "size mem", "actions", ]; - var zcache_info = await new Fetcher().resource("/z$/").json(); - cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight); - for (const namespace of zcache_info["Namespaces"] || []) - { - new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => { - const row = cache_table.add_row( - "", - data["Configuration"]["RootDir"], - data["Buckets"].length, - data["EntryCount"], - Friendly.bytes(data["StorageSize"].DiskSize), - Friendly.bytes(data["StorageSize"].MemorySize) - ); - var cell = row.get_cell(0); - cell.tag().text(namespace).on_click(() 
=> this.view_zcache(namespace)); - - cell = row.get_cell(-1); - const action_tb = new Toolbar(cell, true); - action_tb.left().add("view").on_click(() => this.view_zcache(namespace)); - action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace)); - - row.attr("zs_name", namespace); - }); - } + this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight); + + this._cache_pager = new Pager(section, 25, () => this._render_cache_page(), + Pager.make_search_fn(() => this._cache_data, item => item.namespace)); + const cache_drop_link = document.createElement("span"); + cache_drop_link.className = "dropall zen_action"; + cache_drop_link.style.position = "static"; + cache_drop_link.textContent = "drop-all"; + cache_drop_link.addEventListener("click", () => this.drop_all("z$")); + this._cache_pager.prepend(cache_drop_link); + + const cache_loading = Pager.loading(section); + const zcache_info = await new Fetcher().resource("/z$/").json(); + const namespaces = zcache_info["Namespaces"] || []; + const results = await Promise.allSettled( + namespaces.map(ns => new Fetcher().resource(`/z$/${ns}/`).json().then(data => ({ namespace: ns, data }))) + ); + this._cache_data = results + .filter(r => r.status === "fulfilled") + .map(r => r.value) + .sort((a, b) => a.namespace.localeCompare(b.namespace)); + this._cache_pager.set_total(this._cache_data.length); + this._render_cache_page(); + cache_loading.remove(); } // version @@ -139,15 +125,13 @@ export class Page extends ZenPage version.param("detailed", "true"); version.text().then((data) => ver_tag.text(data)); - this._project_table = project_table; - this._cache_table = cache_table; - // WebSocket for live stats updates this.connect_stats_ws((all_stats) => this._render_stats(all_stats)); } _render_stats(all_stats) { + all_stats = this._merge_last_stats(all_stats); const grid = this._stats_grid; const safe_lookup = this._safe_lookup; @@ -316,6 +300,60 @@ export class Page extends 
ZenPage m.tag().classify("metric-label").text(label); } + _render_projects_page() + { + const { start, end } = this._project_pager.page_range(); + this._project_table.clear(start); + for (let i = start; i < end; i++) + { + const project = this._projects_data[i]; + const row = this._project_table.add_row( + "", + project.ProjectRootDir, + project.EngineRootDir, + ); + + const cell = row.get_cell(0); + cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id); + + const action_cell = row.get_cell(-1); + const action_tb = new Toolbar(action_cell, true); + action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id); + action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id); + + row.attr("zs_name", project.Id); + } + } + + _render_cache_page() + { + const { start, end } = this._cache_pager.page_range(); + this._cache_table.clear(start); + for (let i = start; i < end; i++) + { + const item = this._cache_data[i]; + const data = item.data; + const row = this._cache_table.add_row( + "", + data["Configuration"]["RootDir"], + data["Buckets"].length, + data["EntryCount"], + Friendly.bytes(data["StorageSize"].DiskSize), + Friendly.bytes(data["StorageSize"].MemorySize) + ); + + const cell = row.get_cell(0); + cell.tag().text(item.namespace).on_click(() => this.view_zcache(item.namespace)); + + const action_cell = row.get_cell(-1); + const action_tb = new Toolbar(action_cell, true); + action_tb.left().add("view").on_click(() => this.view_zcache(item.namespace)); + action_tb.left().add("drop").on_click(() => this.drop_zcache(item.namespace)); + + row.attr("zs_name", item.namespace); + } + } + view_stat(provider) { window.location = "?page=stat&provider=" + provider; @@ -361,20 +399,18 @@ export class Page extends ZenPage async drop_all_projects() { - for (const row of this._project_table) + for (const project of this._projects_data || []) { - const project_id = row.attr("zs_name"); - await new Fetcher().resource("prj", 
project_id).delete(); + await new Fetcher().resource("prj", project.Id).delete(); } this.reload(); } async drop_all_zcache() { - for (const row of this._cache_table) + for (const item of this._cache_data || []) { - const namespace = row.attr("zs_name"); - await new Fetcher().resource("z$", namespace).delete(); + await new Fetcher().resource("z$", item.namespace).delete(); } this.reload(); } diff --git a/src/zenserver/frontend/html/pages/workspaces.js b/src/zenserver/frontend/html/pages/workspaces.js index d31fd7373..db02e8be1 100644 --- a/src/zenserver/frontend/html/pages/workspaces.js +++ b/src/zenserver/frontend/html/pages/workspaces.js @@ -4,6 +4,7 @@ import { ZenPage } from "./page.js" import { Fetcher } from "../util/fetcher.js" +import { copy_button } from "../util/widgets.js" //////////////////////////////////////////////////////////////////////////////// export class Page extends ZenPage @@ -13,7 +14,7 @@ export class Page extends ZenPage this.set_title("workspaces"); // Workspace Service Stats - const stats_section = this.add_section("Workspace Service Stats"); + const stats_section = this._collapsible_section("Workspace Service Stats"); this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); const stats = await new Fetcher().resource("stats", "ws").json().catch(() => null); @@ -157,6 +158,7 @@ export class Page extends ZenPage id_wrap.className = "ws-id-wrap"; id_wrap.appendChild(btn_expand); id_wrap.appendChild(document.createTextNode("\u00A0" + id)); + id_wrap.appendChild(copy_button(id)); const td_id = document.createElement("td"); td_id.appendChild(id_wrap); tr.appendChild(td_id); @@ -200,6 +202,7 @@ export class Page extends ZenPage _render_stats(stats) { + stats = this._merge_last_stats(stats); const grid = this._stats_grid; grid.inner().innerHTML = ""; diff --git a/src/zenserver/frontend/html/util/widgets.js b/src/zenserver/frontend/html/util/widgets.js index 17bd2fde7..651686a11 100644 --- 
a/src/zenserver/frontend/html/util/widgets.js +++ b/src/zenserver/frontend/html/util/widgets.js @@ -6,6 +6,58 @@ import { Component } from "./component.js" import { Friendly } from "../util/friendly.js" //////////////////////////////////////////////////////////////////////////////// +export function flash_highlight(element) +{ + if (!element) { return; } + element.classList.add("pager-search-highlight"); + setTimeout(() => { element.classList.remove("pager-search-highlight"); }, 1500); +} + +//////////////////////////////////////////////////////////////////////////////// +export function copy_button(value_or_fn) +{ + if (!navigator.clipboard) + { + const stub = document.createElement("span"); + stub.style.display = "none"; + return stub; + } + + let reset_timer = 0; + const btn = document.createElement("button"); + btn.className = "zen-copy-btn"; + btn.title = "Copy to clipboard"; + btn.textContent = "\u29C9"; + btn.addEventListener("click", async (e) => { + e.stopPropagation(); + const v = typeof value_or_fn === "function" ? value_or_fn() : value_or_fn; + if (!v) { return; } + try + { + await navigator.clipboard.writeText(v); + clearTimeout(reset_timer); + btn.classList.add("zen-copy-ok"); + btn.textContent = "\u2713"; + reset_timer = setTimeout(() => { btn.classList.remove("zen-copy-ok"); btn.textContent = "\u29C9"; }, 800); + } + catch (_e) { /* clipboard not available */ } + }); + return btn; +} + +// Wraps the existing children of `element` plus a copy button into an +// inline-flex nowrap container so the button never wraps to a new line. 
+export function add_copy_button(element, value_or_fn) +{ + if (!navigator.clipboard) { return; } + const wrap = document.createElement("span"); + wrap.className = "zen-copy-wrap"; + while (element.firstChild) { wrap.appendChild(element.firstChild); } + wrap.appendChild(copy_button(value_or_fn)); + element.appendChild(wrap); +} + +//////////////////////////////////////////////////////////////////////////////// class Widget extends Component { } @@ -402,6 +454,135 @@ export class ProgressBar extends Widget //////////////////////////////////////////////////////////////////////////////// +export class Pager +{ + constructor(section, page_size, on_change, search_fn) + { + this._page = 0; + this._page_size = page_size; + this._total = 0; + this._on_change = on_change; + this._search_fn = search_fn || null; + this._search_input = null; + + const pager = section.tag().classify("module-pager").inner(); + this._btn_prev = document.createElement("button"); + this._btn_prev.className = "module-pager-btn"; + this._btn_prev.textContent = "\u2190 Prev"; + this._btn_prev.addEventListener("click", () => this._go_page(this._page - 1)); + this._label = document.createElement("span"); + this._label.className = "module-pager-label"; + this._btn_next = document.createElement("button"); + this._btn_next.className = "module-pager-btn"; + this._btn_next.textContent = "Next \u2192"; + this._btn_next.addEventListener("click", () => this._go_page(this._page + 1)); + + if (this._search_fn) + { + this._search_input = document.createElement("input"); + this._search_input.type = "text"; + this._search_input.className = "module-pager-search"; + this._search_input.placeholder = "Search\u2026"; + this._search_input.addEventListener("keydown", (e) => + { + if (e.key === "Enter") + { + this._do_search(this._search_input.value.trim()); + } + }); + pager.appendChild(this._search_input); + } + + pager.appendChild(this._btn_prev); + pager.appendChild(this._label); + pager.appendChild(this._btn_next); + 
this._pager = pager; + + this._update_ui(); + } + + prepend(element) + { + const ref = this._search_input || this._btn_prev; + this._pager.insertBefore(element, ref); + } + + set_total(n) + { + this._total = n; + const max_page = Math.max(0, Math.ceil(n / this._page_size) - 1); + if (this._page > max_page) + { + this._page = max_page; + } + this._update_ui(); + } + + page_range() + { + const start = this._page * this._page_size; + const end = Math.min(start + this._page_size, this._total); + return { start, end }; + } + + _go_page(n) + { + const max = Math.max(0, Math.ceil(this._total / this._page_size) - 1); + this._page = Math.max(0, Math.min(n, max)); + this._update_ui(); + this._on_change(); + } + + _do_search(term) + { + if (!term || !this._search_fn) + { + return; + } + const result = this._search_fn(term); + if (!result) + { + this._search_input.style.outline = "2px solid var(--theme_fail)"; + setTimeout(() => { this._search_input.style.outline = ""; }, 1000); + return; + } + this._go_page(Math.floor(result.index / this._page_size)); + flash_highlight(this._pager.parentNode.querySelector(`[zs_name="${CSS.escape(result.name)}"]`)); + } + + _update_ui() + { + const total = this._total; + const page_count = Math.max(1, Math.ceil(total / this._page_size)); + const start = this._page * this._page_size + 1; + const end = Math.min(start + this._page_size - 1, total); + + this._btn_prev.disabled = this._page === 0; + this._btn_next.disabled = this._page >= page_count - 1; + this._label.textContent = total === 0 + ? "No items" + : `${start}\u2013${end} of ${total}`; + } + + static make_search_fn(get_data, get_key) + { + return (term) => { + const t = term.toLowerCase(); + const data = get_data(); + const i = data.findIndex(item => get_key(item).toLowerCase().includes(t)); + return i < 0 ? 
null : { index: i, name: get_key(data[i]) }; + }; + } + + static loading(section) + { + return section.tag().classify("pager-loading").text("Loading\u2026").inner(); + } +} + + + +//////////////////////////////////////////////////////////////////////////////// export class WidgetHost { constructor(parent, depth=1) diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index d9f7491ea..d3c6c9036 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -816,6 +816,10 @@ zen-banner + zen-nav::part(nav-bar) { border-color: var(--theme_p0); } +.stats-tile[data-over="true"] { + border-color: var(--theme_fail); +} + .stats-tile-detailed { position: relative; } @@ -1607,6 +1611,25 @@ tr:last-child td { animation: module-dot-deprovisioning-from-provisioned 1s steps(1, end) infinite; } +@keyframes module-dot-obliterating-from-provisioned { + 0%, 59.9% { background: var(--theme_fail); } + 60%, 100% { background: var(--theme_ok); } +} +@keyframes module-dot-obliterating-from-hibernated { + 0%, 59.9% { background: var(--theme_fail); } + 60%, 100% { background: var(--theme_warn); } +} + +.module-state-dot[data-state="obliterating"][data-prev-state="provisioned"] { + animation: module-dot-obliterating-from-provisioned 0.5s steps(1, end) infinite; +} +.module-state-dot[data-state="obliterating"][data-prev-state="hibernated"] { + animation: module-dot-obliterating-from-hibernated 0.5s steps(1, end) infinite; +} +.module-state-dot[data-state="obliterating"] { + animation: module-dot-obliterating-from-provisioned 0.5s steps(1, end) infinite; +} + .module-action-cell { white-space: nowrap; display: flex; @@ -1726,6 +1749,53 @@ tr:last-child td { text-align: center; } +.module-pager-search { + font-size: 12px; + padding: 4px 8px; + width: 14em; + border: 1px solid var(--theme_g2); + border-radius: 4px; + background: var(--theme_g4); + color: var(--theme_g0); + outline: none; + transition: border-color 0.15s, outline 
0.3s; +} + +.module-pager-search:focus { + border-color: var(--theme_p0); +} + +.module-pager-search::placeholder { + color: var(--theme_g1); +} + +@keyframes pager-search-flash { + from { box-shadow: inset 0 0 0 100px var(--theme_p2); } + to { box-shadow: inset 0 0 0 100px transparent; } +} + +.zen_table > .pager-search-highlight > div { + animation: pager-search-flash 1s linear forwards; +} + +.module-table .pager-search-highlight td { + animation: pager-search-flash 1s linear forwards; +} + +@keyframes pager-loading-pulse { + 0%, 100% { opacity: 0.6; } + 50% { opacity: 0.2; } +} + +.pager-loading { + color: var(--theme_g1); + font-style: italic; + font-size: 14px; + font-weight: 600; + padding: 12px 0; + animation: pager-loading-pulse 1.5s ease-in-out infinite; +} + .module-table td, .module-table th { padding-top: 4px; padding-bottom: 4px; @@ -1746,6 +1816,35 @@ tr:last-child td { color: var(--theme_bright); } +.zen-copy-btn { + background: transparent; + border: 1px solid var(--theme_g2); + border-radius: 4px; + color: var(--theme_g1); + cursor: pointer; + font-size: 12px; + line-height: 1; + padding: 2px 5px; + margin-left: 6px; + vertical-align: middle; + flex-shrink: 0; + transition: background 0.1s, color 0.1s; +} +.zen-copy-btn:hover { + background: var(--theme_g2); + color: var(--theme_bright); +} +.zen-copy-btn.zen-copy-ok { + color: var(--theme_ok); + border-color: var(--theme_ok); +} + +.zen-copy-wrap { + display: inline-flex; + align-items: center; + white-space: nowrap; +} + .module-metrics-row td { padding: 6px 10px 10px 42px; background: var(--theme_g3); diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp index c7c8687ca..27b92f33a 100644 --- a/src/zenserver/frontend/zipfs.cpp +++ b/src/zenserver/frontend/zipfs.cpp @@ -189,12 +189,12 @@ ZipFs::GetFile(const std::string_view& FileName) const if (Item.CompressionMethod == 0) { - // Stored — point directly into the buffer + // Stored - point directly into the buffer 
Item.View = MemoryView(FileData, Item.UncompressedSize); } else { - // Deflate — decompress using zlib + // Deflate - decompress using zlib Item.DecompressedData = IoBuffer(Item.UncompressedSize); z_stream Stream = {}; diff --git a/src/zenserver/hub/README.md b/src/zenserver/hub/README.md index 322be3649..c75349fa5 100644 --- a/src/zenserver/hub/README.md +++ b/src/zenserver/hub/README.md @@ -3,23 +3,32 @@ The Zen Server can act in a "hub" mode. In this mode, the only services offered are the basic health and diagnostic services alongside an API to provision and deprovision Storage server instances. +A module ID is an alphanumeric identifier (hyphens allowed) that identifies a dataset, typically +associated with a content plug-in module. + ## Generic Server API GET `/health` - returns an `OK!` payload when all enabled services are up and responding ## Hub API -GET `{moduleid}` - alphanumeric identifier to identify a dataset (typically associated with a content plug-in module) - -GET `/hub/status` - obtain a summary of the currently live instances +GET `/hub/status` - obtain a summary of all currently live instances GET `/hub/modules/{moduleid}` - retrieve information about a module +DELETE `/hub/modules/{moduleid}` - obliterate a module (permanently destroys all data) + POST `/hub/modules/{moduleid}/provision` - provision service for module POST `/hub/modules/{moduleid}/deprovision` - deprovision service for module -GET `/hub/stats` - retrieve stats for service +POST `/hub/modules/{moduleid}/hibernate` - hibernate a provisioned module + +POST `/hub/modules/{moduleid}/wake` - wake a hibernated module + +GET `/stats/hub` - retrieve stats for the hub service + +`/hub/proxy/{port}/{path}` - reverse proxy to a child instance dashboard (all HTTP verbs) ## Hub Configuration diff --git a/src/zenserver/hub/httphubservice.cpp b/src/zenserver/hub/httphubservice.cpp index ebefcf2e3..e4b0c28d0 100644 --- a/src/zenserver/hub/httphubservice.cpp +++ 
b/src/zenserver/hub/httphubservice.cpp @@ -2,6 +2,7 @@ #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" #include "storageserverinstance.h" @@ -43,10 +44,11 @@ namespace { } } // namespace -HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService) +HttpHubService::HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService) : m_Hub(Hub) , m_StatsService(StatsService) , m_StatusService(StatusService) +, m_Proxy(Proxy) { using namespace std::literals; @@ -67,6 +69,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta return true; }); + m_Router.AddMatcher("port", [](std::string_view Str) -> bool { + if (Str.empty()) + { + return false; + } + for (const auto C : Str) + { + if (!std::isdigit(C)) + { + return false; + } + } + return true; + }); + + m_Router.AddMatcher("proxypath", [](std::string_view Str) -> bool { return !Str.empty(); }); + m_Router.RegisterRoute( "status", [this](HttpRouterRequest& Req) { @@ -78,6 +97,10 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta Obj << "moduleId" << ModuleId; Obj << "state" << ToString(Info.State); Obj << "port" << Info.Port; + if (Info.StateChangeTime != std::chrono::system_clock::time_point::min()) + { + Obj << "state_change_time" << ToDateTime(Info.StateChangeTime); + } Obj.BeginObject("process_metrics"); { Obj << "MemoryBytes" << Info.Metrics.MemoryBytes; @@ -98,6 +121,11 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta HttpVerb::kGet); m_Router.RegisterRoute( + "deprovision", + [this](HttpRouterRequest& Req) { HandleDeprovisionAll(Req.ServerRequest()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "modules/{moduleid}", [this](HttpRouterRequest& Req) { std::string_view ModuleId = Req.GetCapture(1); @@ -229,15 +257,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& 
StatsService, HttpSta HttpVerb::kPost); m_Router.RegisterRoute( - "stats", + "proxy/{port}/{proxypath}", [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Obj << "currentInstanceCount" << m_Hub.GetInstanceCount(); - Obj << "maxInstanceCount" << m_Hub.GetMaxInstanceCount(); - Obj << "instanceLimit" << m_Hub.GetConfig().InstanceLimit; - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + std::string_view PortStr = Req.GetCapture(1); + + // Use RelativeUriWithExtension to preserve the file extension that the + // router's URI parser strips (e.g. ".css", ".js") - the upstream server + // needs the full path including the extension. + std::string_view FullUri = Req.ServerRequest().RelativeUriWithExtension(); + std::string_view Prefix = "proxy/"; + + // FullUri is "proxy/{port}/{path...}" - skip past "proxy/{port}/" + size_t PathStart = Prefix.size() + PortStr.size() + 1; + std::string_view PathTail = (PathStart < FullUri.size()) ? FullUri.substr(PathStart) : std::string_view{}; + + m_Proxy.HandleProxyRequest(Req.ServerRequest(), PortStr, PathTail); }, - HttpVerb::kGet); + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); m_StatsService.RegisterHandler("hub", *this); m_StatusService.RegisterHandler("hub", *this); @@ -286,7 +322,37 @@ HttpHubService::HandleStatusRequest(HttpServerRequest& Request) void HttpHubService::HandleStatsRequest(HttpServerRequest& Request) { - Request.WriteResponse(HttpResponseCode::OK, CollectStats()); + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + + Cbo << "currentInstanceCount" << m_Hub.GetInstanceCount(); + Cbo << "maxInstanceCount" << m_Hub.GetMaxInstanceCount(); + Cbo << "instanceLimit" << m_Hub.GetConfig().InstanceLimit; + + SystemMetrics SysMetrics; + DiskSpace Disk; + m_Hub.GetMachineMetrics(SysMetrics, Disk); + Cbo.BeginObject("machine"); + { + Cbo << "disk_free_bytes" << Disk.Free; + Cbo << "disk_total_bytes" << Disk.Total; + Cbo << 
"memory_avail_mib" << SysMetrics.AvailSystemMemoryMiB; + Cbo << "memory_total_mib" << SysMetrics.SystemMemoryMiB; + Cbo << "virtual_memory_avail_mib" << SysMetrics.AvailVirtualMemoryMiB; + Cbo << "virtual_memory_total_mib" << SysMetrics.VirtualMemoryMiB; + } + Cbo.EndObject(); + + const ResourceMetrics& Limits = m_Hub.GetConfig().ResourceLimits; + Cbo.BeginObject("resource_limits"); + { + Cbo << "disk_bytes" << Limits.DiskUsageBytes; + Cbo << "memory_bytes" << Limits.MemoryUsageBytes; + } + Cbo.EndObject(); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } CbObject @@ -310,6 +376,81 @@ HttpHubService::GetActivityCounter() } void +HttpHubService::HandleDeprovisionAll(HttpServerRequest& Request) +{ + std::vector<std::string> ModulesToDeprovision; + m_Hub.EnumerateModules([&ModulesToDeprovision](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) { + if (InstanceInfo.State == HubInstanceState::Provisioned || InstanceInfo.State == HubInstanceState::Hibernated) + { + ModulesToDeprovision.push_back(std::string(ModuleId)); + } + }); + + if (ModulesToDeprovision.empty()) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + std::vector<std::string> Rejected; + std::vector<std::string> Accepted; + std::vector<std::string> Completed; + for (const std::string& ModuleId : ModulesToDeprovision) + { + Hub::Response Response = m_Hub.Deprovision(ModuleId); + switch (Response.ResponseCode) + { + case Hub::EResponseCode::NotFound: + // Ignore + break; + case Hub::EResponseCode::Rejected: + Rejected.push_back(ModuleId); + break; + case Hub::EResponseCode::Accepted: + Accepted.push_back(ModuleId); + break; + case Hub::EResponseCode::Completed: + Completed.push_back(ModuleId); + break; + } + } + if (Rejected.empty() && Accepted.empty() && Completed.empty()) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + HttpResponseCode Response = HttpResponseCode::OK; + CbObjectWriter Writer; + if (!Completed.empty()) + { + 
Writer.BeginArray("Completed"); + for (const std::string& ModuleId : Completed) + { + Writer.AddString(ModuleId); + } + Writer.EndArray(); // Completed + } + if (!Accepted.empty()) + { + Writer.BeginArray("Accepted"); + for (const std::string& ModuleId : Accepted) + { + Writer.AddString(ModuleId); + } + Writer.EndArray(); // Accepted + Response = HttpResponseCode::Accepted; + } + if (!Rejected.empty()) + { + Writer.BeginArray("Rejected"); + for (const std::string& ModuleId : Rejected) + { + Writer.AddString(ModuleId); + } + Writer.EndArray(); // Rejected + Response = HttpResponseCode::Conflict; + } + Request.WriteResponse(Response, Writer.Save()); +} + +void HttpHubService::HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId) { Hub::InstanceInfo InstanceInfo; @@ -328,45 +469,36 @@ HttpHubService::HandleModuleGet(HttpServerRequest& Request, std::string_view Mod void HttpHubService::HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId) { - Hub::InstanceInfo InstanceInfo; - if (!m_Hub.Find(ModuleId, &InstanceInfo)) + Hub::Response Resp = m_Hub.Obliterate(std::string(ModuleId)); + + if (HandleFailureResults(Request, Resp)) { - Request.WriteResponse(HttpResponseCode::NotFound); return; } - if (InstanceInfo.State == HubInstanceState::Provisioned || InstanceInfo.State == HubInstanceState::Hibernated || - InstanceInfo.State == HubInstanceState::Crashed) - { - try - { - Hub::Response Resp = m_Hub.Deprovision(std::string(ModuleId)); - - if (HandleFailureResults(Request, Resp)) - { - return; - } - - // TODO: nuke all related storage + const HttpResponseCode HttpCode = + (Resp.ResponseCode == Hub::EResponseCode::Accepted) ? HttpResponseCode::Accepted : HttpResponseCode::OK; + CbObjectWriter Obj; + Obj << "moduleId" << ModuleId; + Request.WriteResponse(HttpCode, Obj.Save()); +} - const HttpResponseCode HttpCode = - (Resp.ResponseCode == Hub::EResponseCode::Accepted) ? 
HttpResponseCode::Accepted : HttpResponseCode::OK; - CbObjectWriter Obj; - Obj << "moduleId" << ModuleId; - return Request.WriteResponse(HttpCode, Obj.Save()); - } - catch (const std::exception& Ex) - { - ZEN_ERROR("Exception while deprovisioning module '{}': {}", ModuleId, Ex.what()); - throw; - } - } +void +HttpHubService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + m_Proxy.OnWebSocketOpen(std::move(Connection), RelativeUri); +} - // TODO: nuke all related storage +void +HttpHubService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + m_Proxy.OnWebSocketMessage(Conn, Msg); +} - CbObjectWriter Obj; - Obj << "moduleId" << ModuleId; - Request.WriteResponse(HttpResponseCode::OK, Obj.Save()); +void +HttpHubService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + m_Proxy.OnWebSocketClose(Conn, Code, Reason); } } // namespace zen diff --git a/src/zenserver/hub/httphubservice.h b/src/zenserver/hub/httphubservice.h index 1bb1c303e..f4d1b0b89 100644 --- a/src/zenserver/hub/httphubservice.h +++ b/src/zenserver/hub/httphubservice.h @@ -2,11 +2,16 @@ #pragma once +#include <zencore/thread.h> #include <zenhttp/httpserver.h> #include <zenhttp/httpstatus.h> +#include <zenhttp/websocket.h> + +#include <memory> namespace zen { +class HttpProxyHandler; class HttpStatsService; class Hub; @@ -16,10 +21,10 @@ class Hub; * use in UEFN content worker style scenarios. 
* */ -class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider +class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider, public IWebSocketHandler { public: - HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService); + HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService); ~HttpHubService(); HttpHubService(const HttpHubService&) = delete; @@ -32,6 +37,11 @@ public: virtual CbObject CollectStats() override; virtual uint64_t GetActivityCounter() override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId); private: @@ -43,8 +53,11 @@ private: HttpStatsService& m_StatsService; HttpStatusService& m_StatusService; + void HandleDeprovisionAll(HttpServerRequest& Request); void HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId); void HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId); + + HttpProxyHandler& m_Proxy; }; } // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp new file mode 100644 index 000000000..235d7388f --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.cpp @@ -0,0 +1,528 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "httpproxyhandler.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpwsclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <charconv> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + + std::string InjectProxyScript(std::string_view Html, uint16_t Port) + { + ExtendableStringBuilder<2048> Script; + Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/"); + Script.Append(fmt::format("{}", Port)); + Script.Append( + "\";\n" + " var OF = window.fetch;\n" + " window.fetch = function(u, o) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {\n" + " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n" + " }\n" + " }\n" + " return OF.call(this, u, o);\n" + " };\n" + " var OW = window.WebSocket;\n" + " window.WebSocket = function(u, pr) {\n" + " try {\n" + " var p = new URL(u);\n" + " if (p.hostname === location.hostname\n" + " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n" + " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n" + " && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " return pr !== undefined ? 
new OW(u, pr) : new OW(u);\n" + " };\n" + " window.WebSocket.prototype = OW.prototype;\n" + " window.WebSocket.CONNECTING = OW.CONNECTING;\n" + " window.WebSocket.OPEN = OW.OPEN;\n" + " window.WebSocket.CLOSING = OW.CLOSING;\n" + " window.WebSocket.CLOSED = OW.CLOSED;\n" + " var OO = window.open;\n" + " window.open = function(u, t, f) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " }\n" + " return OO.call(this, u, t, f);\n" + " };\n" + " document.addEventListener(\"click\", function(e) {\n" + " var t = e.composedPath ? e.composedPath()[0] : e.target;\n" + " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n" + " if (!t || !t.href) return;\n" + " try {\n" + " var h = new URL(t.href);\n" + " if (h.origin === location.origin && !h.pathname.startsWith(P))\n" + " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n" + " } catch(x) {}\n" + " }, true);\n" + "})();\n</script>"); + + std::string ScriptStr = Script.ToString(); + + size_t HeadClose = Html.find("</head>"); + if (HeadClose != std::string_view::npos) + { + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(Html.substr(0, HeadClose)); + Result.append(ScriptStr); + Result.append(Html.substr(HeadClose)); + return Result; + } + + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(ScriptStr); + Result.append(Html); + return Result; + } + +} // namespace + +struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler +{ + Ref<WebSocketConnection> ClientConn; + std::unique_ptr<HttpWsClient> UpstreamClient; + uint16_t Port = 0; + + void OnWsOpen() override {} + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (!ClientConn->IsOpen()) + { + return; + } + switch 
(Msg.Opcode) + { + case WebSocketOpcode::kText: + ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + default: + break; + } + } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + if (ClientConn->IsOpen()) + { + ClientConn->Close(Code, Reason); + } + } +}; + +HttpProxyHandler::HttpProxyHandler() +{ +} + +HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort)) +{ +} + +void +HttpProxyHandler::SetPortValidator(PortValidator ValidatePort) +{ + m_ValidatePort = std::move(ValidatePort); +} + +HttpProxyHandler::~HttpProxyHandler() +{ + try + { + Shutdown(); + } + catch (...) + { + } +} + +HttpClient& +HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port) +{ + HttpClient* Result = nullptr; + m_ProxyClientsLock.WithExclusiveLock([&] { + auto It = m_ProxyClients.find(Port); + if (It == m_ProxyClients.end()) + { + HttpClientSettings Settings; + Settings.LogCategory = "hub-proxy"; + Settings.ConnectTimeout = std::chrono::milliseconds(5000); + Settings.Timeout = std::chrono::milliseconds(30000); + auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings); + Result = Client.get(); + m_ProxyClients.emplace(Port, std::move(Client)); + } + else + { + Result = It->second.get(); + } + }); + return *Result; +} + +void +HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail) +{ + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL"); + return; + } + + if 
(!m_ValidatePort(Port)) + { + Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "target instance not available"); + return; + } + + HttpClient& Client = GetOrCreateProxyClient(Port); + + std::string RequestPath; + RequestPath.reserve(1 + PathTail.size()); + RequestPath.push_back('/'); + RequestPath.append(PathTail); + + std::string_view QueryString = Request.QueryString(); + if (!QueryString.empty()) + { + RequestPath.push_back('?'); + RequestPath.append(QueryString); + } + + HttpClient::KeyValueMap ForwardHeaders; + HttpContentType AcceptType = Request.AcceptContentType(); + if (AcceptType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType))); + } + + std::string_view Auth = Request.GetAuthorizationHeader(); + if (!Auth.empty()) + { + ForwardHeaders->emplace("Authorization", std::string(Auth)); + } + + HttpContentType ReqContentType = Request.RequestContentType(); + if (ReqContentType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType))); + } + + HttpClient::Response Response; + + switch (Request.RequestVerb()) + { + case HttpVerb::kGet: + Response = Client.Get(RequestPath, ForwardHeaders); + break; + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Post(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kPut: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Put(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kDelete: + Response = Client.Delete(RequestPath, ForwardHeaders); + break; + case HttpVerb::kHead: + Response = Client.Head(RequestPath, ForwardHeaders); + break; + default: + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported"); + return; + } + + if (Response.Error) + { + if (!m_ValidatePort(Port)) + { + 
Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "target instance not available"); + return; + } + + ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage); + switch (Response.Error->ErrorCode) + { + case HttpClientErrorCode::kConnectionFailure: + case HttpClientErrorCode::kHostResolutionFailure: + return Request.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("instance not reachable: {}", Response.Error->ErrorMessage)); + case HttpClientErrorCode::kOperationTimedOut: + return Request.WriteResponse(HttpResponseCode::GatewayTimeout, + HttpContentType::kText, + fmt::format("upstream request timed out: {}", Response.Error->ErrorMessage)); + case HttpClientErrorCode::kRequestCancelled: + return Request.WriteResponse(HttpResponseCode::ServiceUnavailable, + HttpContentType::kText, + fmt::format("upstream request cancelled: {}", Response.Error->ErrorMessage)); + default: + return Request.WriteResponse(HttpResponseCode::BadGateway, + HttpContentType::kText, + fmt::format("upstream request failed: {}", Response.Error->ErrorMessage)); + } + } + + HttpContentType ContentType = Response.ResponsePayload.GetContentType(); + + if (ContentType == HttpContentType::kHTML) + { + std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string Injected = InjectProxyScript(Html, Port); + Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected)); + } + else + { + Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload)); + } +} + +void +HttpProxyHandler::PrunePort(uint16_t Port) +{ + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); }); + + std::vector<Ref<WsBridge>> Stale; + m_WsBridgesLock.WithExclusiveLock([&] { + for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();) + { + if (It->second->Port == Port) + { + 
Stale.push_back(std::move(It->second)); + It = m_WsBridges.erase(It); + } + else + { + ++It; + } + } + }); + + for (auto& Bridge : Stale) + { + if (Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(1001, "instance shutting down"); + } + if (Bridge->ClientConn->IsOpen()) + { + Bridge->ClientConn->Close(1001, "instance shutting down"); + } + } +} + +void +HttpProxyHandler::Shutdown() +{ + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); }); + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); }); +} + +////////////////////////////////////////////////////////////////////////// +// +// WebSocket proxy +// + +void +HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + const std::string_view ProxyPrefix = "proxy/"; + if (!RelativeUri.starts_with(ProxyPrefix)) + { + Connection->Close(1008, "unsupported WebSocket endpoint"); + return; + } + + std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size()); + + size_t SlashPos = ProxyTail.find('/'); + std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail; + std::string_view Path = (SlashPos != std::string_view::npos) ? 
ProxyTail.substr(SlashPos) : "/"; + + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Connection->Close(1008, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Connection->Close(1008, "target instance not available"); + return; + } + + std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path); + + Ref<WsBridge> Bridge(new WsBridge()); + Bridge->ClientConn = Connection; + Bridge->Port = Port; + + Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge); + + try + { + Bridge->UpstreamClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what()); + Connection->Close(1011, "upstream connect failed"); + return; + } + + WebSocketConnection* Key = Connection.Get(); + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); }); +} + +void +HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + Ref<WsBridge> Bridge; + m_WsBridgesLock.WithSharedLock([&] { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Bridge = It->second; + } + }); + + if (!Bridge || !Bridge->UpstreamClient) + { + return; + } + + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + Bridge->UpstreamClient->SendBinary( + std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kClose: + Bridge->UpstreamClient->Close(Msg.CloseCode, {}); + break; + default: + break; + } +} + +void +HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + Ref<WsBridge> Bridge 
= m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Ref<WsBridge> Bridge = std::move(It->second); + m_WsBridges.erase(It); + return Bridge; + } + return {}; + }); + + if (Bridge && Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(Code, Reason); + } +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("server.httpproxyhandler"); + +TEST_CASE("server.httpproxyhandler.html_injection") +{ + SUBCASE("injects before </head>") + { + std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + size_t ScriptEnd = Result.find("</script>"); + size_t HeadClose = Result.find("</head>"); + REQUIRE(ScriptEnd != std::string::npos); + REQUIRE(HeadClose != std::string::npos); + CHECK(ScriptEnd < HeadClose); + } + + SUBCASE("prepends when no </head>") + { + std::string Result = InjectProxyScript("<body>content</body>", 21005); + CHECK(Result.find("<script>") == 0); + CHECK(Result.find("<body>content</body>") != std::string::npos); + } + + SUBCASE("empty html") + { + std::string Result = InjectProxyScript("", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + } + + SUBCASE("preserves original content") + { + std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>"; + std::string Result = InjectProxyScript(Html, 21005); + CHECK(Result.find("<title>Test</title>") != std::string::npos); + CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos); + } +} + +TEST_CASE("server.httpproxyhandler.port_embedding") +{ + std::string Result = InjectProxyScript("<head></head>", 80); + CHECK(Result.find("/hub/proxy/80") != std::string::npos); + + Result = InjectProxyScript("<head></head>", 65535); + 
CHECK(Result.find("/hub/proxy/65535") != std::string::npos); +} + +TEST_SUITE_END(); + +void +httpproxyhandler_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.h b/src/zenserver/hub/httpproxyhandler.h new file mode 100644 index 000000000..8667c0ca1 --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <functional> +#include <memory> +#include <unordered_map> + +namespace zen { + +class HttpClient; + +class HttpProxyHandler +{ +public: + using PortValidator = std::function<bool(uint16_t)>; + + HttpProxyHandler(); + explicit HttpProxyHandler(PortValidator ValidatePort); + ~HttpProxyHandler(); + + void SetPortValidator(PortValidator ValidatePort); + + HttpProxyHandler(const HttpProxyHandler&) = delete; + HttpProxyHandler& operator=(const HttpProxyHandler&) = delete; + + void HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail); + void PrunePort(uint16_t Port); + void Shutdown(); + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason); + +private: + PortValidator m_ValidatePort; + + HttpClient& GetOrCreateProxyClient(uint16_t Port); + + RwLock m_ProxyClientsLock; + std::unordered_map<uint16_t, std::unique_ptr<HttpClient>> m_ProxyClients; + + struct WsBridge; + RwLock m_WsBridgesLock; + std::unordered_map<WebSocketConnection*, Ref<WsBridge>> m_WsBridges; +}; + +} // namespace zen diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp index 6c44e2333..128d3ed35 100644 --- a/src/zenserver/hub/hub.cpp +++ b/src/zenserver/hub/hub.cpp @@ -19,7 +19,6 @@ 
ZEN_THIRD_PARTY_INCLUDES_START ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS -# include <zencore/filesystem.h> # include <zencore/testing.h> # include <zencore/testutils.h> #endif @@ -122,23 +121,73 @@ private: ////////////////////////////////////////////////////////////////////////// -Hub::Hub(const Configuration& Config, - ZenServerEnvironment&& RunEnvironment, - WorkerThreadPool* OptionalWorkerPool, - AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback) +ProcessMetrics +Hub::AtomicProcessMetrics::Load() const +{ + return { + .MemoryBytes = MemoryBytes.load(), + .KernelTimeMs = KernelTimeMs.load(), + .UserTimeMs = UserTimeMs.load(), + .WorkingSetSize = WorkingSetSize.load(), + .PeakWorkingSetSize = PeakWorkingSetSize.load(), + .PagefileUsage = PagefileUsage.load(), + .PeakPagefileUsage = PeakPagefileUsage.load(), + }; +} + +void +Hub::AtomicProcessMetrics::Store(const ProcessMetrics& Metrics) +{ + MemoryBytes.store(Metrics.MemoryBytes); + KernelTimeMs.store(Metrics.KernelTimeMs); + UserTimeMs.store(Metrics.UserTimeMs); + WorkingSetSize.store(Metrics.WorkingSetSize); + PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize); + PagefileUsage.store(Metrics.PagefileUsage); + PeakPagefileUsage.store(Metrics.PeakPagefileUsage); +} + +void +Hub::AtomicProcessMetrics::Reset() +{ + MemoryBytes.store(0); + KernelTimeMs.store(0); + UserTimeMs.store(0); + WorkingSetSize.store(0); + PeakWorkingSetSize.store(0); + PagefileUsage.store(0); + PeakPagefileUsage.store(0); +} + +void +Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const +{ + m_Lock.WithSharedLock([&]() { + OutSystemMetrict = m_SystemMetrics; + OutDiskSpace = m_DiskSpace; + }); +} + +////////////////////////////////////////////////////////////////////////// + +Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback) : m_Config(Config) , m_RunEnvironment(std::move(RunEnvironment)) -, 
m_WorkerPool(OptionalWorkerPool) +, m_WorkerPool(Config.OptionalProvisionWorkerPool) , m_BackgroundWorkLatch(1) , m_ModuleStateChangeCallback(std::move(ModuleStateChangeCallback)) , m_ActiveInstances(Config.InstanceLimit) , m_FreeActiveInstanceIndexes(Config.InstanceLimit) { - m_HostMetrics = GetSystemMetrics(); - m_ResourceLimits.DiskUsageBytes = 1000ull * 1024 * 1024 * 1024; - m_ResourceLimits.MemoryUsageBytes = 16ull * 1024 * 1024 * 1024; + ZEN_ASSERT_FORMAT( + Config.OptionalProvisionWorkerPool != Config.OptionalHydrationWorkerPool || Config.OptionalProvisionWorkerPool == nullptr, + "Provision and hydration worker pools must be distinct to avoid deadlocks"); - if (m_Config.HydrationTargetSpecification.empty()) + if (!m_Config.HydrationTargetSpecification.empty()) + { + m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification; + } + else if (!m_Config.HydrationOptions) { std::filesystem::path FileHydrationPath = m_RunEnvironment.CreateChildDir("hydration_storage"); ZEN_INFO("using file hydration path: '{}'", FileHydrationPath); @@ -146,7 +195,7 @@ Hub::Hub(const Configuration& Config, } else { - m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification; + m_HydrationOptions = m_Config.HydrationOptions; } m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp"); @@ -171,6 +220,9 @@ Hub::Hub(const Configuration& Config, } } #endif + + UpdateMachineMetrics(); + m_WatchDog = std::thread([this]() { WatchDog(); }); } @@ -195,6 +247,9 @@ Hub::Shutdown() { ZEN_INFO("Hub service shutting down, deprovisioning any current instances"); + bool Expected = false; + bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true); + m_WatchDogEvent.Set(); if (m_WatchDog.joinable()) { @@ -203,8 +258,6 @@ Hub::Shutdown() m_WatchDog = {}; - bool Expected = false; - bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true); if (WaitForBackgroundWork && m_WorkerPool) { 
m_BackgroundWorkLatch.CountDown(); @@ -254,7 +307,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end()) { std::string Reason; - if (!CanProvisionInstance(ModuleId, /* out */ Reason)) + if (!CanProvisionInstanceLocked(ModuleId, /* out */ Reason)) { ZEN_WARN("Cannot provision new storage server instance for module '{}': {}", ModuleId, Reason); @@ -272,11 +325,18 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) auto NewInstance = std::make_unique<StorageServerInstance>( m_RunEnvironment, StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex), - .HydrationTempPath = m_HydrationTempPath, + .StateDir = m_RunEnvironment.CreateChildDir(ModuleId), + .TempDir = m_HydrationTempPath / ModuleId, .HydrationTargetSpecification = m_HydrationTargetSpecification, + .HydrationOptions = m_HydrationOptions, .HttpThreadCount = m_Config.InstanceHttpThreadCount, .CoreLimit = m_Config.InstanceCoreLimit, - .ConfigPath = m_Config.InstanceConfigPath}, + .ConfigPath = m_Config.InstanceConfigPath, + .Malloc = m_Config.InstanceMalloc, + .Trace = m_Config.InstanceTrace, + .TraceHost = m_Config.InstanceTraceHost, + .TraceFile = m_Config.InstanceTraceFile, + .OptionalWorkerPool = m_Config.OptionalHydrationWorkerPool}, ModuleId); #if ZEN_PLATFORM_WINDOWS @@ -289,6 +349,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) Instance = NewInstance->LockExclusive(/*Wait*/ true); m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance); + m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset(); m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex); // Set Provisioning while both hub lock and instance lock are held so that any // concurrent Deprovision sees the in-flight state, not Unprovisioned. 
@@ -334,6 +395,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) case HubInstanceState::Unprovisioned: break; case HubInstanceState::Provisioned: + m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now()); return Response{EResponseCode::Completed}; case HubInstanceState::Hibernated: _.ReleaseNow(); @@ -354,6 +416,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) Instance = {}; if (ActualState == HubInstanceState::Provisioned) { + m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now()); return Response{EResponseCode::Completed}; } if (ActualState == HubInstanceState::Provisioning) @@ -540,6 +603,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI switch (CurrentState) { case HubInstanceState::Deprovisioning: + case HubInstanceState::Obliterating: return Response{EResponseCode::Accepted}; case HubInstanceState::Crashed: case HubInstanceState::Hibernated: @@ -585,11 +649,11 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI try { m_WorkerPool->ScheduleWork( - [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable { + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr), OldState]() mutable { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); try { - CompleteDeprovision(*Instance, ActiveInstanceIndex); + CompleteDeprovision(*Instance, ActiveInstanceIndex, OldState); } catch (const std::exception& Ex) { @@ -617,20 +681,222 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI } else { - CompleteDeprovision(Instance, ActiveInstanceIndex); + CompleteDeprovision(Instance, ActiveInstanceIndex, OldState); + } + + return Response{m_WorkerPool ? 
EResponseCode::Accepted : EResponseCode::Completed}; +} + +Hub::Response +Hub::Obliterate(const std::string& ModuleId) +{ + ZEN_ASSERT(!m_ShutdownFlag.load()); + + StorageServerInstance::ExclusiveLockedPtr Instance; + size_t ActiveInstanceIndex = (size_t)-1; + { + RwLock::ExclusiveLockScope Lock(m_Lock); + + if (auto It = m_InstanceLookup.find(ModuleId); It != m_InstanceLookup.end()) + { + ActiveInstanceIndex = It->second; + ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size()); + + HubInstanceState CurrentState = m_ActiveInstances[ActiveInstanceIndex].State.load(); + + switch (CurrentState) + { + case HubInstanceState::Obliterating: + return Response{EResponseCode::Accepted}; + case HubInstanceState::Provisioned: + case HubInstanceState::Hibernated: + case HubInstanceState::Crashed: + break; + case HubInstanceState::Deprovisioning: + return Response{EResponseCode::Rejected, + fmt::format("Module '{}' is being deprovisioned, retry after completion", ModuleId)}; + case HubInstanceState::Recovering: + return Response{EResponseCode::Rejected, fmt::format("Module '{}' is currently recovering from a crash", ModuleId)}; + case HubInstanceState::Unprovisioned: + return Response{EResponseCode::Completed}; + default: + return Response{EResponseCode::Rejected, + fmt::format("Module '{}' is currently in state '{}'", ModuleId, ToString(CurrentState))}; + } + + std::unique_ptr<StorageServerInstance>& RawInstance = m_ActiveInstances[ActiveInstanceIndex].Instance; + ZEN_ASSERT(RawInstance != nullptr); + + Instance = RawInstance->LockExclusive(/*Wait*/ true); + } + else + { + // Module not tracked by hub - obliterate backend data directly. + // Covers the deprovisioned case where data was preserved via dehydration. 
+ if (m_ObliteratingInstances.contains(ModuleId)) + { + return Response{EResponseCode::Accepted}; + } + + m_ObliteratingInstances.insert(ModuleId); + Lock.ReleaseNow(); + + if (m_WorkerPool) + { + m_BackgroundWorkLatch.AddCount(1); + try + { + m_WorkerPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId)]() { + auto Guard = MakeGuard([this, ModuleId]() { + m_Lock.WithExclusiveLock([this, ModuleId]() { m_ObliteratingInstances.erase(ModuleId); }); + m_BackgroundWorkLatch.CountDown(); + }); + try + { + ObliterateBackendData(ModuleId); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed async obliterate of untracked module '{}': {}", ModuleId, Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + ZEN_ERROR("Failed to dispatch async obliterate of untracked module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + { + RwLock::ExclusiveLockScope _(m_Lock); + m_ObliteratingInstances.erase(ModuleId); + } + throw; + } + + return Response{EResponseCode::Accepted}; + } + + auto _ = MakeGuard([this, &ModuleId]() { + RwLock::ExclusiveLockScope _(m_Lock); + m_ObliteratingInstances.erase(ModuleId); + }); + + ObliterateBackendData(ModuleId); + + return Response{EResponseCode::Completed}; + } + } + + HubInstanceState OldState = UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Obliterating); + const uint16_t Port = Instance.GetBasePort(); + NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Obliterating, Port, {}); + + if (m_WorkerPool) + { + std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr = + std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); + + m_BackgroundWorkLatch.AddCount(1); + try + { + m_WorkerPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable { + auto _ = MakeGuard([this]() { 
m_BackgroundWorkLatch.CountDown(); }); + try + { + CompleteObliterate(*Instance, ActiveInstanceIndex); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed async obliterate of module '{}': {}", ModuleId, Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + ZEN_ERROR("Failed async dispatch obliterate of module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, OldState, Port, {}); + { + RwLock::ExclusiveLockScope HubLock(m_Lock); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end()); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex); + UpdateInstanceState(HubLock, ActiveInstanceIndex, OldState); + } + + throw; + } + } + else + { + CompleteObliterate(Instance, ActiveInstanceIndex); } return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; } void -Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex) +Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex) { const std::string ModuleId(Instance.GetModuleId()); const uint16_t Port = Instance.GetBasePort(); try { + Instance.Obliterate(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to obliterate storage server instance for module '{}': {}", ModuleId, Ex.what()); + Instance = {}; + { + RwLock::ExclusiveLockScope HubLock(m_Lock); + UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Crashed); + } + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Crashed, Port, {}); + throw; + } + + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); +} + +void 
+Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState) +{ + const std::string ModuleId(Instance.GetModuleId()); + const uint16_t Port = Instance.GetBasePort(); + + try + { + if (OldState == HubInstanceState::Provisioned) + { + ZEN_INFO("Triggering GC for module {}", ModuleId); + + HttpClient GcClient(fmt::format("http://localhost:{}", Port)); + + HttpClient::KeyValueMap Params; + Params.Entries.insert({"smallobjects", "true"}); + Params.Entries.insert({"skipcid", "false"}); + HttpClient::Response Response = GcClient.Post("/admin/gc", HttpClient::Accept(HttpContentType::kCbObject), Params); + Stopwatch Timer; + while (Response && Timer.GetElapsedTimeMs() < 5000) + { + Response = GcClient.Get("/admin/gc", HttpClient::Accept(HttpContentType::kCbObject)); + if (Response) + { + bool Complete = Response.AsObject()["Status"].AsString() != "Running"; + if (Complete) + { + break; + } + Sleep(50); + } + } + } Instance.Deprovision(); } catch (const std::exception& Ex) @@ -649,20 +915,7 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si } NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Unprovisioned, Port, {}); - Instance = {}; - - std::unique_ptr<StorageServerInstance> DeleteInstance; - { - RwLock::ExclusiveLockScope HubLock(m_Lock); - auto It = m_InstanceLookup.find(std::string(ModuleId)); - ZEN_ASSERT_SLOW(It != m_InstanceLookup.end()); - ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex); - DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); - m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); - m_InstanceLookup.erase(It); - UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); - } - DeleteInstance.reset(); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); } Hub::Response @@ -935,6 +1188,50 @@ Hub::CompleteWake(StorageServerInstance::ExclusiveLockedPtr& 
Instance, size_t Ac } } +void +Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId) +{ + Instance = {}; + + std::unique_ptr<StorageServerInstance> DeleteInstance; + { + RwLock::ExclusiveLockScope HubLock(m_Lock); + auto It = m_InstanceLookup.find(std::string(ModuleId)); + ZEN_ASSERT_SLOW(It != m_InstanceLookup.end()); + ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex); + DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); + m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); + m_InstanceLookup.erase(It); + UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); + } + DeleteInstance.reset(); +} + +void +Hub::ObliterateBackendData(std::string_view ModuleId) +{ + std::filesystem::path ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId; + std::filesystem::path TempDir = m_HydrationTempPath / ModuleId; + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + + HydrationConfig Config{.ServerStateDir = ServerStateDir, + .TempDir = TempDir, + .ModuleId = std::string(ModuleId), + .TargetSpecification = m_HydrationTargetSpecification, + .Options = m_HydrationOptions}; + if (m_Config.OptionalHydrationWorkerPool) + { + Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationWorkerPool, + .AbortFlag = &AbortFlag, + .PauseFlag = &PauseFlag}); + } + + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Obliterate(); +} + bool Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo) { @@ -947,12 +1244,10 @@ Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo) ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size()); const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance; ZEN_ASSERT(Instance); - InstanceInfo Info{ - 
m_ActiveInstances[ActiveInstanceIndex].State.load(), - std::chrono::system_clock::now() // TODO - }; - Instance->GetProcessMetrics(Info.Metrics); - Info.Port = Instance->GetBasePort(); + InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(), + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()}; + Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load(); + Info.Port = Instance->GetBasePort(); *OutInstanceInfo = Info; } @@ -971,12 +1266,10 @@ Hub::EnumerateModules(std::function<void(std::string_view ModuleId, const Instan { const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance; ZEN_ASSERT(Instance); - InstanceInfo Info{ - m_ActiveInstances[ActiveInstanceIndex].State.load(), - std::chrono::system_clock::now() // TODO - }; - Instance->GetProcessMetrics(Info.Metrics); - Info.Port = Instance->GetBasePort(); + InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(), + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()}; + Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load(); + Info.Port = Instance->GetBasePort(); Infos.push_back(std::make_pair(std::string(Instance->GetModuleId()), Info)); } @@ -994,30 +1287,15 @@ Hub::GetInstanceCount() return m_Lock.WithSharedLock([this]() { return gsl::narrow_cast<int>(m_InstanceLookup.size()); }); } -void -Hub::UpdateCapacityMetrics() -{ - m_HostMetrics = GetSystemMetrics(); - - // TODO: Should probably go into WatchDog and use atomic for update so it can be read without locks... 
- // Per-instance stats are already refreshed by WatchDog and are readable via the Find and EnumerateModules -} - -void -Hub::UpdateStats() -{ - int CurrentInstanceCount = m_Lock.WithSharedLock([this] { return gsl::narrow_cast<int>(m_InstanceLookup.size()); }); - int CurrentMaxCount = m_MaxInstanceCount.load(); - - int NewMax = Max(CurrentMaxCount, CurrentInstanceCount); - - m_MaxInstanceCount.compare_exchange_weak(CurrentMaxCount, NewMax); -} - bool -Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason) +Hub::CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason) { - ZEN_UNUSED(ModuleId); + if (m_ObliteratingInstances.contains(std::string(ModuleId))) + { + OutReason = fmt::format("module '{}' is being obliterated", ModuleId); + return false; + } + if (m_FreeActiveInstanceIndexes.empty()) { OutReason = fmt::format("instance limit ({}) exceeded", m_Config.InstanceLimit); @@ -1025,7 +1303,24 @@ Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason) return false; } - // TODO: handle additional resource metrics + const uint64_t DiskUsedBytes = m_DiskSpace.Free <= m_DiskSpace.Total ? m_DiskSpace.Total - m_DiskSpace.Free : 0; + if (m_Config.ResourceLimits.DiskUsageBytes > 0 && DiskUsedBytes > m_Config.ResourceLimits.DiskUsageBytes) + { + OutReason = + fmt::format("disk usage ({}) exceeds ({})", NiceBytes(DiskUsedBytes), NiceBytes(m_Config.ResourceLimits.DiskUsageBytes)); + return false; + } + + const uint64_t RamUsedMiB = m_SystemMetrics.AvailSystemMemoryMiB <= m_SystemMetrics.SystemMemoryMiB + ? 
m_SystemMetrics.SystemMemoryMiB - m_SystemMetrics.AvailSystemMemoryMiB + : 0; + const uint64_t RamUsedBytes = RamUsedMiB * 1024 * 1024; + if (m_Config.ResourceLimits.MemoryUsageBytes > 0 && RamUsedBytes > m_Config.ResourceLimits.MemoryUsageBytes) + { + OutReason = + fmt::format("ram usage ({}) exceeds ({})", NiceBytes(RamUsedBytes), NiceBytes(m_Config.ResourceLimits.MemoryUsageBytes)); + return false; + } return true; } @@ -1036,6 +1331,21 @@ Hub::GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const return gsl::narrow<uint16_t>(m_Config.BasePortNumber + ActiveInstanceIndex); } +bool +Hub::IsInstancePort(uint16_t Port) const +{ + if (Port < m_Config.BasePortNumber) + { + return false; + } + size_t Index = Port - m_Config.BasePortNumber; + if (Index >= m_ActiveInstances.size()) + { + return false; + } + return m_ActiveInstances[Index].State.load(std::memory_order_relaxed) != HubInstanceState::Unprovisioned; +} + HubInstanceState Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState) { @@ -1046,11 +1356,13 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS case HubInstanceState::Unprovisioned: return To == HubInstanceState::Provisioning; case HubInstanceState::Provisioned: - return To == HubInstanceState::Hibernating || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Crashed; + return To == HubInstanceState::Hibernating || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Crashed || + To == HubInstanceState::Obliterating; case HubInstanceState::Hibernated: - return To == HubInstanceState::Waking || To == HubInstanceState::Deprovisioning; + return To == HubInstanceState::Waking || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Obliterating; case HubInstanceState::Crashed: - return To == HubInstanceState::Provisioning || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Recovering; + return To == HubInstanceState::Provisioning || 
To == HubInstanceState::Deprovisioning || + To == HubInstanceState::Recovering || To == HubInstanceState::Obliterating; case HubInstanceState::Provisioning: return To == HubInstanceState::Provisioned || To == HubInstanceState::Unprovisioned || To == HubInstanceState::Crashed; case HubInstanceState::Hibernating: @@ -1062,11 +1374,15 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS To == HubInstanceState::Crashed; case HubInstanceState::Recovering: return To == HubInstanceState::Provisioned || To == HubInstanceState::Unprovisioned; + case HubInstanceState::Obliterating: + return To == HubInstanceState::Unprovisioned || To == HubInstanceState::Crashed; } return false; }(m_ActiveInstances[ActiveInstanceIndex].State.load(), NewState)); + const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now(); m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.store(0); - m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now()); + m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(Now); + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.store(Now); return m_ActiveInstances[ActiveInstanceIndex].State.exchange(NewState); } @@ -1075,10 +1391,14 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId) { StorageServerInstance::ExclusiveLockedPtr Instance; size_t ActiveInstanceIndex = (size_t)-1; - { RwLock::ExclusiveLockScope _(m_Lock); + if (m_ShutdownFlag.load()) + { + return; + } + auto It = m_InstanceLookup.find(std::string(ModuleId)); if (It == m_InstanceLookup.end()) { @@ -1173,14 +1493,14 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, StorageServerInstance::SharedLockedPtr&& LockedInstance, size_t ActiveInstanceIndex) { + const std::string ModuleId(LockedInstance.GetModuleId()); + HubInstanceState InstanceState = m_ActiveInstances[ActiveInstanceIndex].State.load(); if (LockedInstance.IsRunning()) { - LockedInstance.UpdateMetrics(); + 
m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Store(LockedInstance.GetProcessMetrics()); if (InstanceState == HubInstanceState::Provisioned) { - const std::string ModuleId(LockedInstance.GetModuleId()); - const uint16_t Port = LockedInstance.GetBasePort(); const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load(); const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load(); @@ -1260,8 +1580,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, else if (InstanceState == HubInstanceState::Provisioned) { // Process is not running but state says it should be - instance died unexpectedly. - const std::string ModuleId(LockedInstance.GetModuleId()); - const uint16_t Port = LockedInstance.GetBasePort(); + const uint16_t Port = LockedInstance.GetBasePort(); UpdateInstanceState(LockedInstance, ActiveInstanceIndex, HubInstanceState::Crashed); NotifyStateUpdate(ModuleId, HubInstanceState::Provisioned, HubInstanceState::Crashed, Port, {}); LockedInstance = {}; @@ -1272,7 +1591,6 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, { // Process is not running - no HTTP activity check is possible. // Use a pure time-based check; the margin window does not apply here. - const std::string ModuleId = std::string(LockedInstance.GetModuleId()); const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load(); const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load(); const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now(); @@ -1304,7 +1622,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, } else { - // transitional state (Provisioning, Deprovisioning, Hibernating, Waking, Recovering) - expected, skip. 
+ // transitional state (Provisioning, Deprovisioning, Hibernating, Waking, Recovering, Obliterating) - expected, skip. // Crashed is handled above via AttemptRecoverInstance; it appears here only when the instance // lock was busy on a previous cycle and recovery is already pending. return true; @@ -1312,6 +1630,43 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, } void +Hub::UpdateMachineMetrics() +{ + try + { + bool DiskSpaceOk = false; + DiskSpace Disk; + + std::filesystem::path ChildDir = m_RunEnvironment.GetChildBaseDir(); + if (!ChildDir.empty()) + { + if (DiskSpaceInfo(ChildDir, Disk)) + { + DiskSpaceOk = true; + } + else + { + ZEN_WARN("Failed to query disk space for '{}'; disk-based provisioning limits will not be enforced", ChildDir); + } + } + + SystemMetrics Metrics = GetSystemMetrics(); + + m_Lock.WithExclusiveLock([&]() { + if (DiskSpaceOk) + { + m_DiskSpace = Disk; + } + m_SystemMetrics = Metrics; + }); + } + catch (const std::exception& Ex) + { + ZEN_WARN("Failed to update machine metrics. Reason: {}", Ex.what()); + } +} + +void Hub::WatchDog() { const uint64_t CycleIntervalMs = std::chrono::duration_cast<std::chrono::milliseconds>(m_Config.WatchDog.CycleInterval).count(); @@ -1326,16 +1681,18 @@ Hub::WatchDog() [&]() -> bool { return m_WatchDogEvent.Wait(0); }); size_t CheckInstanceIndex = SIZE_MAX; // first increment wraps to 0 - while (!m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs))) + while (!m_ShutdownFlag.load() && !m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs))) { try { + UpdateMachineMetrics(); + // Snapshot slot count. We iterate all slots (including freed nulls) so // round-robin coverage is not skewed by deprovisioned entries. 
size_t SlotsRemaining = m_Lock.WithSharedLock([this]() { return m_ActiveInstances.size(); }); Stopwatch Timer; - bool ShuttingDown = false; + bool ShuttingDown = m_ShutdownFlag.load(); while (SlotsRemaining > 0 && Timer.GetElapsedTimeMs() < CycleProcessingBudgetMs && !ShuttingDown) { StorageServerInstance::SharedLockedPtr LockedInstance; @@ -1366,16 +1723,24 @@ Hub::WatchDog() std::string ModuleId(LockedInstance.GetModuleId()); - bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex); - if (InstanceIsOk) + try { - ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs)); + bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex); + if (InstanceIsOk) + { + ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs)); + } + else + { + ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId); + AttemptRecoverInstance(ModuleId); + } } - else + catch (const std::exception& Ex) { - ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId); - AttemptRecoverInstance(ModuleId); + ZEN_WARN("Failed to check status of module {}. 
Reason: {}", ModuleId, Ex.what()); } + ShuttingDown |= m_ShutdownFlag.load(); } } catch (const std::exception& Ex) @@ -1417,6 +1782,14 @@ static const HttpClientSettings kFastTimeout{.ConnectTimeout = std::chrono::mill namespace hub_testutils { + struct TestHubPools + { + WorkerThreadPool ProvisionPool; + WorkerThreadPool HydrationPool; + + explicit TestHubPools(int ThreadCount) : ProvisionPool(ThreadCount, "hub_test_prov"), HydrationPool(ThreadCount, "hub_test_hydr") {} + }; + ZenServerEnvironment MakeHubEnvironment(const std::filesystem::path& BaseDir) { return ZenServerEnvironment(ZenServerEnvironment::Hub, GetRunningExecutablePath().parent_path(), BaseDir); @@ -1425,9 +1798,14 @@ namespace hub_testutils { std::unique_ptr<Hub> MakeHub(const std::filesystem::path& BaseDir, Hub::Configuration Config = {}, Hub::AsyncModuleStateChangeCallbackFunc StateChangeCallback = {}, - WorkerThreadPool* WorkerPool = nullptr) + TestHubPools* Pools = nullptr) { - return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), WorkerPool, std::move(StateChangeCallback)); + if (Pools) + { + Config.OptionalProvisionWorkerPool = &Pools->ProvisionPool; + Config.OptionalHydrationWorkerPool = &Pools->HydrationPool; + } + return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), std::move(StateChangeCallback)); } struct CallbackRecord @@ -1499,14 +1877,32 @@ namespace hub_testutils { } // namespace hub_testutils -TEST_CASE("hub.provision_basic") +TEST_CASE("hub.provision") { ScopedTemporaryDirectory TempDir; - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path()); + + struct TransitionRecord + { + HubInstanceState OldState; + HubInstanceState NewState; + }; + RwLock CaptureMutex; + std::vector<TransitionRecord> Transitions; + + hub_testutils::StateChangeCapture CaptureInstance; + + auto CaptureFunc = + [&](std::string_view ModuleId, const HubProvisionedInstanceInfo& Info, HubInstanceState OldState, HubInstanceState NewState) { + 
CaptureMutex.WithExclusiveLock([&]() { Transitions.push_back({OldState, NewState}); }); + CaptureInstance.CaptureFunc()(ModuleId, Info, OldState, NewState); + }; + + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, std::move(CaptureFunc)); CHECK_EQ(HubInstance->GetInstanceCount(), 0); CHECK_FALSE(HubInstance->Find("module_a")); + // Provision HubProvisionedInstanceInfo Info; const Hub::Response ProvisionResult = HubInstance->Provision("module_a", Info); REQUIRE_MESSAGE(ProvisionResult.ResponseCode == Hub::EResponseCode::Completed, ProvisionResult.Message); @@ -1515,12 +1911,23 @@ TEST_CASE("hub.provision_basic") Hub::InstanceInfo InstanceInfo; REQUIRE(HubInstance->Find("module_a", &InstanceInfo)); CHECK_EQ(InstanceInfo.State, HubInstanceState::Provisioned); + CHECK_NE(InstanceInfo.StateChangeTime, std::chrono::system_clock::time_point::min()); + CHECK_LE(InstanceInfo.StateChangeTime, std::chrono::system_clock::now()); { HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout); CHECK(ModClient.Get("/health/")); } + // Verify provision callback + { + RwLock::SharedLockScope _(CaptureInstance.CallbackMutex); + REQUIRE_EQ(CaptureInstance.ProvisionCallbacks.size(), 1u); + CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].ModuleId, "module_a"); + CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].Port, Info.Port); + } + + // Deprovision const Hub::Response DeprovisionResult = HubInstance->Deprovision("module_a"); CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed); CHECK_EQ(HubInstance->GetInstanceCount(), 0); @@ -1530,6 +1937,28 @@ TEST_CASE("hub.provision_basic") HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout); CHECK(!ModClient.Get("/health/")); } + + // Verify deprovision callback + { + RwLock::SharedLockScope _(CaptureInstance.CallbackMutex); + REQUIRE_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u); + CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].ModuleId, 
"module_a"); + CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].Port, Info.Port); + } + + // Verify full transition sequence + { + RwLock::SharedLockScope _(CaptureMutex); + REQUIRE_EQ(Transitions.size(), 4u); + CHECK_EQ(Transitions[0].OldState, HubInstanceState::Unprovisioned); + CHECK_EQ(Transitions[0].NewState, HubInstanceState::Provisioning); + CHECK_EQ(Transitions[1].OldState, HubInstanceState::Provisioning); + CHECK_EQ(Transitions[1].NewState, HubInstanceState::Provisioned); + CHECK_EQ(Transitions[2].OldState, HubInstanceState::Provisioned); + CHECK_EQ(Transitions[2].NewState, HubInstanceState::Deprovisioning); + CHECK_EQ(Transitions[3].OldState, HubInstanceState::Deprovisioning); + CHECK_EQ(Transitions[3].NewState, HubInstanceState::Unprovisioned); + } } TEST_CASE("hub.provision_config") @@ -1582,92 +2011,6 @@ TEST_CASE("hub.provision_config") } } -TEST_CASE("hub.provision_callbacks") -{ - ScopedTemporaryDirectory TempDir; - - hub_testutils::StateChangeCapture CaptureInstance; - - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, CaptureInstance.CaptureFunc()); - - HubProvisionedInstanceInfo Info; - - const Hub::Response ProvisionResult = HubInstance->Provision("cb_module", Info); - REQUIRE_MESSAGE(ProvisionResult.ResponseCode == Hub::EResponseCode::Completed, ProvisionResult.Message); - - { - RwLock::SharedLockScope _(CaptureInstance.CallbackMutex); - REQUIRE_EQ(CaptureInstance.ProvisionCallbacks.size(), 1u); - CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].ModuleId, "cb_module"); - CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].Port, Info.Port); - CHECK_NE(CaptureInstance.ProvisionCallbacks[0].Port, 0); - } - - { - HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout); - CHECK(ModClient.Get("/health/")); - } - - const Hub::Response DeprovisionResult = HubInstance->Deprovision("cb_module"); - CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed); - - { - HttpClient 
ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout); - CHECK(!ModClient.Get("/health/")); - } - - { - RwLock::SharedLockScope _(CaptureInstance.CallbackMutex); - REQUIRE_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u); - CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].ModuleId, "cb_module"); - CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].Port, Info.Port); - CHECK_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u); - } -} - -TEST_CASE("hub.provision_callback_sequence") -{ - ScopedTemporaryDirectory TempDir; - - struct TransitionRecord - { - HubInstanceState OldState; - HubInstanceState NewState; - }; - RwLock CaptureMutex; - std::vector<TransitionRecord> Transitions; - - auto CaptureFunc = - [&](std::string_view ModuleId, const HubProvisionedInstanceInfo& Info, HubInstanceState OldState, HubInstanceState NewState) { - ZEN_UNUSED(ModuleId); - ZEN_UNUSED(Info); - CaptureMutex.WithExclusiveLock([&]() { Transitions.push_back({OldState, NewState}); }); - }; - - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, std::move(CaptureFunc)); - - HubProvisionedInstanceInfo Info; - { - const Hub::Response R = HubInstance->Provision("seq_module", Info); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); - } - { - const Hub::Response R = HubInstance->Deprovision("seq_module"); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); - } - - RwLock::SharedLockScope _(CaptureMutex); - REQUIRE_EQ(Transitions.size(), 4u); - CHECK_EQ(Transitions[0].OldState, HubInstanceState::Unprovisioned); - CHECK_EQ(Transitions[0].NewState, HubInstanceState::Provisioning); - CHECK_EQ(Transitions[1].OldState, HubInstanceState::Provisioning); - CHECK_EQ(Transitions[1].NewState, HubInstanceState::Provisioned); - CHECK_EQ(Transitions[2].OldState, HubInstanceState::Provisioned); - CHECK_EQ(Transitions[2].NewState, HubInstanceState::Deprovisioning); - CHECK_EQ(Transitions[3].OldState, 
HubInstanceState::Deprovisioning); - CHECK_EQ(Transitions[3].NewState, HubInstanceState::Unprovisioned); -} - TEST_CASE("hub.instance_limit") { ScopedTemporaryDirectory TempDir; @@ -1699,54 +2042,7 @@ TEST_CASE("hub.instance_limit") CHECK_EQ(HubInstance->GetInstanceCount(), 2); } -TEST_CASE("hub.enumerate_modules") -{ - ScopedTemporaryDirectory TempDir; - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path()); - - HubProvisionedInstanceInfo Info; - - { - const Hub::Response R = HubInstance->Provision("enum_a", Info); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); - } - { - const Hub::Response R = HubInstance->Provision("enum_b", Info); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); - } - - std::vector<std::string> Ids; - int ProvisionedCount = 0; - HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) { - Ids.push_back(std::string(ModuleId)); - if (InstanceInfo.State == HubInstanceState::Provisioned) - { - ProvisionedCount++; - } - }); - CHECK_EQ(Ids.size(), 2u); - CHECK_EQ(ProvisionedCount, 2); - const bool FoundA = std::find(Ids.begin(), Ids.end(), "enum_a") != Ids.end(); - const bool FoundB = std::find(Ids.begin(), Ids.end(), "enum_b") != Ids.end(); - CHECK(FoundA); - CHECK(FoundB); - - HubInstance->Deprovision("enum_a"); - Ids.clear(); - ProvisionedCount = 0; - HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) { - Ids.push_back(std::string(ModuleId)); - if (InstanceInfo.State == HubInstanceState::Provisioned) - { - ProvisionedCount++; - } - }); - REQUIRE_EQ(Ids.size(), 1u); - CHECK_EQ(Ids[0], "enum_b"); - CHECK_EQ(ProvisionedCount, 1); -} - -TEST_CASE("hub.max_instance_count") +TEST_CASE("hub.enumerate_and_instance_tracking") { ScopedTemporaryDirectory TempDir; std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path()); @@ -1756,22 +2052,56 @@ 
TEST_CASE("hub.max_instance_count") HubProvisionedInstanceInfo Info; { - const Hub::Response R = HubInstance->Provision("max_a", Info); + const Hub::Response R = HubInstance->Provision("track_a", Info); REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); } CHECK_GE(HubInstance->GetMaxInstanceCount(), 1); { - const Hub::Response R = HubInstance->Provision("max_b", Info); + const Hub::Response R = HubInstance->Provision("track_b", Info); REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); } CHECK_GE(HubInstance->GetMaxInstanceCount(), 2); + // Enumerate both modules + { + std::vector<std::string> Ids; + int ProvisionedCount = 0; + HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) { + Ids.push_back(std::string(ModuleId)); + if (InstanceInfo.State == HubInstanceState::Provisioned) + { + ProvisionedCount++; + } + }); + CHECK_EQ(Ids.size(), 2u); + CHECK_EQ(ProvisionedCount, 2); + CHECK(std::find(Ids.begin(), Ids.end(), "track_a") != Ids.end()); + CHECK(std::find(Ids.begin(), Ids.end(), "track_b") != Ids.end()); + } + const int MaxAfterTwo = HubInstance->GetMaxInstanceCount(); - HubInstance->Deprovision("max_a"); + // Deprovision one - max instance count must not decrease + HubInstance->Deprovision("track_a"); CHECK_EQ(HubInstance->GetInstanceCount(), 1); CHECK_EQ(HubInstance->GetMaxInstanceCount(), MaxAfterTwo); + + // Enumerate after deprovision + { + std::vector<std::string> Ids; + int ProvisionedCount = 0; + HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) { + Ids.push_back(std::string(ModuleId)); + if (InstanceInfo.State == HubInstanceState::Provisioned) + { + ProvisionedCount++; + } + }); + REQUIRE_EQ(Ids.size(), 1u); + CHECK_EQ(Ids[0], "track_b"); + CHECK_EQ(ProvisionedCount, 1); + } } TEST_CASE("hub.concurrent_callbacks") @@ -1917,7 +2247,7 @@ TEST_CASE("hub.job_object") } # endif // ZEN_PLATFORM_WINDOWS 
-TEST_CASE("hub.hibernate_wake") +TEST_CASE("hub.hibernate_wake_obliterate") { ScopedTemporaryDirectory TempDir; Hub::Configuration Config; @@ -1927,6 +2257,11 @@ TEST_CASE("hub.hibernate_wake") HubProvisionedInstanceInfo ProvInfo; Hub::InstanceInfo Info; + // Error cases on non-existent modules (no provision needed) + CHECK(HubInstance->Hibernate("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); + CHECK(HubInstance->Wake("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); + CHECK(HubInstance->Deprovision("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); + // Provision { const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo); @@ -1934,82 +2269,104 @@ TEST_CASE("hub.hibernate_wake") } REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Provisioned); + const std::chrono::system_clock::time_point ProvisionedTime = Info.StateChangeTime; + CHECK_NE(ProvisionedTime, std::chrono::system_clock::time_point::min()); + CHECK_LE(ProvisionedTime, std::chrono::system_clock::now()); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); CHECK(ModClient.Get("/health/")); } + // Double-wake on provisioned module is idempotent + CHECK(HubInstance->Wake("hib_a").ResponseCode == Hub::EResponseCode::Completed); + // Hibernate - const Hub::Response HibernateResult = HubInstance->Hibernate("hib_a"); - REQUIRE_MESSAGE(HibernateResult.ResponseCode == Hub::EResponseCode::Completed, HibernateResult.Message); + { + const Hub::Response R = HubInstance->Hibernate("hib_a"); + REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); + } REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Hibernated); + const std::chrono::system_clock::time_point HibernatedTime = Info.StateChangeTime; + CHECK_GE(HibernatedTime, ProvisionedTime); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); 
CHECK(!ModClient.Get("/health/")); } + // Double-hibernate on already-hibernated module is idempotent + CHECK(HubInstance->Hibernate("hib_a").ResponseCode == Hub::EResponseCode::Completed); + // Wake - const Hub::Response WakeResult = HubInstance->Wake("hib_a"); - REQUIRE_MESSAGE(WakeResult.ResponseCode == Hub::EResponseCode::Completed, WakeResult.Message); + { + const Hub::Response R = HubInstance->Wake("hib_a"); + REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); + } REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Provisioned); + CHECK_GE(Info.StateChangeTime, HibernatedTime); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); CHECK(ModClient.Get("/health/")); } - // Deprovision - const Hub::Response DeprovisionResult = HubInstance->Deprovision("hib_a"); - CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed); - CHECK_FALSE(HubInstance->Find("hib_a")); + // Hibernate again for obliterate-from-hibernated test { - HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); - CHECK(!ModClient.Get("/health/")); + const Hub::Response R = HubInstance->Hibernate("hib_a"); + REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); } -} - -TEST_CASE("hub.hibernate_wake_errors") -{ - ScopedTemporaryDirectory TempDir; - Hub::Configuration Config; - Config.BasePortNumber = 22700; - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); - - HubProvisionedInstanceInfo ProvInfo; + REQUIRE(HubInstance->Find("hib_a", &Info)); + CHECK_EQ(Info.State, HubInstanceState::Hibernated); - // Hibernate/wake on a non-existent module - returns NotFound (-> 404) - CHECK(HubInstance->Hibernate("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); - CHECK(HubInstance->Wake("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); + // Obliterate from hibernated + { + const 
Hub::Response R = HubInstance->Obliterate("hib_a"); + CHECK(R.ResponseCode == Hub::EResponseCode::Completed); + } + CHECK_EQ(HubInstance->GetInstanceCount(), 0); + CHECK_FALSE(HubInstance->Find("hib_a")); - // Double-hibernate: second hibernate on already-hibernated module returns Completed (idempotent) + // Re-provision for obliterate-from-provisioned test { - const Hub::Response R = HubInstance->Provision("err_b", ProvInfo); + const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo); REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); } + REQUIRE(HubInstance->Find("hib_a", &Info)); + CHECK_EQ(Info.State, HubInstanceState::Provisioned); { - const Hub::Response R = HubInstance->Hibernate("err_b"); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); + HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); + CHECK(ModClient.Get("/health/")); } + // Obliterate from provisioned + { + const Hub::Response R = HubInstance->Obliterate("hib_a"); + CHECK(R.ResponseCode == Hub::EResponseCode::Completed); + } + CHECK_EQ(HubInstance->GetInstanceCount(), 0); + CHECK_FALSE(HubInstance->Find("hib_a")); { - const Hub::Response HibResp = HubInstance->Hibernate("err_b"); - CHECK(HibResp.ResponseCode == Hub::EResponseCode::Completed); + HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); + CHECK(!ModClient.Get("/health/")); } - // Wake on provisioned: succeeds (-> Provisioned), then wake again returns Completed (idempotent) + // Obliterate deprovisioned module (not tracked by hub, backend data may exist) { - const Hub::Response R = HubInstance->Wake("err_b"); + const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo); REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); } - { - const Hub::Response WakeResp = HubInstance->Wake("err_b"); - CHECK(WakeResp.ResponseCode == Hub::EResponseCode::Completed); + const Hub::Response R = 
HubInstance->Deprovision("hib_a"); + REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); + } + CHECK_FALSE(HubInstance->Find("hib_a")); + { + const Hub::Response R = HubInstance->Obliterate("hib_a"); + CHECK(R.ResponseCode == Hub::EResponseCode::Completed); } - // Deprovision not-found - returns NotFound (-> 404) - CHECK(HubInstance->Deprovision("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound); + // Obliterate of a never-provisioned module also succeeds (no-op backend cleanup) + CHECK(HubInstance->Obliterate("never_existed").ResponseCode == Hub::EResponseCode::Completed); } TEST_CASE("hub.async_hibernate_wake") @@ -2019,8 +2376,8 @@ TEST_CASE("hub.async_hibernate_wake") Hub::Configuration Config; Config.BasePortNumber = 23000; - WorkerThreadPool WorkerPool(2, "hub_async_hib_wake"); - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool); + hub_testutils::TestHubPools Pools(2); + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools); HubProvisionedInstanceInfo ProvInfo; Hub::InstanceInfo Info; @@ -2150,25 +2507,21 @@ TEST_CASE("hub.recover_process_crash") if (HubInstance->Find("module_a", &InstanceInfo) && InstanceInfo.State == HubInstanceState::Provisioned && ModClient.Get("/health/")) { - // Recovery must reuse the same port - the instance was never removed from the hub's - // port table during recovery, so AttemptRecoverInstance reuses m_Config.BasePort. 
CHECK_EQ(InstanceInfo.Port, Info.Port); Recovered = true; break; } } - CHECK_MESSAGE(Recovered, "Instance did not recover within timeout"); + REQUIRE_MESSAGE(Recovered, "Instance did not recover within timeout"); // Verify the full crash/recovery callback sequence { RwLock::SharedLockScope _(CaptureMutex); REQUIRE_GE(Transitions.size(), 3u); - // Find the Provisioned->Crashed transition const auto CrashedIt = std::find_if(Transitions.begin(), Transitions.end(), [](const TransitionRecord& R) { return R.OldState == HubInstanceState::Provisioned && R.NewState == HubInstanceState::Crashed; }); REQUIRE_NE(CrashedIt, Transitions.end()); - // Recovery sequence follows: Crashed->Recovering, Recovering->Provisioned const auto RecoveringIt = CrashedIt + 1; REQUIRE_NE(RecoveringIt, Transitions.end()); CHECK_EQ(RecoveringIt->OldState, HubInstanceState::Crashed); @@ -2178,44 +2531,6 @@ TEST_CASE("hub.recover_process_crash") CHECK_EQ(RecoveredIt->OldState, HubInstanceState::Recovering); CHECK_EQ(RecoveredIt->NewState, HubInstanceState::Provisioned); } -} - -TEST_CASE("hub.recover_process_crash_then_deprovision") -{ - ScopedTemporaryDirectory TempDir; - - // Fast watchdog cycle so crash detection is near-instant instead of waiting up to the 3s default. - Hub::Configuration Config; - Config.WatchDog.CycleInterval = std::chrono::milliseconds(10); - Config.WatchDog.InstanceCheckThrottle = std::chrono::milliseconds(1); - - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); - - HubProvisionedInstanceInfo Info; - { - const Hub::Response R = HubInstance->Provision("module_a", Info); - REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message); - } - - // Kill the child process, wait for the watchdog to detect and recover the instance. 
- HubInstance->TerminateModuleForTesting("module_a"); - - constexpr auto kPollIntervalMs = std::chrono::milliseconds(50); - constexpr auto kTimeoutMs = std::chrono::seconds(15); - const auto Deadline = std::chrono::steady_clock::now() + kTimeoutMs; - - bool Recovered = false; - while (std::chrono::steady_clock::now() < Deadline) - { - std::this_thread::sleep_for(kPollIntervalMs); - Hub::InstanceInfo InstanceInfo; - if (HubInstance->Find("module_a", &InstanceInfo) && InstanceInfo.State == HubInstanceState::Provisioned) - { - Recovered = true; - break; - } - } - REQUIRE_MESSAGE(Recovered, "Instance did not recover within timeout"); // After recovery, deprovision should succeed and a re-provision should work. { @@ -2244,8 +2559,8 @@ TEST_CASE("hub.async_provision_concurrent") Config.BasePortNumber = 22800; Config.InstanceLimit = kModuleCount; - WorkerThreadPool WorkerPool(4, "hub_async_concurrent"); - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool); + hub_testutils::TestHubPools Pools(4); + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools); std::vector<HubProvisionedInstanceInfo> Infos(kModuleCount); std::vector<std::string> Reasons(kModuleCount); @@ -2326,8 +2641,8 @@ TEST_CASE("hub.async_provision_shutdown_waits") Config.InstanceLimit = kModuleCount; Config.BasePortNumber = 22900; - WorkerThreadPool WorkerPool(2, "hub_async_shutdown"); - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool); + hub_testutils::TestHubPools Pools(2); + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools); std::vector<HubProvisionedInstanceInfo> Infos(kModuleCount); @@ -2352,15 +2667,15 @@ TEST_CASE("hub.async_provision_shutdown_waits") TEST_CASE("hub.async_provision_rejected") { - // Rejection from CanProvisionInstance fires synchronously even when a WorkerPool is present. 
+ // Rejection from CanProvisionInstanceLocked fires synchronously even when a WorkerPool is present. ScopedTemporaryDirectory TempDir; Hub::Configuration Config; Config.InstanceLimit = 1; Config.BasePortNumber = 23100; - WorkerThreadPool WorkerPool(2, "hub_async_rejected"); - std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool); + hub_testutils::TestHubPools Pools(2); + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools); HubProvisionedInstanceInfo Info; @@ -2369,7 +2684,7 @@ TEST_CASE("hub.async_provision_rejected") REQUIRE_MESSAGE(FirstResult.ResponseCode == Hub::EResponseCode::Accepted, FirstResult.Message); REQUIRE_NE(Info.Port, 0); - // Second provision: CanProvisionInstance rejects synchronously (limit reached), returns Rejected + // Second provision: CanProvisionInstanceLocked rejects synchronously (limit reached), returns Rejected HubProvisionedInstanceInfo Info2; const Hub::Response SecondResult = HubInstance->Provision("async_r2", Info2); CHECK(SecondResult.ResponseCode == Hub::EResponseCode::Rejected); @@ -2448,12 +2763,12 @@ TEST_CASE("hub.instance.inactivity.deprovision") // Phase 1: immediately after setup all three instances must still be alive. // No timeout has elapsed yet (only 100ms have passed). 
- CHECK_MESSAGE(HubInstance->Find("idle"), "idle was deprovisioned within 100ms - its 2s provisioned timeout has not elapsed"); + CHECK_MESSAGE(HubInstance->Find("idle"), "idle was deprovisioned within 100ms - its 4s provisioned timeout has not elapsed"); CHECK_MESSAGE(HubInstance->Find("idle_hib"), "idle_hib was deprovisioned within 100ms - its 1s hibernated timeout has not elapsed"); CHECK_MESSAGE(HubInstance->Find("persistent"), - "persistent was deprovisioned within 100ms - its 2s provisioned timeout has not elapsed"); + "persistent was deprovisioned within 100ms - its 4s provisioned timeout has not elapsed"); // Phase 2: idle_hib must be deprovisioned by the watchdog within its 1s hibernated timeout. // idle must remain alive - its 2s provisioned timeout has not elapsed yet. @@ -2477,7 +2792,7 @@ TEST_CASE("hub.instance.inactivity.deprovision") CHECK_MESSAGE(!HubInstance->Find("idle_hib"), "idle_hib should still be gone - it was deprovisioned in phase 2"); - CHECK_MESSAGE(!HubInstance->Find("idle"), "idle should be gone after its 3s provisioned timeout elapsed"); + CHECK_MESSAGE(!HubInstance->Find("idle"), "idle should be gone after its 4s provisioned timeout elapsed"); CHECK_MESSAGE(HubInstance->Find("persistent"), "persistent was incorrectly deprovisioned - its activity timer was reset by PokeInstance"); @@ -2485,6 +2800,55 @@ TEST_CASE("hub.instance.inactivity.deprovision") HubInstance->Shutdown(); } +TEST_CASE("hub.machine_metrics") +{ + ScopedTemporaryDirectory TempDir; + + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}); + + // UpdateMachineMetrics() is called synchronously in the Hub constructor, so metrics + // are available immediately without waiting for a watchdog cycle. 
+ SystemMetrics SysMetrics; + DiskSpace Disk; + HubInstance->GetMachineMetrics(SysMetrics, Disk); + + CHECK_GT(Disk.Total, 0u); + CHECK_LE(Disk.Free, Disk.Total); + + CHECK_GT(SysMetrics.SystemMemoryMiB, 0u); + CHECK_LE(SysMetrics.AvailSystemMemoryMiB, SysMetrics.SystemMemoryMiB); + + CHECK_GT(SysMetrics.VirtualMemoryMiB, 0u); + CHECK_LE(SysMetrics.AvailVirtualMemoryMiB, SysMetrics.VirtualMemoryMiB); +} + +TEST_CASE("hub.provision_rejected_resource_limits") +{ + // The Hub constructor calls UpdateMachineMetrics() synchronously, so CanProvisionInstanceLocked + // can enforce limits immediately without waiting for a watchdog cycle. + ScopedTemporaryDirectory TempDir; + + { + Hub::Configuration Config; + Config.ResourceLimits.DiskUsageBytes = 1; + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); + HubProvisionedInstanceInfo Info; + const Hub::Response Result = HubInstance->Provision("disk_limit", Info); + CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected); + CHECK_NE(Result.Message.find("disk usage"), std::string::npos); + } + + { + Hub::Configuration Config; + Config.ResourceLimits.MemoryUsageBytes = 1; + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); + HubProvisionedInstanceInfo Info; + const Hub::Response Result = HubInstance->Provision("mem_limit", Info); + CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected); + CHECK_NE(Result.Message.find("ram usage"), std::string::npos); + } +} + TEST_SUITE_END(); void diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index c343b19e2..040f34af5 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -6,6 +6,8 @@ #include "resourcemetrics.h" #include "storageserverinstance.h" +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> #include <zencore/system.h> #include <zenutil/zenserverprocess.h> @@ -16,6 +18,7 @@ #include <memory> #include <thread> #include <unordered_map> +#include <unordered_set> 
namespace zen { @@ -64,10 +67,20 @@ public: uint32_t InstanceHttpThreadCount = 0; // Automatic int InstanceCoreLimit = 0; // Automatic + std::string InstanceMalloc; + std::string InstanceTrace; + std::string InstanceTraceHost; + std::string InstanceTraceFile; std::filesystem::path InstanceConfigPath; std::string HydrationTargetSpecification; + CbObject HydrationOptions; WatchDogConfiguration WatchDog; + + ResourceMetrics ResourceLimits; + + WorkerThreadPool* OptionalProvisionWorkerPool = nullptr; + WorkerThreadPool* OptionalHydrationWorkerPool = nullptr; }; typedef std::function< @@ -76,7 +89,6 @@ public: Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, - WorkerThreadPool* OptionalWorkerPool = nullptr, AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback = {}); ~Hub(); @@ -86,7 +98,7 @@ public: struct InstanceInfo { HubInstanceState State = HubInstanceState::Unprovisioned; - std::chrono::system_clock::time_point ProvisionTime; + std::chrono::system_clock::time_point StateChangeTime; ProcessMetrics Metrics; uint16_t Port = 0; }; @@ -126,6 +138,14 @@ public: Response Deprovision(const std::string& ModuleId); /** + * Obliterate a storage server instance and all associated data. + * Shuts down the process, deletes backend hydration data, and cleans local state. + * + * @param ModuleId The ID of the module to obliterate. + */ + Response Obliterate(const std::string& ModuleId); + + /** * Hibernate a storage server instance for the given module ID. * The instance is shut down but its data is preserved; it can be woken later. 
* @@ -160,6 +180,10 @@ public: int GetMaxInstanceCount() const { return m_MaxInstanceCount.load(); } + void GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const; + + bool IsInstancePort(uint16_t Port) const; + const Configuration& GetConfig() const { return m_Config; } #if ZEN_WITH_TESTS @@ -176,14 +200,31 @@ private: AsyncModuleStateChangeCallbackFunc m_ModuleStateChangeCallback; std::string m_HydrationTargetSpecification; + CbObject m_HydrationOptions; std::filesystem::path m_HydrationTempPath; #if ZEN_PLATFORM_WINDOWS JobObject m_JobObject; #endif - RwLock m_Lock; + mutable RwLock m_Lock; std::unordered_map<std::string, size_t> m_InstanceLookup; + // Mirrors ProcessMetrics with atomic fields, enabling lock-free reads alongside watchdog writes. + struct AtomicProcessMetrics + { + std::atomic<uint64_t> MemoryBytes = 0; + std::atomic<uint64_t> KernelTimeMs = 0; + std::atomic<uint64_t> UserTimeMs = 0; + std::atomic<uint64_t> WorkingSetSize = 0; + std::atomic<uint64_t> PeakWorkingSetSize = 0; + std::atomic<uint64_t> PagefileUsage = 0; + std::atomic<uint64_t> PeakPagefileUsage = 0; + + ProcessMetrics Load() const; + void Store(const ProcessMetrics& Metrics); + void Reset(); + }; + struct ActiveInstance { // Invariant: Instance == nullptr if and only if State == Unprovisioned. @@ -192,11 +233,16 @@ private: // without holding the hub lock. std::unique_ptr<StorageServerInstance> Instance; std::atomic<HubInstanceState> State = HubInstanceState::Unprovisioned; - // TODO: We should move current metrics here (from StorageServerInstance) - // Read and updated by WatchDog, updates to State triggers a reset of both + // Process metrics - written by WatchDog (inside instance shared lock), read lock-free. + AtomicProcessMetrics ProcessMetrics; + + // Activity tracking - written by WatchDog, reset on every state transition. 
std::atomic<uint64_t> LastKnownActivitySum = 0; std::atomic<std::chrono::system_clock::time_point> LastActivityTime = std::chrono::system_clock::time_point::min(); + + // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules. + std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min(); }; // UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive @@ -224,23 +270,23 @@ private: } HubInstanceState UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState); - std::vector<ActiveInstance> m_ActiveInstances; - std::deque<size_t> m_FreeActiveInstanceIndexes; - ResourceMetrics m_ResourceLimits; - SystemMetrics m_HostMetrics; - std::atomic<int> m_MaxInstanceCount = 0; - std::thread m_WatchDog; + std::vector<ActiveInstance> m_ActiveInstances; + std::deque<size_t> m_FreeActiveInstanceIndexes; + SystemMetrics m_SystemMetrics; + DiskSpace m_DiskSpace; + std::atomic<int> m_MaxInstanceCount = 0; + std::thread m_WatchDog; + std::unordered_set<std::string> m_ObliteratingInstances; Event m_WatchDogEvent; void WatchDog(); + void UpdateMachineMetrics(); bool CheckInstanceStatus(HttpClient& ActivityHttpClient, StorageServerInstance::SharedLockedPtr&& LockedInstance, size_t ActiveInstanceIndex); void AttemptRecoverInstance(std::string_view ModuleId); - void UpdateStats(); - void UpdateCapacityMetrics(); - bool CanProvisionInstance(std::string_view ModuleId, std::string& OutReason); + bool CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason); uint16_t GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const; Response InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveInstance& Instance)>&& DeprovisionGate); @@ -248,9 +294,12 @@ private: size_t ActiveInstanceIndex, HubInstanceState OldState, bool IsNewInstance); - void 
CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex); - void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); - void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + void CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + void CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex); + void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + void RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId); + void ObliterateBackendData(std::string_view ModuleId); // Notifications may fire slightly out of sync with the Hub's internal State flag. 
// The guarantee is that notifications are sent in the correct order, but the State diff --git a/src/zenserver/hub/hubinstancestate.cpp b/src/zenserver/hub/hubinstancestate.cpp index c47fdd294..310305e5d 100644 --- a/src/zenserver/hub/hubinstancestate.cpp +++ b/src/zenserver/hub/hubinstancestate.cpp @@ -29,6 +29,8 @@ ToString(HubInstanceState State) return "crashed"; case HubInstanceState::Recovering: return "recovering"; + case HubInstanceState::Obliterating: + return "obliterating"; } ZEN_ASSERT(false); return "unknown"; diff --git a/src/zenserver/hub/hubinstancestate.h b/src/zenserver/hub/hubinstancestate.h index c895f75d1..c7188aa5c 100644 --- a/src/zenserver/hub/hubinstancestate.h +++ b/src/zenserver/hub/hubinstancestate.h @@ -20,7 +20,8 @@ enum class HubInstanceState : uint32_t Hibernating, // Provisioned -> Hibernated (Shutting down process, preserving data on disk) Waking, // Hibernated -> Provisioned (Starting process from preserved data) Deprovisioning, // Provisioned/Hibernated/Crashed -> Unprovisioned (Shutting down process and cleaning up data) - Recovering, // Crashed -> Provisioned/Deprovisioned (Attempting in-place restart after a crash) + Recovering, // Crashed -> Provisioned/Unprovisioned (Attempting in-place restart after a crash) + Obliterating, // Provisioned/Hibernated/Crashed -> Unprovisioned (Destroying all local and backend data) }; std::string_view ToString(HubInstanceState State); diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp index 541127590..b356064f9 100644 --- a/src/zenserver/hub/hydration.cpp +++ b/src/zenserver/hub/hydration.cpp @@ -5,20 +5,25 @@ #include <zencore/basicfile.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> #include <zencore/except_fmt.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/parallelwork.h> +#include 
<zencore/stream.h> #include <zencore/system.h> +#include <zencore/timer.h> #include <zenutil/cloud/imdscredentials.h> #include <zenutil/cloud/s3client.h> +#include <zenutil/filesystemutils.h> -ZEN_THIRD_PARTY_INCLUDES_START -#include <json11.hpp> -ZEN_THIRD_PARTY_INCLUDES_END +#include <numeric> +#include <unordered_map> +#include <unordered_set> #if ZEN_WITH_TESTS -# include <zencore/parallelwork.h> # include <zencore/testing.h> # include <zencore/testutils.h> # include <zencore/thread.h> @@ -29,7 +34,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { +namespace hydration_impl { /// UTC time decomposed to calendar fields with sub-second milliseconds. struct UtcTime @@ -55,483 +60,1046 @@ namespace { } }; -} // namespace + std::filesystem::path FastRelativePath(const std::filesystem::path& Root, const std::filesystem::path& Abs) + { + auto [_, ItAbs] = std::mismatch(Root.begin(), Root.end(), Abs.begin(), Abs.end()); + std::filesystem::path RelativePath; + for (auto I = ItAbs; I != Abs.end(); I++) + { + RelativePath = RelativePath / *I; + } + return RelativePath; + } -/////////////////////////////////////////////////////////////////////////// + void CleanDirectory(WorkerThreadPool& WorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Path) + { + CleanDirectory(WorkerPool, AbortFlag, PauseFlag, Path, std::vector<std::string>{}, {}, 0); + } -constexpr std::string_view FileHydratorPrefix = "file://"; + class StorageBase + { + public: + virtual ~StorageBase() {} + + virtual void Configure(std::string_view ModuleId, + const std::filesystem::path& TempDir, + std::string_view TargetSpecification, + const CbObject& Options) = 0; + virtual void SaveMetadata(const CbObject& Data) = 0; + virtual CbObject LoadMetadata() = 0; + virtual CbObject GetSettings() = 0; + virtual void ParseSettings(const CbObjectView& Settings) = 0; + virtual std::vector<IoHash> List() = 0; + virtual void Put(ParallelWork& Work, + 
WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath) = 0; + virtual void Get(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath) = 0; + virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) = 0; + }; -struct FileHydrator : public HydrationStrategyBase -{ - virtual void Configure(const HydrationConfig& Config) override; - virtual void Hydrate() override; - virtual void Dehydrate() override; + constexpr std::string_view FileHydratorPrefix = "file://"; + constexpr std::string_view FileHydratorType = "file"; -private: - HydrationConfig m_Config; - std::filesystem::path m_StorageModuleRootDir; -}; + constexpr std::string_view S3HydratorPrefix = "s3://"; + constexpr std::string_view S3HydratorType = "s3"; -void -FileHydrator::Configure(const HydrationConfig& Config) -{ - m_Config = Config; + class FileStorage : public StorageBase + { + public: + FileStorage() {} + virtual void Configure(std::string_view ModuleId, + const std::filesystem::path& TempDir, + std::string_view TargetSpecification, + const CbObject& Options) + { + ZEN_UNUSED(TempDir); + if (!TargetSpecification.empty()) + { + m_StoragePath = Utf8ToWide(TargetSpecification.substr(FileHydratorPrefix.length())); + if (m_StoragePath.empty()) + { + throw zen::runtime_error("Hydration config 'file' type requires a directory path"); + } + } + else + { + CbObjectView Settings = Options["settings"].AsObjectView(); + std::string_view Path = Settings["path"].AsString(); + if (Path.empty()) + { + throw zen::runtime_error("Hydration config 'file' type requires 'settings.path'"); + } + m_StoragePath = Utf8ToWide(std::string(Path)); + } + m_StoragePath = m_StoragePath / ModuleId; + MakeSafeAbsolutePathInPlace(m_StoragePath); - std::filesystem::path ConfigPath(Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length()))); - 
MakeSafeAbsolutePathInPlace(ConfigPath); + m_StatePathName = m_StoragePath / "current-state.cbo"; + m_CASPath = m_StoragePath / "cas"; + CreateDirectories(m_CASPath); + } + virtual void SaveMetadata(const CbObject& Data) + { + BinaryWriter Output; + SaveCompactBinary(Output, Data); + WriteFile(m_StatePathName, IoBuffer(IoBuffer::Wrap, Output.GetData(), Output.GetSize())); + } + virtual CbObject LoadMetadata() + { + if (!IsFile(m_StatePathName)) + { + return {}; + } + FileContents Content = ReadFile(m_StatePathName); + if (Content.ErrorCode) + { + ThrowSystemError(Content.ErrorCode.value(), "Failed to read state file"); + } + IoBuffer Payload = Content.Flatten(); + CbValidateError Error; + CbObject Result = ValidateAndReadCompactBinaryObject(std::move(Payload), Error); + if (Error != CbValidateError::None) + { + throw std::runtime_error(fmt::format("Failed to read {} state file. Reason: {}", m_StatePathName, ToString(Error))); + } + return Result; + } - if (!std::filesystem::exists(ConfigPath)) - { - throw std::invalid_argument(fmt::format("Target does not exist: '{}'", ConfigPath.string())); - } + virtual CbObject GetSettings() override { return {}; } + virtual void ParseSettings(const CbObjectView& Settings) { ZEN_UNUSED(Settings); } - m_StorageModuleRootDir = ConfigPath / m_Config.ModuleId; + virtual std::vector<IoHash> List() + { + DirectoryContent DirContent; + GetDirectoryContent(m_CASPath, DirectoryContentFlags::IncludeFiles, DirContent); + std::vector<IoHash> Result; + Result.reserve(DirContent.Files.size()); + for (const std::filesystem::path& Path : DirContent.Files) + { + IoHash Hash; + if (IoHash::TryParse(Path.filename().string(), Hash)) + { + Result.push_back(Hash); + } + } + return Result; + } - CreateDirectories(m_StorageModuleRootDir); -} + virtual void Put(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath) + { + ZEN_UNUSED(Size); + Work.ScheduleWork(WorkerPool, + 
[this, Hash = IoHash(Hash), SourcePath = std::filesystem::path(SourcePath)](std::atomic<bool>& AbortFlag) { + if (!AbortFlag.load()) + { + CopyFile(SourcePath, m_CASPath / fmt::format("{}", Hash), CopyFileOptions{.EnableClone = true}); + } + }); + } -void -FileHydrator::Hydrate() -{ - ZEN_INFO("Hydrating state from '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir); + virtual void Get(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath) + { + ZEN_UNUSED(Size); + Work.ScheduleWork( + WorkerPool, + [this, Hash = IoHash(Hash), DestinationPath = std::filesystem::path(DestinationPath)](std::atomic<bool>& AbortFlag) { + if (!AbortFlag.load()) + { + CopyFile(m_CASPath / fmt::format("{}", Hash), DestinationPath, CopyFileOptions{.EnableClone = true}); + } + }); + } - // Ensure target is clean - ZEN_DEBUG("Wiping server state at '{}'", m_Config.ServerStateDir); - const bool ForceRemoveReadOnlyFiles = true; - CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); + virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override + { + ZEN_UNUSED(Work); + ZEN_UNUSED(WorkerPool); + DeleteDirectories(m_StoragePath); + } - bool WipeServerState = false; + private: + std::filesystem::path m_StoragePath; + std::filesystem::path m_StatePathName; + std::filesystem::path m_CASPath; + }; - try + class S3Storage : public StorageBase { - ZEN_DEBUG("Copying '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir); - CopyTree(m_StorageModuleRootDir, m_Config.ServerStateDir, {.EnableClone = true}); - } - catch (std::exception& Ex) - { - ZEN_WARN("Copy failed: {}. 
Will wipe any partially copied state from '{}'", Ex.what(), m_Config.ServerStateDir); + public: + S3Storage() {} - // We don't do the clean right here to avoid potentially running into double-throws - WipeServerState = true; - } + virtual void Configure(std::string_view ModuleId, + const std::filesystem::path& TempDir, + std::string_view TargetSpecification, + const CbObject& Options) + { + m_Options = Options; - if (WipeServerState) - { - ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir); - CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); - } -} + CbObjectView Settings = m_Options["settings"].AsObjectView(); + std::string_view Spec; + if (!TargetSpecification.empty()) + { + Spec = TargetSpecification; + Spec.remove_prefix(S3HydratorPrefix.size()); + } + else + { + std::string_view Uri = Settings["uri"].AsString(); + if (Uri.empty()) + { + throw zen::runtime_error("Incremental S3 hydration config requires 'settings.uri'"); + } + Spec = Uri; + Spec.remove_prefix(S3HydratorPrefix.size()); + } -void -FileHydrator::Dehydrate() -{ - ZEN_INFO("Dehydrating state from '{}' to '{}'", m_Config.ServerStateDir, m_StorageModuleRootDir); + size_t SlashPos = Spec.find('/'); + std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{}; + m_Bucket = std::string(SlashPos != std::string_view::npos ? Spec.substr(0, SlashPos) : Spec); + m_KeyPrefix = UserPrefix.empty() ? std::string(ModuleId) : UserPrefix + "/" + std::string(ModuleId); - const std::filesystem::path TargetDir = m_StorageModuleRootDir; + ZEN_ASSERT(!m_Bucket.empty()); - // Ensure target is clean. 
This could be replaced with an atomic copy at a later date - // (i.e copy into a temporary directory name and rename it once complete) + std::string Region = std::string(Settings["region"].AsString()); + if (Region.empty()) + { + Region = GetEnvVariable("AWS_DEFAULT_REGION"); + } + if (Region.empty()) + { + Region = GetEnvVariable("AWS_REGION"); + } + if (Region.empty()) + { + Region = "us-east-1"; + } + m_Region = std::move(Region); + + std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID"); + if (AccessKeyId.empty()) + { + m_CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({})); + } + else + { + m_Credentials.AccessKeyId = std::move(AccessKeyId); + m_Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY"); + m_Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN"); + } + m_TempDir = TempDir; + m_Client = CreateS3Client(); + } - ZEN_DEBUG("Cleaning storage root '{}'", TargetDir); - const bool ForceRemoveReadOnlyFiles = true; - CleanDirectory(TargetDir, ForceRemoveReadOnlyFiles); + virtual void SaveMetadata(const CbObject& Data) + { + S3Client& Client = *m_Client; + BinaryWriter Output; + SaveCompactBinary(Output, Data); + IoBuffer Payload(IoBuffer::Clone, Output.GetData(), Output.GetSize()); + + std::string Key = m_KeyPrefix + "/incremental-state.cbo"; + S3Result Result = Client.PutObject(Key, std::move(Payload)); + if (!Result.IsSuccess()) + { + throw zen::runtime_error("Failed to save incremental metadata to '{}': {}", Key, Result.Error); + } + } - bool CopySuccess = true; + virtual CbObject LoadMetadata() + { + S3Client& Client = *m_Client; + std::string Key = m_KeyPrefix + "/incremental-state.cbo"; + S3GetObjectResult Result = Client.GetObject(Key); + if (!Result.IsSuccess()) + { + if (Result.Error == S3GetObjectResult::NotFoundErrorText) + { + return {}; + } + throw zen::runtime_error("Failed to load incremental metadata from '{}': {}", Key, Result.Error); + } - try - { - ZEN_DEBUG("Copying '{}' to 
'{}'", m_Config.ServerStateDir, TargetDir); - CopyTree(m_Config.ServerStateDir, TargetDir, {.EnableClone = true}); - } - catch (std::exception& Ex) - { - ZEN_WARN("Copy failed: {}. Will wipe any partially copied state from '{}'", Ex.what(), m_StorageModuleRootDir); + CbValidateError Error; + CbObject Meta = ValidateAndReadCompactBinaryObject(std::move(Result.Content), Error); + if (Error != CbValidateError::None) + { + throw zen::runtime_error("Failed to parse incremental metadata from '{}': {}", Key, ToString(Error)); + } + return Meta; + } - // We don't do the clean right here to avoid potentially running into double-throws - CopySuccess = false; - } + virtual CbObject GetSettings() override + { + CbObjectWriter Writer; + Writer << "MultipartChunkSize" << m_MultipartChunkSize; + return Writer.Save(); + } - if (!CopySuccess) - { - ZEN_DEBUG("Removing partially copied state from '{}'", TargetDir); - CleanDirectory(TargetDir, ForceRemoveReadOnlyFiles); - } + virtual void ParseSettings(const CbObjectView& Settings) + { + m_MultipartChunkSize = Settings["MultipartChunkSize"].AsUInt64(DefaultMultipartChunkSize); + } - ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir); - CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); -} + virtual std::vector<IoHash> List() + { + S3Client& Client = *m_Client; + std::string Prefix = m_KeyPrefix + "/cas/"; + S3ListObjectsResult Result = Client.ListObjects(Prefix); + if (!Result.IsSuccess()) + { + throw zen::runtime_error("Failed to list S3 objects under '{}': {}", Prefix, Result.Error); + } -/////////////////////////////////////////////////////////////////////////// + std::vector<IoHash> Hashes; + Hashes.reserve(Result.Objects.size()); + for (const S3ObjectInfo& Obj : Result.Objects) + { + size_t LastSlash = Obj.Key.rfind('/'); + if (LastSlash == std::string::npos) + { + continue; + } + IoHash Hash; + if (IoHash::TryParse(Obj.Key.substr(LastSlash + 1), Hash)) + { + Hashes.push_back(Hash); + } + } + return 
Hashes; + } -constexpr std::string_view S3HydratorPrefix = "s3://"; + virtual void Put(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath) + { + Work.ScheduleWork( + WorkerPool, + [this, Hash = IoHash(Hash), Size, SourcePath = std::filesystem::path(SourcePath)](std::atomic<bool>& AbortFlag) { + if (AbortFlag.load()) + { + return; + } + S3Client& Client = *m_Client; + std::string Key = m_KeyPrefix + "/cas/" + fmt::format("{}", Hash); -struct S3Hydrator : public HydrationStrategyBase -{ - void Configure(const HydrationConfig& Config) override; - void Dehydrate() override; - void Hydrate() override; + if (Size >= (m_MultipartChunkSize + (m_MultipartChunkSize / 4))) + { + BasicFile File(SourcePath, BasicFile::Mode::kRead); + S3Result Result = Client.PutObjectMultipart( + Key, + Size, + [&File](uint64_t Offset, uint64_t ChunkSize) { return File.ReadRange(Offset, ChunkSize); }, + m_MultipartChunkSize); + if (!Result.IsSuccess()) + { + throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, Result.Error); + } + } + else + { + BasicFile File(SourcePath, BasicFile::Mode::kRead); + S3Result Result = Client.PutObject(Key, File.ReadAll()); + if (!Result.IsSuccess()) + { + throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, Result.Error); + } + } + }); + } -private: - S3Client CreateS3Client() const; - std::string BuildTimestampFolderName() const; - std::string MakeObjectKey(std::string_view FolderName, const std::filesystem::path& RelPath) const; - - HydrationConfig m_Config; - std::string m_Bucket; - std::string m_KeyPrefix; // "<user-prefix>/<ModuleId>" or just "<ModuleId>" - no trailing slash - std::string m_Region; - SigV4Credentials m_Credentials; - Ref<ImdsCredentialProvider> m_CredentialProvider; -}; + virtual void Get(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath) + { + 
std::string Key = m_KeyPrefix + "/cas/" + fmt::format("{}", Hash); -void -S3Hydrator::Configure(const HydrationConfig& Config) -{ - m_Config = Config; + if (Size >= (m_MultipartChunkSize + (m_MultipartChunkSize / 4))) + { + class WorkData + { + public: + WorkData(const std::filesystem::path& DestPath, uint64_t Size) : m_DestFile(DestPath, BasicFile::Mode::kTruncate) + { + PrepareFileForScatteredWrite(m_DestFile.Handle(), Size); + } + ~WorkData() { m_DestFile.Flush(); } + void Write(const void* Data, uint64_t Size, uint64_t Offset) { m_DestFile.Write(Data, Size, Offset); } - std::string_view Spec = m_Config.TargetSpecification; - Spec.remove_prefix(S3HydratorPrefix.size()); + private: + BasicFile m_DestFile; + }; - size_t SlashPos = Spec.find('/'); - std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{}; - m_Bucket = std::string(SlashPos != std::string_view::npos ? Spec.substr(0, SlashPos) : Spec); - m_KeyPrefix = UserPrefix.empty() ? 
m_Config.ModuleId : UserPrefix + "/" + m_Config.ModuleId; + std::shared_ptr<WorkData> Data = std::make_shared<WorkData>(DestinationPath, Size); - ZEN_ASSERT(!m_Bucket.empty()); + uint64_t Offset = 0; + while (Offset < Size) + { + uint64_t ChunkSize = std::min<uint64_t>(m_MultipartChunkSize, Size - Offset); + + Work.ScheduleWork(WorkerPool, [this, Key = Key, Offset, ChunkSize, Data](std::atomic<bool>& AbortFlag) { + if (AbortFlag) + { + return; + } + S3GetObjectResult Chunk = m_Client->GetObjectRange(Key, Offset, ChunkSize); + if (!Chunk.IsSuccess()) + { + throw zen::runtime_error("Failed to download '{}' bytes [{}-{}] from S3: {}", + Key, + Offset, + Offset + ChunkSize - 1, + Chunk.Error); + } + + Data->Write(Chunk.Content.GetData(), Chunk.Content.GetSize(), Offset); + }); + Offset += ChunkSize; + } + } + else + { + Work.ScheduleWork( + WorkerPool, + [this, Key = Key, DestinationPath = std::filesystem::path(DestinationPath)](std::atomic<bool>& AbortFlag) { + if (AbortFlag) + { + return; + } + S3GetObjectResult Chunk = m_Client->GetObject(Key, m_TempDir); + if (!Chunk.IsSuccess()) + { + throw zen::runtime_error("Failed to download '{}' from S3: {}", Key, Chunk.Error); + } + + if (IoBufferFileReference FileRef; Chunk.Content.GetFileReference(FileRef)) + { + std::error_code Ec; + std::filesystem::path ChunkPath = PathFromHandle(FileRef.FileHandle, Ec); + if (Ec) + { + WriteFile(DestinationPath, Chunk.Content); + } + else + { + Chunk.Content.SetDeleteOnClose(false); + Chunk.Content = {}; + RenameFile(ChunkPath, DestinationPath, Ec); + if (Ec) + { + Chunk.Content = IoBufferBuilder::MakeFromFile(ChunkPath); + Chunk.Content.SetDeleteOnClose(true); + WriteFile(DestinationPath, Chunk.Content); + } + } + } + else + { + WriteFile(DestinationPath, Chunk.Content); + } + }); + } + } - std::string Region = GetEnvVariable("AWS_DEFAULT_REGION"); - if (Region.empty()) - { - Region = GetEnvVariable("AWS_REGION"); - } - if (Region.empty()) - { - Region = "us-east-1"; - } - m_Region = 
std::move(Region); + virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override + { + std::string Prefix = m_KeyPrefix + "/"; + S3ListObjectsResult ListResult = m_Client->ListObjects(Prefix); + if (!ListResult.IsSuccess()) + { + throw zen::runtime_error("Failed to list S3 objects for deletion under '{}': {}", Prefix, ListResult.Error); + } + for (const S3ObjectInfo& Obj : ListResult.Objects) + { + Work.ScheduleWork(WorkerPool, [this, Key = Obj.Key](std::atomic<bool>& AbortFlag) { + if (AbortFlag.load()) + { + return; + } + S3Result DelResult = m_Client->DeleteObject(Key); + if (!DelResult.IsSuccess()) + { + throw zen::runtime_error("Failed to delete S3 object '{}': {}", Key, DelResult.Error); + } + }); + } + } - std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID"); - if (AccessKeyId.empty()) - { - m_CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({})); - } - else - { - m_Credentials.AccessKeyId = std::move(AccessKeyId); - m_Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY"); - m_Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN"); - } -} + private: + std::unique_ptr<S3Client> CreateS3Client() const + { + S3ClientOptions Options; + Options.BucketName = m_Bucket; + Options.Region = m_Region; -S3Client -S3Hydrator::CreateS3Client() const + CbObjectView Settings = m_Options["settings"].AsObjectView(); + std::string_view Endpoint = Settings["endpoint"].AsString(); + if (!Endpoint.empty()) + { + Options.Endpoint = std::string(Endpoint); + Options.PathStyle = Settings["path-style"].AsBool(); + } + + if (m_CredentialProvider) + { + Options.CredentialProvider = m_CredentialProvider; + } + else + { + Options.Credentials = m_Credentials; + } + + Options.HttpSettings.MaximumInMemoryDownloadSize = 16u * 1024u; + + return std::make_unique<S3Client>(Options); + } + + static constexpr uint64_t DefaultMultipartChunkSize = 32u * 1024u * 1024u; + + std::string m_KeyPrefix; + CbObject m_Options; + 
std::string m_Bucket; + std::string m_Region; + SigV4Credentials m_Credentials; + Ref<ImdsCredentialProvider> m_CredentialProvider; + std::unique_ptr<S3Client> m_Client; + std::filesystem::path m_TempDir; + uint64_t m_MultipartChunkSize = DefaultMultipartChunkSize; + }; + +} // namespace hydration_impl + +using namespace hydration_impl; + +/////////////////////////////////////////////////////////////////////////// + +class IncrementalHydrator : public HydrationStrategyBase { - S3ClientOptions Options; - Options.BucketName = m_Bucket; - Options.Region = m_Region; +public: + IncrementalHydrator(std::unique_ptr<StorageBase>&& Storage); + virtual ~IncrementalHydrator() override; + virtual void Configure(const HydrationConfig& Config) override; + virtual void Dehydrate(const CbObject& CachedState) override; + virtual CbObject Hydrate() override; + virtual void Obliterate() override; - if (!m_Config.S3Endpoint.empty()) +private: + struct Entry { - Options.Endpoint = m_Config.S3Endpoint; - Options.PathStyle = m_Config.S3PathStyle; - } + std::filesystem::path RelativePath; + uint64_t Size; + uint64_t ModTick; + IoHash Hash; + }; - if (m_CredentialProvider) - { - Options.CredentialProvider = m_CredentialProvider; - } - else - { - Options.Credentials = m_Credentials; - } + std::unique_ptr<StorageBase> m_Storage; + HydrationConfig m_Config; + WorkerThreadPool m_FallbackWorkPool; + std::atomic<bool> m_FallbackAbortFlag{false}; + std::atomic<bool> m_FallbackPauseFlag{false}; + HydrationConfig::ThreadingOptions m_Threading{.WorkerPool = &m_FallbackWorkPool, + .AbortFlag = &m_FallbackAbortFlag, + .PauseFlag = &m_FallbackPauseFlag}; +}; - return S3Client(Options); +IncrementalHydrator::IncrementalHydrator(std::unique_ptr<StorageBase>&& Storage) : m_Storage(std::move(Storage)), m_FallbackWorkPool(0) +{ } -std::string -S3Hydrator::BuildTimestampFolderName() const +IncrementalHydrator::~IncrementalHydrator() { - UtcTime Now = UtcTime::Now(); - return 
fmt::format("{:04d}{:02d}{:02d}-{:02d}{:02d}{:02d}-{:03d}", - Now.Tm.tm_year + 1900, - Now.Tm.tm_mon + 1, - Now.Tm.tm_mday, - Now.Tm.tm_hour, - Now.Tm.tm_min, - Now.Tm.tm_sec, - Now.Ms); + m_Storage.reset(); } -std::string -S3Hydrator::MakeObjectKey(std::string_view FolderName, const std::filesystem::path& RelPath) const +void +IncrementalHydrator::Configure(const HydrationConfig& Config) { - return m_KeyPrefix + "/" + std::string(FolderName) + "/" + RelPath.generic_string(); + m_Config = Config; + m_Storage->Configure(Config.ModuleId, Config.TempDir, Config.TargetSpecification, Config.Options); + if (Config.Threading) + { + m_Threading = *Config.Threading; + } } void -S3Hydrator::Dehydrate() +IncrementalHydrator::Dehydrate(const CbObject& CachedState) { - ZEN_INFO("Dehydrating state from '{}' to s3://{}/{}", m_Config.ServerStateDir, m_Bucket, m_KeyPrefix); + Stopwatch Timer; + const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir); try { - S3Client Client = CreateS3Client(); - std::string FolderName = BuildTimestampFolderName(); - uint64_t TotalBytes = 0; - uint32_t FileCount = 0; - std::chrono::steady_clock::time_point UploadStart = std::chrono::steady_clock::now(); + std::unordered_map<std::string, size_t> StateEntryLookup; + std::vector<Entry> StateEntries; + for (CbFieldView FieldView : CachedState["Files"].AsArrayView()) + { + CbObjectView EntryView = FieldView.AsObjectView(); + std::filesystem::path RelativePath(EntryView["Path"].AsString()); + uint64_t Size = EntryView["Size"].AsUInt64(); + uint64_t ModTick = EntryView["ModTick"].AsUInt64(); + IoHash Hash = EntryView["Hash"].AsHash(); + + StateEntryLookup.insert_or_assign(RelativePath.generic_string(), StateEntries.size()); + StateEntries.push_back(Entry{.RelativePath = RelativePath, .Size = Size, .ModTick = ModTick, .Hash = Hash}); + } DirectoryContent DirContent; - GetDirectoryContent(m_Config.ServerStateDir, DirectoryContentFlags::IncludeFiles | 
DirectoryContentFlags::Recursive, DirContent); + GetDirectoryContent(*m_Threading.WorkerPool, + ServerStateDir, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | + DirectoryContentFlags::IncludeFileSizes | DirectoryContentFlags::IncludeModificationTick, + DirContent); + + ZEN_INFO("Dehydrating module '{}' from folder '{}'. {} ({}) files", + m_Config.ModuleId, + m_Config.ServerStateDir, + DirContent.Files.size(), + NiceBytes(std::accumulate(DirContent.FileSizes.begin(), DirContent.FileSizes.end(), uint64_t(0)))); + + std::vector<Entry> Entries; + Entries.resize(DirContent.Files.size()); + + uint64_t TotalBytes = 0; + uint64_t TotalFiles = 0; + uint64_t HashedFiles = 0; + uint64_t HashedBytes = 0; + + std::unordered_set<IoHash> ExistsLookup; - for (const std::filesystem::path& AbsPath : DirContent.Files) { - std::filesystem::path RelPath = AbsPath.lexically_relative(m_Config.ServerStateDir); - if (RelPath.empty() || *RelPath.begin() == "..") + ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (size_t FileIndex = 0; FileIndex < DirContent.Files.size(); FileIndex++) { - throw zen::runtime_error( - "lexically_relative produced a '..'-escape path for '{}' relative to '{}' - " - "path form mismatch (e.g. 
\\\\?\\ prefix on one but not the other)", - AbsPath.string(), - m_Config.ServerStateDir.string()); - } - std::string Key = MakeObjectKey(FolderName, RelPath); + const std::filesystem::path AbsPath = MakeSafeAbsolutePath(DirContent.Files[FileIndex]); + if (AbsPath.filename() == "reserve.gc") + { + continue; + } + const std::filesystem::path RelativePath = FastRelativePath(ServerStateDir, DirContent.Files[FileIndex]); + if (*RelativePath.begin() == ".sentry-native") + { + continue; + } + if (RelativePath == ".lock") + { + continue; + } - BasicFile File(AbsPath, BasicFile::Mode::kRead); - uint64_t FileSize = File.FileSize(); + Entry& CurrentEntry = Entries[TotalFiles]; + CurrentEntry.RelativePath = RelativePath; + CurrentEntry.Size = DirContent.FileSizes[FileIndex]; + CurrentEntry.ModTick = DirContent.FileModificationTicks[FileIndex]; - S3Result UploadResult = - Client.PutObjectMultipart(Key, FileSize, [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); }); - if (!UploadResult.IsSuccess()) - { - throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, UploadResult.Error); + bool FoundHash = false; + if (auto KnownIt = StateEntryLookup.find(CurrentEntry.RelativePath.generic_string()); KnownIt != StateEntryLookup.end()) + { + const Entry& StateEntry = StateEntries[KnownIt->second]; + if (StateEntry.Size == CurrentEntry.Size && StateEntry.ModTick == CurrentEntry.ModTick) + { + CurrentEntry.Hash = StateEntry.Hash; + FoundHash = true; + } + } + + if (!FoundHash) + { + Work.ScheduleWork(*m_Threading.WorkerPool, [AbsPath, EntryIndex = TotalFiles, &Entries](std::atomic<bool>& AbortFlag) { + if (AbortFlag) + { + return; + } + + Entry& CurrentEntry = Entries[EntryIndex]; + + bool FoundHash = false; + if (AbsPath.extension().empty()) + { + auto It = CurrentEntry.RelativePath.begin(); + if (It != CurrentEntry.RelativePath.end() && It->filename().string().ends_with("cas")) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = 
+ CompressedBuffer::FromCompressed(SharedBuffer(IoBufferBuilder::MakeFromFile(AbsPath)), + RawHash, + RawSize); + if (Compressed) + { + // We compose a meta-hash since taking the RawHash might collide with an existing + // non-compressed file with the same content. The collision is unlikely except if the + // compressed data is zero bytes causing RawHash to be the same as an empty file. + IoHashStream Hasher; + Hasher.Append(RawHash.Hash, sizeof(RawHash.Hash)); + Hasher.Append(&CurrentEntry.Size, sizeof(CurrentEntry.Size)); + CurrentEntry.Hash = Hasher.GetHash(); + FoundHash = true; + } + } + } + + if (!FoundHash) + { + CurrentEntry.Hash = IoHash::HashBuffer(IoBufferBuilder::MakeFromFile(AbsPath)); + } + }); + HashedFiles++; + HashedBytes += CurrentEntry.Size; + } + TotalFiles++; + TotalBytes += CurrentEntry.Size; } - TotalBytes += FileSize; - ++FileCount; + std::vector<IoHash> ExistingEntries = m_Storage->List(); + ExistsLookup.insert(ExistingEntries.begin(), ExistingEntries.end()); + + Work.Wait(); + + Entries.resize(TotalFiles); } - // Write current-state.json - int64_t UploadDurationMs = - std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - UploadStart).count(); - - UtcTime Now = UtcTime::Now(); - std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z", - Now.Tm.tm_year + 1900, - Now.Tm.tm_mon + 1, - Now.Tm.tm_mday, - Now.Tm.tm_hour, - Now.Tm.tm_min, - Now.Tm.tm_sec, - Now.Ms); - - CbObjectWriter Meta; - Meta << "FolderName" << FolderName; - Meta << "ModuleId" << m_Config.ModuleId; - Meta << "HostName" << GetMachineName(); - Meta << "UploadTimeUtc" << UploadTimeUtc; - Meta << "UploadDurationMs" << UploadDurationMs; - Meta << "TotalSizeBytes" << TotalBytes; - Meta << "FileCount" << FileCount; - - ExtendableStringBuilder<1024> JsonBuilder; - Meta.Save().ToJson(JsonBuilder); - - std::string MetaKey = m_KeyPrefix + "/current-state.json"; - std::string_view JsonText = JsonBuilder.ToView(); - 
IoBuffer MetaBuf(IoBuffer::Clone, JsonText.data(), JsonText.size()); - S3Result MetaUploadResult = Client.PutObject(MetaKey, std::move(MetaBuf)); - if (!MetaUploadResult.IsSuccess()) + uint64_t UploadedFiles = 0; + uint64_t UploadedBytes = 0; { - throw zen::runtime_error("Failed to write current-state.json to '{}': {}", MetaKey, MetaUploadResult.Error); + ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::DisableBacklog); + + for (const Entry& CurrentEntry : Entries) + { + if (!ExistsLookup.contains(CurrentEntry.Hash)) + { + m_Storage->Put(Work, + *m_Threading.WorkerPool, + CurrentEntry.Hash, + CurrentEntry.Size, + MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath)); + UploadedFiles++; + UploadedBytes += CurrentEntry.Size; + } + } + + Work.Wait(); + uint64_t UploadTimeMs = Timer.GetElapsedTimeMs(); + + UtcTime Now = UtcTime::Now(); + std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z", + Now.Tm.tm_year + 1900, + Now.Tm.tm_mon + 1, + Now.Tm.tm_mday, + Now.Tm.tm_hour, + Now.Tm.tm_min, + Now.Tm.tm_sec, + Now.Ms); + + CbObjectWriter Meta; + Meta << "SourceFolder" << ServerStateDir.generic_string(); + Meta << "ModuleId" << m_Config.ModuleId; + Meta << "HostName" << GetMachineName(); + Meta << "UploadTimeUtc" << UploadTimeUtc; + Meta << "UploadDurationMs" << UploadTimeMs; + Meta << "TotalSizeBytes" << TotalBytes; + Meta << "StorageSettings" << m_Storage->GetSettings(); + + Meta.BeginArray("Files"); + for (const Entry& CurrentEntry : Entries) + { + Meta.BeginObject(); + { + Meta << "Path" << CurrentEntry.RelativePath.generic_string(); + Meta << "Size" << CurrentEntry.Size; + Meta << "ModTick" << CurrentEntry.ModTick; + Meta << "Hash" << CurrentEntry.Hash; + } + Meta.EndObject(); + } + Meta.EndArray(); + + m_Storage->SaveMetadata(Meta.Save()); } - ZEN_INFO("Dehydration complete: {} files, {} bytes, {} ms", FileCount, TotalBytes, UploadDurationMs); + ZEN_DEBUG("Cleaning server 
state '{}'", m_Config.ServerStateDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir); + + ZEN_INFO("Dehydration of module '{}' completed from folder '{}'. Hashed {} ({}). Uploaded {} ({}). Total {} ({}) in {}", + m_Config.ModuleId, + m_Config.ServerStateDir, + HashedFiles, + NiceBytes(HashedBytes), + UploadedFiles, + NiceBytes(UploadedBytes), + TotalFiles, + NiceBytes(TotalBytes), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } - catch (std::exception& Ex) + catch (const std::exception& Ex) { - // Any in-progress multipart upload has already been aborted by PutObjectMultipart. - // current-state.json is only written on success, so the previous S3 state remains valid. - ZEN_WARN("S3 dehydration failed: {}. S3 state not updated.", Ex.what()); + ZEN_WARN("Dehydration of module '{}' failed: {}. Leaving server state '{}'", m_Config.ModuleId, Ex.what(), m_Config.ServerStateDir); } } -void -S3Hydrator::Hydrate() +CbObject +IncrementalHydrator::Hydrate() { - ZEN_INFO("Hydrating state from s3://{}/{} to '{}'", m_Bucket, m_KeyPrefix, m_Config.ServerStateDir); - - const bool ForceRemoveReadOnlyFiles = true; - - // Clean temp dir before starting in case of leftover state from a previous failed hydration - ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); - CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); - - bool WipeServerState = false; + Stopwatch Timer; + const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir); + const std::filesystem::path TempDir = MakeSafeAbsolutePath(m_Config.TempDir); try { - S3Client Client = CreateS3Client(); - std::string MetaKey = m_KeyPrefix + "/current-state.json"; - - S3HeadObjectResult HeadResult = Client.HeadObject(MetaKey); - if (HeadResult.Status == HeadObjectResult::NotFound) - { - throw zen::runtime_error("No state found in S3 at '{}'", MetaKey); - } - if (!HeadResult.IsSuccess()) + CbObject Meta = m_Storage->LoadMetadata(); + if 
(!Meta) { - throw zen::runtime_error("Failed to check for state in S3 at '{}': {}", MetaKey, HeadResult.Error); + ZEN_INFO("No dehydrated state for module {} found, cleaning server state: '{}'", m_Config.ModuleId, m_Config.ServerStateDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir); + return CbObject(); } - S3GetObjectResult MetaResult = Client.GetObject(MetaKey); - if (!MetaResult.IsSuccess()) - { - throw zen::runtime_error("Failed to read current-state.json from '{}': {}", MetaKey, MetaResult.Error); - } + std::unordered_map<std::string, size_t> EntryLookup; + std::vector<Entry> Entries; + uint64_t TotalSize = 0; - std::string ParseError; - json11::Json MetaJson = json11::Json::parse(std::string(MetaResult.AsText()), ParseError); - if (!ParseError.empty()) + for (CbFieldView FieldView : Meta["Files"]) { - throw zen::runtime_error("Failed to parse current-state.json from '{}': {}", MetaKey, ParseError); + CbObjectView EntryView = FieldView.AsObjectView(); + if (EntryView) + { + Entry NewEntry = {.RelativePath = std::filesystem::path(EntryView["Path"].AsString()), + .Size = EntryView["Size"].AsUInt64(), + .ModTick = EntryView["ModTick"].AsUInt64(), + .Hash = EntryView["Hash"].AsHash()}; + TotalSize += NewEntry.Size; + EntryLookup.insert_or_assign(NewEntry.RelativePath.generic_string(), Entries.size()); + Entries.emplace_back(std::move(NewEntry)); + } } - std::string FolderName = MetaJson["FolderName"].string_value(); - if (FolderName.empty()) - { - throw zen::runtime_error("current-state.json from '{}' has missing or empty FolderName", MetaKey); - } + ZEN_INFO("Hydrating module '{}' to folder '{}'. 
{} ({}) files", + m_Config.ModuleId, + m_Config.ServerStateDir, + Entries.size(), + NiceBytes(TotalSize)); - std::string FolderPrefix = m_KeyPrefix + "/" + FolderName + "/"; - S3ListObjectsResult ListResult = Client.ListObjects(FolderPrefix); - if (!ListResult.IsSuccess()) - { - throw zen::runtime_error("Failed to list S3 objects under '{}': {}", FolderPrefix, ListResult.Error); - } + m_Storage->ParseSettings(Meta["StorageSettings"].AsObjectView()); + + uint64_t DownloadedBytes = 0; + uint64_t DownloadedFiles = 0; - for (const S3ObjectInfo& Obj : ListResult.Objects) { - if (!Obj.Key.starts_with(FolderPrefix)) - { - ZEN_WARN("Skipping unexpected S3 key '{}' (expected prefix '{}')", Obj.Key, FolderPrefix); - continue; - } + ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - std::string RelKey = Obj.Key.substr(FolderPrefix.size()); - if (RelKey.empty()) + for (const Entry& CurrentEntry : Entries) { - continue; + std::filesystem::path Path = MakeSafeAbsolutePath(TempDir / CurrentEntry.RelativePath); + CreateDirectories(Path.parent_path()); + m_Storage->Get(Work, *m_Threading.WorkerPool, CurrentEntry.Hash, CurrentEntry.Size, Path); + DownloadedBytes += CurrentEntry.Size; + DownloadedFiles++; } - std::filesystem::path DestPath = MakeSafeAbsolutePath(m_Config.TempDir / std::filesystem::path(RelKey)); - CreateDirectories(DestPath.parent_path()); - BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate); - DestFile.SetFileSize(Obj.Size); - - if (Obj.Size > 0) - { - BasicFileWriter Writer(DestFile, 64 * 1024); - - uint64_t Offset = 0; - while (Offset < Obj.Size) - { - uint64_t ChunkSize = std::min<uint64_t>(8 * 1024 * 1024, Obj.Size - Offset); - S3GetObjectResult Chunk = Client.GetObjectRange(Obj.Key, Offset, ChunkSize); - if (!Chunk.IsSuccess()) - { - throw zen::runtime_error("Failed to download '{}' bytes [{}-{}] from S3: {}", - Obj.Key, - Offset, - Offset + ChunkSize - 1, - Chunk.Error); - } - - 
Writer.Write(Chunk.Content.GetData(), Chunk.Content.GetSize(), Offset); - Offset += ChunkSize; - } - - Writer.Flush(); - } + Work.Wait(); } // Downloaded successfully - swap into ServerStateDir - ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir); - CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); + ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir); // If the two paths share at least one common component they are on the same drive/volume // and atomic renames will succeed. Otherwise fall back to a full copy. - auto [ItTmp, ItState] = - std::mismatch(m_Config.TempDir.begin(), m_Config.TempDir.end(), m_Config.ServerStateDir.begin(), m_Config.ServerStateDir.end()); - if (ItTmp != m_Config.TempDir.begin()) + auto [ItTmp, ItState] = std::mismatch(TempDir.begin(), TempDir.end(), ServerStateDir.begin(), ServerStateDir.end()); + if (ItTmp != TempDir.begin()) { - // Fast path: atomic renames - no data copying needed - for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.TempDir)) + DirectoryContent DirContent; + GetDirectoryContent(*m_Threading.WorkerPool, + TempDir, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs, + DirContent); + + for (const std::filesystem::path& AbsPath : DirContent.Directories) { - std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / Entry.path().filename()); - if (Entry.is_directory()) + std::filesystem::path Dest = MakeSafeAbsolutePath(ServerStateDir / AbsPath.filename()); + std::error_code Ec = RenameDirectoryWithRetry(AbsPath, Dest); + if (Ec) { - RenameDirectory(Entry.path(), Dest); + throw std::system_error(Ec, fmt::format("Failed to rename directory from '{}' to '{}'", AbsPath, Dest)); } - else + } + for (const std::filesystem::path& AbsPath : DirContent.Files) + { + std::filesystem::path Dest = 
MakeSafeAbsolutePath(ServerStateDir / AbsPath.filename()); + std::error_code Ec = RenameFileWithRetry(AbsPath, Dest); + if (Ec) { - RenameFile(Entry.path(), Dest); + throw std::system_error(Ec, fmt::format("Failed to rename file from '{}' to '{}'", AbsPath, Dest)); } } + ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); - CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir); } else { // Slow path: TempDir and ServerStateDir are on different filesystems, so rename // would fail. Copy the tree instead and clean up the temp files afterwards. ZEN_DEBUG("TempDir and ServerStateDir are on different filesystems - using CopyTree"); - CopyTree(m_Config.TempDir, m_Config.ServerStateDir, {.EnableClone = true}); + CopyTree(TempDir, ServerStateDir, {.EnableClone = true}); ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); - CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir); + } + + // TODO: This could perhaps be done more efficiently, but ok for now + DirectoryContent DirContent; + GetDirectoryContent(*m_Threading.WorkerPool, + ServerStateDir, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | + DirectoryContentFlags::IncludeFileSizes | DirectoryContentFlags::IncludeModificationTick, + DirContent); + + CbObjectWriter HydrateState; + HydrateState.BeginArray("Files"); + for (size_t FileIndex = 0; FileIndex < DirContent.Files.size(); FileIndex++) + { + std::filesystem::path RelativePath = FastRelativePath(ServerStateDir, DirContent.Files[FileIndex]); + + if (auto It = EntryLookup.find(RelativePath.generic_string()); It != EntryLookup.end()) + { + HydrateState.BeginObject(); + { + HydrateState << "Path" << RelativePath.generic_string(); + HydrateState << "Size" << DirContent.FileSizes[FileIndex]; + HydrateState << "ModTick" << 
DirContent.FileModificationTicks[FileIndex]; + HydrateState << "Hash" << Entries[It->second].Hash; + } + HydrateState.EndObject(); + } + else + { + ZEN_ASSERT(false); + } } + HydrateState.EndArray(); + + CbObject StateObject = HydrateState.Save(); - ZEN_INFO("Hydration complete from folder '{}'", FolderName); + ZEN_INFO("Hydration of module '{}' complete to folder '{}'. {} ({}) files in {}", + m_Config.ModuleId, + m_Config.ServerStateDir, + DownloadedFiles, + NiceBytes(DownloadedBytes), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + + return StateObject; } - catch (std::exception& Ex) + catch (const std::exception& Ex) { - ZEN_WARN("S3 hydration failed: {}. Will wipe any partially installed state.", Ex.what()); - - // We don't do the clean right here to avoid potentially running into double-throws - WipeServerState = true; + ZEN_WARN("Hydration of module '{}' failed: {}. Cleaning server state '{}'", m_Config.ModuleId, Ex.what(), m_Config.ServerStateDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir); + ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir); + return {}; } +} - if (WipeServerState) +void +IncrementalHydrator::Obliterate() +{ + const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir); + const std::filesystem::path TempDir = MakeSafeAbsolutePath(m_Config.TempDir); + + try { - ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir); - CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); - ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); - CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); + ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + m_Storage->Delete(Work, *m_Threading.WorkerPool); + Work.Wait(); } + catch (const std::exception& Ex) + { + ZEN_WARN("Failed to 
delete backend storage for module '{}': {}. Proceeding with local cleanup.", m_Config.ModuleId, Ex.what()); + } + + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir); + CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir); } std::unique_ptr<HydrationStrategyBase> CreateHydrator(const HydrationConfig& Config) { - if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0) + std::unique_ptr<StorageBase> Storage; + + if (!Config.TargetSpecification.empty()) { - std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>(); - Hydrator->Configure(Config); - return Hydrator; + if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0) + { + Storage = std::make_unique<FileStorage>(); + } + else if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0) + { + Storage = std::make_unique<S3Storage>(); + } + else + { + throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification)); + } } - if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0) + else { - std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>(); - Hydrator->Configure(Config); - return Hydrator; + std::string_view Type = Config.Options["type"].AsString(); + if (Type == FileHydratorType) + { + Storage = std::make_unique<FileStorage>(); + } + else if (Type == S3HydratorType) + { + Storage = std::make_unique<S3Storage>(); + } + else if (!Type.empty()) + { + throw zen::runtime_error("Unknown hydration target type '{}'", Type); + } + else + { + throw zen::runtime_error("No hydration target configured"); + } } - throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification)); + + auto Hydrator = 
std::make_unique<IncrementalHydrator>(std::move(Storage)); + Hydrator->Configure(Config); + return Hydrator; } #if ZEN_WITH_TESTS namespace { + struct TestThreading + { + WorkerThreadPool WorkerPool; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig::ThreadingOptions Options{.WorkerPool = &WorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag}; + + explicit TestThreading(int ThreadCount) : WorkerPool(ThreadCount) {} + }; + /// Scoped RAII helper to set/restore a single environment variable within a test. /// Used to configure AWS credentials for each S3 test's MinIO instance /// without polluting the global environment. @@ -593,10 +1161,10 @@ namespace { /// subdir/file_b.bin /// subdir/nested/file_c.bin /// Returns a vector of (relative path, content) pairs for later verification. - std::vector<std::pair<std::filesystem::path, IoBuffer>> CreateTestTree(const std::filesystem::path& BaseDir) - { - std::vector<std::pair<std::filesystem::path, IoBuffer>> Files; + typedef std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFileList; + TestFileList AddTestFiles(const std::filesystem::path& BaseDir, TestFileList& Files) + { auto AddFile = [&](std::filesystem::path RelPath, IoBuffer Content) { std::filesystem::path FullPath = BaseDir / RelPath; CreateDirectories(FullPath.parent_path()); @@ -607,6 +1175,36 @@ namespace { AddFile("file_a.bin", CreateSemiRandomBlob(1024)); AddFile("subdir/file_b.bin", CreateSemiRandomBlob(2048)); AddFile("subdir/nested/file_c.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_d.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_e.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_f.bin", CreateSemiRandomBlob(512)); + + return Files; + } + + TestFileList CreateSmallTestTree(const std::filesystem::path& BaseDir) + { + TestFileList Files; + AddTestFiles(BaseDir, Files); + return Files; + } + + TestFileList CreateTestTree(const 
std::filesystem::path& BaseDir) + { + TestFileList Files; + AddTestFiles(BaseDir, Files); + + auto AddFile = [&](std::filesystem::path RelPath, IoBuffer Content) { + std::filesystem::path FullPath = BaseDir / RelPath; + CreateDirectories(FullPath.parent_path()); + WriteFile(FullPath, Content); + Files.emplace_back(std::move(RelPath), std::move(Content)); + }; + + AddFile("subdir/nested/medium.bulk", CreateSemiRandomBlob(256u * 1024u)); + AddFile("subdir/nested/big.bulk", CreateSemiRandomBlob(512u * 1024u)); + AddFile("subdir/nested/huge.bulk", CreateSemiRandomBlob(9u * 1024u * 1024u)); + AddFile("subdir/nested/biggest.bulk", CreateSemiRandomBlob(63u * 1024u * 1024u)); return Files; } @@ -644,7 +1242,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate") CreateDirectories(HydrationTemp); const std::string ModuleId = "testmodule"; - auto TestFiles = CreateTestTree(ServerStateDir); + auto TestFiles = CreateSmallTestTree(ServerStateDir); HydrationConfig Config; Config.ServerStateDir = ServerStateDir; @@ -655,7 +1253,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate") // Dehydrate: copy server state to file store { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Dehydrate(CbObject()); } // Verify the module folder exists in the store and ServerStateDir was wiped @@ -672,7 +1270,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate") VerifyTree(ServerStateDir, TestFiles); } -TEST_CASE("hydration.file.dehydrate_cleans_server_state") +TEST_CASE("hydration.file.hydrate_overwrites_existing_state") { ScopedTemporaryDirectory TempDir; @@ -683,7 +1281,7 @@ TEST_CASE("hydration.file.dehydrate_cleans_server_state") CreateDirectories(HydrationStore); CreateDirectories(HydrationTemp); - CreateTestTree(ServerStateDir); + auto TestFiles = CreateSmallTestTree(ServerStateDir); HydrationConfig Config; Config.ServerStateDir = ServerStateDir; @@ -691,14 +1289,26 @@ TEST_CASE("hydration.file.dehydrate_cleans_server_state") 
Config.ModuleId = "testmodule"; Config.TargetSpecification = "file://" + HydrationStore.string(); - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + // Dehydrate the original state + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } - // FileHydrator::Dehydrate() must wipe ServerStateDir when done - CHECK(std::filesystem::is_empty(ServerStateDir)); + // Put a stale file in ServerStateDir to simulate leftover state + WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); + + // Hydrate - must wipe stale file and restore original + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Hydrate(); + } + + CHECK_FALSE(std::filesystem::exists(ServerStateDir / "stale.bin")); + VerifyTree(ServerStateDir, TestFiles); } -TEST_CASE("hydration.file.hydrate_overwrites_existing_state") +TEST_CASE("hydration.file.excluded_files_not_dehydrated") { ScopedTemporaryDirectory TempDir; @@ -709,31 +1319,86 @@ TEST_CASE("hydration.file.hydrate_overwrites_existing_state") CreateDirectories(HydrationStore); CreateDirectories(HydrationTemp); - auto TestFiles = CreateTestTree(ServerStateDir); + auto TestFiles = CreateSmallTestTree(ServerStateDir); + + // Add files that the dehydrator should skip + WriteFile(ServerStateDir / "reserve.gc", CreateSemiRandomBlob(64)); + CreateDirectories(ServerStateDir / ".sentry-native"); + WriteFile(ServerStateDir / ".sentry-native" / "db.lock", CreateSemiRandomBlob(32)); + WriteFile(ServerStateDir / ".sentry-native" / "breadcrumb.json", CreateSemiRandomBlob(128)); HydrationConfig Config; Config.ServerStateDir = ServerStateDir; Config.TempDir = HydrationTemp; - Config.ModuleId = "testmodule"; + Config.ModuleId = "testmodule_excl"; Config.TargetSpecification = "file://" + HydrationStore.string(); - // Dehydrate the original state { std::unique_ptr<HydrationStrategyBase> Hydrator = 
CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Dehydrate(CbObject()); } - // Put a stale file in ServerStateDir to simulate leftover state - WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); - - // Hydrate - must wipe stale file and restore original + // Hydrate into a clean directory + CleanDirectory(ServerStateDir, true); { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); Hydrator->Hydrate(); } - CHECK_FALSE(std::filesystem::exists(ServerStateDir / "stale.bin")); + // Normal files must be restored VerifyTree(ServerStateDir, TestFiles); + // Excluded files must NOT be restored + CHECK_FALSE(std::filesystem::exists(ServerStateDir / "reserve.gc")); + CHECK_FALSE(std::filesystem::exists(ServerStateDir / ".sentry-native")); +} + +// --------------------------------------------------------------------------- +// FileHydrator obliterate test +// --------------------------------------------------------------------------- + +TEST_CASE("hydration.file.obliterate") +{ + ScopedTemporaryDirectory TempDir; + + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationStore); + CreateDirectories(HydrationTemp); + + const std::string ModuleId = "obliterate_test"; + CreateSmallTestTree(ServerStateDir); + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + Config.TargetSpecification = "file://" + HydrationStore.string(); + + // Dehydrate so the backend store has data + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } + CHECK(std::filesystem::exists(HydrationStore / ModuleId)); + + // Put some files back in ServerStateDir and TempDir to verify cleanup 
+ CreateSmallTestTree(ServerStateDir); + WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64)); + + // Obliterate + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Obliterate(); + } + + // Backend store directory deleted + CHECK_FALSE(std::filesystem::exists(HydrationStore / ModuleId)); + // ServerStateDir cleaned + CHECK(std::filesystem::is_empty(ServerStateDir)); + // TempDir cleaned + CHECK(std::filesystem::is_empty(HydrationTemp)); } // --------------------------------------------------------------------------- @@ -750,6 +1415,8 @@ TEST_CASE("hydration.file.concurrent") std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store"; CreateDirectories(HydrationStore); + TestThreading Threading(8); + struct ModuleData { HydrationConfig Config; @@ -769,7 +1436,8 @@ TEST_CASE("hydration.file.concurrent") Modules[I].Config.TempDir = TempPath; Modules[I].Config.ModuleId = ModuleId; Modules[I].Config.TargetSpecification = "file://" + HydrationStore.string(); - Modules[I].Files = CreateTestTree(StateDir); + Modules[I].Config.Threading = Threading.Options; + Modules[I].Files = CreateSmallTestTree(StateDir); } // Concurrent dehydrate @@ -783,7 +1451,7 @@ TEST_CASE("hydration.file.concurrent") { Work.ScheduleWork(Pool, [&Config = Modules[I].Config](std::atomic<bool>&) { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Dehydrate(CbObject()); }); } Work.Wait(); @@ -818,14 +1486,14 @@ TEST_CASE("hydration.file.concurrent") // --------------------------------------------------------------------------- // S3Hydrator tests // -// Each test case spawns its own local MinIO instance (self-contained, no external setup needed). +// Each test case spawns a local MinIO instance (self-contained, no external setup needed). // The MinIO binary must be present in the same directory as the test executable (copied by xmake). 
// --------------------------------------------------------------------------- TEST_CASE("hydration.s3.dehydrate_hydrate") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19010; + MinioOpts.Port = 19011; MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -840,168 +1508,57 @@ TEST_CASE("hydration.s3.dehydrate_hydrate") CreateDirectories(ServerStateDir); CreateDirectories(HydrationTemp); - const std::string ModuleId = "s3test_roundtrip"; - auto TestFiles = CreateTestTree(ServerStateDir); - HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = ModuleId; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; - - // Dehydrate: upload server state to MinIO + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_roundtrip"; { - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); } - // Wipe server state - CleanDirectory(ServerStateDir, true); - CHECK(std::filesystem::is_empty(ServerStateDir)); - - // Hydrate: download from MinIO back to server state + // Hydrate with no prior S3 state (first-boot path). Pre-populate ServerStateDir + // with a stale file to confirm the cleanup branch wipes it. 
+ WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); Hydrator->Hydrate(); } - - // Verify restored contents match the original - VerifyTree(ServerStateDir, TestFiles); -} - -TEST_CASE("hydration.s3.current_state_json_selects_latest_folder") -{ - // Each Dehydrate() uploads files to a new timestamp-named folder and then overwrites - // current-state.json to point at that folder. Old folders are NOT deleted. - // Hydrate() must read current-state.json to determine which folder to restore from. - // - // This test verifies that: - // 1. After two dehydrations, Hydrate() restores from the second snapshot, not the first, - // confirming that current-state.json was updated between dehydrations. - // 2. current-state.json is updated to point at the second (latest) folder. - // 3. Hydrate() restores the v2 snapshot (identified by v2marker.bin), NOT the v1 snapshot. - - MinioProcessOptions MinioOpts; - MinioOpts.Port = 19011; - MinioProcess Minio(MinioOpts); - Minio.SpawnMinioServer(); - Minio.CreateBucket("zen-hydration-test"); - - ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); - ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); - - ScopedTemporaryDirectory TempDir; - - std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; - std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; - CreateDirectories(ServerStateDir); - CreateDirectories(HydrationTemp); - - const std::string ModuleId = "s3test_folder_select"; - - HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = ModuleId; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + CHECK(std::filesystem::is_empty(ServerStateDir)); // v1: dehydrate without a marker file - CreateTestTree(ServerStateDir); + 
CreateSmallTestTree(ServerStateDir); { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Dehydrate(CbObject()); } - // ServerStateDir is now empty. Wait so the v2 timestamp folder name is strictly later - // (timestamp resolution is 1 ms, but macOS scheduler granularity requires a larger margin). - Sleep(100); - // v2: dehydrate WITH a marker file that only v2 has - CreateTestTree(ServerStateDir); + CreateSmallTestTree(ServerStateDir); WriteFile(ServerStateDir / "v2marker.bin", CreateSemiRandomBlob(64)); { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Dehydrate(CbObject()); } - // Hydrate must restore v2 (current-state.json points to the v2 folder) + // Hydrate must restore v2 (the latest dehydrated state) CleanDirectory(ServerStateDir, true); { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); Hydrator->Hydrate(); } - // v2 marker must be present - confirms current-state.json pointed to the v2 folder + // v2 marker must be present - confirms the second dehydration overwrote the first CHECK(std::filesystem::exists(ServerStateDir / "v2marker.bin")); - // Subdirectory hierarchy must also be intact CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "file_b.bin")); CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "nested" / "file_c.bin")); } -TEST_CASE("hydration.s3.module_isolation") -{ - // Two independent modules dehydrate/hydrate without interfering with each other. - // Uses VerifyTree with per-module byte content to detect cross-module data mixing. 
- MinioProcessOptions MinioOpts; - MinioOpts.Port = 19012; - MinioProcess Minio(MinioOpts); - Minio.SpawnMinioServer(); - Minio.CreateBucket("zen-hydration-test"); - - ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); - ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); - - ScopedTemporaryDirectory TempDir; - - struct ModuleData - { - HydrationConfig Config; - std::vector<std::pair<std::filesystem::path, IoBuffer>> Files; - }; - - std::vector<ModuleData> Modules; - for (const char* ModuleId : {"s3test_iso_a", "s3test_iso_b"}) - { - std::filesystem::path StateDir = TempDir.Path() / ModuleId / "state"; - std::filesystem::path TempPath = TempDir.Path() / ModuleId / "temp"; - CreateDirectories(StateDir); - CreateDirectories(TempPath); - - ModuleData Data; - Data.Config.ServerStateDir = StateDir; - Data.Config.TempDir = TempPath; - Data.Config.ModuleId = ModuleId; - Data.Config.TargetSpecification = "s3://zen-hydration-test"; - Data.Config.S3Endpoint = Minio.Endpoint(); - Data.Config.S3PathStyle = true; - Data.Files = CreateTestTree(StateDir); - - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Data.Config); - Hydrator->Dehydrate(); - - Modules.push_back(std::move(Data)); - } - - for (ModuleData& Module : Modules) - { - CleanDirectory(Module.Config.ServerStateDir, true); - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Module.Config); - Hydrator->Hydrate(); - - // Each module's files must be independently restorable with correct byte content. - // If S3 key prefixes were mixed up, CreateSemiRandomBlob content would differ. - VerifyTree(Module.Config.ServerStateDir, Module.Files); - } -} - -// --------------------------------------------------------------------------- -// S3Hydrator concurrent test -// --------------------------------------------------------------------------- - TEST_CASE("hydration.s3.concurrent") { // N modules dehydrate and hydrate concurrently against MinIO. 
@@ -1015,7 +1572,10 @@ TEST_CASE("hydration.s3.concurrent") ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); - constexpr int kModuleCount = 4; + constexpr int kModuleCount = 6; + constexpr int kThreadCount = 4; + + TestThreading Threading(kThreadCount); ScopedTemporaryDirectory TempDir; @@ -1034,18 +1594,25 @@ TEST_CASE("hydration.s3.concurrent") CreateDirectories(StateDir); CreateDirectories(TempPath); - Modules[I].Config.ServerStateDir = StateDir; - Modules[I].Config.TempDir = TempPath; - Modules[I].Config.ModuleId = ModuleId; - Modules[I].Config.TargetSpecification = "s3://zen-hydration-test"; - Modules[I].Config.S3Endpoint = Minio.Endpoint(); - Modules[I].Config.S3PathStyle = true; - Modules[I].Files = CreateTestTree(StateDir); + Modules[I].Config.ServerStateDir = StateDir; + Modules[I].Config.TempDir = TempPath; + Modules[I].Config.ModuleId = ModuleId; + Modules[I].Config.Threading = Threading.Options; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Modules[I].Config.Options = std::move(Root).AsObject(); + } + Modules[I].Files = CreateTestTree(StateDir); } // Concurrent dehydrate { - WorkerThreadPool Pool(kModuleCount, "hydration_s3_dehy"); + WorkerThreadPool Pool(kThreadCount, "hydration_s3_dehy"); std::atomic<bool> AbortFlag{false}; std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); @@ -1054,7 +1621,7 @@ TEST_CASE("hydration.s3.concurrent") { Work.ScheduleWork(Pool, [&Config = Modules[I].Config](std::atomic<bool>&) { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + 
Hydrator->Dehydrate(CbObject()); }); } Work.Wait(); @@ -1063,7 +1630,7 @@ TEST_CASE("hydration.s3.concurrent") // Concurrent hydrate { - WorkerThreadPool Pool(kModuleCount, "hydration_s3_hy"); + WorkerThreadPool Pool(kThreadCount, "hydration_s3_hy"); std::atomic<bool> AbortFlag{false}; std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); @@ -1087,17 +1654,82 @@ TEST_CASE("hydration.s3.concurrent") } } -// --------------------------------------------------------------------------- -// S3Hydrator: no prior state (first-boot path) -// --------------------------------------------------------------------------- +TEST_CASE("hydration.s3.obliterate") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19019; + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + const std::string ModuleId = "s3test_obliterate"; + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + + // Dehydrate to populate backend + CreateSmallTestTree(ServerStateDir); + { + std::unique_ptr<HydrationStrategyBase> Hydrator = 
CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } + + auto ListModuleObjects = [&]() { + S3ClientOptions Opts; + Opts.BucketName = "zen-hydration-test"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = Minio.RootUser(); + Opts.Credentials.SecretAccessKey = Minio.RootPassword(); + S3Client Client(Opts); + return Client.ListObjects(ModuleId + "/"); + }; + + // Verify objects exist in S3 + CHECK(!ListModuleObjects().Objects.empty()); + + // Re-populate ServerStateDir and TempDir for cleanup verification + CreateSmallTestTree(ServerStateDir); + WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64)); + + // Obliterate + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Obliterate(); + } + + // Verify S3 objects deleted + CHECK(ListModuleObjects().Objects.empty()); + // Local directories cleaned + CHECK(std::filesystem::is_empty(ServerStateDir)); + CHECK(std::filesystem::is_empty(HydrationTemp)); +} -TEST_CASE("hydration.s3.no_prior_state") +TEST_CASE("hydration.s3.config_overrides") { - // Hydrate() against an empty bucket (first-boot scenario) must leave ServerStateDir empty. - // The "No state found in S3" path goes through the error-cleanup branch, which wipes - // ServerStateDir to ensure no partial or stale content is left for the server to start on. MinioProcessOptions MinioOpts; - MinioOpts.Port = 19014; + MinioOpts.Port = 19015; MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -1112,36 +1744,244 @@ TEST_CASE("hydration.s3.no_prior_state") CreateDirectories(ServerStateDir); CreateDirectories(HydrationTemp); - // Pre-populate ServerStateDir to confirm the wipe actually runs. - WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); + // Path prefix: "s3://bucket/some/prefix" stores objects under + // "some/prefix/<ModuleId>/..." rather than directly under "<ModuleId>/...". 
+ { + auto TestFiles = CreateSmallTestTree(ServerStateDir); + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_prefix"; + { + std::string ConfigJson = fmt::format( + R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test/team/project","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } + + CleanDirectory(ServerStateDir, true); + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Hydrate(); + } + + VerifyTree(ServerStateDir, TestFiles); + } + + // Region override: 'region' in Options["settings"] takes precedence over AWS_DEFAULT_REGION. + // AWS_DEFAULT_REGION is set to a bogus value; hydration must succeed using the region from Options. 
+ { + CleanDirectory(ServerStateDir, true); + auto TestFiles = CreateSmallTestTree(ServerStateDir); + + ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "wrong-region"); + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_region_override"; + { + std::string ConfigJson = fmt::format( + R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"region":"us-east-1"}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } + + CleanDirectory(ServerStateDir, true); + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Hydrate(); + } + + VerifyTree(ServerStateDir, TestFiles); + } +} + +TEST_CASE("hydration.s3.dehydrate_hydrate.performance" * doctest::skip()) +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19010; + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + const std::string ModuleId = "s3test_performance"; + CopyTree("E:\\Dev\\hub\\brainrot\\20260402-225355-508", ServerStateDir, {.EnableClone = true}); + // auto TestFiles = CreateTestTree(ServerStateDir); + + TestThreading Threading(4); + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + 
Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + Config.Threading = Threading.Options; + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + + // Dehydrate: upload server state to MinIO + { + ZEN_INFO("============== DEHYDRATE =============="); + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(CbObject()); + } + + for (size_t I = 0; I < 1; I++) + { + // Wipe server state + CleanDirectory(ServerStateDir, true); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + // Hydrate: download from MinIO back to server state + { + ZEN_INFO("=============== HYDRATE ==============="); + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Hydrate(); + } + } +} + +//#define REAL_DATA_PATH "E:\\Dev\\hub\\zenddc\\Zen" +//#define REAL_DATA_PATH "E:\\Dev\\hub\\brainrot\\20260402-225355-508" + +TEST_CASE("hydration.file.incremental") +{ + std::filesystem::path TmpPath; +# ifdef REAL_DATA_PATH + TmpPath = std::filesystem::path(REAL_DATA_PATH).parent_path() / "hub"; +# endif + ScopedTemporaryDirectory TempDir(TmpPath); + + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationStore); + CreateDirectories(HydrationTemp); + + const std::string ModuleId = "testmodule"; + // auto TestFiles = CreateTestTree(ServerStateDir); + + TestThreading Threading(4); HydrationConfig Config; Config.ServerStateDir = ServerStateDir; Config.TempDir = HydrationTemp; - 
Config.ModuleId = "s3test_no_prior"; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ModuleId = ModuleId; + Config.TargetSpecification = "file://" + HydrationStore.string(); + Config.Threading = Threading.Options; - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Hydrate(); + std::unique_ptr<StorageBase> Storage = std::make_unique<FileStorage>(); + std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<IncrementalHydrator>(std::move(Storage)); - // ServerStateDir must be empty: the error path wipes it to prevent a server start - // against stale or partially-installed content. + // Hydrate with no prior state + CbObject HydrationState; + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + CHECK_FALSE(HydrationState); + } + +# ifdef REAL_DATA_PATH + ZEN_INFO("Writing state data..."); + CopyTree(REAL_DATA_PATH, ServerStateDir, {.EnableClone = true}); + ZEN_INFO("Writing state data complete"); +# else + // Create test files and dehydrate + auto TestFiles = CreateTestTree(ServerStateDir); +# endif + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } + CHECK(std::filesystem::is_empty(ServerStateDir)); + + // Hydrate: restore from S3 + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } +# ifndef REAL_DATA_PATH + VerifyTree(ServerStateDir, TestFiles); +# endif + // Dehydrate again with cached state (should skip re-uploading unchanged files) + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } CHECK(std::filesystem::is_empty(ServerStateDir)); + + // Hydrate one more time to confirm second dehydrate produced valid state + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } + + // Replace files and dehydrate + TestFiles = CreateTestTree(ServerStateDir); + { + Hydrator->Configure(Config); + 
Hydrator->Dehydrate(HydrationState); + } + + // Hydrate one more time to confirm second dehydrate produced valid state + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } +# ifndef REAL_DATA_PATH + VerifyTree(ServerStateDir, TestFiles); +# endif // 0 + + // Dehydrate, nothing touched - no hashing, no upload + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } } // --------------------------------------------------------------------------- -// S3Hydrator: bucket path prefix in TargetSpecification +// S3Storage test // --------------------------------------------------------------------------- -TEST_CASE("hydration.s3.path_prefix") +TEST_CASE("hydration.s3.incremental") { - // TargetSpecification of the form "s3://bucket/some/prefix" stores objects under - // "some/prefix/<ModuleId>/..." rather than directly under "<ModuleId>/...". - // Tests the second branch of the m_KeyPrefix calculation in S3Hydrator::Configure(). MinioProcessOptions MinioOpts; - MinioOpts.Port = 19015; + MinioOpts.Port = 19017; MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -1149,36 +1989,132 @@ TEST_CASE("hydration.s3.path_prefix") ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); - ScopedTemporaryDirectory TempDir; + std::filesystem::path TmpPath; +# ifdef REAL_DATA_PATH + TmpPath = std::filesystem::path(REAL_DATA_PATH).parent_path() / "hub"; +# endif + ScopedTemporaryDirectory TempDir(TmpPath); std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; CreateDirectories(ServerStateDir); CreateDirectories(HydrationTemp); - std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFiles = CreateTestTree(ServerStateDir); + const std::string ModuleId = "s3test_incremental"; + + TestThreading Threading(8); 
HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = "s3test_prefix"; - Config.TargetSpecification = "s3://zen-hydration-test/team/project"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + Config.Threading = Threading.Options; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + + std::unique_ptr<StorageBase> Storage = std::make_unique<S3Storage>(); + std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<IncrementalHydrator>(std::move(Storage)); + // Hydrate with no prior state + CbObject HydrationState; { - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Dehydrate(); + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + CHECK_FALSE(HydrationState); } - CleanDirectory(ServerStateDir, true); +# ifdef REAL_DATA_PATH + ZEN_INFO("Writing state data..."); + CopyTree(REAL_DATA_PATH, ServerStateDir, {.EnableClone = true}); + ZEN_INFO("Writing state data complete"); +# else + // Create test files and dehydrate + auto TestFiles = CreateTestTree(ServerStateDir); +# endif + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } + CHECK(std::filesystem::is_empty(ServerStateDir)); + // Hydrate: restore from S3 { - std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - Hydrator->Hydrate(); + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } +# ifndef REAL_DATA_PATH + VerifyTree(ServerStateDir, TestFiles); +# endif + // Dehydrate again 
with cached state (should skip re-uploading unchanged files) + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } + CHECK(std::filesystem::is_empty(ServerStateDir)); + + // Hydrate one more time to confirm second dehydrate produced valid state + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } + + // Replace files and dehydrate + TestFiles = CreateTestTree(ServerStateDir); + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); } + // Hydrate one more time to confirm second dehydrate produced valid state + { + Hydrator->Configure(Config); + HydrationState = Hydrator->Hydrate(); + } + +# ifndef REAL_DATA_PATH VerifyTree(ServerStateDir, TestFiles); +# endif // 0 + + // Dehydrate, nothing touched - no hashing, no upload + { + Hydrator->Configure(Config); + Hydrator->Dehydrate(HydrationState); + } +} + +TEST_CASE("hydration.create_hydrator_rejects_invalid_config") +{ + ScopedTemporaryDirectory TempDir; + + HydrationConfig Config; + Config.ServerStateDir = TempDir.Path() / "state"; + Config.TempDir = TempDir.Path() / "temp"; + Config.ModuleId = "invalid_test"; + + // Unknown TargetSpecification prefix + Config.TargetSpecification = "ftp://somewhere"; + CHECK_THROWS(CreateHydrator(Config)); + + // Unknown Options type + Config.TargetSpecification.clear(); + { + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(R"({"type":"dynamodb"})", ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + CHECK_THROWS(CreateHydrator(Config)); + + // Empty Options (no type field) + Config.Options = CbObject(); + CHECK_THROWS(CreateHydrator(Config)); } TEST_SUITE_END(); diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h index d29ffe5c0..fc2f309b2 100644 --- a/src/zenserver/hub/hydration.h +++ b/src/zenserver/hub/hydration.h @@ -2,10 +2,15 @@ #pragma once +#include <zencore/compactbinary.h> + #include 
<filesystem> +#include <optional> namespace zen { +class WorkerThreadPool; + struct HydrationConfig { // Location of server state to hydrate/dehydrate @@ -16,12 +21,18 @@ struct HydrationConfig std::string ModuleId; // Back-end specific target specification (e.g. S3 bucket, file path, etc) std::string TargetSpecification; + // Full config object when using --hub-hydration-target-config (mutually exclusive with TargetSpecification) + CbObject Options; + + struct ThreadingOptions + { + WorkerThreadPool* WorkerPool; + std::atomic<bool>* AbortFlag; + std::atomic<bool>* PauseFlag; + }; - // Optional S3 endpoint override (e.g. "http://localhost:9000" for MinIO). - std::string S3Endpoint; - // Use path-style S3 URLs (endpoint/bucket/key) instead of virtual-hosted-style - // (bucket.endpoint/key). Required for MinIO and other non-AWS endpoints. - bool S3PathStyle = false; + // External threading for parallel I/O and hashing. If not set, work runs inline on the caller's thread. + std::optional<ThreadingOptions> Threading; }; /** @@ -36,11 +47,22 @@ struct HydrationStrategyBase { virtual ~HydrationStrategyBase() = default; - virtual void Dehydrate() = 0; - virtual void Hydrate() = 0; + // Set up the hydration target from Config. Must be called before Hydrate/Dehydrate. virtual void Configure(const HydrationConfig& Config) = 0; + + // Upload server state to the configured target. ServerStateDir is wiped on success. + // On failure, ServerStateDir is left intact. + virtual void Dehydrate(const CbObject& CachedState) = 0; + + // Download state from the configured target into ServerStateDir. Returns cached state for the next Dehydrate. + // On failure, ServerStateDir is wiped and an empty CbObject is returned. + virtual CbObject Hydrate() = 0; + + // Delete all stored data for this module from the configured backend, then clean ServerStateDir and TempDir. + virtual void Obliterate() = 0; }; +// Create a configured hydrator based on Config. 
Ready to call Hydrate/Dehydrate immediately. std::unique_ptr<HydrationStrategyBase> CreateHydrator(const HydrationConfig& Config); #if ZEN_WITH_TESTS diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp index 6b139dbf1..af2c19113 100644 --- a/src/zenserver/hub/storageserverinstance.cpp +++ b/src/zenserver/hub/storageserverinstance.cpp @@ -16,8 +16,6 @@ StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironmen , m_ModuleId(ModuleId) , m_ServerInstance(RunEnvironment, ZenServerInstance::ServerMode::kStorageServer) { - m_BaseDir = RunEnvironment.CreateChildDir(ModuleId); - m_TempDir = Config.HydrationTempPath / ModuleId; } StorageServerInstance::~StorageServerInstance() @@ -31,7 +29,7 @@ StorageServerInstance::SpawnServerProcess() m_ServerInstance.ResetDeadProcess(); m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath()); - m_ServerInstance.SetDataDir(m_BaseDir); + m_ServerInstance.SetDataDir(m_Config.StateDir); #if ZEN_PLATFORM_WINDOWS m_ServerInstance.SetJobObject(m_JobObject); #endif @@ -50,6 +48,36 @@ StorageServerInstance::SpawnServerProcess() { AdditionalOptions << " --config=\"" << MakeSafeAbsolutePath(m_Config.ConfigPath).string() << "\""; } + if (!m_Config.Malloc.empty()) + { + AdditionalOptions << " --malloc=" << m_Config.Malloc; + } + if (!m_Config.Trace.empty()) + { + AdditionalOptions << " --trace=" << m_Config.Trace; + } + if (!m_Config.TraceHost.empty()) + { + AdditionalOptions << " --tracehost=" << m_Config.TraceHost; + } + if (!m_Config.TraceFile.empty()) + { + constexpr std::string_view ModuleIdPattern = "{moduleid}"; + constexpr std::string_view PortPattern = "{port}"; + + std::string ResolvedTraceFile = m_Config.TraceFile; + for (size_t Pos = ResolvedTraceFile.find(ModuleIdPattern); Pos != std::string::npos; + Pos = ResolvedTraceFile.find(ModuleIdPattern, Pos)) + { + ResolvedTraceFile.replace(Pos, ModuleIdPattern.length(), m_ModuleId); + } + std::string PortStr = 
fmt::format("{}", m_Config.BasePort); + for (size_t Pos = ResolvedTraceFile.find(PortPattern); Pos != std::string::npos; Pos = ResolvedTraceFile.find(PortPattern, Pos)) + { + ResolvedTraceFile.replace(Pos, PortPattern.length(), PortStr); + } + AdditionalOptions << " --tracefile=\"" << ResolvedTraceFile << "\""; + } m_ServerInstance.SpawnServerAndWaitUntilReady(m_Config.BasePort, AdditionalOptions.ToView()); ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, m_Config.BasePort); @@ -57,16 +85,15 @@ StorageServerInstance::SpawnServerProcess() m_ServerInstance.EnableShutdownOnDestroy(); } -void -StorageServerInstance::GetProcessMetrics(ProcessMetrics& OutMetrics) const +ProcessMetrics +StorageServerInstance::GetProcessMetrics() const { - OutMetrics.MemoryBytes = m_MemoryBytes.load(); - OutMetrics.KernelTimeMs = m_KernelTimeMs.load(); - OutMetrics.UserTimeMs = m_UserTimeMs.load(); - OutMetrics.WorkingSetSize = m_WorkingSetSize.load(); - OutMetrics.PeakWorkingSetSize = m_PeakWorkingSetSize.load(); - OutMetrics.PagefileUsage = m_PagefileUsage.load(); - OutMetrics.PeakPagefileUsage = m_PeakPagefileUsage.load(); + ProcessMetrics Metrics; + if (m_ServerInstance.IsRunning()) + { + zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics); + } + return Metrics; } void @@ -78,7 +105,7 @@ StorageServerInstance::ProvisionLocked() return; } - ZEN_INFO("Provisioning storage server instance for module '{}', at '{}'", m_ModuleId, m_BaseDir); + ZEN_INFO("Provisioning storage server instance for module '{}', at '{}'", m_ModuleId, m_Config.StateDir); try { Hydrate(); @@ -88,7 +115,7 @@ StorageServerInstance::ProvisionLocked() { ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during provisioning. 
Reason: {}", m_ModuleId, - m_BaseDir, + m_Config.StateDir, Ex.what()); throw; } @@ -118,6 +145,22 @@ StorageServerInstance::DeprovisionLocked() } void +StorageServerInstance::ObliterateLocked() +{ + if (m_ServerInstance.IsRunning()) + { + // m_ServerInstance.Shutdown() never throws. + m_ServerInstance.Shutdown(); + } + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Obliterate(); +} + +void StorageServerInstance::HibernateLocked() { // Signal server to shut down, but keep data around for later wake @@ -147,7 +190,10 @@ StorageServerInstance::WakeLocked() } catch (const std::exception& Ex) { - ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during waking. Reason: {}", m_ModuleId, m_BaseDir, Ex.what()); + ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during waking. Reason: {}", + m_ModuleId, + m_Config.StateDir, + Ex.what()); throw; } } @@ -155,27 +201,38 @@ StorageServerInstance::WakeLocked() void StorageServerInstance::Hydrate() { - HydrationConfig Config{.ServerStateDir = m_BaseDir, - .TempDir = m_TempDir, - .ModuleId = m_ModuleId, - .TargetSpecification = m_Config.HydrationTargetSpecification}; - + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); - - Hydrator->Hydrate(); + m_HydrationState = Hydrator->Hydrate(); } void StorageServerInstance::Dehydrate() { - HydrationConfig Config{.ServerStateDir = m_BaseDir, - .TempDir = m_TempDir, - .ModuleId = m_ModuleId, - .TargetSpecification = m_Config.HydrationTargetSpecification}; - + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); 
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(m_HydrationState); +} + +HydrationConfig +StorageServerInstance::MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag) +{ + HydrationConfig Config{.ServerStateDir = m_Config.StateDir, + .TempDir = m_Config.TempDir, + .ModuleId = m_ModuleId, + .TargetSpecification = m_Config.HydrationTargetSpecification, + .Options = m_Config.HydrationOptions}; + if (m_Config.OptionalWorkerPool) + { + Config.Threading.emplace( + HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalWorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag}); + } - Hydrator->Dehydrate(); + return Config; } StorageServerInstance::SharedLockedPtr::SharedLockedPtr() : m_Lock(nullptr), m_Instance(nullptr) @@ -249,25 +306,6 @@ StorageServerInstance::SharedLockedPtr::IsRunning() const return m_Instance->m_ServerInstance.IsRunning(); } -void -StorageServerInstance::UpdateMetricsLocked() -{ - if (m_ServerInstance.IsRunning()) - { - ProcessMetrics Metrics; - zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics); - - m_MemoryBytes.store(Metrics.MemoryBytes); - m_KernelTimeMs.store(Metrics.KernelTimeMs); - m_UserTimeMs.store(Metrics.UserTimeMs); - m_WorkingSetSize.store(Metrics.WorkingSetSize); - m_PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize); - m_PagefileUsage.store(Metrics.PagefileUsage); - m_PeakPagefileUsage.store(Metrics.PeakPagefileUsage); - } - // TODO: Resource metrics... 
-} - #if ZEN_WITH_TESTS void StorageServerInstance::SharedLockedPtr::TerminateForTesting() const @@ -363,6 +401,13 @@ StorageServerInstance::ExclusiveLockedPtr::Deprovision() } void +StorageServerInstance::ExclusiveLockedPtr::Obliterate() +{ + ZEN_ASSERT(m_Instance != nullptr); + m_Instance->ObliterateLocked(); +} + +void StorageServerInstance::ExclusiveLockedPtr::Hibernate() { ZEN_ASSERT(m_Instance != nullptr); diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h index 94c47630c..80f8a5016 100644 --- a/src/zenserver/hub/storageserverinstance.h +++ b/src/zenserver/hub/storageserverinstance.h @@ -2,8 +2,9 @@ #pragma once -#include "resourcemetrics.h" +#include "hydration.h" +#include <zencore/compactbinary.h> #include <zenutil/zenserverprocess.h> #include <atomic> @@ -11,6 +12,8 @@ namespace zen { +class WorkerThreadPool; + /** * Storage Server Instance * @@ -24,21 +27,27 @@ public: struct Configuration { uint16_t BasePort; - std::filesystem::path HydrationTempPath; + std::filesystem::path StateDir; + std::filesystem::path TempDir; std::string HydrationTargetSpecification; + CbObject HydrationOptions; uint32_t HttpThreadCount = 0; // Automatic int CoreLimit = 0; // Automatic std::filesystem::path ConfigPath; + std::string Malloc; + std::string Trace; + std::string TraceHost; + std::string TraceFile; + + WorkerThreadPool* OptionalWorkerPool = nullptr; }; StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId); ~StorageServerInstance(); - const ResourceMetrics& GetResourceMetrics() const { return m_ResourceMetrics; } - inline std::string_view GetModuleId() const { return m_ModuleId; } inline uint16_t GetBasePort() const { return m_Config.BasePort; } - void GetProcessMetrics(ProcessMetrics& OutMetrics) const; + ProcessMetrics GetProcessMetrics() const; #if ZEN_PLATFORM_WINDOWS void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; } @@ -68,15 +77,10 
@@ public: } bool IsRunning() const; - const ResourceMetrics& GetResourceMetrics() const + ProcessMetrics GetProcessMetrics() const { ZEN_ASSERT(m_Instance); - return m_Instance->m_ResourceMetrics; - } - void UpdateMetrics() - { - ZEN_ASSERT(m_Instance); - return m_Instance->UpdateMetricsLocked(); + return m_Instance->GetProcessMetrics(); } #if ZEN_WITH_TESTS @@ -114,14 +118,9 @@ public: } bool IsRunning() const; - const ResourceMetrics& GetResourceMetrics() const - { - ZEN_ASSERT(m_Instance); - return m_Instance->m_ResourceMetrics; - } - void Provision(); void Deprovision(); + void Obliterate(); void Hibernate(); void Wake(); @@ -135,29 +134,17 @@ public: private: void ProvisionLocked(); void DeprovisionLocked(); + void ObliterateLocked(); void HibernateLocked(); void WakeLocked(); - void UpdateMetricsLocked(); - mutable RwLock m_Lock; const Configuration m_Config; std::string m_ModuleId; ZenServerInstance m_ServerInstance; - std::filesystem::path m_BaseDir; - - std::filesystem::path m_TempDir; - ResourceMetrics m_ResourceMetrics; - - std::atomic<uint64_t> m_MemoryBytes = 0; - std::atomic<uint64_t> m_KernelTimeMs = 0; - std::atomic<uint64_t> m_UserTimeMs = 0; - std::atomic<uint64_t> m_WorkingSetSize = 0; - std::atomic<uint64_t> m_PeakWorkingSetSize = 0; - std::atomic<uint64_t> m_PagefileUsage = 0; - std::atomic<uint64_t> m_PeakPagefileUsage = 0; + CbObject m_HydrationState; #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; @@ -165,8 +152,9 @@ private: void SpawnServerProcess(); - void Hydrate(); - void Dehydrate(); + void Hydrate(); + void Dehydrate(); + HydrationConfig MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag); friend class SharedLockedPtr; friend class ExclusiveLockedPtr; diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 314031246..1390d112e 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -2,21 +2,29 @@ #include "zenhubserver.h" 
+#include "config/luaconfig.h" #include "frontend/frontend.h" #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" +#include <zencore/compactbinary.h> #include <zencore/config.h> +#include <zencore/except.h> +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> +#include <zencore/intmath.h> #include <zencore/memory/llm.h> #include <zencore/memory/memorytrace.h> #include <zencore/memory/tagtrace.h> #include <zencore/scopeguard.h> #include <zencore/sentryintegration.h> +#include <zencore/system.h> +#include <zencore/thread.h> #include <zencore/windows.h> #include <zenhttp/httpapiservice.h> #include <zenutil/service.h> -#include <zenutil/workerpools.h> ZEN_THIRD_PARTY_INCLUDES_START #include <cxxopts.hpp> @@ -53,12 +61,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", "instance-id", - "Instance ID for use in notifications", + "Instance ID for use in notifications (deprecated, use --upstream-notification-instance-id)", cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""), ""); Options.add_option("hub", "", + "upstream-notification-instance-id", + "Instance ID for use in notifications", + cxxopts::value<std::string>(m_ServerOptions.InstanceId), + ""); + + Options.add_option("hub", + "", "consul-endpoint", "Consul endpoint URL for service registration (empty = disabled)", cxxopts::value<std::string>(m_ServerOptions.ConsulEndpoint)->default_value(""), @@ -88,13 +103,27 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", + "consul-register-hub", + "Register the hub parent service with Consul (instance registration is unaffected)", + cxxopts::value<bool>(m_ServerOptions.ConsulRegisterHub)->default_value("true"), + ""); + + Options.add_option("hub", + "", "hub-base-port-number", - "Base port number for provisioned instances", + "Base port number for provisioned instances (deprecated, use 
--hub-instance-base-port-number)", cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber)->default_value("21000"), ""); Options.add_option("hub", "", + "hub-instance-base-port-number", + "Base port number for provisioned instances", + cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber), + ""); + + Options.add_option("hub", + "", "hub-instance-limit", "Maximum number of provisioned instances for this hub", cxxopts::value<int>(m_ServerOptions.HubInstanceLimit)->default_value("1000"), @@ -113,6 +142,34 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", + "hub-instance-malloc", + "Select memory allocator for provisioned instances (ansi|stomp|rpmalloc|mimalloc)", + cxxopts::value<std::string>(m_ServerOptions.HubInstanceMalloc)->default_value(""), + "<allocator>"); + + Options.add_option("hub", + "", + "hub-instance-trace", + "Trace channel specification for provisioned instances (e.g. default, cpu,log, memory)", + cxxopts::value<std::string>(m_ServerOptions.HubInstanceTrace)->default_value(""), + "<channels>"); + + Options.add_option("hub", + "", + "hub-instance-tracehost", + "Trace host for provisioned instances", + cxxopts::value<std::string>(m_ServerOptions.HubInstanceTraceHost)->default_value(""), + "<host>"); + + Options.add_option("hub", + "", + "hub-instance-tracefile", + "Trace file path for provisioned instances", + cxxopts::value<std::string>(m_ServerOptions.HubInstanceTraceFile)->default_value(""), + "<path>"); + + Options.add_option("hub", + "", "hub-instance-http-threads", "Number of http server connection threads for provisioned instances", cxxopts::value<unsigned int>(m_ServerOptions.HubInstanceHttpThreadCount), @@ -131,6 +188,16 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HubInstanceConfigPath), "<instance config>"); + const uint32_t DefaultHubInstanceProvisionThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + + 
Options.add_option("hub", + "", + "hub-instance-provision-threads", + fmt::format("Number of threads for instance provisioning (default {})", DefaultHubInstanceProvisionThreadCount), + cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceProvisionThreadCount) + ->default_value(fmt::format("{}", DefaultHubInstanceProvisionThreadCount)), + "<threads>"); + Options.add_option("hub", "", "hub-hydration-target-spec", @@ -139,6 +206,24 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HydrationTargetSpecification), "<hydration-target-spec>"); + Options.add_option("hub", + "", + "hub-hydration-target-config", + "Path to JSON file specifying the hydration target (mutually exclusive with " + "--hub-hydration-target-spec). Supported types: 'file', 's3'.", + cxxopts::value(m_ServerOptions.HydrationTargetConfigPath), + "<path>"); + + const uint32_t DefaultHubHydrationThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + + Options.add_option( + "hub", + "", + "hub-hydration-threads", + fmt::format("Number of threads for hydration/dehydration (default {})", DefaultHubHydrationThreadCount), + cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationThreadCount)->default_value(fmt::format("{}", DefaultHubHydrationThreadCount)), + "<threads>"); + #if ZEN_PLATFORM_WINDOWS Options.add_option("hub", "", @@ -203,12 +288,112 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Request timeout in milliseconds for instance activity check requests", cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs)->default_value("200"), "<ms>"); + + Options.add_option("hub", + "", + "hub-provision-disk-limit-bytes", + "Reject provisioning when used disk bytes exceed this value (0 = no limit).", + cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionDiskLimitBytes), + "<bytes>"); + + Options.add_option("hub", + "", + "hub-provision-disk-limit-percent", + "Reject provisioning when used disk exceeds this 
percentage of total disk (0 = no limit).", + cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionDiskLimitPercent), + "<percent>"); + + Options.add_option("hub", + "", + "hub-provision-memory-limit-bytes", + "Reject provisioning when used memory bytes exceed this value (0 = no limit).", + cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionMemoryLimitBytes), + "<bytes>"); + + Options.add_option("hub", + "", + "hub-provision-memory-limit-percent", + "Reject provisioning when used memory exceeds this percentage of total RAM (0 = no limit).", + cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionMemoryLimitPercent), + "<percent>"); } void ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) { - ZEN_UNUSED(Options); + using namespace std::literals; + + Options.AddOption("hub.upstreamnotification.endpoint"sv, + m_ServerOptions.UpstreamNotificationEndpoint, + "upstream-notification-endpoint"sv); + Options.AddOption("hub.upstreamnotification.instanceid"sv, m_ServerOptions.InstanceId, "upstream-notification-instance-id"sv); + + Options.AddOption("hub.consul.endpoint"sv, m_ServerOptions.ConsulEndpoint, "consul-endpoint"sv); + Options.AddOption("hub.consul.tokenenv"sv, m_ServerOptions.ConsulTokenEnv, "consul-token-env"sv); + Options.AddOption("hub.consul.healthintervalseconds"sv, + m_ServerOptions.ConsulHealthIntervalSeconds, + "consul-health-interval-seconds"sv); + Options.AddOption("hub.consul.deregisterafterseconds"sv, + m_ServerOptions.ConsulDeregisterAfterSeconds, + "consul-deregister-after-seconds"sv); + Options.AddOption("hub.consul.registerhub"sv, m_ServerOptions.ConsulRegisterHub, "consul-register-hub"sv); + + Options.AddOption("hub.instance.baseportnumber"sv, m_ServerOptions.HubBasePortNumber, "hub-instance-base-port-number"sv); + Options.AddOption("hub.instance.http"sv, m_ServerOptions.HubInstanceHttpClass, "hub-instance-http"sv); + Options.AddOption("hub.instance.malloc"sv, m_ServerOptions.HubInstanceMalloc, "hub-instance-malloc"sv); + 
Options.AddOption("hub.instance.trace"sv, m_ServerOptions.HubInstanceTrace, "hub-instance-trace"sv); + Options.AddOption("hub.instance.tracehost"sv, m_ServerOptions.HubInstanceTraceHost, "hub-instance-tracehost"sv); + Options.AddOption("hub.instance.tracefile"sv, m_ServerOptions.HubInstanceTraceFile, "hub-instance-tracefile"sv); + Options.AddOption("hub.instance.httpthreads"sv, m_ServerOptions.HubInstanceHttpThreadCount, "hub-instance-http-threads"sv); + Options.AddOption("hub.instance.corelimit"sv, m_ServerOptions.HubInstanceCoreLimit, "hub-instance-corelimit"sv); + Options.AddOption("hub.instance.config"sv, m_ServerOptions.HubInstanceConfigPath, "hub-instance-config"sv); + Options.AddOption("hub.instance.limits.count"sv, m_ServerOptions.HubInstanceLimit, "hub-instance-limit"sv); + Options.AddOption("hub.instance.limits.disklimitbytes"sv, + m_ServerOptions.HubProvisionDiskLimitBytes, + "hub-provision-disk-limit-bytes"sv); + Options.AddOption("hub.instance.limits.disklimitpercent"sv, + m_ServerOptions.HubProvisionDiskLimitPercent, + "hub-provision-disk-limit-percent"sv); + Options.AddOption("hub.instance.limits.memorylimitbytes"sv, + m_ServerOptions.HubProvisionMemoryLimitBytes, + "hub-provision-memory-limit-bytes"sv); + Options.AddOption("hub.instance.limits.memorylimitpercent"sv, + m_ServerOptions.HubProvisionMemoryLimitPercent, + "hub-provision-memory-limit-percent"sv); + Options.AddOption("hub.instance.provisionthreads"sv, + m_ServerOptions.HubInstanceProvisionThreadCount, + "hub-instance-provision-threads"sv); + + Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv); + Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv); + Options.AddOption("hub.hydration.threads"sv, m_ServerOptions.HubHydrationThreadCount, "hub-hydration-threads"sv); + + Options.AddOption("hub.watchdog.cycleintervalms"sv, 
m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv); + Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv, + m_ServerOptions.WatchdogConfig.CycleProcessingBudgetMs, + "hub-watchdog-cycle-processing-budget-ms"sv); + Options.AddOption("hub.watchdog.instancecheckthrottlems"sv, + m_ServerOptions.WatchdogConfig.InstanceCheckThrottleMs, + "hub-watchdog-instance-check-throttle-ms"sv); + Options.AddOption("hub.watchdog.provisionedinactivitytimeoutseconds"sv, + m_ServerOptions.WatchdogConfig.ProvisionedInactivityTimeoutSeconds, + "hub-watchdog-provisioned-inactivity-timeout-seconds"sv); + Options.AddOption("hub.watchdog.hibernatedinactivitytimeoutseconds"sv, + m_ServerOptions.WatchdogConfig.HibernatedInactivityTimeoutSeconds, + "hub-watchdog-hibernated-inactivity-timeout-seconds"sv); + Options.AddOption("hub.watchdog.inactivitycheckmarginseconds"sv, + m_ServerOptions.WatchdogConfig.InactivityCheckMarginSeconds, + "hub-watchdog-inactivity-check-margin-seconds"sv); + Options.AddOption("hub.watchdog.activitycheckconnecttimeoutms"sv, + m_ServerOptions.WatchdogConfig.ActivityCheckConnectTimeoutMs, + "hub-watchdog-activity-check-connect-timeout-ms"sv); + Options.AddOption("hub.watchdog.activitycheckrequesttimeoutms"sv, + m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs, + "hub-watchdog-activity-check-request-timeout-ms"sv); + +#if ZEN_PLATFORM_WINDOWS + Options.AddOption("hub.usejobobject"sv, m_ServerOptions.HubUseJobObject, "hub-use-job-object"sv); +#endif } void @@ -226,6 +411,28 @@ ZenHubServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) void ZenHubServerConfigurator::ValidateOptions() { + if (m_ServerOptions.HubProvisionDiskLimitPercent > 100) + { + throw OptionParseException( + fmt::format("'--hub-provision-disk-limit-percent' ({}) must be in range 0..100", m_ServerOptions.HubProvisionDiskLimitPercent), + {}); + } + if (m_ServerOptions.HubProvisionMemoryLimitPercent > 100) + { + throw 
OptionParseException(fmt::format("'--hub-provision-memory-limit-percent' ({}) must be in range 0..100", + m_ServerOptions.HubProvisionMemoryLimitPercent), + {}); + } + if (!m_ServerOptions.HydrationTargetSpecification.empty() && !m_ServerOptions.HydrationTargetConfigPath.empty()) + { + throw OptionParseException("'--hub-hydration-target-spec' and '--hub-hydration-target-config' are mutually exclusive", {}); + } + if (!m_ServerOptions.HydrationTargetConfigPath.empty() && !std::filesystem::exists(m_ServerOptions.HydrationTargetConfigPath)) + { + throw OptionParseException( + fmt::format("'--hub-hydration-target-config': file not found: '{}'", m_ServerOptions.HydrationTargetConfigPath.string()), + {}); + } } /////////////////////////////////////////////////////////////////////////// @@ -247,6 +454,15 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, HubInstanceState NewState) { ZEN_UNUSED(PreviousState); + + if (NewState == HubInstanceState::Deprovisioning || NewState == HubInstanceState::Hibernating) + { + if (Info.Port != 0) + { + m_Proxy->PrunePort(Info.Port); + } + } + if (!m_ConsulClient) { return; @@ -262,12 +478,9 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, .Tags = std::vector<std::pair<std::string, std::string>>{std::make_pair("module", std::string(ModuleId)), std::make_pair("zen-hub", std::string(HubInstanceId)), std::make_pair("version", std::string(ZEN_CFG_VERSION))}, - .HealthIntervalSeconds = NewState == HubInstanceState::Provisioning - ? 0u - : m_ConsulHealthIntervalSeconds, // Disable health checks while not finished provisioning - .DeregisterAfterSeconds = NewState == HubInstanceState::Provisioning - ? 0u - : m_ConsulDeregisterAfterSeconds}; // Disable health checks while not finished provisioning + .HealthIntervalSeconds = NewState == HubInstanceState::Provisioning ? 0u : m_ConsulHealthIntervalSeconds, + .DeregisterAfterSeconds = NewState == HubInstanceState::Provisioning ? 
0u : m_ConsulDeregisterAfterSeconds, + .InitialStatus = NewState == HubInstanceState::Provisioned ? "passing" : ""}; if (!m_ConsulClient->RegisterService(ServiceInfo)) { @@ -294,8 +507,8 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, ZEN_INFO("Deregistered storage server instance for module '{}' at port {} from Consul", ModuleId, Info.Port); } } - // Transitional states (Deprovisioning, Hibernating, Waking, Recovering, Crashed) - // and Hibernated are intentionally ignored. + // Transitional states (Waking, Recovering, Crashed) and stable states + // not handled above (Hibernated) are intentionally ignored by Consul. } int @@ -317,6 +530,10 @@ ZenHubServer::Initialize(const ZenHubServerConfig& ServerConfig, ZenServerState: // the main test range. ZenServerEnvironment::SetBaseChildId(1000); + m_ProvisionWorkerPool = + std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision"); + m_HydrationWorkerPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration"); + m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; InitializeState(ServerConfig); @@ -342,12 +559,18 @@ ZenHubServer::Cleanup() m_IoRunner.join(); } - ShutdownServices(); if (m_Http) { m_Http->Close(); } + ShutdownServices(); + + if (m_Proxy) + { + m_Proxy->Shutdown(); + } + if (m_Hub) { m_Hub->Shutdown(); @@ -357,6 +580,7 @@ ZenHubServer::Cleanup() m_HubService.reset(); m_ApiService.reset(); m_Hub.reset(); + m_Proxy.reset(); m_ConsulRegistration.reset(); m_ConsulClient.reset(); @@ -373,49 +597,121 @@ ZenHubServer::InitializeState(const ZenHubServerConfig& ServerConfig) ZEN_UNUSED(ServerConfig); } +ResourceMetrics +ZenHubServer::ResolveLimits(const ZenHubServerConfig& ServerConfig) +{ + uint64_t DiskTotal = 0; + uint64_t MemoryTotal = 0; + + if (ServerConfig.HubProvisionDiskLimitPercent > 0) + { + DiskSpace Disk; + if (DiskSpaceInfo(ServerConfig.DataDir, Disk)) + { + DiskTotal = 
Disk.Total; + } + else + { + ZEN_WARN("Failed to query disk space for '{}'; disk percent limit will not be applied", ServerConfig.DataDir); + } + } + if (ServerConfig.HubProvisionMemoryLimitPercent > 0) + { + MemoryTotal = GetSystemMetrics().SystemMemoryMiB * 1024 * 1024; + } + + auto Resolve = [](uint64_t Bytes, uint32_t Pct, uint64_t Total) -> uint64_t { + const uint64_t PctBytes = Pct > 0 ? (Total * Pct) / 100 : 0; + if (Bytes > 0 && PctBytes > 0) + { + return Min(Bytes, PctBytes); + } + return Bytes > 0 ? Bytes : PctBytes; + }; + + return { + .DiskUsageBytes = Resolve(ServerConfig.HubProvisionDiskLimitBytes, ServerConfig.HubProvisionDiskLimitPercent, DiskTotal), + .MemoryUsageBytes = Resolve(ServerConfig.HubProvisionMemoryLimitBytes, ServerConfig.HubProvisionMemoryLimitPercent, MemoryTotal), + }; +} + void ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) { ZEN_INFO("instantiating Hub"); + Hub::Configuration HubConfig{ + .UseJobObject = ServerConfig.HubUseJobObject, + .BasePortNumber = ServerConfig.HubBasePortNumber, + .InstanceLimit = ServerConfig.HubInstanceLimit, + .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, + .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, + .InstanceMalloc = ServerConfig.HubInstanceMalloc, + .InstanceTrace = ServerConfig.HubInstanceTrace, + .InstanceTraceHost = ServerConfig.HubInstanceTraceHost, + .InstanceTraceFile = ServerConfig.HubInstanceTraceFile, + .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, + .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, + .WatchDog = + { + .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), + .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs), + .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs), + .ProvisionedInactivityTimeout = 
std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds), + .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds), + .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds), + .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), + .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), + }, + .ResourceLimits = ResolveLimits(ServerConfig), + .OptionalProvisionWorkerPool = m_ProvisionWorkerPool.get(), + .OptionalHydrationWorkerPool = m_HydrationWorkerPool.get()}; + + if (!ServerConfig.HydrationTargetConfigPath.empty()) + { + FileContents Contents = ReadFile(ServerConfig.HydrationTargetConfigPath); + if (!Contents) + { + throw zen::runtime_error("Failed to read hydration config '{}': {}", + ServerConfig.HydrationTargetConfigPath.string(), + Contents.ErrorCode.message()); + } + IoBuffer Buffer(Contents.Flatten()); + std::string_view JsonText(static_cast<const char*>(Buffer.GetData()), Buffer.GetSize()); + + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(JsonText, ParseError); + if (!ParseError.empty() || !Root.IsObject()) + { + throw zen::runtime_error("Failed to parse hydration config '{}': {}", + ServerConfig.HydrationTargetConfigPath.string(), + ParseError.empty() ? 
"root must be a JSON object" : ParseError); + } + HubConfig.HydrationOptions = std::move(Root).AsObject(); + } + + m_Proxy = std::make_unique<HttpProxyHandler>(); + m_Hub = std::make_unique<Hub>( - Hub::Configuration{ - .UseJobObject = ServerConfig.HubUseJobObject, - .BasePortNumber = ServerConfig.HubBasePortNumber, - .InstanceLimit = ServerConfig.HubInstanceLimit, - .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, - .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, - .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, - .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, - .WatchDog = - { - .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), - .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs), - .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs), - .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds), - .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds), - .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds), - .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), - .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), - }}, + std::move(HubConfig), ZenServerEnvironment(ZenServerEnvironment::Hub, ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers", ServerConfig.HubInstanceHttpClass), - &GetMediumWorkerPool(EWorkloadType::Background), - m_ConsulClient ? 
Hub::AsyncModuleStateChangeCallbackFunc{[this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)]( - std::string_view ModuleId, - const HubProvisionedInstanceInfo& Info, - HubInstanceState PreviousState, - HubInstanceState NewState) { - OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); - }} - : Hub::AsyncModuleStateChangeCallbackFunc{}); + Hub::AsyncModuleStateChangeCallbackFunc{ + [this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](std::string_view ModuleId, + const HubProvisionedInstanceInfo& Info, + HubInstanceState PreviousState, + HubInstanceState NewState) { + OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); + }}); + + m_Proxy->SetPortValidator([Hub = m_Hub.get()](uint16_t Port) { return Hub->IsInstancePort(Port); }); ZEN_INFO("instantiating API service"); m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http); ZEN_INFO("instantiating hub service"); - m_HubService = std::make_unique<HttpHubService>(*m_Hub, m_StatsService, m_StatusService); + m_HubService = std::make_unique<HttpHubService>(*m_Hub, *m_Proxy, m_StatsService, m_StatusService); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); @@ -465,21 +761,32 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi } else { - ZEN_INFO("Consul token read from environment variable '{}'", ConsulAccessTokenEnvName); + ZEN_INFO("Consul token will be read from environment variable '{}'", ConsulAccessTokenEnvName); } try { - m_ConsulClient = std::make_unique<consul::ConsulClient>(ServerConfig.ConsulEndpoint, ConsulAccessToken); + m_ConsulClient = std::make_unique<consul::ConsulClient>(consul::ConsulClient::Configuration{ + .BaseUri = ServerConfig.ConsulEndpoint, + .TokenEnvName = ConsulAccessTokenEnvName, + }); 
m_ConsulHealthIntervalSeconds = ServerConfig.ConsulHealthIntervalSeconds; m_ConsulDeregisterAfterSeconds = ServerConfig.ConsulDeregisterAfterSeconds; + if (!ServerConfig.ConsulRegisterHub) + { + ZEN_INFO( + "Hub parent Consul registration skipped (consul-register-hub is false); " + "instance registration remains enabled"); + return; + } + consul::ServiceRegistrationInfo Info; Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId); Info.ServiceName = "zen-hub"; // Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri? Info.Port = static_cast<uint16_t>(EffectivePort); - Info.HealthEndpoint = "hub/health"; + Info.HealthEndpoint = "health"; Info.Tags = std::vector<std::pair<std::string, std::string>>{ std::make_pair("zen-hub", Info.ServiceId), std::make_pair("version", std::string(ZEN_CFG_VERSION)), @@ -569,6 +876,8 @@ ZenHubServer::Run() OnReady(); + StartSelfSession("zenhub"); + m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index 77df3eaa3..5e465bb14 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -3,8 +3,10 @@ #pragma once #include "hubinstancestate.h" +#include "resourcemetrics.h" #include "zenserver.h" +#include <zencore/workthreadpool.h> #include <zenutil/consul.h> namespace cxxopts { @@ -19,6 +21,7 @@ namespace zen { class HttpApiService; class HttpFrontendService; class HttpHubService; +class HttpProxyHandler; struct ZenHubWatchdogConfig { @@ -34,21 +37,33 @@ struct ZenHubWatchdogConfig struct ZenHubServerConfig : public ZenServerConfig { - std::string UpstreamNotificationEndpoint; - std::string InstanceId; // For use in notifications - std::string ConsulEndpoint; // If set, enables Consul service registration - std::string ConsulTokenEnv; // Environment variable name to read a Consul token from; defaults to CONSUL_HTTP_TOKEN if empty - uint32_t 
ConsulHealthIntervalSeconds = 10; // Interval in seconds between Consul health checks - uint32_t ConsulDeregisterAfterSeconds = 30; // Seconds before Consul deregisters an unhealthy service - uint16_t HubBasePortNumber = 21000; - int HubInstanceLimit = 1000; - bool HubUseJobObject = true; - std::string HubInstanceHttpClass = "asio"; - uint32_t HubInstanceHttpThreadCount = 0; // Automatic - int HubInstanceCoreLimit = 0; // Automatic - std::filesystem::path HubInstanceConfigPath; // Path to Lua config file - std::string HydrationTargetSpecification; // hydration/dehydration target specification + std::string UpstreamNotificationEndpoint; + std::string InstanceId; // For use in notifications + std::string ConsulEndpoint; // If set, enables Consul service registration + std::string ConsulTokenEnv; // Environment variable name to read a Consul token from; defaults to CONSUL_HTTP_TOKEN if empty + uint32_t ConsulHealthIntervalSeconds = 10; // Interval in seconds between Consul health checks + uint32_t ConsulDeregisterAfterSeconds = 30; // Seconds before Consul deregisters an unhealthy service + bool ConsulRegisterHub = true; // Whether to register the hub parent service with Consul (instance registration unaffected) + uint16_t HubBasePortNumber = 21000; + int HubInstanceLimit = 1000; + bool HubUseJobObject = true; + std::string HubInstanceHttpClass = "asio"; + std::string HubInstanceMalloc; + std::string HubInstanceTrace; + std::string HubInstanceTraceHost; + std::string HubInstanceTraceFile; + uint32_t HubInstanceHttpThreadCount = 0; // Automatic + uint32_t HubInstanceProvisionThreadCount = 0; // Synchronous provisioning + uint32_t HubHydrationThreadCount = 0; // Synchronous hydration/dehydration + int HubInstanceCoreLimit = 0; // Automatic + std::filesystem::path HubInstanceConfigPath; // Path to Lua config file + std::string HydrationTargetSpecification; // hydration/dehydration target specification + std::filesystem::path HydrationTargetConfigPath; // path to JSON 
config file (mutually exclusive with HydrationTargetSpecification) ZenHubWatchdogConfig WatchdogConfig; + uint64_t HubProvisionDiskLimitBytes = 0; + uint32_t HubProvisionDiskLimitPercent = 0; + uint64_t HubProvisionMemoryLimitBytes = 0; + uint32_t HubProvisionMemoryLimitPercent = 0; }; class Hub; @@ -115,7 +130,10 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr<Hub> m_Hub; + std::unique_ptr<HttpProxyHandler> m_Proxy; + std::unique_ptr<WorkerThreadPool> m_ProvisionWorkerPool; + std::unique_ptr<WorkerThreadPool> m_HydrationWorkerPool; + std::unique_ptr<Hub> m_Hub; std::unique_ptr<HttpHubService> m_HubService; std::unique_ptr<HttpApiService> m_ApiService; @@ -126,6 +144,8 @@ private: uint32_t m_ConsulHealthIntervalSeconds = 10; uint32_t m_ConsulDeregisterAfterSeconds = 30; + static ResourceMetrics ResolveLimits(const ZenHubServerConfig& ServerConfig); + void InitializeState(const ZenHubServerConfig& ServerConfig); void InitializeServices(const ZenHubServerConfig& ServerConfig); void RegisterServices(const ZenHubServerConfig& ServerConfig); diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 00b7a67d7..108685eb9 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -14,7 +14,6 @@ #include <zencore/memory/memorytrace.h> #include <zencore/memory/newdelete.h> #include <zencore/scopeguard.h> -#include <zencore/sentryintegration.h> #include <zencore/session.h> #include <zencore/string.h> #include <zencore/thread.h> @@ -169,7 +168,12 @@ AppMain(int argc, char* argv[]) if (IsDir(ServerOptions.DataDir)) { ZEN_CONSOLE_INFO("Deleting files from '{}' ({})", ServerOptions.DataDir, DeleteReason); - DeleteDirectories(ServerOptions.DataDir); + std::error_code Ec; + DeleteDirectories(ServerOptions.DataDir, Ec); + if (Ec) + { + ZEN_WARN("could not fully clean '{}': {} (continuing anyway)", ServerOptions.DataDir, Ec.message()); + } } } @@ -250,7 +254,7 @@ test_main(int argc, char** argv) 
zen::MaximizeOpenFileCount(); zen::testing::TestRunner Runner; - Runner.ApplyCommandLine(argc, argv); + Runner.ApplyCommandLine(argc, argv, "server.*"); return Runner.Run(); } #endif diff --git a/src/zenserver/proxy/httptrafficinspector.cpp b/src/zenserver/proxy/httptrafficinspector.cpp index 74ecbfd48..913bd2c28 100644 --- a/src/zenserver/proxy/httptrafficinspector.cpp +++ b/src/zenserver/proxy/httptrafficinspector.cpp @@ -10,29 +10,33 @@ namespace zen { // clang-format off -http_parser_settings HttpTrafficInspector::s_RequestSettings{ - .on_message_begin = [](http_parser*) { return 0; }, - .on_url = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnUrl(Data, Len); }, - .on_status = [](http_parser*, const char*, size_t) { return 0; }, - .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }, - .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser*, const char*, size_t) { return 0; }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; - -http_parser_settings HttpTrafficInspector::s_ResponseSettings{ - .on_message_begin = [](http_parser*) { return 0; }, - .on_url = [](http_parser*, const char*, size_t) { return 0; }, - .on_status = [](http_parser*, const char*, size_t) { return 0; }, - .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }, - .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser*, const char*, size_t) { return 0; }, - .on_message_complete = [](http_parser* p) { return 
GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; +llhttp_settings_t HttpTrafficInspector::s_RequestSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t*) { return 0; }; + S.on_url = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnUrl(Data, Len); }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); + +llhttp_settings_t HttpTrafficInspector::s_ResponseSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t*) { return 0; }; + S.on_url = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); // clang-format on HttpTrafficInspector::HttpTrafficInspector(Direction Dir, std::string_view SessionLabel) @@ -40,7 +44,8 @@ HttpTrafficInspector::HttpTrafficInspector(Direction Dir, std::string_view Sessi , m_Direction(Dir) , m_SessionLabel(SessionLabel) { - http_parser_init(&m_Parser, 
Dir == Direction::Request ? HTTP_REQUEST : HTTP_RESPONSE); + llhttp_settings_t* Settings = (Dir == Direction::Request) ? &s_RequestSettings : &s_ResponseSettings; + llhttp_init(&m_Parser, Dir == Direction::Request ? HTTP_REQUEST : HTTP_RESPONSE, Settings); m_Parser.data = this; } @@ -52,11 +57,9 @@ HttpTrafficInspector::Inspect(const char* Data, size_t Length) return; } - http_parser_settings* Settings = (m_Direction == Direction::Request) ? &s_RequestSettings : &s_ResponseSettings; + llhttp_errno_t Err = llhttp_execute(&m_Parser, Data, Length); - size_t Parsed = http_parser_execute(&m_Parser, Settings, Data, Length); - - if (m_Parser.upgrade) + if (Err == HPE_PAUSED_UPGRADE) { if (m_Direction == Direction::Request) { @@ -72,15 +75,9 @@ HttpTrafficInspector::Inspect(const char* Data, size_t Length) return; } - http_errno Error = HTTP_PARSER_ERRNO(&m_Parser); - if (Error != HPE_OK) - { - ZEN_DEBUG("[{}] non-HTTP traffic detected ({}), disabling inspection", m_SessionLabel, http_errno_name(Error)); - m_Disabled = true; - } - else if (Parsed != Length) + if (Err != HPE_OK) { - ZEN_DEBUG("[{}] parser consumed {}/{} bytes, disabling inspection", m_SessionLabel, Parsed, Length); + ZEN_DEBUG("[{}] non-HTTP traffic detected ({}), disabling inspection", m_SessionLabel, llhttp_errno_name(Err)); m_Disabled = true; } } @@ -127,11 +124,11 @@ HttpTrafficInspector::OnHeadersComplete() { if (m_Direction == Direction::Request) { - m_Method = http_method_str(static_cast<http_method>(m_Parser.method)); + m_Method = llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser))); } else { - m_StatusCode = m_Parser.status_code; + m_StatusCode = static_cast<uint16_t>(llhttp_get_status_code(&m_Parser)); } return 0; } diff --git a/src/zenserver/proxy/httptrafficinspector.h b/src/zenserver/proxy/httptrafficinspector.h index f4af0e77e..8192632ba 100644 --- a/src/zenserver/proxy/httptrafficinspector.h +++ b/src/zenserver/proxy/httptrafficinspector.h @@ -6,7 +6,7 @@ #include 
<zencore/uid.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <http_parser.h> +#include <llhttp.h> ZEN_THIRD_PARTY_INCLUDES_END #include <atomic> @@ -45,15 +45,15 @@ private: void ResetMessageState(); - static HttpTrafficInspector* GetThis(http_parser* Parser) { return static_cast<HttpTrafficInspector*>(Parser->data); } + static HttpTrafficInspector* GetThis(llhttp_t* Parser) { return static_cast<HttpTrafficInspector*>(Parser->data); } - static http_parser_settings s_RequestSettings; - static http_parser_settings s_ResponseSettings; + static llhttp_settings_t s_RequestSettings; + static llhttp_settings_t s_ResponseSettings; LoggerRef Log() { return m_Log; } LoggerRef m_Log; - http_parser m_Parser; + llhttp_t m_Parser; Direction m_Direction; std::string m_SessionLabel; bool m_Disabled = false; diff --git a/src/zenserver/proxy/zenproxyserver.cpp b/src/zenserver/proxy/zenproxyserver.cpp index 7e59a7b7e..ffa9a4295 100644 --- a/src/zenserver/proxy/zenproxyserver.cpp +++ b/src/zenserver/proxy/zenproxyserver.cpp @@ -257,7 +257,7 @@ ZenProxyServerConfigurator::ValidateOptions() for (const std::string& Raw : m_RawProxyMappings) { // The mode keyword "proxy" from argv[1] gets captured as a positional - // argument — skip it. + // argument - skip it. if (Raw == "proxy") { continue; @@ -304,7 +304,7 @@ ZenProxyServer::Initialize(const ZenProxyServerConfig& ServerConfig, ZenServerSt // worker threads don't exit prematurely between async operations. m_ProxyIoWorkGuard.emplace(m_ProxyIoContext.get_executor()); - // Start proxy I/O worker threads. Use a modest thread count — proxy work is + // Start proxy I/O worker threads. Use a modest thread count - proxy work is // I/O-bound so we don't need a thread per core, but having more than one // avoids head-of-line blocking when many connections are active. 
unsigned int ThreadCount = std::max(GetHardwareConcurrency() / 4, 4u); @@ -385,6 +385,8 @@ ZenProxyServer::Run() OnReady(); + StartSelfSession("zenproxy"); + m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); @@ -422,15 +424,16 @@ ZenProxyServer::Cleanup() m_IoRunner.join(); } - m_ProxyStatsService.reset(); - m_FrontendService.reset(); - m_ApiService.reset(); - - ShutdownServices(); if (m_Http) { m_Http->Close(); } + + ShutdownServices(); + + m_ProxyStatsService.reset(); + m_FrontendService.reset(); + m_ApiService.reset(); } catch (const std::exception& Ex) { diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp index fdf2e1f21..56a22fb04 100644 --- a/src/zenserver/sessions/httpsessions.cpp +++ b/src/zenserver/sessions/httpsessions.cpp @@ -377,7 +377,7 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) if (ServerRequest.RequestContentType() == HttpContentType::kText) { - // Raw text — split by newlines, one entry per line + // Raw text - split by newlines, one entry per line IoBuffer Payload = ServerRequest.ReadPayload(); std::string_view Text(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize()); const DateTime Now = DateTime::Now(); @@ -512,8 +512,9 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) // void -HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("Sessions WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h index 86a23f835..6ebe61c8d 100644 --- a/src/zenserver/sessions/httpsessions.h +++ b/src/zenserver/sessions/httpsessions.h @@ -37,7 +37,7 @@ public: void SetSelfSessionId(const Oid& Id) { m_SelfSessionId = Id; } // 
IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp index 1212ba5d8..9d4e3120c 100644 --- a/src/zenserver/sessions/sessions.cpp +++ b/src/zenserver/sessions/sessions.cpp @@ -129,7 +129,7 @@ SessionsService::~SessionsService() = default; bool SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, std::string Mode, const Oid& JobId, CbObjectView Metadata) { - // Log outside the lock scope — InProcSessionLogSink calls back into + // Log outside the lock scope - InProcSessionLogSink calls back into // GetSession() which acquires m_Lock shared, so logging while holding // m_Lock exclusively would deadlock. 
{ diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index bbbb0c37b..f935e2c6b 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -162,96 +162,81 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) fmt::format("Invalid blob hash '{}'", Hash)); } - std::vector<std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs; + m_BuildStoreStats.BlobReadCount++; + IoBuffer Blob = m_BuildStore.GetBlob(BlobHash); + if (!Blob) + { + return ServerRequest.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, fmt::format("Blob {} not found", Hash)); + } + m_BuildStoreStats.BlobHitCount++; + if (ServerRequest.RequestVerb() == HttpVerb::kPost) { + if (ServerRequest.AcceptContentType() != HttpContentType::kCbPackage) + { + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Accept type '{}' is not supported for blob {}, expected '{}'", + ToString(ServerRequest.AcceptContentType()), + Hash, + ToString(HttpContentType::kCbPackage))); + } + CbObject RangePayload = ServerRequest.ReadPayloadObject(); - if (RangePayload) + if (!RangePayload) { - CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView(); - OffsetAndLengthPairs.reserve(RangesArray.Num()); - for (CbFieldView FieldView : RangesArray) - { - CbObjectView RangeView = FieldView.AsObjectView(); - uint64_t RangeOffset = RangeView["offset"sv].AsUInt64(); - uint64_t RangeLength = RangeView["length"sv].AsUInt64(); - OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength)); - } - if (OffsetAndLengthPairs.size() > MaxRangeCountPerRequestSupported) - { - return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, - HttpContentType::kText, - fmt::format("Number of ranges ({}) for blob request exceeds maximum range count {}", - OffsetAndLengthPairs.size(), - 
MaxRangeCountPerRequestSupported)); - } + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Missing payload for range request on blob {}", BlobHash)); } - if (OffsetAndLengthPairs.empty()) + + CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView(); + const uint64_t RangeCount = RangesArray.Num(); + if (RangeCount == 0) { m_BuildStoreStats.BadRequestCount++; return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, - "Fetching blob without ranges must be done with the GET verb"); + "POST request must include a non-empty 'ranges' array"); } - } - else - { - HttpRanges Ranges; - bool HasRange = ServerRequest.TryGetRanges(Ranges); - if (HasRange) + if (RangeCount > MaxRangeCountPerRequestSupported) { - if (Ranges.size() > 1) - { - // Only a single http range is supported, we have limited support for http multirange responses - m_BuildStoreStats.BadRequestCount++; - return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, - HttpContentType::kText, - fmt::format("Multiple ranges in blob request is only supported for {} accept type", - ToString(HttpContentType::kCbPackage))); - } - const HttpRange& FirstRange = Ranges.front(); - OffsetAndLengthPairs.push_back(std::make_pair<uint64_t, uint64_t>(FirstRange.Start, FirstRange.End - FirstRange.Start + 1)); + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse( + HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Range count {} exceeds maximum of {}", RangeCount, MaxRangeCountPerRequestSupported)); } - } - - m_BuildStoreStats.BlobReadCount++; - IoBuffer Blob = m_BuildStore.GetBlob(BlobHash); - if (!Blob) - { - return ServerRequest.WriteResponse(HttpResponseCode::NotFound, - HttpContentType::kText, - fmt::format("Blob with hash '{}' could not be found", Hash)); - } - m_BuildStoreStats.BlobHitCount++; - if (OffsetAndLengthPairs.empty()) - { - 
return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); - } + const uint64_t BlobSize = Blob.GetSize(); + std::vector<IoBuffer> RangeBuffers; + RangeBuffers.reserve(RangeCount); - if (ServerRequest.AcceptContentType() == HttpContentType::kCbPackage) - { - const uint64_t BlobSize = Blob.GetSize(); + CbPackage ResponsePackage; + CbObjectWriter Writer; - CbPackage ResponsePackage; - std::vector<IoBuffer> RangeBuffers; - CbObjectWriter Writer; Writer.BeginArray("ranges"sv); - for (const std::pair<uint64_t, uint64_t>& Range : OffsetAndLengthPairs) + for (CbFieldView FieldView : RangesArray) { - const uint64_t MaxBlobSize = Range.first < BlobSize ? BlobSize - Range.first : 0; - const uint64_t RangeSize = Min(Range.second, MaxBlobSize); + CbObjectView RangeView = FieldView.AsObjectView(); + uint64_t RangeOffset = RangeView["offset"sv].AsUInt64(); + uint64_t RangeLength = RangeView["length"sv].AsUInt64(); + + const uint64_t MaxBlobSize = RangeOffset < BlobSize ? BlobSize - RangeOffset : 0; + const uint64_t RangeSize = Min(RangeLength, MaxBlobSize); Writer.BeginObject(); { - if (Range.first + RangeSize <= BlobSize) + if (RangeOffset + RangeSize <= BlobSize) { - RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize)); - Writer.AddInteger("offset"sv, Range.first); + RangeBuffers.push_back(IoBuffer(Blob, RangeOffset, RangeSize)); + Writer.AddInteger("offset"sv, RangeOffset); Writer.AddInteger("length"sv, RangeSize); } else { - Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("offset"sv, RangeOffset); Writer.AddInteger("length"sv, 0); } } @@ -259,7 +244,7 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) } Writer.EndArray(); - CompositeBuffer Ranges(RangeBuffers); + CompositeBuffer Ranges(std::move(RangeBuffers)); CbAttachment PayloadAttachment(std::move(Ranges), BlobHash); Writer.AddAttachment("payload", PayloadAttachment); @@ -269,32 +254,21 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) 
ResponsePackage.SetObject(HeaderObject); CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage); - uint64_t ResponseSize = RpcResponseBuffer.GetSize(); - ZEN_UNUSED(ResponseSize); return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); } else { - if (OffsetAndLengthPairs.size() != 1) + HttpRanges RequestedRangeHeader; + bool HasRange = ServerRequest.TryGetRanges(RequestedRangeHeader); + if (HasRange) { - // Only a single http range is supported, we have limited support for http multirange responses - m_BuildStoreStats.BadRequestCount++; - return ServerRequest.WriteResponse( - HttpResponseCode::BadRequest, - HttpContentType::kText, - fmt::format("Multiple ranges in blob request is only supported for {} accept type", ToString(HttpContentType::kCbPackage))); + // Standard HTTP GET with Range header: framework handles 206, Content-Range, and 416 on OOB. + return ServerRequest.WriteResponse(HttpContentType::kBinary, Blob, RequestedRangeHeader); } - - const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengthPairs.front(); - const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? 
BlobSize - OffsetAndLength.first : 0; - const uint64_t RangeSize = Min(OffsetAndLength.second, MaxBlobSize); - if (OffsetAndLength.first + RangeSize > BlobSize) + else { - return ServerRequest.WriteResponse(HttpResponseCode::NoContent); + return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); } - Blob = IoBuffer(Blob, OffsetAndLength.first, RangeSize); - return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob); } } diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index c1727270c..8ad48225b 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -80,7 +80,8 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache) + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("cache")) , m_CacheStore(InCacheStore) , m_StatsService(StatsService) @@ -90,6 +91,7 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach , m_DiskWriteBlocker(InDiskWriteBlocker) , m_OpenProcessCache(InOpenProcessCache) , m_RpcHandler(m_Log, m_CacheStats, UpstreamCache, InCacheStore, InCidStore, InDiskWriteBlocker) +, m_LocalRefPolicy(InLocalRefPolicy) { m_StatsService.RegisterHandler("z$", *this); m_StatusService.RegisterHandler("z$", *this); @@ -114,6 +116,18 @@ HttpStructuredCacheService::BaseUri() const return "/z$/"; } +bool +HttpStructuredCacheService::AcceptsLocalFileReferences() const +{ + return true; +} + +const ILocalRefPolicy* +HttpStructuredCacheService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpStructuredCacheService::Flush() { diff --git a/src/zenserver/storage/cache/httpstructuredcache.h 
b/src/zenserver/storage/cache/httpstructuredcache.h index fc80b449e..f606126d6 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.h +++ b/src/zenserver/storage/cache/httpstructuredcache.h @@ -76,11 +76,14 @@ public: HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache); + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpStructuredCacheService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; void Flush(); @@ -125,6 +128,7 @@ private: const DiskWriteBlocker* m_DiskWriteBlocker = nullptr; OpenProcessCache& m_OpenProcessCache; CacheRpcHandler m_RpcHandler; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; void ReplayRequestRecorder(const CacheRequestContext& Context, cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount); diff --git a/src/zenserver/storage/localrefpolicy.cpp b/src/zenserver/storage/localrefpolicy.cpp new file mode 100644 index 000000000..47ef13b28 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.cpp @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "localrefpolicy.h" + +#include <zencore/except_fmt.h> +#include <zencore/fmtutils.h> + +#include <filesystem> + +namespace zen { + +DataRootLocalRefPolicy::DataRootLocalRefPolicy(const std::filesystem::path& DataRoot) +: m_CanonicalRoot(std::filesystem::weakly_canonical(DataRoot).string()) +{ +} + +void +DataRootLocalRefPolicy::ValidatePath(const std::filesystem::path& Path) const +{ + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(Path); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < m_CanonicalRoot.size() || FileStr.compare(0, m_CanonicalRoot.size(), m_CanonicalRoot) != 0) + { + throw zen::invalid_argument("local file reference '{}' is outside allowed data root", CanonicalFile); + } +} + +} // namespace zen diff --git a/src/zenserver/storage/localrefpolicy.h b/src/zenserver/storage/localrefpolicy.h new file mode 100644 index 000000000..3686d1880 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/localrefpolicy.h> + +#include <filesystem> +#include <string> + +namespace zen { + +/// Local ref policy that restricts file paths to a canonical data root directory. +/// Uses weakly_canonical + string prefix comparison to detect path traversal. 
+class DataRootLocalRefPolicy : public ILocalRefPolicy +{ +public: + explicit DataRootLocalRefPolicy(const std::filesystem::path& DataRoot); + + void ValidatePath(const std::filesystem::path& Path) const override; + +private: + std::string m_CanonicalRoot; +}; + +} // namespace zen diff --git a/src/zenserver/storage/objectstore/objectstore.cpp b/src/zenserver/storage/objectstore/objectstore.cpp index d6516fa1a..1115c1cd6 100644 --- a/src/zenserver/storage/objectstore/objectstore.cpp +++ b/src/zenserver/storage/objectstore/objectstore.cpp @@ -637,11 +637,7 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_ } HttpRanges Ranges; - if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1) - { - // Only a single range is supported - return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); - } + Request.ServerRequest().TryGetRanges(Ranges); FileContents File; { @@ -665,42 +661,49 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_ if (Ranges.empty()) { - const uint64_t TotalServed = m_TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size(); - + const uint64_t TotalServed = m_TotalBytesServed.fetch_add(FileBuf.GetSize()) + FileBuf.GetSize(); ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' ({}) [OK] (Served: {})", BucketName, RelativeBucketPath, - NiceBytes(FileBuf.Size()), + NiceBytes(FileBuf.GetSize()), NiceBytes(TotalServed)); - - Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf); } else { - const auto Range = Ranges[0]; - const uint64_t RangeSize = 1 + (Range.End - Range.Start); - const uint64_t TotalServed = m_TotalBytesServed.fetch_add(RangeSize) + RangeSize; - - ZEN_LOG_DEBUG(LogObj, - "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})", - BucketName, - RelativeBucketPath, - Range.Start, - Range.End, - NiceBytes(RangeSize), - NiceBytes(FileBuf.Size()), - NiceBytes(TotalServed)); - - MemoryView RangeView = 
FileBuf.GetView().Mid(Range.Start, RangeSize); - if (RangeView.GetSize() != RangeSize) + const uint64_t TotalSize = FileBuf.GetSize(); + uint64_t ServedBytes = 0; + for (const HttpRange& Range : Ranges) { - return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + if (RangeEnd < TotalSize && Range.Start <= RangeEnd) + { + ServedBytes += 1 + (RangeEnd - Range.Start); + } + } + if (ServedBytes > 0) + { + const uint64_t TotalServed = m_TotalBytesServed.fetch_add(ServedBytes) + ServedBytes; + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' (Ranges: {}) ({}/{}) [OK] (Served: {})", + BucketName, + RelativeBucketPath, + Ranges.size(), + NiceBytes(ServedBytes), + NiceBytes(TotalSize), + NiceBytes(TotalServed)); + } + else + { + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' (Ranges: {}) [416] ({})", + BucketName, + RelativeBucketPath, + Ranges.size(), + NiceBytes(TotalSize)); } - - IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize()); - Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf); } + Request.ServerRequest().WriteResponse(HttpContentType::kBinary, FileBuf, Ranges); } void diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index a7c8c66b6..2a6c62195 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -18,7 +18,6 @@ #include <zenremotestore/builds/buildstoragecache.h> #include <zenremotestore/builds/buildstorageutil.h> #include <zenremotestore/jupiter/jupiterhost.h> -#include <zenremotestore/operationlogoutput.h> #include <zenremotestore/projectstore/buildsremoteprojectstore.h> #include <zenremotestore/projectstore/fileremoteprojectstore.h> #include <zenremotestore/projectstore/jupiterremoteprojectstore.h> @@ -279,7 +278,7 @@ namespace { 
{ ZEN_MEMSCOPE(GetProjectHttpTag()); - auto Log = [InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -566,11 +565,9 @@ namespace { .AllowResume = true, .RetryCount = 2}; - std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log())); - try { - ResolveResult = ResolveBuildStorage(*Output, + ResolveResult = ResolveBuildStorage(Log(), ClientSettings, Host, OverrideHost, @@ -656,7 +653,8 @@ HttpProjectService::HttpProjectService(CidStore& Store, JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool InAllowExternalOidcTokenExe) + bool InAllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("project")) , m_CidStore(Store) , m_ProjectStore(Projects) @@ -668,6 +666,7 @@ HttpProjectService::HttpProjectService(CidStore& Store, , m_RestrictContentTypes(InRestrictContentTypes) , m_OidcTokenExePath(InOidcTokenExePath) , m_AllowExternalOidcTokenExe(InAllowExternalOidcTokenExe) +, m_LocalRefPolicy(InLocalRefPolicy) { ZEN_MEMSCOPE(GetProjectHttpTag()); @@ -785,22 +784,22 @@ HttpProjectService::HttpProjectService(CidStore& Store, HttpVerb::kPost); m_Router.RegisterRoute( - "details\\$", + "details$", [this](HttpRouterRequest& Req) { HandleDetailsRequest(Req); }, HttpVerb::kGet); m_Router.RegisterRoute( - "details\\$/{project}", + "details$/{project}", [this](HttpRouterRequest& Req) { HandleProjectDetailsRequest(Req); }, HttpVerb::kGet); m_Router.RegisterRoute( - "details\\$/{project}/{log}", + "details$/{project}/{log}", [this](HttpRouterRequest& Req) { HandleOplogDetailsRequest(Req); }, HttpVerb::kGet); m_Router.RegisterRoute( - "details\\$/{project}/{log}/{chunk}", + "details$/{project}/{log}/{chunk}", [this](HttpRouterRequest& Req) { HandleOplogOpDetailsRequest(Req); }, HttpVerb::kGet); @@ -820,6 +819,18 @@ HttpProjectService::BaseUri() const return "/prj/"; } +bool +HttpProjectService::AcceptsLocalFileReferences() const +{ + return true; 
+} + +const ILocalRefPolicy* +HttpProjectService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpProjectService::HandleRequest(HttpServerRequest& Request) { @@ -1250,7 +1261,7 @@ HttpProjectService::HandleChunkInfoRequest(HttpRouterRequest& Req) const Oid Obj = Oid::FromHexString(ChunkId); - CbObject ResponsePayload = ProjectStore::GetChunkInfo(Log(), *Project, *FoundLog, Obj); + CbObject ResponsePayload = ProjectStore::GetChunkInfo(*Project, *FoundLog, Obj); if (ResponsePayload) { m_ProjectStats.ChunkHitCount++; @@ -1339,7 +1350,7 @@ HttpProjectService::HandleChunkByIdRequest(HttpRouterRequest& Req) HttpContentType AcceptType = HttpReq.AcceptContentType(); ProjectStore::GetChunkRangeResult Result = - ProjectStore::GetChunkRange(Log(), *Project, *FoundLog, Obj, Offset, Size, AcceptType, /*OptionalInOutModificationTag*/ nullptr); + ProjectStore::GetChunkRange(*Project, *FoundLog, Obj, Offset, Size, AcceptType, /*OptionalInOutModificationTag*/ nullptr); switch (Result.Error) { @@ -1668,7 +1679,8 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req) CbPackage Package; - if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) + const bool ValidateHashes = false; + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver, ValidateHashes)) { CbValidateError ValidateResult; if (CbObject Core = ValidateAndReadCompactBinaryObject(IoBuffer(Payload), ValidateResult); @@ -2676,6 +2688,7 @@ HttpProjectService::HandleOplogLoadRequest(HttpRouterRequest& Req) try { CbObject ContainerObject = BuildContainer( + Log(), m_CidStore, *Project, *Oplog, @@ -2763,7 +2776,11 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) case HttpContentType::kCbPackage: try { - Package = ParsePackageMessage(Payload); + ParseFlags PkgFlags = (HttpReq.IsLocalMachineRequest() && AcceptsLocalFileReferences()) ? 
ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? GetLocalRefPolicy() : nullptr; + Package = ParsePackageMessage(Payload, {}, PkgFlags, PkgPolicy); Cb = Package.GetObject(); } catch (const std::invalid_argument& ex) @@ -2872,6 +2889,7 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) try { LoadOplog(LoadOplogContext{ + .Log = Log(), .ChunkStore = m_CidStore, .RemoteStore = *RemoteStoreResult->Store, .OptionalCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Cache.get() : nullptr, @@ -2997,7 +3015,8 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) try { - SaveOplog(m_CidStore, + SaveOplog(Log(), + m_CidStore, *ActualRemoteStore, *Project, *Oplog, diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h index e3ed02f26..8aa345fa7 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.h +++ b/src/zenserver/storage/projectstore/httpprojectstore.h @@ -47,11 +47,14 @@ public: JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool AllowExternalOidcTokenExe); + bool AllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpProjectService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; virtual void HandleStatusRequest(HttpServerRequest& Request) override; virtual void HandleStatsRequest(HttpServerRequest& Request) override; @@ -117,6 +120,7 @@ private: bool m_RestrictContentTypes; std::filesystem::path m_OidcTokenExePath; bool 
m_AllowExternalOidcTokenExe; + const ILocalRefPolicy* m_LocalRefPolicy; Ref<TransferThreadWorkers> GetThreadWorkers(bool BoostWorkers, bool SingleThreaded); }; diff --git a/src/zenserver/storage/storageconfig.cpp b/src/zenserver/storage/storageconfig.cpp index 0dbb45164..bb4f053e4 100644 --- a/src/zenserver/storage/storageconfig.cpp +++ b/src/zenserver/storage/storageconfig.cpp @@ -57,6 +57,12 @@ ZenStorageServerConfigurator::ValidateOptions() ZEN_WARN("'--gc-v2=false' is deprecated, reverting to '--gc-v2=true'"); ServerOptions.GcConfig.UseGCV2 = true; } + if (ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent > 100) + { + throw OptionParseException(fmt::format("'--buildstore-disksizelimit-percent' ('{}') is invalid, must be between 1 and 100.", + ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent), + {}); + } } class ZenStructuredCacheBucketsConfigOption : public LuaConfig::OptionValue @@ -382,6 +388,9 @@ ZenStorageServerConfigurator::AddConfigOptions(LuaConfig::Options& LuaOptions) ////// buildsstore LuaOptions.AddOption("server.buildstore.enabled"sv, ServerOptions.BuildStoreConfig.Enabled, "buildstore-enabled"sv); LuaOptions.AddOption("server.buildstore.disksizelimit"sv, ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit, "buildstore-disksizelimit"); + LuaOptions.AddOption("server.buildstore.disksizelimitpercent"sv, + ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent, + "buildstore-disksizelimit-percent"); ////// cache LuaOptions.AddOption("cache.enable"sv, ServerOptions.StructuredCacheConfig.Enabled); @@ -477,7 +486,7 @@ ZenStorageServerConfigurator::AddConfigOptions(LuaConfig::Options& LuaOptions) ServerOptions.GcConfig.CompactBlockUsageThresholdPercent, "gc-compactblock-threshold"sv); LuaOptions.AddOption("gc.verbose"sv, ServerOptions.GcConfig.Verbose, "gc-verbose"sv); - LuaOptions.AddOption("gc.single-threaded"sv, ServerOptions.GcConfig.SingleThreaded, "gc-single-threaded"sv); + LuaOptions.AddOption("gc.singlethreaded"sv, 
ServerOptions.GcConfig.SingleThreaded, "gc-single-threaded"sv); LuaOptions.AddOption("gc.cache.attachment.store"sv, ServerOptions.GcConfig.StoreCacheAttachmentMetaData, "gc-cache-attachment-store"); LuaOptions.AddOption("gc.projectstore.attachment.store"sv, ServerOptions.GcConfig.StoreProjectAttachmentMetaData, @@ -1035,6 +1044,13 @@ ZenStorageServerCmdLineOptions::AddBuildStoreOptions(cxxopts::Options& options, "Max number of bytes before build store entries get evicted. Default set to 1099511627776 (1TB week)", cxxopts::value<uint64_t>(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit)->default_value("1099511627776"), ""); + options.add_option("buildstore", + "", + "buildstore-disksizelimit-percent", + "Max percentage (1-100) of total drive capacity (of --data-dir drive) before build store entries get evicted. " + "0 (default) disables this limit. When combined with --buildstore-disksizelimit, the lower value wins.", + cxxopts::value<uint32_t>(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent)->default_value("0"), + ""); } void diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h index 18af4f096..fec8fd70b 100644 --- a/src/zenserver/storage/storageconfig.h +++ b/src/zenserver/storage/storageconfig.h @@ -135,8 +135,9 @@ struct ZenProjectStoreConfig struct ZenBuildStoreConfig { - bool Enabled = false; - uint64_t MaxDiskSpaceLimit = 1u * 1024u * 1024u * 1024u * 1024u; // 1TB + bool Enabled = false; + uint64_t MaxDiskSpaceLimit = 1u * 1024u * 1024u * 1024u * 1024u; // 1TB + uint32_t MaxDiskSpaceLimitPercent = 0; }; struct ZenWorkspacesConfig diff --git a/src/zenserver/storage/upstream/upstreamcache.cpp b/src/zenserver/storage/upstream/upstreamcache.cpp index b26c57414..a516c452c 100644 --- a/src/zenserver/storage/upstream/upstreamcache.cpp +++ b/src/zenserver/storage/upstream/upstreamcache.cpp @@ -772,7 +772,7 @@ namespace detail { UpstreamEndpointInfo m_Info; UpstreamStatus m_Status; UpstreamEndpointStats m_Stats; - 
RefPtr<JupiterClient> m_Client; + Ref<JupiterClient> m_Client; const bool m_AllowRedirect = false; }; @@ -1446,7 +1446,7 @@ namespace detail { // Make sure we safely bump the refcount inside a scope lock RwLock::SharedLockScope _(m_ClientLock); ZEN_ASSERT(m_Client); - Ref<ZenStructuredCacheClient> ClientRef(m_Client); + Ref<ZenStructuredCacheClient> ClientRef(m_Client.Get()); _.ReleaseNow(); return ClientRef; } @@ -1485,15 +1485,15 @@ namespace detail { LoggerRef Log() { return m_Log; } - LoggerRef m_Log; - UpstreamEndpointInfo m_Info; - UpstreamStatus m_Status; - UpstreamEndpointStats m_Stats; - std::vector<ZenEndpoint> m_Endpoints; - std::chrono::milliseconds m_ConnectTimeout; - std::chrono::milliseconds m_Timeout; - RwLock m_ClientLock; - RefPtr<ZenStructuredCacheClient> m_Client; + LoggerRef m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + std::vector<ZenEndpoint> m_Endpoints; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + RwLock m_ClientLock; + Ref<ZenStructuredCacheClient> m_Client; }; } // namespace detail diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index bc0a8f4ac..7b52f2832 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -37,8 +37,6 @@ #include <zenutil/sessionsclient.h> #include <zenutil/workerpools.h> #include <zenutil/zenserverprocess.h> -#include "sessions/inprocsessionlogsink.h" -#include "sessions/sessions.h" #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -165,11 +163,6 @@ ZenStorageServer::RegisterServices() m_Http->RegisterService(*m_HttpWorkspacesService); } - if (m_HttpSessionsService) - { - m_Http->RegisterService(*m_HttpSessionsService); - } - m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); if (m_FrontendService) @@ -223,12 +216,13 @@ 
ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions ZEN_INFO("instantiating project service"); + m_LocalRefPolicy = std::make_unique<DataRootLocalRefPolicy>(m_DataRoot); m_JobQueue = MakeJobQueue(8, "bgjobs"); m_OpenProcessCache = std::make_unique<OpenProcessCache>(); m_ProjectStore = new ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager, ProjectStore::Configuration{}); m_HttpProjectService.reset(new HttpProjectService{*m_CidStore, - m_ProjectStore, + m_ProjectStore.Get(), m_StatusService, m_StatsService, *m_AuthMgr, @@ -236,7 +230,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_JobQueue, ServerOptions.RestrictContentTypes, ServerOptions.OidcTokenExecutable, - ServerOptions.AllowExternalOidcTokenExe}); + ServerOptions.AllowExternalOidcTokenExe, + m_LocalRefPolicy.get()}); if (ServerOptions.WorksSpacesConfig.Enabled) { @@ -251,16 +246,6 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_Workspaces)); } - { - m_SessionsService = std::make_unique<SessionsService>(); - m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService, m_IoContext); - m_HttpSessionsService->SetSelfSessionId(GetSessionId()); - - m_InProcSessionLogSink = logging::SinkPtr(new InProcSessionLogSink(*m_SessionsService)); - m_InProcSessionLogSink->SetLevel(logging::Info); - GetDefaultBroadcastSink()->AddSink(m_InProcSessionLogSink); - } - if (!ServerOptions.SessionsTargetUrl.empty()) { m_SessionsClient = std::make_unique<SessionsServiceClient>(SessionsServiceClient::Options{ @@ -281,7 +266,31 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions BuildStoreConfig BuildsCfg; BuildsCfg.RootDirectory = m_DataRoot / "builds"; BuildsCfg.MaxDiskSpaceLimit = ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit; - m_BuildStore = std::make_unique<BuildStore>(std::move(BuildsCfg), m_GcManager, *m_BuildCidStore); + 
+ if (ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent > 0) + { + DiskSpace Space; + if (DiskSpaceInfo(m_DataRoot, Space) && Space.Total > 0) + { + uint64_t PercentLimit = Space.Total * ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent / 100; + BuildsCfg.MaxDiskSpaceLimit = ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit > 0 + ? std::min(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit, PercentLimit) + : PercentLimit; + ZEN_INFO("buildstore disk limit: {}% of {} = {} (effective limit: {})", + ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent, + NiceBytes(Space.Total), + NiceBytes(PercentLimit), + NiceBytes(BuildsCfg.MaxDiskSpaceLimit)); + } + else + { + ZEN_WARN("buildstore-disksizelimit-percent: failed to query disk space for {}, using absolute limit {}", + m_DataRoot.string(), + NiceBytes(BuildsCfg.MaxDiskSpaceLimit)); + } + } + + m_BuildStore = std::make_unique<BuildStore>(std::move(BuildsCfg), m_GcManager, *m_BuildCidStore); } if (ServerOptions.StructuredCacheConfig.Enabled) @@ -323,13 +332,13 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions ZEN_OTEL_SPAN("InitializeComputeService"); m_HttpComputeService = - std::make_unique<compute::HttpComputeService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); + std::make_unique<compute::HttpComputeService>(*m_CidStore, *m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); } #endif #if ZEN_WITH_VFS m_VfsServiceImpl = std::make_unique<VfsServiceImpl>(); - m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore)); + m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore.Get())); m_VfsServiceImpl->AddService(Ref<ZenCacheStore>(m_CacheStore)); m_VfsService = std::make_unique<VfsService>(m_StatusService, m_VfsServiceImpl.get()); @@ -713,7 +722,8 @@ ZenStorageServer::InitializeStructuredCache(const ZenStorageServerConfig& Server m_StatusService, *m_UpstreamCache, m_GcManager.GetDiskWriteBlocker(), - *m_OpenProcessCache); + 
*m_OpenProcessCache, + m_LocalRefPolicy.get()); m_StatsReporter.AddProvider(m_CacheStore.Get()); m_StatsReporter.AddProvider(m_CidStore.get()); @@ -838,7 +848,7 @@ ZenStorageServer::Run() OnReady(); - m_SessionsService->RegisterSession(GetSessionId(), "zenserver", GetServerMode(), Oid::Zero, {}); + StartSelfSession("zenserver"); if (m_SessionsClient) { @@ -888,11 +898,6 @@ ZenStorageServer::Cleanup() m_Http->Close(); } - if (m_InProcSessionLogSink) - { - GetDefaultBroadcastSink()->RemoveSink(m_InProcSessionLogSink); - m_InProcSessionLogSink = {}; - } if (m_SessionLogSink) { GetDefaultBroadcastSink()->RemoveSink(m_SessionLogSink); @@ -904,11 +909,6 @@ ZenStorageServer::Cleanup() m_SessionsClient.reset(); } - if (m_SessionsService) - { - m_SessionsService->RemoveSession(GetSessionId()); - } - ShutdownServices(); if (m_JobQueue) @@ -940,8 +940,6 @@ ZenStorageServer::Cleanup() m_UpstreamCache.reset(); m_CacheStore = {}; - m_HttpSessionsService.reset(); - m_SessionsService.reset(); m_HttpWorkspacesService.reset(); m_Workspaces.reset(); m_HttpProjectService.reset(); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index fad22ad54..9fa46ba9b 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -11,6 +11,7 @@ #include <zenstore/cache/structuredcachestore.h> #include <zenstore/gc.h> #include <zenstore/projectstore.h> +#include "localrefpolicy.h" #include "admin/admin.h" #include "buildstore/httpbuildstore.h" @@ -19,7 +20,6 @@ #include "frontend/frontend.h" #include "objectstore/objectstore.h" #include "projectstore/httpprojectstore.h" -#include "sessions/httpsessions.h" #include "stats/statsreporter.h" #include "upstream/upstream.h" #include "vfs/vfsservice.h" @@ -65,27 +65,26 @@ private: void InitializeServices(const ZenStorageServerConfig& ServerOptions); void RegisterServices(); - std::unique_ptr<JobQueue> m_JobQueue; - GcManager m_GcManager; - GcScheduler 
m_GcScheduler{m_GcManager}; - std::unique_ptr<CidStore> m_CidStore; - Ref<ZenCacheStore> m_CacheStore; - std::unique_ptr<OpenProcessCache> m_OpenProcessCache; - HttpTestService m_TestService; - std::unique_ptr<CidStore> m_BuildCidStore; - std::unique_ptr<BuildStore> m_BuildStore; + std::unique_ptr<DataRootLocalRefPolicy> m_LocalRefPolicy; + std::unique_ptr<JobQueue> m_JobQueue; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<CidStore> m_CidStore; + Ref<ZenCacheStore> m_CacheStore; + std::unique_ptr<OpenProcessCache> m_OpenProcessCache; + HttpTestService m_TestService; + std::unique_ptr<CidStore> m_BuildCidStore; + std::unique_ptr<BuildStore> m_BuildStore; #if ZEN_WITH_TESTS HttpTestingService m_TestingService; #endif - RefPtr<ProjectStore> m_ProjectStore; + Ref<ProjectStore> m_ProjectStore; std::unique_ptr<VfsServiceImpl> m_VfsServiceImpl; std::unique_ptr<HttpProjectService> m_HttpProjectService; std::unique_ptr<Workspaces> m_Workspaces; std::unique_ptr<HttpWorkspacesService> m_HttpWorkspacesService; - std::unique_ptr<SessionsService> m_SessionsService; - std::unique_ptr<HttpSessionsService> m_HttpSessionsService; std::unique_ptr<UpstreamCache> m_UpstreamCache; std::unique_ptr<HttpUpstreamService> m_UpstreamService; std::unique_ptr<HttpStructuredCacheService> m_StructuredCacheService; @@ -98,7 +97,6 @@ private: std::unique_ptr<SessionsServiceClient> m_SessionsClient; logging::SinkPtr m_SessionLogSink; - logging::SinkPtr m_InProcSessionLogSink; asio::steady_timer m_SessionAnnounceTimer{m_IoContext}; void EnqueueSessionAnnounceTimer(); diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index c2c81e7aa..b609d1050 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -32,7 +32,7 @@ target("zenserver") add_deps("protozero", "asio", "cxxopts") add_deps("sol2") - add_packages("http_parser") + add_packages("llhttp") add_packages("json11") add_packages("zlib") add_packages("lua") diff --git 
a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 6aa02eb87..e68e46bd6 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -13,6 +13,7 @@ #include <zencore/iobuffer.h> #include <zencore/jobqueue.h> #include <zencore/logging.h> +#include <zencore/logging/broadcastsink.h> #include <zencore/memory/fmalloc.h> #include <zencore/scopeguard.h> #include <zencore/sentryintegration.h> @@ -28,6 +29,7 @@ #include <zenhttp/security/passwordsecurityfilter.h> #include <zentelemetry/otlptrace.h> #include <zenutil/authutils.h> +#include <zenutil/logging.h> #include <zenutil/service.h> #include <zenutil/workerpools.h> #include <zenutil/zenserverprocess.h> @@ -64,6 +66,9 @@ ZEN_THIRD_PARTY_INCLUDES_END #include "config/config.h" #include "diag/logging.h" +#include "sessions/httpsessions.h" +#include "sessions/inprocsessionlogsink.h" +#include "sessions/sessions.h" #include <zencore/memory/llm.h> @@ -225,6 +230,8 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: LogSettingsSummary(ServerOptions); + InitializeSessions(); + return EffectiveBasePort; } @@ -233,6 +240,11 @@ ZenServerBase::Finalize() { m_StatsService.RegisterHandler("http", *m_Http); + if (m_HttpSessionsService) + { + m_Http->RegisterService(*m_HttpSessionsService); + } + m_Http->SetDefaultRedirect("/dashboard/"); // Register health service last so if we return "OK" for health it means all services have been properly initialized @@ -243,11 +255,49 @@ ZenServerBase::Finalize() void ZenServerBase::ShutdownServices() { - m_StatsService.UnregisterHandler("http", *m_Http); + if (m_InProcSessionLogSink) + { + GetDefaultBroadcastSink()->RemoveSink(m_InProcSessionLogSink); + m_InProcSessionLogSink = {}; + } + + if (m_SessionsService) + { + m_SessionsService->RemoveSession(GetSessionId()); + } + + m_HttpSessionsService.reset(); + m_SessionsService.reset(); + + if (m_Http) + { + m_StatsService.UnregisterHandler("http", *m_Http); + } 
m_StatsService.Shutdown(); } void +ZenServerBase::InitializeSessions() +{ + m_SessionsService = std::make_unique<SessionsService>(); + m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService, m_IoContext); + m_HttpSessionsService->SetSelfSessionId(GetSessionId()); + + m_InProcSessionLogSink = logging::SinkPtr(new InProcSessionLogSink(*m_SessionsService)); + m_InProcSessionLogSink->SetLevel(logging::Info); + GetDefaultBroadcastSink()->AddSink(m_InProcSessionLogSink); +} + +void +ZenServerBase::StartSelfSession(std::string_view AppName) +{ + if (m_SessionsService) + { + m_SessionsService->RegisterSession(GetSessionId(), std::string(AppName), GetServerMode(), Oid::Zero, {}); + } +} + +void ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) const { ZEN_MEMSCOPE(GetZenserverTag()); @@ -272,6 +322,8 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co OutOptions << Separator; OutOptions << "ZEN_WITH_MEMTRACK=" << (ZEN_WITH_MEMTRACK ? "1" : "0"); OutOptions << Separator; + OutOptions << "ZEN_WITH_COMPUTE_SERVICES=" << (ZEN_WITH_COMPUTE_SERVICES ? "1" : "0"); + OutOptions << Separator; OutOptions << "ZEN_WITH_TRACE=" << (ZEN_WITH_TRACE ? "1" : "0"); } @@ -697,7 +749,7 @@ ZenServerMain::Run() // The entry's process failed to pick up our sponsor request after // multiple attempts. Before reclaiming the entry, verify that the // PID does not still belong to a zenserver process. If it does, the - // server is alive but unresponsive – fall back to the original error + // server is alive but unresponsive - fall back to the original error // path. If the PID is gone or belongs to a different executable the // entry is genuinely stale and safe to reclaim. 
const int StalePid = Entry->Pid.load(); @@ -715,7 +767,7 @@ ZenServerMain::Run() } ZEN_CONSOLE_WARN( "Failed to add sponsor to process on port {} (pid {}); " - "pid belongs to '{}' – assuming stale entry and reclaiming", + "pid belongs to '{}' - assuming stale entry and reclaiming", m_ServerOptions.BasePort, StalePid, ExeEc ? "<unknown>" : PidExePath.filename().string()); diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index f5286e9ee..995ff054f 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/basicfile.h> +#include <zencore/logging/sink.h> #include <zencore/system.h> #include <zenhttp/httpserver.h> #include <zenhttp/httpstats.h> @@ -27,6 +28,8 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { +class HttpSessionsService; +class SessionsService; struct FLLMTag; extern const FLLMTag& GetZenserverTag(); @@ -57,6 +60,7 @@ protected: int Initialize(const ZenServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry); void Finalize(); void ShutdownServices(); + void StartSelfSession(std::string_view AppName); void GetBuildOptions(StringBuilderBase& OutOptions, char Separator = ',') const; static std::vector<std::pair<std::string_view, std::string>> BuildSettingsList(const ZenServerConfig& ServerConfig); void LogSettingsSummary(const ZenServerConfig& ServerConfig); @@ -104,6 +108,11 @@ protected: HttpStatusService m_StatusService; SystemMetricsTracker m_MetricsTracker; + // Sessions (shared by all derived servers) + std::unique_ptr<SessionsService> m_SessionsService; + std::unique_ptr<HttpSessionsService> m_HttpSessionsService; + logging::SinkPtr m_InProcSessionLogSink; + // Stats reporting StatsReporter m_StatsReporter; @@ -137,6 +146,7 @@ protected: virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: + void InitializeSessions(); void InitializeSecuritySettings(const ZenServerConfig& ServerOptions); }; class ZenServerMain diff --git 
a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index dff1c3c61..a08741d31 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -1052,7 +1052,7 @@ public: { ZEN_TRACE_CPU("Builds::PreCache"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -1107,7 +1107,7 @@ public: ZEN_TRACE_CPU("Builds::GetUnusedReferences"); ZEN_MEMSCOPE(GetBuildstoreTag()); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); size_t InitialCount = IoCids.size(); size_t UsedCount = InitialCount; @@ -1152,7 +1152,7 @@ BuildStore::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats) ZEN_TRACE_CPU("Builds::RemoveExpiredData"); ZEN_MEMSCOPE(GetBuildstoreTag()); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index 4640309d9..45a4b6456 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -3083,7 +3083,7 @@ public: { ZEN_TRACE_CPU("Z$::Bucket::CompactStore"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -3338,7 +3338,7 @@ ZenCacheDiskLayer::CacheBucket::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats) { ZEN_TRACE_CPU("Z$::Bucket::RemoveExpiredData"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); size_t TotalEntries = 0; @@ -3502,7 +3502,7 @@ ZenCacheDiskLayer::CacheBucket::GetReferences(const LoggerRef& Logger, { ZEN_TRACE_CPU("Z$::Bucket::GetReferencesLocked"); - auto Log = [&Logger]() { return Logger; }; + ZEN_SCOPED_LOG(Logger); auto GetAttachments = [&](const IoHash& RawHash, MemoryView Data) -> bool { if (CbValidateError Error = ValidateCompactBinary(Data, CbValidateMode::Default); Error == 
CbValidateError::None) @@ -3718,7 +3718,7 @@ public: { ZEN_TRACE_CPU("Z$::Bucket::PreCache"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -3753,7 +3753,7 @@ public: { ZEN_TRACE_CPU("Z$::Bucket::UpdateLockedState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -3784,7 +3784,7 @@ public: { ZEN_TRACE_CPU("Z$::Bucket::GetUnusedReferences"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); const size_t InitialCount = IoCids.size(); size_t UsedCount = InitialCount; @@ -3818,7 +3818,7 @@ ZenCacheDiskLayer::CacheBucket::CreateReferenceCheckers(GcCtx& Ctx) { ZEN_TRACE_CPU("Z$::Bucket::CreateReferenceCheckers"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index cff0e9a35..97b793083 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -468,7 +468,7 @@ ZenCacheStore::LogWorker() LoggerRef ZCacheLog(logging::Get("z$")); - auto Log = [&ZCacheLog]() -> LoggerRef { return ZCacheLog; }; + ZEN_SCOPED_LOG(ZCacheLog); std::vector<AccessLogItem> Items; while (true) @@ -1086,11 +1086,9 @@ ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view Bu std::vector<RwLock::SharedLockScope> ZenCacheStore::LockState(GcCtx& Ctx) { + ZEN_UNUSED(Ctx); ZEN_TRACE_CPU("CacheStore::LockState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; - ZEN_UNUSED(Log); - std::vector<RwLock::SharedLockScope> Locks; Locks.emplace_back(RwLock::SharedLockScope(m_NamespacesLock)); for (auto& NamespaceIt : m_Namespaces) @@ -1211,7 +1209,7 @@ public: { ZEN_TRACE_CPU("Z$::UpdateLockedState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); 
Stopwatch Timer; @@ -1276,7 +1274,7 @@ public: { ZEN_TRACE_CPU("Z$::GetUnusedReferences"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); const size_t InitialCount = IoCids.size(); size_t UsedCount = InitialCount; @@ -1309,7 +1307,7 @@ ZenCacheStore::CreateReferenceCheckers(GcCtx& Ctx) { ZEN_TRACE_CPU("CacheStore::CreateReferenceCheckers"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp index b20d8f565..ac8a75a58 100644 --- a/src/zenstore/cidstore.cpp +++ b/src/zenstore/cidstore.cpp @@ -188,12 +188,24 @@ CidStore::Initialize(const CidStoreConfiguration& Config) } CidStore::InsertResult +CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) +{ + return m_Impl->AddChunk(ChunkData, RawHash, InsertMode::kMayBeMovedInPlace); +} + +CidStore::InsertResult CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode) { return m_Impl->AddChunk(ChunkData, RawHash, Mode); } std::vector<CidStore::InsertResult> +CidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) +{ + return m_Impl->AddChunks(ChunkDatas, RawHashes, InsertMode::kMayBeMovedInPlace); +} + +std::vector<CidStore::InsertResult> CidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, InsertMode Mode) { return m_Impl->AddChunks(ChunkDatas, RawHashes, Mode); diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp index 43dc389e2..6f1e1d701 100644 --- a/src/zenstore/compactcas.cpp +++ b/src/zenstore/compactcas.cpp @@ -698,7 +698,7 @@ public: ZEN_TRACE_CPU("CasContainer::CompactStore"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -875,7 +875,7 @@ public: ZEN_MEMSCOPE(GetCasContainerTag()); ZEN_TRACE_CPU("CasContainer::RemoveUnreferencedData"); - auto Log = [&Ctx]() 
{ return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -958,7 +958,7 @@ CasContainerStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&) ZEN_MEMSCOPE(GetCasContainerTag()); ZEN_TRACE_CPU("CasContainer::CreateReferencePruner"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -1391,7 +1391,7 @@ TEST_CASE("compactcas.compact.gc") { ScopedTemporaryDirectory TempDir; - const int kIterationCount = 1000; + const int kIterationCount = 200; std::vector<IoHash> Keys(kIterationCount); @@ -1504,7 +1504,7 @@ TEST_CASE("compactcas.threadedinsert") ScopedTemporaryDirectory TempDir; const uint64_t kChunkSize = 1048; - const int32_t kChunkCount = 2048; + const int32_t kChunkCount = 512; uint64_t ExpectedSize = 0; tsl::robin_map<IoHash, IoBuffer, IoHash::Hasher> Chunks; @@ -1803,7 +1803,7 @@ TEST_CASE("compactcas.restart") } const uint64_t kChunkSize = 1048 + 395; - const size_t kChunkCount = 7167; + const size_t kChunkCount = 2000; std::vector<IoHash> Hashes; Hashes.reserve(kChunkCount); @@ -1984,9 +1984,8 @@ TEST_CASE("compactcas.iteratechunks") WorkerThreadPool ThreadPool(Max(GetHardwareConcurrency() - 1u, 2u), "put"); const uint64_t kChunkSize = 1048 + 395; - const size_t kChunkCount = 63840; + const size_t kChunkCount = 10000; - for (uint32_t N = 0; N < 2; N++) { GcManager Gc; CasContainerStrategy Cas(Gc); @@ -2017,7 +2016,7 @@ TEST_CASE("compactcas.iteratechunks") size_t BatchCount = Min<size_t>(kChunkCount - Offset, 512u); WorkLatch.AddCount(1); ThreadPool.ScheduleWork( - [N, &WorkLatch, &InsertLock, &ChunkHashesLookup, &ExpectedSize, &Hashes, &Cas, Offset, BatchCount]() { + [&WorkLatch, &InsertLock, &ChunkHashesLookup, &ExpectedSize, &Hashes, &Cas, Offset, BatchCount]() { auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); std::vector<IoBuffer> BatchBlobs; @@ -2028,7 +2027,7 @@ TEST_CASE("compactcas.iteratechunks") while 
(BatchBlobs.size() < BatchCount) { IoBuffer Chunk = CreateRandomBlob( - N + kChunkSize + ((BatchHashes.size() % 100) + (BatchHashes.size() % 7) * 315u + Offset % 377)); + kChunkSize + ((BatchHashes.size() % 100) + (BatchHashes.size() % 7) * 315u + Offset % 377)); IoHash Hash = IoHash::HashBuffer(Chunk); { RwLock::ExclusiveLockScope __(InsertLock); diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp index 0088afe6e..3a7a72ee3 100644 --- a/src/zenstore/filecas.cpp +++ b/src/zenstore/filecas.cpp @@ -1231,7 +1231,7 @@ public: ZEN_MEMSCOPE(GetFileCasTag()); ZEN_TRACE_CPU("FileCas::CompactStore"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -1375,7 +1375,7 @@ public: ZEN_MEMSCOPE(GetFileCasTag()); ZEN_TRACE_CPU("FileCas::RemoveUnreferencedData"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -1458,7 +1458,7 @@ FileCasStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&) ZEN_TRACE_CPU("FileCas::CreateReferencePruner"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp index f3edf804d..928fc3f08 100644 --- a/src/zenstore/gc.cpp +++ b/src/zenstore/gc.cpp @@ -546,7 +546,7 @@ FilterReferences(GcCtx& Ctx, std::string_view Context, std::vector<IoHash>& InOu return false; } - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); const bool Filter = Ctx.Settings.AttachmentRangeMax != IoHash::Max || Ctx.Settings.AttachmentRangeMin != IoHash::Zero; @@ -2063,6 +2063,14 @@ GcScheduler::GetState() const { { std::unique_lock Lock(m_GcMutex); + + if (m_TriggerGcParams || m_TriggerScrubParams) + { + // If a trigger is pending, treat it as running + Result.Status = GcSchedulerStatus::kRunning; + return Result; + } + Result.LastFullGcTime = m_LastGcTime; 
Result.LastFullGCDiff = m_LastFullGCDiff; Result.LastFullGcDuration = m_LastFullGcDuration; diff --git a/src/zenstore/include/zenstore/cache/cachepolicy.h b/src/zenstore/include/zenstore/cache/cachepolicy.h index 7773cd3d1..4a062a0c2 100644 --- a/src/zenstore/include/zenstore/cache/cachepolicy.h +++ b/src/zenstore/include/zenstore/cache/cachepolicy.h @@ -163,9 +163,9 @@ private: friend class CacheRecordPolicyBuilder; friend class OptionalCacheRecordPolicy; - CachePolicy RecordPolicy = CachePolicy::Default; - CachePolicy DefaultValuePolicy = CachePolicy::Default; - RefPtr<const Private::ICacheRecordPolicyShared> Shared; + CachePolicy RecordPolicy = CachePolicy::Default; + CachePolicy DefaultValuePolicy = CachePolicy::Default; + Ref<const Private::ICacheRecordPolicyShared> Shared; }; /** A cache record policy builder is used to construct a cache record policy. */ @@ -186,8 +186,8 @@ public: CacheRecordPolicy Build(); private: - CachePolicy BasePolicy = CachePolicy::Default; - RefPtr<Private::ICacheRecordPolicyShared> Shared; + CachePolicy BasePolicy = CachePolicy::Default; + Ref<Private::ICacheRecordPolicyShared> Shared; }; /** diff --git a/src/zenstore/include/zenstore/cidstore.h b/src/zenstore/include/zenstore/cidstore.h index d54062476..c00e0449f 100644 --- a/src/zenstore/include/zenstore/cidstore.h +++ b/src/zenstore/include/zenstore/cidstore.h @@ -58,16 +58,14 @@ struct CidStoreConfiguration * */ -class CidStore final : public ChunkResolver, public StatsProvider +class CidStore final : public ChunkStore, public StatsProvider { public: CidStore(GcManager& Gc); ~CidStore(); - struct InsertResult - { - bool New = false; - }; + using InsertResult = ChunkStore::InsertResult; + enum class InsertMode { kCopyOnly, @@ -75,17 +73,17 @@ public: }; void Initialize(const CidStoreConfiguration& Config); - InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace); - std::vector<InsertResult> 
AddChunks(std::span<IoBuffer> ChunkDatas, - std::span<IoHash> RawHashes, - InsertMode Mode = InsertMode::kMayBeMovedInPlace); - virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) override; + InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode); + std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) override; + std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, InsertMode Mode); + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; bool IterateChunks(std::span<IoHash> DecompressedIds, const std::function<bool(size_t Index, const IoBuffer& Payload)>& AsyncCallback, WorkerThreadPool* OptionalWorkerPool, uint64_t LargeSizeLimit); - bool ContainsChunk(const IoHash& DecompressedId); - void FilterChunks(HashKeySet& InOutChunks); + bool ContainsChunk(const IoHash& DecompressedId) override; + void FilterChunks(HashKeySet& InOutChunks) override; void Flush(); CidStoreSize TotalSize() const; CidStoreStats Stats() const; diff --git a/src/zenstore/include/zenstore/memorycidstore.h b/src/zenstore/include/zenstore/memorycidstore.h new file mode 100644 index 000000000..0311274d5 --- /dev/null +++ b/src/zenstore/include/zenstore/memorycidstore.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "cidstore.h" +#include "zenstore.h" + +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/thread.h> + +#include <deque> +#include <span> +#include <thread> +#include <unordered_map> + +namespace zen { + +class HashKeySet; + +/** Memory-backed chunk store. + * + * Stores chunks in an in-memory hash map, optionally layered over a + * standard CidStore for write-through and read fallback. 
When a backing + * store is provided: + * + * - AddChunk writes to memory and asynchronously to the backing store. + * - FindChunkByCid checks memory first, then falls back to the backing store. + * - ContainsChunk and FilterChunks check memory first, then the backing store. + * + * The memory store does NOT cache read-through results from the backing store. + * Only chunks explicitly added via AddChunk/AddChunks are held in memory. + */ + +class MemoryCidStore : public ChunkStore +{ +public: + explicit MemoryCidStore(CidStore* BackingStore = nullptr); + ~MemoryCidStore(); + + InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) override; + std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) override; + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + bool ContainsChunk(const IoHash& DecompressedId) override; + void FilterChunks(HashKeySet& InOutChunks) override; + +private: + RwLock m_Lock; + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> m_Chunks; + CidStore* m_BackingStore = nullptr; + + // Async write-through to backing store + struct PendingWrite + { + IoBuffer Data; + IoHash Hash; + }; + + std::mutex m_FlushLock; + std::vector<PendingWrite> m_FlushQueue; + Event m_FlushEvent; + std::thread m_FlushThread; + std::atomic<bool> m_FlushThreadEnabled{false}; + + void FlushThreadFunction(); +}; + +} // namespace zen diff --git a/src/zenstore/include/zenstore/projectstore.h b/src/zenstore/include/zenstore/projectstore.h index 100a82907..d05261967 100644 --- a/src/zenstore/include/zenstore/projectstore.h +++ b/src/zenstore/include/zenstore/projectstore.h @@ -305,11 +305,11 @@ public: std::unordered_set<IoHash, IoHash::Hasher> m_PendingPrepOpAttachments; GcClock::TimePoint m_PendingPrepOpAttachmentsRetainEnd; - RefPtr<OplogStorage> m_Storage; - uint64_t m_LogFlushPosition = 0; - bool m_IsLegacySnapshot = false; + Ref<OplogStorage> m_Storage; + uint64_t m_LogFlushPosition = 0; + bool 
m_IsLegacySnapshot = false; - RefPtr<OplogStorage> GetStorage(); + Ref<OplogStorage> GetStorage(); /** Scan oplog and register each entry, thus updating the in-memory tracking tables */ @@ -484,7 +484,7 @@ public: Project& Project, Oplog& Oplog, const std::unordered_set<std::string>& WantedFieldNames); - static CbObject GetChunkInfo(LoggerRef InLog, Project& Project, Oplog& Oplog, const Oid& ChunkId); + static CbObject GetChunkInfo(Project& Project, Oplog& Oplog, const Oid& ChunkId); struct GetChunkRangeResult { enum class EError : uint8_t @@ -502,8 +502,7 @@ public: uint64_t RawSize = 0; ZenContentType ContentType = ZenContentType::kUnknownContentType; }; - static GetChunkRangeResult GetChunkRange(LoggerRef InLog, - Project& Project, + static GetChunkRangeResult GetChunkRange(Project& Project, Oplog& Oplog, const Oid& ChunkId, uint64_t Offset, diff --git a/src/zenstore/include/zenstore/zenstore.h b/src/zenstore/include/zenstore/zenstore.h index bed219b4b..95ae33a4a 100644 --- a/src/zenstore/include/zenstore/zenstore.h +++ b/src/zenstore/include/zenstore/zenstore.h @@ -4,19 +4,56 @@ #include <zencore/zencore.h> +#include <span> +#include <vector> + #define ZENSTORE_API namespace zen { +class HashKeySet; class IoBuffer; struct IoHash; class ChunkResolver { public: + virtual ~ChunkResolver() = default; virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) = 0; }; +/** Abstract chunk store interface. + * + * Extends ChunkResolver with write and query operations. Both CidStore + * (disk-backed) and MemoryCidStore (in-memory) implement this interface, + * allowing callers to be agnostic about the storage backend. 
+ */ +class ChunkStore : public ChunkResolver +{ +public: + struct InsertResult + { + bool New = false; + }; + + virtual InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) = 0; + virtual std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) = 0; + virtual bool ContainsChunk(const IoHash& DecompressedId) = 0; + virtual void FilterChunks(HashKeySet& InOutChunks) = 0; +}; + +/** Composite resolver that tries a primary store first, then a fallback. */ +class FallbackChunkResolver : public ChunkResolver +{ +public: + FallbackChunkResolver(ChunkResolver& Primary, ChunkResolver& Fallback); + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + +private: + ChunkResolver& m_Primary; + ChunkResolver& m_Fallback; +}; + ZENSTORE_API void zenstore_forcelinktests(); } // namespace zen diff --git a/src/zenstore/memorycidstore.cpp b/src/zenstore/memorycidstore.cpp new file mode 100644 index 000000000..b4832029b --- /dev/null +++ b/src/zenstore/memorycidstore.cpp @@ -0,0 +1,143 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenstore/hashkeyset.h> +#include <zenstore/memorycidstore.h> + +namespace zen { + +MemoryCidStore::MemoryCidStore(CidStore* BackingStore) : m_BackingStore(BackingStore) +{ + if (m_BackingStore) + { + m_FlushThreadEnabled = true; + m_FlushThread = std::thread(&MemoryCidStore::FlushThreadFunction, this); + } +} + +MemoryCidStore::~MemoryCidStore() +{ + m_FlushThreadEnabled = false; + m_FlushEvent.Set(); + if (m_FlushThread.joinable()) + { + m_FlushThread.join(); + } +} + +MemoryCidStore::InsertResult +MemoryCidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) +{ + bool IsNew = false; + + m_Lock.WithExclusiveLock([&] { + auto [It, Inserted] = m_Chunks.try_emplace(RawHash, ChunkData); + IsNew = Inserted; + }); + + if (m_BackingStore) + { + std::lock_guard<std::mutex> Lock(m_FlushLock); + m_FlushQueue.push_back({.Data = ChunkData, .Hash = RawHash}); + m_FlushEvent.Set(); + } + + return {.New = IsNew}; +} + +std::vector<MemoryCidStore::InsertResult> +MemoryCidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) +{ + std::vector<MemoryCidStore::InsertResult> Results; + Results.reserve(ChunkDatas.size()); + + for (size_t i = 0; i < ChunkDatas.size(); ++i) + { + Results.push_back(AddChunk(ChunkDatas[i], RawHashes[i])); + } + + return Results; +} + +IoBuffer +MemoryCidStore::FindChunkByCid(const IoHash& DecompressedId) +{ + IoBuffer Result; + + m_Lock.WithSharedLock([&] { + auto It = m_Chunks.find(DecompressedId); + if (It != m_Chunks.end()) + { + Result = It->second; + } + }); + + if (!Result && m_BackingStore) + { + Result = m_BackingStore->FindChunkByCid(DecompressedId); + } + + return Result; +} + +bool +MemoryCidStore::ContainsChunk(const IoHash& DecompressedId) +{ + bool Found = false; + + m_Lock.WithSharedLock([&] { Found = m_Chunks.find(DecompressedId) != m_Chunks.end(); }); + + if (!Found && m_BackingStore) + { + Found = m_BackingStore->ContainsChunk(DecompressedId); + } + + return Found; +} + +void 
+MemoryCidStore::FilterChunks(HashKeySet& InOutChunks) +{ + // Remove hashes that are present in our memory store + m_Lock.WithSharedLock([&] { InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return m_Chunks.find(Hash) != m_Chunks.end(); }); }); + + // Delegate remainder to backing store + if (m_BackingStore && !InOutChunks.IsEmpty()) + { + m_BackingStore->FilterChunks(InOutChunks); + } +} + +void +MemoryCidStore::FlushThreadFunction() +{ + SetCurrentThreadName("MemCidFlush"); + + while (m_FlushThreadEnabled) + { + m_FlushEvent.Wait(); + + std::vector<PendingWrite> Batch; + { + std::lock_guard<std::mutex> Lock(m_FlushLock); + Batch.swap(m_FlushQueue); + } + + for (PendingWrite& Write : Batch) + { + m_BackingStore->AddChunk(Write.Data, Write.Hash); + } + } + + // Drain remaining writes on shutdown + std::vector<PendingWrite> Remaining; + { + std::lock_guard<std::mutex> Lock(m_FlushLock); + Remaining.swap(m_FlushQueue); + } + for (PendingWrite& Write : Remaining) + { + m_BackingStore->AddChunk(Write.Data, Write.Hash); + } +} + +} // namespace zen diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp index 13674da4d..8159f9f83 100644 --- a/src/zenstore/projectstore.cpp +++ b/src/zenstore/projectstore.cpp @@ -3180,6 +3180,7 @@ ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&, } else { + m_ChunkMap.erase(FileId); Entry.ServerPath = ServerPath; } @@ -3402,11 +3403,11 @@ ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage) return EntryId; } -RefPtr<ProjectStore::OplogStorage> +Ref<ProjectStore::OplogStorage> ProjectStore::Oplog::GetStorage() { ZEN_MEMSCOPE(GetProjectstoreTag()); - RefPtr<OplogStorage> Storage; + Ref<OplogStorage> Storage; { RwLock::SharedLockScope _(m_OplogLock); Storage = m_Storage; @@ -3424,7 +3425,7 @@ ProjectStore::Oplog::AppendNewOplogEntry(CbObjectView Core) using namespace std::literals; - RefPtr<OplogStorage> Storage = GetStorage(); + Ref<OplogStorage> Storage = GetStorage(); if (!Storage) { 
return {}; @@ -3456,7 +3457,7 @@ ProjectStore::Oplog::AppendNewOplogEntries(std::span<CbObjectView> Cores) using namespace std::literals; - RefPtr<OplogStorage> Storage = GetStorage(); + Ref<OplogStorage> Storage = GetStorage(); if (!Storage) { return std::vector<ProjectStore::LogSequenceNumber>(Cores.size(), LogSequenceNumber{}); @@ -3515,7 +3516,18 @@ ProjectStore::Project::~Project() // Only write access times if we have not been explicitly deleted if (!m_OplogStoragePath.empty()) { - WriteAccessTimes(); + try + { + WriteAccessTimes(); + } + catch (const std::exception& Ex) + { + // RefCounted::Release() is noexcept, so a destructor that propagates an exception + // terminates the program. WriteAccessTimes() already catches I/O failures, but lock + // acquisition and allocations ahead of the internal try/catch can still throw, so + // we defend here as well. + ZEN_ERROR("project '{}': ~Project threw exception: '{}'", Identifier, Ex.what()); + } } } @@ -4738,7 +4750,7 @@ ProjectStore::GetProjectsList() CbObject ProjectStore::GetProjectFiles(LoggerRef InLog, Project& Project, Oplog& Oplog, const std::unordered_set<std::string>& WantedFieldNames) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -4893,7 +4905,7 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl ZEN_MEMSCOPE(GetProjectstoreTag()); ZEN_TRACE_CPU("ProjectStore::GetProjectChunkInfos"); - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -5050,16 +5062,13 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl } CbObject -ProjectStore::GetChunkInfo(LoggerRef InLog, Project& Project, Oplog& Oplog, const Oid& ChunkId) +ProjectStore::GetChunkInfo(Project& Project, Oplog& Oplog, const Oid& ChunkId) { ZEN_MEMSCOPE(GetProjectstoreTag()); ZEN_TRACE_CPU("ProjectStore::GetChunkInfo"); using namespace std::literals; - auto Log = [&InLog]() { return 
InLog; }; - ZEN_UNUSED(Log); - IoBuffer Chunk = Oplog.FindChunk(Project.RootDir, ChunkId, nullptr); if (!Chunk) { @@ -5168,7 +5177,10 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac const bool IsFullRange = (Offset == 0) && ((Size == ~(0ull)) || (Size == ChunkSize)); if (IsFullRange) { - Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk))); + if (ChunkSize > 0) + { + Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk))); + } Result.RawSize = 0; } else @@ -5205,8 +5217,7 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac } ProjectStore::GetChunkRangeResult -ProjectStore::GetChunkRange(LoggerRef InLog, - Project& Project, +ProjectStore::GetChunkRange(Project& Project, Oplog& Oplog, const Oid& ChunkId, uint64_t Offset, @@ -5218,9 +5229,6 @@ ProjectStore::GetChunkRange(LoggerRef InLog, ZEN_TRACE_CPU("ProjectStore::GetChunkRange"); - auto Log = [&InLog]() { return InLog; }; - ZEN_UNUSED(Log); - uint64_t OldTag = OptionalInOutModificationTag == nullptr ? 
0 : *OptionalInOutModificationTag; IoBuffer Chunk = Oplog.FindChunk(Project.RootDir, ChunkId, OptionalInOutModificationTag); if (!Chunk) @@ -5727,7 +5735,7 @@ public: ZEN_TRACE_CPU("Store::CompactStore"); ZEN_MEMSCOPE(GetProjectstoreTag()); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -5863,7 +5871,7 @@ ProjectStore::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats) ZEN_TRACE_CPU("Store::RemoveExpiredData"); ZEN_MEMSCOPE(GetProjectstoreTag()); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -6016,7 +6024,7 @@ public: { ZEN_TRACE_CPU("Store::UpdateLockedState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; @@ -6093,7 +6101,7 @@ public: { ZEN_TRACE_CPU("Store::GetUnusedReferences"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); size_t InitialCount = IoCids.size(); size_t UsedCount = InitialCount; @@ -6117,6 +6125,7 @@ public: } private: + LoggerRef Log() { return m_ProjectStore.Log(); } ProjectStore& m_ProjectStore; std::vector<IoHash> m_References; }; @@ -6161,7 +6170,7 @@ public: { ZEN_TRACE_CPU("Store::Oplog::PreCache"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -6279,7 +6288,7 @@ public: { ZEN_TRACE_CPU("Store::Oplog::UpdateLockedState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); Stopwatch Timer; const auto _ = MakeGuard([&] { @@ -6387,7 +6396,7 @@ public: { ZEN_TRACE_CPU("Store::Oplog::GetUnusedReferences"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); const size_t InitialCount = IoCids.size(); size_t UsedCount = InitialCount; @@ -6413,6 +6422,7 @@ public: return UnusedReferences; } + LoggerRef Log() { return m_Project->Log(); } ProjectStore& m_ProjectStore; Ref<ProjectStore::Project> 
m_Project; std::string m_OplogId; @@ -6428,7 +6438,7 @@ ProjectStore::CreateReferenceCheckers(GcCtx& Ctx) { ZEN_TRACE_CPU("Store::CreateReferenceCheckers"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); size_t ProjectCount = 0; size_t OplogCount = 0; @@ -6490,11 +6500,9 @@ ProjectStore::CreateReferenceCheckers(GcCtx& Ctx) std::vector<RwLock::SharedLockScope> ProjectStore::LockState(GcCtx& Ctx) { + ZEN_UNUSED(Ctx); ZEN_TRACE_CPU("Store::LockState"); - auto Log = [&Ctx]() { return Ctx.Logger; }; - ZEN_UNUSED(Log); - std::vector<RwLock::SharedLockScope> Locks; Locks.emplace_back(RwLock::SharedLockScope(m_ProjectsLock)); for (auto& ProjectIt : m_Projects) @@ -6526,7 +6534,7 @@ public: { ZEN_TRACE_CPU("Store::Validate"); - auto Log = [&Ctx]() { return Ctx.Logger; }; + ZEN_SCOPED_LOG(Ctx.Logger); ProjectStore::Oplog::ValidationResult Result; @@ -6625,9 +6633,6 @@ ProjectStore::CreateReferenceValidators(GcCtx& Ctx) return {}; } - auto Log = [&Ctx]() { return Ctx.Logger; }; - ZEN_UNUSED(Log); - DiscoverProjects(); std::vector<std::pair<std::string, std::string>> Oplogs; @@ -8242,8 +8247,7 @@ TEST_CASE("project.store.partial.read") { uint64_t ModificationTag = 0; - auto Result = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result = ProjectStore.GetChunkRange(*Project1, *Oplog1, Attachments[OpIds[1]][0].first, 0, @@ -8258,8 +8262,7 @@ TEST_CASE("project.store.partial.read") CompressedBuffer Attachment = CompressedBuffer::FromCompressed(Result.Chunk, RawHash, RawSize); CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize()); - auto Result2 = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result2 = ProjectStore.GetChunkRange(*Project1, *Oplog1, Attachments[OpIds[1]][0].first, 0, @@ -8272,8 +8275,7 @@ TEST_CASE("project.store.partial.read") { uint64_t FullChunkModificationTag = 0; { - auto Result = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result = ProjectStore.GetChunkRange(*Project1, *Oplog1, 
Attachments[OpIds[2]][1].first, 0, @@ -8286,8 +8288,7 @@ TEST_CASE("project.store.partial.read") Attachments[OpIds[2]][1].second.DecodeRawSize()); } { - auto Result = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result = ProjectStore.GetChunkRange(*Project1, *Oplog1, Attachments[OpIds[2]][1].first, 0, @@ -8300,8 +8301,7 @@ TEST_CASE("project.store.partial.read") { uint64_t PartialChunkModificationTag = 0; { - auto Result = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result = ProjectStore.GetChunkRange(*Project1, *Oplog1, Attachments[OpIds[2]][1].first, 5, @@ -8324,8 +8324,7 @@ TEST_CASE("project.store.partial.read") } { - auto Result = ProjectStore.GetChunkRange(Log(), - *Project1, + auto Result = ProjectStore.GetChunkRange(*Project1, *Oplog1, Attachments[OpIds[2]][1].first, 0, diff --git a/src/zenstore/workspaces.cpp b/src/zenstore/workspaces.cpp index ad21bbc68..cfdcd294c 100644 --- a/src/zenstore/workspaces.cpp +++ b/src/zenstore/workspaces.cpp @@ -331,9 +331,6 @@ ScanFolder(LoggerRef InLog, const std::filesystem::path& Path, WorkerThreadPool& { ZEN_TRACE_CPU("workspaces::ScanFolderImpl"); - auto Log = [&InLog]() { return InLog; }; - ZEN_UNUSED(Log); - FolderScanner Data(InLog, WorkerPool, Path); Data.Traverse(); return std::make_unique<FolderStructure>(std::move(Data.FoundFiles), std::move(Data.FoundFileIds)); @@ -811,7 +808,7 @@ Workspaces::GetShareAlias(std::string_view Alias) const std::vector<Workspaces::WorkspaceConfiguration> Workspaces::ReadConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, std::string& OutError) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -835,7 +832,7 @@ Workspaces::WriteConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, const std::vector<WorkspaceConfiguration>& WorkspaceConfigurations) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -850,7 
+847,7 @@ Workspaces::WriteConfig(const LoggerRef& InLog, std::vector<Workspaces::WorkspaceShareConfiguration> Workspaces::ReadWorkspaceConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, std::string& OutError) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -874,7 +871,7 @@ Workspaces::WriteWorkspaceConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, const std::vector<WorkspaceShareConfiguration>& WorkspaceShareConfigurations) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); using namespace std::literals; @@ -1049,9 +1046,6 @@ Workspaces::RemoveWorkspaceShare(const LoggerRef& Log, const std::filesystem::pa Workspaces::WorkspaceConfiguration Workspaces::FindWorkspace(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, const Oid& WorkspaceId) { - auto Log = [&InLog]() { return InLog; }; - ZEN_UNUSED(Log); - std::string Error; std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error); if (!Error.empty()) @@ -1075,9 +1069,6 @@ Workspaces::FindWorkspace(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, const std::filesystem::path& WorkspaceRoot) { - auto Log = [&InLog]() { return InLog; }; - ZEN_UNUSED(Log); - std::string Error; std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error); if (!Error.empty()) @@ -1102,7 +1093,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog, std::string_view ShareAlias, WorkspaceConfiguration& OutWorkspace) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); std::string Error; std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error); @@ -1151,7 +1142,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog, Workspaces::WorkspaceShareConfiguration Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::path& 
WorkspaceRoot, const Oid& WorkspaceShareId) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); std::string Error; std::vector<WorkspaceShareConfiguration> Shares = ReadWorkspaceConfig(InLog, WorkspaceRoot, Error); if (!Error.empty()) @@ -1174,7 +1165,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::pa Workspaces::WorkspaceShareConfiguration Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, const std::filesystem::path& SharePath) { - auto Log = [&InLog]() { return InLog; }; + ZEN_SCOPED_LOG(InLog); std::string Error; std::vector<WorkspaceShareConfiguration> Shares = ReadWorkspaceConfig(InLog, WorkspaceRoot, Error); if (!Error.empty()) diff --git a/src/zenstore/zenstore.cpp b/src/zenstore/zenstore.cpp index c563cc202..bf0c71211 100644 --- a/src/zenstore/zenstore.cpp +++ b/src/zenstore/zenstore.cpp @@ -2,6 +2,26 @@ #include "zenstore/zenstore.h" +#include <zencore/iobuffer.h> + +namespace zen { + +FallbackChunkResolver::FallbackChunkResolver(ChunkResolver& Primary, ChunkResolver& Fallback) : m_Primary(Primary), m_Fallback(Fallback) +{ +} + +IoBuffer +FallbackChunkResolver::FindChunkByCid(const IoHash& DecompressedId) +{ + if (IoBuffer Result = m_Primary.FindChunkByCid(DecompressedId)) + { + return Result; + } + return m_Fallback.FindChunkByCid(DecompressedId); +} + +} // namespace zen + #if ZEN_WITH_TESTS # include <zenstore/blockstore.h> diff --git a/src/zentelemetry/include/zentelemetry/stats.h b/src/zentelemetry/include/zentelemetry/stats.h index 260b0fcfb..ddec8e883 100644 --- a/src/zentelemetry/include/zentelemetry/stats.h +++ b/src/zentelemetry/include/zentelemetry/stats.h @@ -40,7 +40,7 @@ private: * Suitable for tracking quantities that go up and down over time, such as * requests in flight or active jobs. All operations are lock-free via atomics. * - * Unlike a Meter, a Counter does not track rates — it only records a running total. 
+ * Unlike a Meter, a Counter does not track rates - it only records a running total. */ class Counter { @@ -63,7 +63,7 @@ private: * rate = rate + alpha * (instantRate - rate) * * where instantRate = Count / Interval. The alpha value controls how quickly - * the average responds to changes — higher alpha means more weight on recent + * the average responds to changes - higher alpha means more weight on recent * samples. Typical alphas are derived from a decay half-life (e.g. 1, 5, 15 * minutes) and a fixed tick interval. * @@ -205,7 +205,7 @@ private: * * Records exact min, max, count and mean across all values ever seen, plus a * reservoir sample (via UniformSample) used to compute percentiles. Percentiles - * are therefore probabilistic — they reflect the distribution of a representative + * are therefore probabilistic - they reflect the distribution of a representative * sample rather than the full history. * * All operations are thread-safe via lock-free atomics. @@ -343,7 +343,7 @@ struct RequestStatsSnapshot * * Maintains two independent histogram+meter pairs: one for request duration * (in hi-freq timer ticks) and one for transferred bytes. Both dimensions - * share the same request count — a single Update() call advances both. + * share the same request count - a single Update() call advances both. * * Duration accessors return values in hi-freq timer ticks. Multiply by * GetHifreqTimerToSeconds() to convert to seconds. @@ -383,7 +383,7 @@ public: * or destruction. * * The byte count can be supplied at construction or updated at any point via - * SetBytes() before the scope ends — useful when the response size is not + * SetBytes() before the scope ends - useful when the response size is not * known until the operation completes. * * Call Cancel() to discard the measurement entirely. 
diff --git a/src/zentelemetry/otelmetricsprotozero.h b/src/zentelemetry/otelmetricsprotozero.h index 12bae0d75..f02eca7ab 100644 --- a/src/zentelemetry/otelmetricsprotozero.h +++ b/src/zentelemetry/otelmetricsprotozero.h @@ -49,22 +49,22 @@ option go_package = "go.opentelemetry.io/proto/otlp/metrics/v1"; // data but do not implement the OTLP protocol. // // MetricsData -// └─── ResourceMetrics -// ├── Resource -// ├── SchemaURL -// └── ScopeMetrics -// ├── Scope -// ├── SchemaURL -// └── Metric -// ├── Name -// ├── Description -// ├── Unit -// └── data -// ├── Gauge -// ├── Sum -// ├── Histogram -// ├── ExponentialHistogram -// └── Summary +// +--- ResourceMetrics +// |-- Resource +// |-- SchemaURL +// +-- ScopeMetrics +// |-- Scope +// |-- SchemaURL +// +-- Metric +// |-- Name +// |-- Description +// |-- Unit +// +-- data +// |-- Gauge +// |-- Sum +// |-- Histogram +// |-- ExponentialHistogram +// +-- Summary // // The main difference between this message and collector protocol is that // in this message there will not be any "control" or "metadata" specific to diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 54e54edde..73cb7ff2d 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -33,6 +33,13 @@ using namespace zen; // Some basic functions to implement some test "compute" functions +struct ForcedExitException : std::exception +{ + int Code; + explicit ForcedExitException(int InCode) : Code(InCode) {} + const char* what() const noexcept override { return "forced exit"; } +}; + std::string Rot13Function(std::string_view InputString) { @@ -111,6 +118,16 @@ DescribeFunctions() << "Sleep"sv; Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Fail"sv; + Versions << "Version"sv << Guid::FromString("fa11fa11-fa11-fa11-fa11-fa11fa11fa11"sv); + 
Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Crash"sv; + Versions << "Version"sv << Guid::FromString("c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"sv); + Versions.EndObject(); Versions.EndArray(); return Versions.Save(); @@ -201,6 +218,38 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) zen::Sleep(static_cast<int>(SleepTimeMs)); return Apply(IdentityFunction); } + else if (Function == "Fail"sv) + { + int FailExitCode = static_cast<int>(Action["Constants"sv].AsObjectView()["ExitCode"sv].AsUInt64()); + if (FailExitCode == 0) + { + FailExitCode = 1; + } + throw ForcedExitException(FailExitCode); + } + else if (Function == "Crash"sv) + { + // Crash modes: + // "abort" - calls std::abort() (SIGABRT / process termination) + // "nullptr" - dereferences a null pointer (SIGSEGV / access violation) + std::string_view Mode = Action["Constants"sv].AsObjectView()["Mode"sv].AsString(); + + printf("[zentest] crashing with mode: %.*s\n", int(Mode.size()), Mode.data()); + fflush(stdout); + + if (Mode == "nullptr"sv) + { + volatile int* Ptr = nullptr; + *Ptr = 42; + } + + // Default crash mode (also reached after nullptr write on platforms + // that don't immediately fault on null dereference) +#if defined(_MSC_VER) + _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); +#endif + std::abort(); + } else { return {}; @@ -421,6 +470,12 @@ main(int argc, char* argv[]) } } } + catch (ForcedExitException& Ex) + { + printf("[zentest] forced exit with code: %d\n", Ex.Code); + + ExitCode = Ex.Code; + } catch (std::exception& Ex) { printf("[zentest] exception caught in main: '%s'\n", Ex.what()); diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp index dde1dc019..b025eb6da 100644 --- a/src/zenutil/cloud/imdscredentials.cpp +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -115,7 +115,7 @@ ImdsCredentialProvider::FetchToken() HttpClient::KeyValueMap Headers; 
Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); - HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers); if (!Response.IsSuccess()) { ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); @@ -213,7 +213,7 @@ ImdsCredentialProvider::FetchCredentials() } else { - // Expiration is in the past or unparseable — force refresh next time + // Expiration is in the past or unparseable - force refresh next time NewExpiresAt = std::chrono::steady_clock::now(); } @@ -369,7 +369,7 @@ TEST_CASE("imdscredentials.fetch_from_mock") TEST_CASE("imdscredentials.unreachable_endpoint") { - // Point at a non-existent server — should return empty credentials, not crash + // Point at a non-existent server - should return empty credentials, not crash ImdsCredentialProviderOptions Opts; Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening Opts.ConnectTimeout = std::chrono::milliseconds(100); diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp index 457453bd8..2db0010dc 100644 --- a/src/zenutil/cloud/minioprocess.cpp +++ b/src/zenutil/cloud/minioprocess.cpp @@ -45,7 +45,7 @@ struct MinioProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); @@ -102,7 +102,7 @@ struct MinioProcess::Impl { if (m_DataDir.empty()) { - ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first"); + ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized - call SpawnMinioServer() first"); return; } diff --git a/src/zenutil/cloud/mockimds.cpp 
b/src/zenutil/cloud/mockimds.cpp index 6919fab4d..88b348ed6 100644 --- a/src/zenutil/cloud/mockimds.cpp +++ b/src/zenutil/cloud/mockimds.cpp @@ -93,7 +93,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Autoscaling lifecycle state — 404 when not in an ASG + // Autoscaling lifecycle state - 404 when not in an ASG if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") { if (Aws.AutoscalingState.empty()) @@ -105,7 +105,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Spot interruption notice — 404 when no interruption pending + // Spot interruption notice - 404 when no interruption pending if (Uri == "latest/meta-data/spot/instance-action") { if (Aws.SpotAction.empty()) @@ -117,7 +117,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // IAM role discovery — returns the role name + // IAM role discovery - returns the role name if (Uri == "latest/meta-data/iam/security-credentials/") { if (Aws.IamRoleName.empty()) diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index 26d1023f4..83238f5cc 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -137,6 +137,8 @@ namespace { } // namespace +std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; + S3Client::S3Client(const S3ClientOptions& Options) : m_Log(logging::Get("s3")) , m_BucketName(Options.BucketName) @@ -145,13 +147,8 @@ S3Client::S3Client(const S3ClientOptions& Options) , m_PathStyle(Options.PathStyle) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) -, m_HttpClient(BuildEndpoint(), - HttpClientSettings{ - .LogCategory = "s3", - .ConnectTimeout = Options.ConnectTimeout, - .Timeout = Options.Timeout, - .RetryCount = Options.RetryCount, - }) +, m_HttpClient(BuildEndpoint(), Options.HttpSettings) +, m_Verbose(Options.HttpSettings.Verbose) { m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for 
bucket '{}' in region '{}' (endpoint: {}, {})", @@ -342,26 +339,37 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content) return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + } return {}; } S3GetObjectResult -S3Client::GetObject(std::string_view Key) +S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath) { std::string Path = KeyToPath(Key); HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); - HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + std::string Err = Response.ErrorMessage("S3 GET failed"); ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + } return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; } @@ -377,6 +385,11 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran HttpClient::Response Response = m_HttpClient.Get(Path, Headers); if (!Response.IsSuccess()) { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + std::string Err = Response.ErrorMessage("S3 GET range failed"); ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; @@ -397,11 
+410,14 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 GET range '{}' [{}-{}] succeeded ({} bytes)", - Key, - RangeStart, - RangeStart + RangeSize - 1, - Response.ResponsePayload.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 GET range '{}' [{}-{}] succeeded ({} bytes)", + Key, + RangeStart, + RangeStart + RangeSize - 1, + Response.ResponsePayload.GetSize()); + } return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; } @@ -420,7 +436,10 @@ S3Client::DeleteObject(std::string_view Key) return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + if (m_Verbose) + { + ZEN_INFO("S3 DELETE '{}' succeeded", Key); + } return {}; } @@ -462,7 +481,10 @@ S3Client::HeadObject(std::string_view Key) Info.LastModified = *V; } - ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + if (m_Verbose) + { + ZEN_INFO("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + } return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; } @@ -559,10 +581,16 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } ContinuationToken = std::string(NextToken); - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } } - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + } return Result; } @@ -601,7 +629,10 @@ S3Client::CreateMultipartUpload(std::string_view Key) return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 
CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; } @@ -636,7 +667,10 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + if (m_Verbose) + { + ZEN_INFO("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + } return S3UploadPartResult{{}, *ETag}; } @@ -685,7 +719,10 @@ S3Client::CompleteMultipartUpload(std::string_view Key, return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + if (m_Verbose) + { + ZEN_INFO("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + } return {}; } @@ -706,7 +743,10 @@ S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return {}; } @@ -749,7 +789,10 @@ S3Client::PutObjectMultipart(std::string_view Key, return PutObject(Key, TotalSize > 0 ? 
FetchRange(0, TotalSize) : IoBuffer{}); } - ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + } S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); if (!InitResult) @@ -797,7 +840,10 @@ S3Client::PutObjectMultipart(std::string_view Key, throw; } - ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + } return {}; } @@ -885,7 +931,10 @@ TEST_CASE("s3client.minio_integration") { using namespace std::literals; - // Spawn a local MinIO server + // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered + // the TEST_CASE from the top, spawning and killing MinIO per subcase - slow and flaky on + // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance + // that is torn down via RAII at scope exit. 
MinioProcessOptions MinioOpts; MinioOpts.Port = 19000; MinioOpts.RootUser = "testuser"; @@ -893,11 +942,8 @@ TEST_CASE("s3client.minio_integration") MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); - - // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) Minio.CreateBucket("integration-test"); - // Configure S3Client for the test bucket S3ClientOptions Opts; Opts.BucketName = "integration-test"; Opts.Region = "us-east-1"; @@ -908,7 +954,7 @@ TEST_CASE("s3client.minio_integration") S3Client Client(Opts); - SUBCASE("put_get_delete") + // -- put_get_delete ------------------------------------------------------- { // PUT std::string_view TestData = "hello, minio integration test!"sv; @@ -937,14 +983,14 @@ TEST_CASE("s3client.minio_integration") CHECK(HeadRes2.Status == HeadObjectResult::NotFound); } - SUBCASE("head_not_found") + // -- head_not_found ------------------------------------------------------- { S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); CHECK(Res.IsSuccess()); CHECK(Res.Status == HeadObjectResult::NotFound); } - SUBCASE("list_objects") + // -- list_objects --------------------------------------------------------- { // Upload several objects with a common prefix for (int i = 0; i < 3; ++i) @@ -979,7 +1025,7 @@ TEST_CASE("s3client.minio_integration") } } - SUBCASE("multipart_upload") + // -- multipart_upload ----------------------------------------------------- { // Create a payload large enough to exercise multipart (use minimum part size) constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum @@ -1006,7 +1052,7 @@ TEST_CASE("s3client.minio_integration") Client.DeleteObject("multipart/large.bin"); } - SUBCASE("presigned_urls") + // -- presigned_urls ------------------------------------------------------- { // Upload an object std::string_view TestData = "presigned-url-test-data"sv; @@ -1032,8 +1078,6 @@ TEST_CASE("s3client.minio_integration") // Cleanup Client.DeleteObject("presigned/test.txt"); } 
- - Minio.StopMinioServer(); } TEST_SUITE_END(); diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp index 124132aed..10e8abb31 100644 --- a/src/zenutil/consoletui.cpp +++ b/src/zenutil/consoletui.cpp @@ -311,7 +311,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items) printf("\033[1;7m"); // bold + reverse video } - // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶) + // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (>) const char* Indicator = IsSelected ? " \xe2\x96\xb6 " : " "; printf("%s%s", Indicator, Items[i].c_str()); @@ -328,7 +328,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items) printf("\r\033[K\n"); // Hint footer - // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓ + // \xe2\x86\x91 = U+2191 ^ \xe2\x86\x93 = U+2193 v printf( "\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate " "\033[2mEnter\033[0m confirm " diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp index c9144e589..c372b131d 100644 --- a/src/zenutil/consul/consul.cpp +++ b/src/zenutil/consul/consul.cpp @@ -9,9 +9,13 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpserver.h> + #include <fmt/format.h> namespace zen::consul { @@ -31,7 +35,7 @@ struct ConsulProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL); CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); @@ -107,7 +111,7 @@ ConsulProcess::StopConsulAgent() ////////////////////////////////////////////////////////////////////////// -ConsulClient::ConsulClient(std::string_view 
BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri) +ConsulClient::ConsulClient(const Configuration& Config) : m_Config(Config), m_HttpClient(m_Config.BaseUri) { } @@ -193,12 +197,18 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) // when no interval is configured (e.g. during Provisioning). Writer.BeginObject("Check"sv); { - Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint)); + Writer.AddString( + "HTTP"sv, + fmt::format("http://{}:{}/{}", Info.Address.empty() ? "localhost" : Info.Address, Info.Port, Info.HealthEndpoint)); Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds)); if (Info.DeregisterAfterSeconds != 0) { Writer.AddString("DeregisterCriticalServiceAfter"sv, fmt::format("{}s", Info.DeregisterAfterSeconds)); } + if (!Info.InitialStatus.empty()) + { + Writer.AddString("Status"sv, Info.InitialStatus); + } } Writer.EndObject(); // Check } @@ -223,27 +233,112 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) bool ConsulClient::DeregisterService(std::string_view ServiceId) { + using namespace std::literals; + HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON)); - HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), IoBuffer{}, AdditionalHeaders); + if (Result) + { + return true; + } + + // Agent deregister failed - fall back to catalog deregister. + // This handles cases where the service was registered via a different Consul agent + // (e.g. load-balanced endpoint routing to different agents). 
+ std::string NodeName = GetNodeName(); + if (!NodeName.empty()) + { + CbObjectWriter Writer; + Writer.AddString("Node"sv, NodeName); + Writer.AddString("ServiceID"sv, ServiceId); + ExtendableStringBuilder<256> SB; + CompactBinaryToJson(Writer.Save(), SB); + + IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()); + PayloadBuffer.SetContentType(HttpContentType::kJSON); + + HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders); + if (CatalogResult) + { + ZEN_INFO("ConsulClient::DeregisterService() deregistered service '{}' via catalog fallback (agent error: {})", + ServiceId, + Result.ErrorMessage("")); + return true; + } + + ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, catalog: {})", + ServiceId, + Result.ErrorMessage(""), + CatalogResult.ErrorMessage("")); + } + else + { + ZEN_WARN( + "ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, could not determine node name for catalog " + "fallback)", + ServiceId, + Result.ErrorMessage("")); + } + + return false; +} + +std::string +ConsulClient::GetNodeName() +{ + using namespace std::literals; + + HttpClient::KeyValueMap AdditionalHeaders; + ApplyCommonHeaders(AdditionalHeaders); + + HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders); if (!Result) { - ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage("")); - return false; + return {}; } - return true; + std::string JsonError; + CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError); + if (!Root || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView Field : Root) + { + if (Field.GetName() == "Config"sv) + { + CbObjectView Config = Field.AsObjectView(); + if (Config) + { + return std::string(Config["NodeName"sv].AsString()); + } + } + } + + return {}; } void 
ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap) { - if (!m_Token.empty()) + std::string Token; + if (!m_Config.StaticToken.empty()) { - InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token); + Token = m_Config.StaticToken; + } + else if (!m_Config.TokenEnvName.empty()) + { + Token = GetEnvVariable(m_Config.TokenEnvName); + } + + if (!Token.empty()) + { + InOutHeaderMap.Entries.emplace("X-Consul-Token", Token); } } @@ -446,4 +541,196 @@ ServiceRegistration::RegistrationLoop() } } +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +consul_forcelink() +{ +} + +struct MockHealthService : public HttpService +{ + std::atomic<bool> FailHealth{false}; + std::atomic<int> HealthCheckCount{0}; + + const char* BaseUri() const override { return "/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + std::string_view Uri = Request.RelativeUri(); + if (Uri == "health/" || Uri == "health") + { + HealthCheckCount.fetch_add(1); + if (FailHealth.load()) + { + Request.WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } +}; + +struct TestHealthServer +{ + MockHealthService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + int Port() const { return m_Port; } + + ~TestHealthServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + 
std::thread m_ServerThread; + int m_Port = -1; +}; + +static bool +WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs)) + { + if (Predicate()) + { + return true; + } + Sleep(PollIntervalMs); + } + return Predicate(); +} + +static std::string +GetCheckStatus(ConsulClient& Client, std::string_view ServiceId) +{ + using namespace std::literals; + + std::string JsonError; + CbFieldIterator ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError); + if (!ChecksRoot || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView F : ChecksRoot) + { + if (!F.IsObject()) + { + continue; + } + for (CbFieldView C : F.AsObjectView()) + { + CbObjectView Check = C.AsObjectView(); + if (Check["ServiceID"sv].AsString() == ServiceId) + { + return std::string(Check["Status"sv].AsString()); + } + } + } + return {}; +} + +TEST_SUITE_BEGIN("util.consul"); + +TEST_CASE("util.consul.service_lifecycle") +{ + ConsulProcess ConsulProc; + ConsulProc.SpawnConsulAgent(); + + TestHealthServer HealthServer; + HealthServer.Start(); + + ConsulClient Client({.BaseUri = "http://localhost:8500/"}); + + const std::string ServiceId = "test-health-svc"; + + ServiceRegistrationInfo Info; + Info.ServiceId = ServiceId; + Info.ServiceName = "zen-test-health"; + Info.Address = "127.0.0.1"; + Info.Port = static_cast<uint16_t>(HealthServer.Port()); + Info.HealthEndpoint = "health/"; + Info.HealthIntervalSeconds = 1; + Info.DeregisterAfterSeconds = 60; + + // Phase 1: Register and verify Consul sends health checks to our service + REQUIRE(Client.RegisterService(Info)); + REQUIRE(Client.HasService(ServiceId)); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + // Phase 2: Explicit 
deregister + REQUIRE(Client.DeregisterService(ServiceId)); + CHECK_FALSE(Client.HasService(ServiceId)); + + // Phase 3: Register with InitialStatus, verify immediately passing before any health check fires, + // then fail health and verify check goes critical + HealthServer.Mock.HealthCheckCount.store(0); + HealthServer.Mock.FailHealth.store(false); + + Info.InitialStatus = "passing"; + REQUIRE(Client.RegisterService(Info)); + REQUIRE(Client.HasService(ServiceId)); + + CHECK_EQ(HealthServer.Mock.HealthCheckCount.load(), 0); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + HealthServer.Mock.FailHealth.store(true); + + // Wait for Consul to observe the failing check + REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical"); + + // Phase 4: Explicit deregister while critical + REQUIRE(Client.DeregisterService(ServiceId)); + CHECK_FALSE(Client.HasService(ServiceId)); + + // Phase 5: Deregister an already-deregistered service - should not crash + Client.DeregisterService(ServiceId); + CHECK_FALSE(Client.HasService(ServiceId)); + + ConsulProc.StopConsulAgent(); +} + +TEST_SUITE_END(); + +#endif + } // namespace zen::consul diff --git a/src/zenremotestore/filesystemutils.cpp b/src/zenutil/filesystemutils.cpp index fdb2143d8..ccc42a838 100644 --- a/src/zenremotestore/filesystemutils.cpp +++ b/src/zenutil/filesystemutils.cpp @@ -1,8 +1,6 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zenremotestore/filesystemutils.h> - -#include <zenremotestore/chunking/chunkedcontent.h> +#include <zenutil/filesystemutils.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> @@ -83,117 +81,6 @@ BufferedOpenFile::GetRange(uint64_t Offset, uint64_t Size) return Result; } -ReadFileCache::ReadFileCache(std::atomic<uint64_t>& OpenReadCount, - std::atomic<uint64_t>& CurrentOpenFileCount, - std::atomic<uint64_t>& ReadCount, - std::atomic<uint64_t>& ReadByteCount, - const std::filesystem::path& Path, - const ChunkedFolderContent& LocalContent, - const ChunkedContentLookup& LocalLookup, - size_t MaxOpenFileCount) -: m_Path(Path) -, m_LocalContent(LocalContent) -, m_LocalLookup(LocalLookup) -, m_OpenReadCount(OpenReadCount) -, m_CurrentOpenFileCount(CurrentOpenFileCount) -, m_ReadCount(ReadCount) -, m_ReadByteCount(ReadByteCount) -{ - m_OpenFiles.reserve(MaxOpenFileCount); -} -ReadFileCache::~ReadFileCache() -{ - m_OpenFiles.clear(); -} - -CompositeBuffer -ReadFileCache::GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size) -{ - ZEN_TRACE_CPU("ReadFileCache::GetRange"); - - auto CacheIt = - std::find_if(m_OpenFiles.begin(), m_OpenFiles.end(), [SequenceIndex](const auto& Lhs) { return Lhs.first == SequenceIndex; }); - if (CacheIt != m_OpenFiles.end()) - { - if (CacheIt != m_OpenFiles.begin()) - { - auto CachedFile(std::move(CacheIt->second)); - m_OpenFiles.erase(CacheIt); - m_OpenFiles.insert(m_OpenFiles.begin(), std::make_pair(SequenceIndex, std::move(CachedFile))); - } - CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size); - return Result; - } - const uint32_t LocalPathIndex = m_LocalLookup.SequenceIndexFirstPathIndex[SequenceIndex]; - const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred(); - if (Size == m_LocalContent.RawSizes[LocalPathIndex]) - { - IoBuffer Result = IoBufferBuilder::MakeFromFile(LocalFilePath); - return CompositeBuffer(SharedBuffer(Result)); 
- } - if (m_OpenFiles.size() == m_OpenFiles.capacity()) - { - m_OpenFiles.pop_back(); - } - m_OpenFiles.insert( - m_OpenFiles.begin(), - std::make_pair( - SequenceIndex, - std::make_unique<BufferedOpenFile>(LocalFilePath, m_OpenReadCount, m_CurrentOpenFileCount, m_ReadCount, m_ReadByteCount))); - CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size); - return Result; -} - -uint32_t -SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes) -{ -#if ZEN_PLATFORM_WINDOWS - if (SourcePlatform == SourcePlatform::Windows) - { - SetFileAttributesToPath(FilePath, Attributes); - return Attributes; - } - else - { - uint32_t CurrentAttributes = GetFileAttributesFromPath(FilePath); - uint32_t NewAttributes = zen::MakeFileAttributeReadOnly(CurrentAttributes, zen::IsFileModeReadOnly(Attributes)); - if (CurrentAttributes != NewAttributes) - { - SetFileAttributesToPath(FilePath, NewAttributes); - } - return NewAttributes; - } -#endif // ZEN_PLATFORM_WINDOWS -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - if (SourcePlatform != SourcePlatform::Windows) - { - zen::SetFileMode(FilePath, Attributes); - return Attributes; - } - else - { - uint32_t CurrentMode = zen::GetFileMode(FilePath); - uint32_t NewMode = zen::MakeFileModeReadOnly(CurrentMode, zen::IsFileAttributeReadOnly(Attributes)); - if (CurrentMode != NewMode) - { - zen::SetFileMode(FilePath, NewMode); - } - return NewMode; - } -#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -}; - -uint32_t -GetNativeFileAttributes(const std::filesystem::path FilePath) -{ -#if ZEN_PLATFORM_WINDOWS - return GetFileAttributesFromPath(FilePath); -#endif // ZEN_PLATFORM_WINDOWS -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - return GetFileMode(FilePath); -#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -} - bool IsFileWithRetry(const std::filesystem::path& Path) { @@ -256,6 +143,21 @@ RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesyst } 
std::error_code +RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameDirectory(SourcePath, TargetPath, Ec); + for (size_t Retries = 0; Ec && Retries < 5; Retries++) + { + ZEN_ASSERT_SLOW(IsDir(SourcePath)); + Sleep(50 + int(Retries * 150)); + Ec.clear(); + RenameDirectory(SourcePath, TargetPath, Ec); + } + return Ec; +} + +std::error_code TryRemoveFile(const std::filesystem::path& Path) { std::error_code Ec; @@ -336,6 +238,124 @@ FastCopyFile(bool AllowFileClone, } } +void +GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent) +{ + struct Visitor : public GetDirectoryContentVisitor + { + Visitor(zen::DirectoryContent& OutContent, const std::filesystem::path& InRootPath) : Content(OutContent), RootPath(InRootPath) {} + virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const + { + ZEN_UNUSED(Parent, DirectoryName); + return true; + } + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& InContent) + { + std::vector<std::filesystem::path> Files; + std::vector<std::filesystem::path> Directories; + + if (!InContent.FileNames.empty()) + { + Files.reserve(InContent.FileNames.size()); + for (const std::filesystem::path& FileName : InContent.FileNames) + { + if (RelativeRoot.empty()) + { + Files.push_back(RootPath / FileName); + } + else + { + Files.push_back(RootPath / RelativeRoot / FileName); + } + } + } + + if (!InContent.DirectoryNames.empty()) + { + Directories.reserve(InContent.DirectoryNames.size()); + for (const std::filesystem::path& DirName : InContent.DirectoryNames) + { + if (RelativeRoot.empty()) + { + Directories.push_back(RootPath / DirName); + } + else + { + Directories.push_back(RootPath / RelativeRoot / DirName); + } + } + } + + Lock.WithExclusiveLock([&]() { + if 
(!InContent.FileNames.empty()) + { + for (const std::filesystem::path& FileName : InContent.FileNames) + { + if (RelativeRoot.empty()) + { + Content.Files.push_back(RootPath / FileName); + } + else + { + Content.Files.push_back(RootPath / RelativeRoot / FileName); + } + } + } + if (!InContent.FileSizes.empty()) + { + Content.FileSizes.insert(Content.FileSizes.end(), InContent.FileSizes.begin(), InContent.FileSizes.end()); + } + if (!InContent.FileAttributes.empty()) + { + Content.FileAttributes.insert(Content.FileAttributes.end(), + InContent.FileAttributes.begin(), + InContent.FileAttributes.end()); + } + if (!InContent.FileModificationTicks.empty()) + { + Content.FileModificationTicks.insert(Content.FileModificationTicks.end(), + InContent.FileModificationTicks.begin(), + InContent.FileModificationTicks.end()); + } + + if (!InContent.DirectoryNames.empty()) + { + for (const std::filesystem::path& DirName : InContent.DirectoryNames) + { + if (RelativeRoot.empty()) + { + Content.Directories.push_back(RootPath / DirName); + } + else + { + Content.Directories.push_back(RootPath / RelativeRoot / DirName); + } + } + } + if (!InContent.DirectoryAttributes.empty()) + { + Content.DirectoryAttributes.insert(Content.DirectoryAttributes.end(), + InContent.DirectoryAttributes.begin(), + InContent.DirectoryAttributes.end()); + } + }); + } + RwLock Lock; + zen::DirectoryContent& Content; + const std::filesystem::path& RootPath; + }; + + Visitor RootVisitor(OutContent, Path); + + Latch PendingWork(1); + GetDirectoryContent(Path, Flags, RootVisitor, WorkerPool, PendingWork); + PendingWork.CountDown(); + PendingWork.Wait(); +} + CleanDirectoryResult CleanDirectory( WorkerThreadPool& IOWorkerPool, @@ -468,7 +488,7 @@ CleanDirectory( if (!FailedRemovePaths.empty()) { RwLock::ExclusiveLockScope _(ResultLock); - FailedRemovePaths.insert(FailedRemovePaths.end(), FailedRemovePaths.begin(), FailedRemovePaths.end()); + Result.FailedRemovePaths.insert(Result.FailedRemovePaths.end(), 
FailedRemovePaths.begin(), FailedRemovePaths.end()); } else if (!RelativeRoot.empty()) { @@ -637,7 +657,7 @@ namespace { void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); } } // namespace -TEST_SUITE_BEGIN("remotestore.filesystemutils"); +TEST_SUITE_BEGIN("util.filesystemutils"); TEST_CASE("filesystemutils.CleanDirectory") { diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h index d0c0155b0..28e1e8ba6 100644 --- a/src/zenutil/include/zenutil/cloud/mockimds.h +++ b/src/zenutil/include/zenutil/cloud/mockimds.h @@ -23,7 +23,7 @@ namespace zen::compute { * * When a request arrives for a provider that is not the ActiveProvider, the * mock returns 404, causing CloudMetadata to write a sentinel file and move - * on to the next provider — exactly like a failed probe on bare metal. + * on to the next provider - exactly like a failed probe on bare metal. * * All config fields are public and can be mutated between poll cycles to * simulate state changes (e.g. a spot interruption appearing mid-run). @@ -45,13 +45,13 @@ public: std::string AvailabilityZone = "us-east-1a"; std::string LifeCycle = "on-demand"; // "spot" or "on-demand" - // Empty string → endpoint returns 404 (instance not in an ASG). - // Non-empty → returned as the response body. "InService" means healthy; + // Empty string -> endpoint returns 404 (instance not in an ASG). + // Non-empty -> returned as the response body. "InService" means healthy; // anything else (e.g. "Terminated:Wait") triggers termination detection. std::string AutoscalingState; - // Empty string → endpoint returns 404 (no spot interruption). - // Non-empty → returned as the response body, signalling a spot reclaim. + // Empty string -> endpoint returns 404 (no spot interruption). + // Non-empty -> returned as the response body, signalling a spot reclaim. 
std::string SpotAction; // IAM credential fields for ImdsCredentialProvider testing @@ -69,10 +69,10 @@ public: std::string Location = "eastus"; std::string Priority = "Regular"; // "Spot" or "Regular" - // Empty → instance is not in a VM Scale Set (no autoscaling). + // Empty -> instance is not in a VM Scale Set (no autoscaling). std::string VmScaleSetName; - // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // Empty -> no scheduled events. Set to "Preempt", "Terminate", or // "Reboot" to simulate a termination-class event. std::string ScheduledEventType; std::string ScheduledEventStatus = "Scheduled"; diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h index bd30aa8a2..b0402d231 100644 --- a/src/zenutil/include/zenutil/cloud/s3client.h +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -35,9 +35,7 @@ struct S3ClientOptions /// Overrides the static Credentials field. Ref<ImdsCredentialProvider> CredentialProvider; - std::chrono::milliseconds ConnectTimeout{5000}; - std::chrono::milliseconds Timeout{}; - uint8_t RetryCount = 3; + HttpClientSettings HttpSettings = {.LogCategory = "s3", .ConnectTimeout = std::chrono::milliseconds(5000), .RetryCount = 3}; }; struct S3ObjectInfo @@ -70,6 +68,8 @@ struct S3GetObjectResult : S3Result IoBuffer Content; std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); } + + static std::string_view NotFoundErrorText; }; /// Result of HeadObject - carries object metadata and existence status. 
@@ -119,7 +119,7 @@ public: S3Result PutObject(std::string_view Key, IoBuffer Content); /// Download an object from S3 - S3GetObjectResult GetObject(std::string_view Key); + S3GetObjectResult GetObject(std::string_view Key, const std::filesystem::path& TempFilePath = {}); /// Download a byte range of an object from S3 /// @param RangeStart First byte offset (inclusive) @@ -219,6 +219,7 @@ private: SigV4Credentials m_Credentials; Ref<ImdsCredentialProvider> m_CredentialProvider; HttpClient m_HttpClient; + bool m_Verbose = false; // Cached signing key (only changes once per day, protected by RwLock for thread safety) mutable RwLock m_SigningKeyLock; diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 22737589b..49bb0cc92 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -30,7 +30,7 @@ bool IsTuiAvailable(); // - Title: a short description printed once above the list // - Items: pre-formatted display labels, one per selectable entry // -// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels. +// Arrow keys (^/v) navigate the selection, Enter confirms, Esc cancels. // Returns the index of the selected item, or -1 if the user cancelled. // // Precondition: IsTuiAvailable() must be true. 
diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h index 4002d5d23..c3d0e5f1d 100644 --- a/src/zenutil/include/zenutil/consul.h +++ b/src/zenutil/include/zenutil/consul.h @@ -23,12 +23,20 @@ struct ServiceRegistrationInfo std::vector<std::pair<std::string, std::string>> Tags; uint32_t HealthIntervalSeconds = 10; uint32_t DeregisterAfterSeconds = 30; + std::string InitialStatus; }; class ConsulClient { public: - ConsulClient(std::string_view BaseUri, std::string_view Token = ""); + struct Configuration + { + std::string BaseUri; + std::string StaticToken; + std::string TokenEnvName; + }; + + ConsulClient(const Configuration& Config); ~ConsulClient(); ConsulClient(const ConsulClient&) = delete; @@ -55,9 +63,10 @@ public: private: static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId); void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap); + std::string GetNodeName(); - std::string m_Token; - HttpClient m_HttpClient; + Configuration m_Config; + HttpClient m_HttpClient; }; class ConsulProcess @@ -109,4 +118,6 @@ private: void RegistrationLoop(); }; +void consul_forcelink(); + } // namespace zen::consul diff --git a/src/zenremotestore/include/zenremotestore/filesystemutils.h b/src/zenutil/include/zenutil/filesystemutils.h index cb2d718f7..05defd1a8 100644 --- a/src/zenremotestore/include/zenremotestore/filesystemutils.h +++ b/src/zenutil/include/zenutil/filesystemutils.h @@ -3,7 +3,7 @@ #pragma once #include <zencore/basicfile.h> -#include <zenremotestore/chunking/chunkedcontent.h> +#include <zencore/filesystem.h> namespace zen { @@ -42,42 +42,12 @@ private: IoBuffer m_Cache; }; -class ReadFileCache -{ -public: - // A buffered file reader that provides CompositeBuffer where the buffers are owned and the memory never overwritten - ReadFileCache(std::atomic<uint64_t>& OpenReadCount, - std::atomic<uint64_t>& CurrentOpenFileCount, - std::atomic<uint64_t>& ReadCount, - std::atomic<uint64_t>& 
ReadByteCount, - const std::filesystem::path& Path, - const ChunkedFolderContent& LocalContent, - const ChunkedContentLookup& LocalLookup, - size_t MaxOpenFileCount); - ~ReadFileCache(); - - CompositeBuffer GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size); - -private: - const std::filesystem::path m_Path; - const ChunkedFolderContent& m_LocalContent; - const ChunkedContentLookup& m_LocalLookup; - std::vector<std::pair<uint32_t, std::unique_ptr<BufferedOpenFile>>> m_OpenFiles; - std::atomic<uint64_t>& m_OpenReadCount; - std::atomic<uint64_t>& m_CurrentOpenFileCount; - std::atomic<uint64_t>& m_ReadCount; - std::atomic<uint64_t>& m_ReadByteCount; -}; - -uint32_t SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes); - -uint32_t GetNativeFileAttributes(const std::filesystem::path FilePath); - bool IsFileWithRetry(const std::filesystem::path& Path); bool SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly); std::error_code RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); +std::error_code RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); std::error_code TryRemoveFile(const std::filesystem::path& Path); @@ -101,6 +71,13 @@ struct CleanDirectoryResult std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths; }; +class WorkerThreadPool; + +void GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent); + CleanDirectoryResult CleanDirectory( WorkerThreadPool& IOWorkerPool, std::atomic<bool>& AbortFlag, diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h index 4a25170df..95d7fa43d 100644 --- a/src/zenutil/include/zenutil/process/subprocessmanager.h +++ 
b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -95,19 +95,24 @@ public: /// Spawn a new child process and begin monitoring it. /// /// If Options.StdoutPipe is set, the pipe is consumed and async reading - /// begins automatically. Similarly for Options.StderrPipe. + /// begins automatically. Similarly for Options.StderrPipe. When providing + /// pipes, pass the corresponding data callback here so it is installed + /// before the first async read completes - setting it later via + /// SetStdoutCallback risks losing early output. /// /// Returns a non-owning pointer valid until Remove() or manager destruction. /// The exit callback fires on an io_context thread when the process terminates. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process by handle. Takes ownership of handle internals. ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); - /// Stop monitoring a process by pid. Does NOT kill the process — call + /// Stop monitoring a process by pid. Does NOT kill the process - call /// process->Kill() first if needed. The exit callback will not fire after /// this returns. void Remove(int Pid); @@ -182,12 +187,6 @@ public: /// yet computed. [[nodiscard]] float GetCpuUsagePercent() const; - /// Set per-process stdout callback (overrides manager default). - void SetStdoutCallback(ProcessDataCallback Callback); - - /// Set per-process stderr callback (overrides manager default). - void SetStderrCallback(ProcessDataCallback Callback); - /// Return all stdout captured so far. When a callback is set, output is /// delivered there instead of being accumulated. [[nodiscard]] std::string GetCapturedStdout() const; @@ -220,7 +219,7 @@ private: /// A group of managed processes with OS-level backing. 
/// /// On Windows: backed by a JobObject. All processes assigned on spawn. -/// Kill-on-close guarantee — if the group is destroyed, the OS terminates +/// Kill-on-close guarantee - if the group is destroyed, the OS terminates /// all member processes. /// On Linux/macOS: uses setpgid() so children share a process group. /// Enables bulk signal delivery via kill(-pgid, sig). @@ -237,11 +236,14 @@ public: /// Group name (as passed to CreateGroup). [[nodiscard]] std::string_view GetName() const; - /// Spawn a process into this group. + /// Spawn a process into this group. See SubprocessManager::Spawn for + /// details on the stdout/stderr callback parameters. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process into this group. /// On Windows the process is assigned to the group's JobObject. diff --git a/src/zenutil/include/zenutil/progress.h b/src/zenutil/include/zenutil/progress.h new file mode 100644 index 000000000..6a137ae9c --- /dev/null +++ b/src/zenutil/include/zenutil/progress.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logbase.h> + +#include <memory> +#include <string> + +namespace zen { + +class ProgressBase +{ +public: + virtual ~ProgressBase() = default; + + virtual void SetLogOperationName(std::string_view Name) = 0; + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0; + virtual uint32_t GetProgressUpdateDelayMS() const = 0; + + class ProgressBar + { + public: + struct State + { + bool operator==(const State&) const = default; + std::string Task; + std::string Details; + uint64_t TotalCount = 0; + uint64_t RemainingCount = 0; + uint64_t OptionalElapsedTime = (uint64_t)-1; + enum class EStatus + { + Running, + Aborted, + Paused + }; + EStatus Status = EStatus::Running; + + static constexpr EStatus CalculateStatus(bool IsAborted, bool IsPaused) + { + if (IsAborted) + { + return EStatus::Aborted; + } + if (IsPaused) + { + return EStatus::Paused; + } + return EStatus::Running; + } + }; + + virtual ~ProgressBar() = default; + + virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0; + virtual void Finish() = 0; + }; + + virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) = 0; +}; + +ProgressBase* CreateStandardProgress(LoggerRef Log); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/sessionsclient.h b/src/zenutil/include/zenutil/sessionsclient.h index aca45e61d..12ff5e593 100644 --- a/src/zenutil/include/zenutil/sessionsclient.h +++ b/src/zenutil/include/zenutil/sessionsclient.h @@ -35,13 +35,13 @@ public: SessionsServiceClient(const SessionsServiceClient&) = delete; SessionsServiceClient& operator=(const SessionsServiceClient&) = delete; - /// POST /sessions/{id} — register or re-announce the session with optional metadata. + /// POST /sessions/{id} - register or re-announce the session with optional metadata. [[nodiscard]] bool Announce(CbObjectView Metadata = {}); - /// PUT /sessions/{id} — update metadata on an existing session. 
+ /// PUT /sessions/{id} - update metadata on an existing session. [[nodiscard]] bool UpdateMetadata(CbObjectView Metadata = {}); - /// DELETE /sessions/{id} — remove the session. + /// DELETE /sessions/{id} - remove the session. [[nodiscard]] bool Remove(); /// Create a logging sink that forwards log messages to the session's log endpoint. diff --git a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h index f4ac5ff22..4387e616a 100644 --- a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h +++ b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h @@ -85,12 +85,12 @@ public: void Flush() override { - // Nothing to flush — writes happen asynchronously + // Nothing to flush - writes happen asynchronously } void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override { - // Not used — we output the raw payload directly + // Not used - we output the raw payload directly } private: @@ -124,7 +124,7 @@ private: { break; // don't retry during shutdown } - continue; // drop batch — will retry on next batch + continue; // drop batch - will retry on next batch } // Build a gathered buffer sequence so the entire batch is written @@ -176,7 +176,7 @@ private: std::string m_Source; uint32_t m_MaxQueueSize; - // Sequence counter — incremented atomically by Log() callers. + // Sequence counter - incremented atomically by Log() callers. // Gaps in the sequence seen by the receiver indicate dropped messages. 
std::atomic<uint64_t> m_NextSequence{0}; diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 03d507400..d6f66fbea 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -66,6 +66,7 @@ public: std::filesystem::path CreateNewTestDir(); std::filesystem::path CreateChildDir(std::string_view ChildName); std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; } + std::filesystem::path GetChildBaseDir() const { return m_ChildProcessBaseDir; } std::filesystem::path GetTestRootDir(std::string_view Path); inline bool IsInitialized() const { return m_IsInitialized; } inline bool IsTestEnvironment() const { return m_IsTestInstance; } diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp index 2a4840241..283a8bc37 100644 --- a/src/zenutil/logging/fullformatter.cpp +++ b/src/zenutil/logging/fullformatter.cpp @@ -12,6 +12,7 @@ #include <atomic> #include <chrono> #include <string> +#include "zencore/logging.h" namespace zen::logging { @@ -25,7 +26,7 @@ struct FullFormatter::Impl { } - explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' ') {} std::chrono::time_point<std::chrono::system_clock> m_Epoch; std::tm m_CachedLocalTm{}; @@ -155,15 +156,7 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(' '); } - // append logger name if exists - if (Msg.GetLoggerName().size() > 0) - { - OutBuffer.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - + // level OutBuffer.push_back('['); if (IsColorEnabled()) { @@ -177,6 +170,23 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(']'); OutBuffer.push_back(' '); + // logger 
name + if (Msg.GetLoggerName().size() > 0) + { + if (IsColorEnabled()) + { + OutBuffer.append("\033[1m"sv); + } + OutBuffer.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); + OutBuffer.push_back(']'); + if (IsColorEnabled()) + { + OutBuffer.append("\033[0m"sv); + } + OutBuffer.push_back(' '); + } + // add source location if present if (Msg.GetSource()) { diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp index 673a03c94..c63ad891e 100644 --- a/src/zenutil/logging/jsonformatter.cpp +++ b/src/zenutil/logging/jsonformatter.cpp @@ -19,8 +19,6 @@ static void WriteEscapedString(MemoryBuffer& Dest, std::string_view Text) { // Strip ANSI SGR sequences before escaping so they don't appear in JSON output - static const auto IsEscapeStart = [](char C) { return C == '\033'; }; - const char* RangeStart = Text.data(); const char* End = Text.data() + Text.size(); diff --git a/src/zenutil/logging/logging.cpp b/src/zenutil/logging/logging.cpp index aa34fc50c..c1636da61 100644 --- a/src/zenutil/logging/logging.cpp +++ b/src/zenutil/logging/logging.cpp @@ -124,7 +124,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) LoggerRef DefaultLogger = zen::logging::Default(); - // Build the broadcast sink — a shared indirection point that all + // Build the broadcast sink - a shared indirection point that all // loggers cloned from the default will share. Adding or removing // a child sink later is immediately visible to every logger. 
std::vector<logging::SinkPtr> BroadcastChildren; @@ -179,7 +179,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) { return; } - static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"}; + static constinit logging::LogPoint ErrorPoint{0, 0, logging::Err, "{}"}; if (auto ErrLogger = zen::logging::ErrorLog()) { try @@ -249,7 +249,7 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); logging::Registry::Instance().ApplyAll([&](auto Logger) { - static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + static constinit logging::LogPoint LogStartPoint{0, 0, logging::Info, "log starting at {}"}; Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); } diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp index 23cf60d16..df59af5fe 100644 --- a/src/zenutil/logging/rotatingfilesink.cpp +++ b/src/zenutil/logging/rotatingfilesink.cpp @@ -85,7 +85,7 @@ struct RotatingFileSink::Impl m_CurrentSize = m_CurrentFile.FileSize(OutEc); if (OutEc) { - // FileSize failed but we have an open file — reset to 0 + // FileSize failed but we have an open file - reset to 0 // so we can at least attempt writes from the start m_CurrentSize = 0; OutEc.clear(); diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp index 2fdcda30d..8eac350c6 100644 --- a/src/zenutil/process/asyncpipereader.cpp +++ b/src/zenutil/process/asyncpipereader.cpp @@ -50,7 +50,7 @@ struct AsyncPipeReader::Impl int Fd = Pipe.ReadFd; - // Close the write end — child already has it + // Close the write end - child already has it Pipe.CloseWriteEnd(); // Set non-blocking @@ -156,7 +156,7 @@ CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe) // The read end should not be inherited by the child SetHandleInformation(ReadHandle, HANDLE_FLAG_INHERIT, 0); - // Open the client (write) end — inheritable, 
for the child process + // Open the client (write) end - inheritable, for the child process SECURITY_ATTRIBUTES Sa; Sa.nLength = sizeof(Sa); Sa.lpSecurityDescriptor = nullptr; @@ -202,7 +202,7 @@ struct AsyncPipeReader::Impl HANDLE ReadHandle = static_cast<HANDLE>(Pipe.ReadHandle); - // Close the write end — child already has it + // Close the write end - child already has it Pipe.CloseWriteEnd(); // Take ownership of the read handle diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp index 3a91b0a61..d0b912a0d 100644 --- a/src/zenutil/process/subprocessmanager.cpp +++ b/src/zenutil/process/subprocessmanager.cpp @@ -196,18 +196,6 @@ ManagedProcess::GetCpuUsagePercent() const return m_Impl->m_CpuUsagePercent.load(); } -void -ManagedProcess::SetStdoutCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StdoutCallback = std::move(Callback); -} - -void -ManagedProcess::SetStderrCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StderrCallback = std::move(Callback); -} - std::string ManagedProcess::GetCapturedStdout() const { @@ -288,7 +276,9 @@ struct SubprocessManager::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void RemoveAll(); @@ -462,7 +452,9 @@ ManagedProcess* SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -476,6 +468,16 @@ SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, 
ImplPtr->m_Handle.Initialize(static_cast<int>(Result)); #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -719,10 +721,12 @@ ManagedProcess* SubprocessManager::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("SubprocessManager::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -835,7 +839,9 @@ struct ProcessGroup::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void KillAll(); @@ -884,7 +890,9 @@ ManagedProcess* ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -895,7 +903,11 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, Options.AssignToJob = &m_JobObject; } #else - if (m_Pgid > 0) + if (m_Pgid == 0) + { + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; + } + 
else { Options.ProcessGroupId = m_Pgid; } @@ -917,6 +929,16 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, } #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -1077,10 +1099,12 @@ ManagedProcess* ProcessGroup::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("ProcessGroup::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -1185,7 +1209,17 @@ TEST_CASE("SubprocessManager.SpawnAndDetectExit") CallbackFired = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); CHECK(ReceivedExitCode == 42); @@ -1210,7 +1244,17 @@ TEST_CASE("SubprocessManager.SpawnAndDetectCleanExit") CallbackFired = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); CHECK(ReceivedExitCode == 0); @@ -1235,7 +1279,17 @@ TEST_CASE("SubprocessManager.StdoutCapture") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + 
{ + break; + } + } + } CHECK(Exited); std::string Captured = Proc->GetCapturedStdout(); @@ -1264,7 +1318,17 @@ TEST_CASE("SubprocessManager.StderrCapture") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); std::string CapturedErr = Proc->GetCapturedStderr(); @@ -1289,11 +1353,24 @@ TEST_CASE("SubprocessManager.StdoutCallback") std::string ReceivedData; bool Exited = false; - ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); + ManagedProcess* Proc = Manager.Spawn( + AppStub, + CmdLine, + Options, + [&](ManagedProcess&, int) { Exited = true; }, + [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); - Proc->SetStdoutCallback([&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); - - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); CHECK(ReceivedData.find("callback_test") != std::string::npos); @@ -1316,8 +1393,18 @@ TEST_CASE("SubprocessManager.MetricsSampling") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - // Run for enough time to get metrics samples - IoContext.run_for(1s); + // Poll until metrics are available + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Proc->GetLatestMetrics().WorkingSetSize > 0) + { + break; + } + } + } ProcessMetrics Metrics = Proc->GetLatestMetrics(); CHECK(Metrics.WorkingSetSize > 0); @@ -1326,7 +1413,17 @@ TEST_CASE("SubprocessManager.MetricsSampling") CHECK(Snapshot.size() == 1); // Let it finish - IoContext.run_for(3s); + { + Stopwatch Timer; + while 
(Timer.GetElapsedTimeMs() < 10'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); } @@ -1350,7 +1447,7 @@ TEST_CASE("SubprocessManager.RemoveWhileRunning") // Let it start IoContext.run_for(100ms); - // Remove without killing — callback should NOT fire after this + // Remove without killing - callback should NOT fire after this Manager.Remove(Pid); IoContext.run_for(500ms); @@ -1375,12 +1472,31 @@ TEST_CASE("SubprocessManager.KillAndWaitForExit") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { CallbackFired = true; }); // Let it start - IoContext.run_for(200ms); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Proc->IsRunning()) + { + break; + } + } + } Proc->Kill(); - IoContext.run_for(2s); - + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); } @@ -1401,7 +1517,17 @@ TEST_CASE("SubprocessManager.AdoptProcess") Manager.Adopt(ProcessHandle(Result), [&](ManagedProcess&, int ExitCode) { ReceivedExitCode = ExitCode; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (ReceivedExitCode != -1) + { + break; + } + } + } CHECK(ReceivedExitCode == 7); } @@ -1424,7 +1550,17 @@ TEST_CASE("SubprocessManager.UserTag") Proc->SetTag("my-worker-1"); CHECK(Proc->GetTag() == "my-worker-1"); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (!ReceivedTag.empty()) + { + break; + } + } + } CHECK(ReceivedTag == "my-worker-1"); } @@ -1454,7 +1590,17 @@ TEST_CASE("ProcessGroup.SpawnAndMembership") CHECK(Group->GetProcessCount() == 2); CHECK(Manager.GetProcessCount() == 2); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + 
IoContext.run_for(10ms); + if (ExitCount == 2) + { + break; + } + } + } CHECK(ExitCount == 2); } @@ -1504,7 +1650,17 @@ TEST_CASE("ProcessGroup.AggregateMetrics") Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); // Wait for metrics sampling - IoContext.run_for(1s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Group->GetAggregateMetrics().TotalWorkingSetSize > 0) + { + break; + } + } + } AggregateProcessMetrics GroupAgg = Group->GetAggregateMetrics(); CHECK(GroupAgg.ProcessCount == 2); @@ -1570,7 +1726,17 @@ TEST_CASE("ProcessGroup.MixedGroupedAndUngrouped") CHECK(Group->GetProcessCount() == 2); CHECK(Manager.GetProcessCount() == 3); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (GroupExitCount == 2 && UngroupedExitCode != -1) + { + break; + } + } + } CHECK(GroupExitCount == 2); CHECK(UngroupedExitCode == 0); @@ -1590,7 +1756,7 @@ TEST_CASE("ProcessGroup.FindGroup") TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) { - // Seed for reproducibility — change to explore different orderings + // Seed for reproducibility - change to explore different orderings // // Note that while this is a stress test, it is still single-threaded @@ -1619,7 +1785,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 1: Spawn multiple groups with varied workloads // ======================================================================== - ZEN_INFO("StressTest: Phase 1 — spawning initial groups"); + ZEN_INFO("StressTest: Phase 1 - spawning initial groups"); constexpr int NumInitialGroups = 8; std::vector<std::string> GroupNames; @@ -1673,7 +1839,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 2: Randomly kill some groups, create replacements, add ungrouped // ======================================================================== - ZEN_INFO("StressTest: Phase 2 — random 
group kills and replacements"); + ZEN_INFO("StressTest: Phase 2 - random group kills and replacements"); constexpr int NumGroupsToKill = 3; @@ -1738,7 +1904,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 3: Rapid spawn/exit churn // ======================================================================== - ZEN_INFO("StressTest: Phase 3 — rapid spawn/exit churn"); + ZEN_INFO("StressTest: Phase 3 - rapid spawn/exit churn"); std::atomic<int> ChurnExitCount{0}; int TotalChurnSpawned = 0; @@ -1762,7 +1928,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Brief pump to allow some exits to be processed IoContext.run_for(200ms); - // Destroy the group — any still-running processes get killed + // Destroy the group - any still-running processes get killed Manager.DestroyGroup(Name); } @@ -1772,7 +1938,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 4: Drain and verify // ======================================================================== - ZEN_INFO("StressTest: Phase 4 — draining remaining processes"); + ZEN_INFO("StressTest: Phase 4 - draining remaining processes"); // Check metrics were collected before we wind down AggregateProcessMetrics Agg = Manager.GetAggregateMetrics(); @@ -1803,7 +1969,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // (exact count is hard to predict due to killed groups, but should be > 0) CHECK(TotalExitCallbacks.load() > 0); - ZEN_INFO("StressTest: PASSED — seed={}", Seed); + ZEN_INFO("StressTest: PASSED - seed={}", Seed); } TEST_SUITE_END(); diff --git a/src/zenutil/progress.cpp b/src/zenutil/progress.cpp new file mode 100644 index 000000000..a12076dcc --- /dev/null +++ b/src/zenutil/progress.cpp @@ -0,0 +1,99 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/progress.h> + +#include <zencore/logging.h> + +namespace zen { + +class StandardProgressBase; + +class StandardProgressBar : public ProgressBase::ProgressBar +{ +public: + StandardProgressBar(StandardProgressBase& Owner, std::string_view InSubTask) : m_Owner(Owner), m_SubTask(InSubTask) {} + + virtual void UpdateState(const State& NewState, bool DoLinebreak) override; + virtual void Finish() override; + +private: + LoggerRef Log(); + StandardProgressBase& m_Owner; + std::string m_SubTask; + State m_State; +}; + +class StandardProgressBase : public ProgressBase +{ +public: + StandardProgressBase(LoggerRef Log) : m_Log(Log) {} + + virtual void SetLogOperationName(std::string_view Name) override + { + m_LogOperationName = Name; + ZEN_INFO("{}", m_LogOperationName); + } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override + { + const size_t PercentDone = StepCount > 0u ? (100u * StepIndex) / StepCount : 0u; + ZEN_INFO("{}: {}%", m_LogOperationName, PercentDone); + } + virtual uint32_t GetProgressUpdateDelayMS() const override { return 2000; } + virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) override + { + return std::make_unique<StandardProgressBar>(*this, InSubTask); + } + +private: + friend class StandardProgressBar; + LoggerRef m_Log; + std::string m_LogOperationName; + LoggerRef Log() { return m_Log; } +}; + +LoggerRef +StandardProgressBar::Log() +{ + return m_Owner.Log(); +} + +void +StandardProgressBar::UpdateState(const State& NewState, bool DoLinebreak) +{ + ZEN_UNUSED(DoLinebreak); + const size_t PercentDone = + NewState.TotalCount > 0u ? 
(100u * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount : 0u; + std::string Task = NewState.Task; + switch (NewState.Status) + { + case State::EStatus::Aborted: + Task = "Aborting"; + break; + case State::EStatus::Paused: + Task = "Paused"; + break; + default: + break; + } + ZEN_INFO("{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? "" : fmt::format(" {}", NewState.Details)); + m_State = NewState; +} +void +StandardProgressBar::Finish() +{ + if (m_State.RemainingCount > 0) + { + State NewState = m_State; + NewState.RemainingCount = 0; + NewState.Details = ""; + UpdateState(NewState, /*DoLinebreak*/ true); + } +} + +ProgressBase* +CreateStandardProgress(LoggerRef Log) +{ + return new StandardProgressBase(Log); +} + +} // namespace zen diff --git a/src/zenutil/sessionsclient.cpp b/src/zenutil/sessionsclient.cpp index c62cc4099..ec9c6177a 100644 --- a/src/zenutil/sessionsclient.cpp +++ b/src/zenutil/sessionsclient.cpp @@ -21,7 +21,7 @@ namespace zen { ////////////////////////////////////////////////////////////////////////// // -// SessionLogSink — batching log sink that forwards to /sessions/{id}/log +// SessionLogSink - batching log sink that forwards to /sessions/{id}/log // static const char* @@ -108,7 +108,7 @@ public: void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override { - // No formatting needed — we send raw message text + // No formatting needed - we send raw message text } private: @@ -124,6 +124,9 @@ private: { if (Msg.Type == BufferedLogEntry::Type::Shutdown) { + // Mark complete so WaitAndDequeue returns false on empty queue + m_Queue.CompleteAdding(); + // Drain remaining log entries BufferedLogEntry Remaining; while (m_Queue.WaitAndDequeue(Remaining)) @@ -172,7 +175,7 @@ private: { SendBatch(Batch); } - // Drain remaining + m_Queue.CompleteAdding(); while (m_Queue.WaitAndDequeue(Extra)) { if (Extra.Type == BufferedLogEntry::Type::Log) @@ -226,7 +229,7 @@ private: } catch (const std::exception&) 
{ - // Best-effort — silently discard on failure + // Best-effort - silently discard on failure } } diff --git a/src/zenutil/splitconsole/logstreamlistener.cpp b/src/zenutil/splitconsole/logstreamlistener.cpp index 04718b543..df985a196 100644 --- a/src/zenutil/splitconsole/logstreamlistener.cpp +++ b/src/zenutil/splitconsole/logstreamlistener.cpp @@ -17,7 +17,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { ////////////////////////////////////////////////////////////////////////// -// LogStreamSession — reads CbObject-framed messages from a single TCP connection +// LogStreamSession - reads CbObject-framed messages from a single TCP connection class LogStreamSession : public RefCounted { @@ -34,7 +34,7 @@ private: [Self](const asio::error_code& Ec, std::size_t BytesRead) { if (Ec) { - return; // connection closed or error — session ends + return; // connection closed or error - session ends } Self->m_BufferUsed += BytesRead; Self->ProcessBuffer(); @@ -119,7 +119,7 @@ private: m_BufferUsed -= Consumed; } - // If buffer is full and we can't parse a message, the message is too large — drop connection + // If buffer is full and we can't parse a message, the message is too large - drop connection if (m_BufferUsed == m_ReadBuf.size()) { ZEN_WARN("LogStreamSession: buffer full with no complete message, dropping connection"); @@ -141,7 +141,7 @@ private: struct LogStreamListener::Impl { - // Owned io_context mode — creates and runs its own thread + // Owned io_context mode - creates and runs its own thread Impl(LogStreamTarget& Target, uint16_t Port) : m_Target(Target) , m_OwnedIoContext(std::make_unique<asio::io_context>()) @@ -154,7 +154,7 @@ struct LogStreamListener::Impl }); } - // External io_context mode — caller drives the io_context + // External io_context mode - caller drives the io_context Impl(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port) : m_Target(Target), m_Acceptor(IoContext) { SetupAcceptor(Port); @@ -312,7 +312,7 @@ namespace { 
logging::LogMessage MakeLogMessage(std::string_view Text, logging::LogLevel Level = logging::Info) { - static logging::LogPoint Point{{}, Level, {}}; + static logging::LogPoint Point{0, 0, Level, {}}; Point.Level = Level; return logging::LogMessage(Point, "test", Text); } @@ -367,7 +367,7 @@ TEST_CASE("DroppedMessageDetection") asio::ip::tcp::socket Socket(IoContext); Socket.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), Listener.GetPort())); - // Send seq=0, then seq=5 — the listener should detect a gap of 4 + // Send seq=0, then seq=5 - the listener should detect a gap of 4 for (uint64_t Seq : {uint64_t(0), uint64_t(5)}) { CbObjectWriter Writer; diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index 2b27b2d8b..e1ffeeb3e 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -181,7 +181,7 @@ ZenServerState::Initialize() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 0766 : 0666); if (Fd < 0) { // Work around a potential issue if the service user is changed in certain configurations. @@ -191,7 +191,7 @@ ZenServerState::Initialize() // shared memory object and retry, we'll be able to get past shm_open() so long as we have // the appropriate permissions to create the shared memory object. shm_unlink("/UnrealEngineZen"); - Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 
0766 : 0666); if (Fd < 0) { ThrowLastError("Could not open a shared memory object"); @@ -244,7 +244,7 @@ ZenServerState::InitializeReadOnly() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDONLY, 0666); if (Fd < 0) { return false; @@ -267,6 +267,8 @@ ZenServerState::InitializeReadOnly() ZenServerState::ZenServerEntry* ZenServerState::Lookup(int DesiredListenPort) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].DesiredListenPort; @@ -274,6 +276,14 @@ ZenServerState::Lookup(int DesiredListenPort) const { if (DesiredListenPort == 0 || (EntryPort == DesiredListenPort)) { + // If the entry's PID matches our own but we haven't registered yet, + // this is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Skip it. + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -289,6 +299,8 @@ ZenServerState::Lookup(int DesiredListenPort) const ZenServerState::ZenServerEntry* ZenServerState::LookupByEffectivePort(int Port) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].EffectiveListenPort; @@ -296,6 +308,11 @@ ZenServerState::LookupByEffectivePort(int Port) const { if (EntryPort == Port) { + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -358,12 +375,26 @@ ZenServerState::Sweep() ZEN_ASSERT(m_IsReadOnly == false); + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.DesiredListenPort) { + // If the entry's PID matches our own but we haven't registered yet, + // this 
is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Reclaim it. + if (Entry.Pid == OurPid && m_OurEntry == nullptr) + { + ZEN_CONSOLE_DEBUG("Sweep - pid {} matches current process but no registration yet, reclaiming stale entry (port {})", + Entry.Pid.load(), + Entry.DesiredListenPort.load()); + Entry.Reset(); + continue; + } + std::error_code ErrorCode; if (Entry.Pid != 0 && IsProcessRunning(Entry.Pid, ErrorCode) == false) { @@ -620,7 +651,7 @@ ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data ThrowLastError("Could not map instance info shared memory"); } #else - int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666); if (Fd < 0) { ThrowLastError("Could not create instance info shared memory"); @@ -687,7 +718,7 @@ ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId) return false; } #else - int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDONLY, 0666); if (Fd < 0) { return false; @@ -1583,7 +1614,7 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason) std::optional<int> StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options) { - auto Log = [&LogRef]() { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); // Check if a matching server is already running { @@ -1653,7 +1684,7 @@ ShutdownZenServer(LoggerRef LogRef, ZenServerState::ZenServerEntry* Entry, const std::filesystem::path& ProgramBaseDir) { - auto Log = [&LogRef]() { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); int EntryPort = (int)Entry->DesiredListenPort.load(); const uint32_t ServerProcessPid = Entry->Pid.load(); try diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 2ca380c75..b282adc03 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,9 +5,11 @@ #if ZEN_WITH_TESTS # include <zenutil/cloud/imdscredentials.h> 
+# include <zenutil/consul.h> # include <zenutil/cloud/s3client.h> # include <zenutil/cloud/sigv4.h> # include <zenutil/config/commandlineoptions.h> +# include <zenutil/filesystemutils.h> # include <zenutil/rpcrecording.h> # include <zenutil/splitconsole/logstreamlistener.h> # include <zenutil/process/subprocessmanager.h> @@ -20,6 +22,8 @@ zenutil_forcelinktests() { cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + consul::consul_forcelink(); + filesystemutils_forcelink(); imdscredentials_forcelink(); logstreamlistener_forcelink(); subprocessmanager_forcelink(); |