diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src')
429 files changed, 54397 insertions, 5283 deletions
diff --git a/src/zen.ico b/src/zen.ico Binary files differnew file mode 100644 index 000000000..f7fb251b5 --- /dev/null +++ b/src/zen.ico diff --git a/src/zen/authutils.cpp b/src/zen/authutils.cpp index 31db82efd..534f7952b 100644 --- a/src/zen/authutils.cpp +++ b/src/zen/authutils.cpp @@ -154,21 +154,34 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, ZEN_ASSERT(!SystemRootDir.empty()); if (!Auth) { - if (m_EncryptionKey.empty()) + static const std::string_view DefaultEncryptionKey("abcdefghijklmnopqrstuvxyz0123456"); + static const std::string_view DefaultEncryptionIV("0123456789abcdef"); + if (m_EncryptionKey.empty() && m_EncryptionIV.empty()) { - m_EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456"; + m_EncryptionKey = DefaultEncryptionKey; + m_EncryptionIV = DefaultEncryptionIV; if (!Quiet) { - ZEN_CONSOLE_WARN("Using default encryption key"); + ZEN_CONSOLE_WARN("Auth: Using default encryption key and initialization vector for auth storage"); } } - - if (m_EncryptionIV.empty()) + else { - m_EncryptionIV = "0123456789abcdef"; - if (!Quiet) + if (m_EncryptionKey.empty()) + { + m_EncryptionKey = DefaultEncryptionKey; + if (!Quiet) + { + ZEN_CONSOLE_WARN("Auth: Using default encryption key for auth storage"); + } + } + if (m_EncryptionIV.empty()) { - ZEN_CONSOLE_WARN("Using default encryption initialization vector"); + m_EncryptionIV = DefaultEncryptionIV; + if (!Quiet) + { + ZEN_CONSOLE_WARN("Auth: Using default encryption initialization vector for auth storage"); + } } } @@ -187,9 +200,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, { ExtendableStringBuilder<128> SB; SB << "\n RootDirectory: " << AuthMgrConfig.RootDirectory.string(); - SB << "\n EncryptionKey: " << m_EncryptionKey; - SB << "\n EncryptionIV: " << m_EncryptionIV; - ZEN_CONSOLE("Creating auth manager with:{}", SB.ToString()); + SB << "\n EncryptionKey: " << HideSensitiveString(m_EncryptionKey); + SB << "\n EncryptionIV: " << HideSensitiveString(m_EncryptionIV); + 
ZEN_CONSOLE("Auth: Creating auth manager with:{}", SB.ToString()); } Auth = AuthMgr::Create(AuthMgrConfig); } @@ -204,13 +217,18 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, ExtendableStringBuilder<128> SB; SB << "\n Name: " << ProviderName; SB << "\n Url: " << m_OpenIdProviderUrl; - SB << "\n ClientId: " << m_OpenIdClientId; - ZEN_CONSOLE("Adding openid auth provider:{}", SB.ToString()); + SB << "\n ClientId: " << HideSensitiveString(m_OpenIdClientId); + ZEN_CONSOLE("Auth: Adding Open ID auth provider:{}", SB.ToString()); } Auth->AddOpenIdProvider({.Name = ProviderName, .Url = m_OpenIdProviderUrl, .ClientId = m_OpenIdClientId}); if (!m_OpenIdRefreshToken.empty()) { - ZEN_CONSOLE("Adding open id refresh token {} to provider {}", m_OpenIdRefreshToken, ProviderName); + if (!Quiet) + { + ZEN_CONSOLE("Auth: Adding open id refresh token {} to provider {}", + HideSensitiveString(m_OpenIdRefreshToken), + ProviderName); + } Auth->AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = m_OpenIdRefreshToken}); } } @@ -225,21 +243,21 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, if (!m_AccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Adding static auth token: {}", m_AccessToken); + ZEN_CONSOLE("Auth: Using static auth token: {}", HideSensitiveString(m_AccessToken)); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(m_AccessToken); } else if (!m_AccessTokenPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_AccessTokenPath); + MakeSafeAbsolutePathInPlace(m_AccessTokenPath); std::string ResolvedAccessToken = ReadAccessTokenFromJsonFile(m_AccessTokenPath); if (!ResolvedAccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Adding static auth token from {}: {}", m_AccessTokenPath, ResolvedAccessToken); + ZEN_CONSOLE("Auth: Adding static auth token from {}: {}", m_AccessTokenPath, HideSensitiveString(ResolvedAccessToken)); } ClientSettings.AccessTokenProvider = 
httpclientauth::CreateFromStaticToken(ResolvedAccessToken); } @@ -250,9 +268,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, { ExtendableStringBuilder<128> SB; SB << "\n Url: " << m_OAuthUrl; - SB << "\n ClientId: " << m_OAuthClientId; - SB << "\n ClientSecret: " << m_OAuthClientSecret; - ZEN_CONSOLE("Adding oauth provider:{}", SB.ToString()); + SB << "\n ClientId: " << HideSensitiveString(m_OAuthClientId); + SB << "\n ClientSecret: " << HideSensitiveString(m_OAuthClientSecret); + ZEN_CONSOLE("Auth: Adding oauth provider:{}", SB.ToString()); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOAuthClientCredentials( {.Url = m_OAuthUrl, .ClientId = m_OAuthClientId, .ClientSecret = m_OAuthClientSecret}); @@ -260,25 +278,27 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, else if (!m_OpenIdProviderName.empty()) { CreateAuthMgr(); - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using openid provider: {}", m_OpenIdProviderName); + ZEN_CONSOLE("Auth: Using OpenId provider: {}", m_OpenIdProviderName); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOpenIdProvider(*Auth, m_OpenIdProviderName); } else if (std::string ResolvedAccessToken = GetEnvAccessToken(m_AccessTokenEnv); !ResolvedAccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using environment variable '{}' as access token '{}'", m_AccessTokenEnv, ResolvedAccessToken); + ZEN_CONSOLE("Auth: Resolved environment variable '{}' to access token '{}'", + m_AccessTokenEnv, + HideSensitiveString(ResolvedAccessToken)); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken); } else if (std::filesystem::path OidcTokenExePath = FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); !OidcTokenExePath.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Running oidctoken exe from path '{}'", m_OidcTokenAuthExecutablePath); + ZEN_CONSOLE("Auth: Using oidctoken exe from path '{}'", OidcTokenExePath); } 
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, HostUrl, Quiet, m_OidcTokenUnattended, Hidden); @@ -291,9 +311,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, if (!ClientSettings.AccessTokenProvider) { CreateAuthMgr(); - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using default openid provider"); + ZEN_CONSOLE("Auth: Using default Open ID provider"); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(*Auth); } diff --git a/src/zen/cmds/admin_cmd.h b/src/zen/cmds/admin_cmd.h index 87ef8091b..83bcf8893 100644 --- a/src/zen/cmds/admin_cmd.h +++ b/src/zen/cmds/admin_cmd.h @@ -13,6 +13,9 @@ namespace zen { class ScrubCommand : public StorageCommand { public: + static constexpr char Name[] = "scrub"; + static constexpr char Description[] = "Scrub zen storage (verify data integrity)"; + ScrubCommand(); ~ScrubCommand(); @@ -20,7 +23,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"scrub", "Scrub zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_DryRun = false; bool m_NoGc = false; @@ -33,6 +36,9 @@ private: class GcCommand : public StorageCommand { public: + static constexpr char Name[] = "gc"; + static constexpr char Description[] = "Garbage collect zen storage"; + GcCommand(); ~GcCommand(); @@ -40,7 +46,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc", "Garbage collect zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_SmallObjects{false}; bool m_SkipCid{false}; @@ -62,6 +68,9 @@ private: class GcStatusCommand : public StorageCommand { public: + static constexpr char Name[] = "gc-status"; + static constexpr char Description[] = "Garbage collect zen storage status check"; + GcStatusCommand(); ~GcStatusCommand(); @@ -69,7 +78,7 @@ public: 
virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc-status", "Garbage collect zen storage status check"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Details = false; }; @@ -77,6 +86,9 @@ private: class GcStopCommand : public StorageCommand { public: + static constexpr char Name[] = "gc-stop"; + static constexpr char Description[] = "Request cancel of running garbage collection in zen storage"; + GcStopCommand(); ~GcStopCommand(); @@ -84,7 +96,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc-stop", "Request cancel of running garbage collection in zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; @@ -93,6 +105,9 @@ private: class JobCommand : public ZenCmdBase { public: + static constexpr char Name[] = "jobs"; + static constexpr char Description[] = "Show/cancel zen background jobs"; + JobCommand(); ~JobCommand(); @@ -100,7 +115,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"jobs", "Show/cancel zen background jobs"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::uint64_t m_JobId = 0; bool m_Cancel = 0; @@ -111,6 +126,9 @@ private: class LoggingCommand : public ZenCmdBase { public: + static constexpr char Name[] = "logs"; + static constexpr char Description[] = "Show/control zen logging"; + LoggingCommand(); ~LoggingCommand(); @@ -118,7 +136,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"logs", "Show/control zen logging"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_CacheWriteLog; std::string m_CacheAccessLog; @@ -133,6 +151,9 @@ private: class FlushCommand : public StorageCommand { public: + static constexpr char Name[] = "flush"; + static 
constexpr char Description[] = "Flush storage"; + FlushCommand(); ~FlushCommand(); @@ -140,7 +161,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"flush", "Flush zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; @@ -149,6 +170,9 @@ private: class CopyStateCommand : public StorageCommand { public: + static constexpr char Name[] = "copy-state"; + static constexpr char Description[] = "Copy zen server disk state"; + CopyStateCommand(); ~CopyStateCommand(); @@ -156,7 +180,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"copy-state", "Copy zen server disk state"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_DataPath; std::filesystem::path m_TargetPath; bool m_SkipLogs = false; diff --git a/src/zen/cmds/bench_cmd.h b/src/zen/cmds/bench_cmd.h index ed123be75..7fbf85340 100644 --- a/src/zen/cmds/bench_cmd.h +++ b/src/zen/cmds/bench_cmd.h @@ -9,6 +9,9 @@ namespace zen { class BenchCommand : public ZenCmdBase { public: + static constexpr char Name[] = "bench"; + static constexpr char Description[] = "Utility command for benchmarking"; + BenchCommand(); ~BenchCommand(); @@ -17,7 +20,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"bench", "Benchmarking utility command"}; + cxxopts::Options m_Options{Name, Description}; bool m_PurgeStandbyLists = false; bool m_SingleProcess = false; }; diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index f4edb65ab..b4b4df7c9 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -67,13 +67,11 @@ ZEN_THIRD_PARTY_INCLUDES_END static const bool DoExtraContentVerify = false; -#define ZEN_CLOUD_STORAGE "Cloud Storage" - namespace zen { using namespace std::literals; -namespace { +namespace builds_impl { 
static std::atomic<bool> AbortFlag = false; static std::atomic<bool> PauseFlag = false; @@ -270,10 +268,11 @@ namespace { static bool IsQuiet = false; static ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; -#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ - if (IsVerbose) \ - { \ - ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \ +#undef ZEN_CONSOLE_VERBOSE +#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ + if (IsVerbose) \ + { \ + ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \ } const std::string DefaultAccessTokenEnvVariableName( @@ -1467,9 +1466,16 @@ namespace { ZEN_CONSOLE("Downloading build {}, parts:{} to '{}' ({})", BuildId, BuildPartString.ToView(), Path, NiceBytes(RawSize)); } + Stopwatch IndexTimer; + const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalState.State.ChunkedContent); const ChunkedContentLookup RemoteLookup = BuildChunkedContentLookup(RemoteContent); + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); + } + ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Download, TaskSteps::StepCount); BuildsOperationUpdateFolder Updater( @@ -1588,7 +1594,7 @@ namespace { } } } - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { if (SB.Size() > 0) { @@ -1643,9 +1649,9 @@ namespace { } if (Options.PrimeCacheOnly) { - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { - Storage.BuildCacheStorage->Flush(5000, [](intptr_t Remaining) { + Storage.CacheStorage->Flush(5000, [](intptr_t Remaining) { if (!IsQuiet) { if (Remaining == 0) @@ -2002,12 +2008,13 @@ namespace { ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Cleanup, TaskSteps::StepCount); } -} // namespace +} // namespace builds_impl ////////////////////////////////////////////////////////////////////////////////////////////////////// BuildsCommand::BuildsCommand() { + using namespace builds_impl; 
m_Options.add_options()("h,help", "Print help"); auto AddSystemOptions = [this](cxxopts::Options& Ops) { @@ -2648,6 +2655,7 @@ BuildsCommand::~BuildsCommand() = default; void BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace builds_impl; ZEN_UNUSED(GlobalOptions); signal(SIGINT, SignalCallbackHandler); @@ -2680,7 +2688,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_SystemRootDir = PickDefaultSystemRootDirectory(); } - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); }; ParseSystemOptions(); @@ -2729,7 +2737,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--host', '--url', '--override-host' or '--storage-path' is required", SubOption->help()); } - MakeSafeAbsolutePathÍnPlace(m_StoragePath); + MakeSafeAbsolutePathInPlace(m_StoragePath); }; auto ParseOutputOptions = [&]() { @@ -2800,8 +2808,6 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) .Verbose = m_VerboseHttp, .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}; - std::unique_ptr<AuthMgr> Auth; - std::string StorageDescription; std::string CacheDescription; @@ -2820,44 +2826,47 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*Output, ClientSettings, m_Host, m_OverrideHost, m_ZenCacheHost, ZenCacheResolveMode::All, m_Verbose); - if (!ResolveRes.HostUrl.empty()) + if (!ResolveRes.Cloud.Address.empty()) { - ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2; + ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2; Result.BuildStorageHttp = - std::make_unique<HttpClient>(ResolveRes.HostUrl, ClientSettings, []() { return AbortFlag.load(); }); + std::make_unique<HttpClient>(ResolveRes.Cloud.Address, ClientSettings, []() { return 
AbortFlag.load(); }); - Result.BuildStorage = CreateJupiterBuildStorage(Log(), + Result.BuildStorage = CreateJupiterBuildStorage(Log(), *Result.BuildStorageHttp, StorageStats, m_Namespace, m_Bucket, m_AllowRedirect, TempPath / "storage"); - Result.StorageName = ResolveRes.HostName; + Result.BuildStorageHost = ResolveRes.Cloud; + + uint64_t HostLatencyNs = ResolveRes.Cloud.LatencySec >= 0 ? uint64_t(ResolveRes.Cloud.LatencySec * 1000000000.0) : 0; - StorageDescription = fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'", - ResolveRes.HostName, - (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl), - Result.BuildStorageHttp->GetSessionId(), - m_Namespace, - m_Bucket); - ; + StorageDescription = + fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'. Latency: {}", + ResolveRes.Cloud.Name, + (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address), + Result.BuildStorageHttp->GetSessionId(), + m_Namespace, + m_Bucket, + NiceLatencyNs(HostLatencyNs)); - if (!ResolveRes.CacheUrl.empty()) + if (!ResolveRes.Cache.Address.empty()) { Result.CacheHttp = std::make_unique<HttpClient>( - ResolveRes.CacheUrl, + ResolveRes.Cache.Address, HttpClientSettings{ .LogCategory = "httpcacheclient", .ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, + .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2, .AllowResume = true, .RetryCount = 0, .Verbose = m_VerboseHttp, .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, []() { return AbortFlag.load(); }); - Result.BuildCacheStorage = + Result.CacheStorage = CreateZenBuildStorageCache(*Result.CacheHttp, StorageCacheStats, m_Namespace, @@ -2865,14 +2874,17 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) TempPath / "zencache", BoostCacheBackgroundWorkerPool ? 
GetSmallWorkerPool(EWorkloadType::Background) : GetTinyWorkerPool(EWorkloadType::Background)); - Result.CacheName = ResolveRes.CacheName; + Result.CacheHost = ResolveRes.Cache; + + uint64_t CacheLatencyNs = ResolveRes.Cache.LatencySec >= 0 ? uint64_t(ResolveRes.Cache.LatencySec * 1000000000.0) : 0; CacheDescription = - fmt::format("Zen {}{}. SessionId: '{}'", - ResolveRes.CacheName, - (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl), - Result.CacheHttp->GetSessionId()); - ; + fmt::format("Zen {}{}. SessionId: '{}'. Latency: {}", + ResolveRes.Cache.Name, + (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? "" : fmt::format(" {}", ResolveRes.Cache.Address), + Result.CacheHttp->GetSessionId(), + NiceLatencyNs(CacheLatencyNs)); + if (!m_Namespace.empty()) { CacheDescription += fmt::format(". Namespace '{}'", m_Namespace); @@ -2888,41 +2900,56 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { StorageDescription = fmt::format("folder {}", m_StoragePath); Result.BuildStorage = CreateFileBuildStorage(m_StoragePath, StorageStats, false, DefaultLatency, DefaultDelayPerKBSec); - Result.StorageName = fmt::format("Disk {}", m_StoragePath.stem()); + + Result.BuildStorageHost = BuildStorageResolveResult::Host{.Address = m_StoragePath.generic_string(), + .Name = "Disk", + .LatencySec = 1.0 / 100000, // 1 us + .Caps = {.MaxRangeCountPerRequest = 2048u}}; if (!m_ZenCacheHost.empty()) { - Result.CacheHttp = std::make_unique<HttpClient>( - m_ZenCacheHost, - HttpClientSettings{ - .LogCategory = "httpcacheclient", - .ConnectTimeout = std::chrono::milliseconds{3000}, - .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = m_AssumeHttp2, - .AllowResume = true, - .RetryCount = 0, - .Verbose = m_VerboseHttp, - .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, - []() { return AbortFlag.load(); }); - Result.BuildCacheStorage = - 
CreateZenBuildStorageCache(*Result.CacheHttp, - StorageCacheStats, - m_Namespace, - m_Bucket, - TempPath / "zencache", - BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background) - : GetTinyWorkerPool(EWorkloadType::Background)); - Result.CacheName = m_ZenCacheHost; - - CacheDescription = fmt::format("Zen {}{}. SessionId: '{}'", Result.CacheName, "", Result.CacheHttp->GetSessionId()); - ; - if (!m_Namespace.empty()) - { - CacheDescription += fmt::format(". Namespace '{}'", m_Namespace); - } - if (!m_Bucket.empty()) + ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(m_ZenCacheHost, m_AssumeHttp2, m_VerboseHttp); + + if (TestResult.Success) { - CacheDescription += fmt::format(" Bucket '{}'", m_Bucket); + Result.CacheHttp = std::make_unique<HttpClient>( + m_ZenCacheHost, + HttpClientSettings{ + .LogCategory = "httpcacheclient", + .ConnectTimeout = std::chrono::milliseconds{3000}, + .Timeout = std::chrono::milliseconds{30000}, + .AssumeHttp2 = m_AssumeHttp2, + .AllowResume = true, + .RetryCount = 0, + .Verbose = m_VerboseHttp, + .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, + []() { return AbortFlag.load(); }); + + Result.CacheStorage = + CreateZenBuildStorageCache(*Result.CacheHttp, + StorageCacheStats, + m_Namespace, + m_Bucket, + TempPath / "zencache", + BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background) + : GetTinyWorkerPool(EWorkloadType::Background)); + Result.CacheHost = + BuildStorageResolveResult::Host{.Address = m_ZenCacheHost, + .Name = m_ZenCacheHost, + .AssumeHttp2 = m_AssumeHttp2, + .LatencySec = TestResult.LatencySeconds, + .Caps = {.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest}}; + + CacheDescription = fmt::format("Zen {}. SessionId: '{}'", Result.CacheHost.Name, Result.CacheHttp->GetSessionId()); + + if (!m_Namespace.empty()) + { + CacheDescription += fmt::format(". 
Namespace '{}'", m_Namespace); + } + if (!m_Bucket.empty()) + { + CacheDescription += fmt::format(" Bucket '{}'", m_Bucket); + } } } } @@ -2934,7 +2961,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!IsQuiet) { ZEN_CONSOLE("Remote: {}", StorageDescription); - if (!Result.CacheName.empty()) + if (!Result.CacheHost.Name.empty()) { ZEN_CONSOLE("Cache : {}", CacheDescription); } @@ -2947,7 +2974,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--local-path' is required", SubOption->help()); } - MakeSafeAbsolutePathÍnPlace(m_Path); + MakeSafeAbsolutePathInPlace(m_Path); }; auto ParseFileFilters = [&](std::vector<std::string>& OutIncludeWildcards, std::vector<std::string>& OutExcludeWildcards) { @@ -3004,7 +3031,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--compare-path' is required", SubOption->help()); } - MakeSafeAbsolutePathÍnPlace(m_DiffPath); + MakeSafeAbsolutePathInPlace(m_DiffPath); }; auto ParseBlobHash = [&]() -> IoHash { @@ -3016,7 +3043,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_BlobHash.length() != IoHash::StringLength) { throw OptionParseException( - fmt::format("'--blob-hash' ('{}') is malfomed, it must be {} characters long", m_BlobHash, IoHash::StringLength), + fmt::format("'--blob-hash' ('{}') is malformed, it must be {} characters long", m_BlobHash, IoHash::StringLength), SubOption->help()); } @@ -3033,7 +3060,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_BuildId.length() != Oid::StringLength) { throw OptionParseException( - fmt::format("'--build-id' ('{}') is malfomed, it must be {} characters long", m_BuildId, Oid::StringLength), + fmt::format("'--build-id' ('{}') is malformed, it must be {} characters long", m_BuildId, Oid::StringLength), SubOption->help()); } else if (Oid BuildId = 
Oid::FromHexString(m_BuildId); BuildId == Oid::Zero) @@ -3105,7 +3132,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!m_BuildMetadataPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_BuildMetadataPath); + MakeSafeAbsolutePathInPlace(m_BuildMetadataPath); IoBuffer MetaDataJson = ReadFile(m_BuildMetadataPath).Flatten(); std::string_view Json(reinterpret_cast<const char*>(MetaDataJson.GetData()), MetaDataJson.GetSize()); std::string JsonError; @@ -3202,8 +3229,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (SubOption == &m_ListOptions) { - MakeSafeAbsolutePathÍnPlace(m_ListQueryPath); - MakeSafeAbsolutePathÍnPlace(m_ListResultPath); + MakeSafeAbsolutePathInPlace(m_ListQueryPath); + MakeSafeAbsolutePathInPlace(m_ListResultPath); if (!m_ListResultPath.empty()) { @@ -3255,7 +3282,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), m_ZenFolderPath); }); @@ -3294,7 +3321,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (SubOption == &m_ListBlocksOptions) { - MakeSafeAbsolutePathÍnPlace(m_ListResultPath); + MakeSafeAbsolutePathInPlace(m_ListResultPath); if (!m_ListResultPath.empty()) { @@ -3316,7 +3343,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), m_ZenFolderPath); }); @@ -3387,8 +3414,8 @@ 
BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); - MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3475,7 +3502,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) "Requests: {}\n" "Avg Request Time: {}\n" "Avg I/O Time: {}", - Storage.StorageName, + Storage.BuildStorageHost.Name, NiceBytes(StorageStats.TotalBytesRead.load()), NiceBytes(StorageStats.TotalBytesWritten.load()), StorageStats.TotalRequestCount.load(), @@ -3532,7 +3559,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -3632,7 +3659,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -3652,7 +3679,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::unique_ptr<CbObjectWriter> StructuredOutput; if (!m_LsResultPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_LsResultPath); + MakeSafeAbsolutePathInPlace(m_LsResultPath); StructuredOutput = std::make_unique<CbObjectWriter>(); } @@ -3696,7 +3723,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ParsePath(); ParseDiffPath(); - 
MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); std::vector<std::string> ExcludeFolders = DefaultExcludeFolders; std::vector<std::string> ExcludeExtensions = DefaultExcludeExtensions; @@ -3745,7 +3772,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3796,12 +3823,12 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!IsQuiet) { - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { - ZEN_CONSOLE("Uploaded {} ({}) blobs", + ZEN_CONSOLE("Uploaded {} ({}) blobs to {}", StorageCacheStats.PutBlobCount.load(), NiceBytes(StorageCacheStats.PutBlobByteCount), - Storage.CacheName); + Storage.CacheHost.Name); } } @@ -3828,7 +3855,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3883,7 +3910,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3933,7 +3960,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { 
m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); EPartialBlockRequestMode PartialBlockRequestMode = ParseAllowPartialBlockRequests(); @@ -4083,8 +4110,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); - MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); StorageInstance Storage = CreateBuildStorage(StorageStats, StorageCacheStats, diff --git a/src/zen/cmds/builds_cmd.h b/src/zen/cmds/builds_cmd.h index f5c44ab55..5c80beed5 100644 --- a/src/zen/cmds/builds_cmd.h +++ b/src/zen/cmds/builds_cmd.h @@ -71,7 +71,7 @@ private: bool m_AppendNewContent = false; uint8_t m_BlockReuseMinPercentLimit = 85; bool m_AllowMultiparts = true; - std::string m_AllowPartialBlockRequests = "mixed"; + std::string m_AllowPartialBlockRequests = "true"; AuthCommandLineOptions m_AuthOptions; diff --git a/src/zen/cmds/cache_cmd.h b/src/zen/cmds/cache_cmd.h index 4dc05bbdc..4f5b90f4d 100644 --- a/src/zen/cmds/cache_cmd.h +++ b/src/zen/cmds/cache_cmd.h @@ -9,6 +9,9 @@ namespace zen { class DropCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "drop"; + static constexpr char Description[] = "Drop cache namespace or bucket"; + DropCommand(); ~DropCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"drop", "Drop cache namespace or bucket"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_NamespaceName; std::string m_BucketName; @@ -25,13 +28,16 @@ private: class CacheInfoCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-info"; + static constexpr char Description[] = "Info on cache, namespace or 
bucket"; + CacheInfoCommand(); ~CacheInfoCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-info", "Info on cache, namespace or bucket"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_NamespaceName; std::string m_SizeInfoBucketNames; @@ -42,26 +48,32 @@ private: class CacheStatsCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-stats"; + static constexpr char Description[] = "Stats on cache"; + CacheStatsCommand(); ~CacheStatsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-stats", "Stats info on cache"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class CacheDetailsCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-details"; + static constexpr char Description[] = "Details on cache"; + CacheDetailsCommand(); ~CacheDetailsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-details", "Detailed info on cache"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_CSV = false; bool m_Details = false; diff --git a/src/zen/cmds/copy_cmd.h b/src/zen/cmds/copy_cmd.h index e1a5dcb82..757a8e691 100644 --- a/src/zen/cmds/copy_cmd.h +++ b/src/zen/cmds/copy_cmd.h @@ -11,6 +11,9 @@ namespace zen { class CopyCommand : public ZenCmdBase { public: + static constexpr char Name[] = "copy"; + static constexpr char Description[] = "Copy file(s)"; + CopyCommand(); ~CopyCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const 
override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"copy", "Copy files efficiently"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_CopySource; std::filesystem::path m_CopyTarget; bool m_NoClone = false; diff --git a/src/zen/cmds/dedup_cmd.h b/src/zen/cmds/dedup_cmd.h index 5b8387dd2..835b35e92 100644 --- a/src/zen/cmds/dedup_cmd.h +++ b/src/zen/cmds/dedup_cmd.h @@ -11,6 +11,9 @@ namespace zen { class DedupCommand : public ZenCmdBase { public: + static constexpr char Name[] = "dedup"; + static constexpr char Description[] = "Dedup files"; + DedupCommand(); ~DedupCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"dedup", "Deduplicate files"}; + cxxopts::Options m_Options{Name, Description}; std::vector<std::string> m_Positional; std::filesystem::path m_DedupSource; std::filesystem::path m_DedupTarget; diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp new file mode 100644 index 000000000..42c7119e7 --- /dev/null +++ b/src/zen/cmds/exec_cmd.cpp @@ -0,0 +1,1374 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "exec_cmd.h" + +#include <zencompute/computeservice.h> +#include <zencompute/recordingreader.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/system.h> +#include <zencore/timer.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/packageformat.h> + +#include <EASTL/hash_map.h> +#include <EASTL/hash_set.h> +#include <EASTL/map.h> + +using namespace std::literals; + +namespace eastl { + +template<> +struct hash<zen::IoHash> : public zen::IoHash::Hasher +{ +}; + +} // namespace eastl + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen { + +ExecCommand::ExecCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "<hosturl>"); + m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); + m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); + m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); + m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); + m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); + m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); + m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); + m_Options.add_option("", + "", + "mode", + "Select 
execution mode (http,inproc,dump,direct,beacon,buildlog)", + cxxopts::value(m_Mode)->default_value("http"), + "<string>"); + m_Options + .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>"); + m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); + m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>"); + m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); + m_Options.parse_positional("mode"); +} + +ExecCommand::~ExecCommand() +{ +} + +void +ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + // Configure + + if (!ParseOptions(argc, argv)) + { + return; + } + + m_HostName = ResolveTargetHostSpec(m_HostName); + + if (m_RecordingPath.empty()) + { + throw OptionParseException("replay path is required!", m_Options.help()); + } + + m_VerboseLogging = GlobalOptions.IsVerbose; + m_QuietLogging = m_Quiet && !m_VerboseLogging; + + enum ExecMode + { + kHttp, + kDirect, + kInproc, + kDump, + kBeacon, + kBuildLog + } Mode; + + if (m_Mode == "http"sv) + { + Mode = kHttp; + } + else if (m_Mode == "direct"sv) + { + Mode = kDirect; + } + else if (m_Mode == "inproc"sv) + { + Mode = kInproc; + } + else if (m_Mode == "dump"sv) + { + Mode = kDump; + } + else if (m_Mode == "beacon"sv) + { + Mode = kBeacon; + } + else if (m_Mode == "buildlog"sv) + { + Mode = kBuildLog; + } + else + { + throw OptionParseException("invalid mode specified!", m_Options.help()); + } + + // Gather information from recording path + + std::unique_ptr<zen::compute::RecordingReader> Reader; + std::unique_ptr<zen::compute::UeRecordingReader> UeReader; + + std::filesystem::path RecordingPath{m_RecordingPath}; + + if (!std::filesystem::is_directory(RecordingPath)) + { + throw OptionParseException("replay path should be a directory 
path!", m_Options.help()); + } + else + { + if (std::filesystem::is_directory(RecordingPath / "cid")) + { + Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath); + m_WorkerMap = Reader->ReadWorkers(); + m_ChunkResolver = Reader.get(); + m_RecordingReader = Reader.get(); + } + else + { + UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath); + m_WorkerMap = UeReader->ReadWorkers(); + m_ChunkResolver = UeReader.get(); + m_RecordingReader = UeReader.get(); + } + } + + ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); + + for (auto& Kv : m_WorkerMap) + { + CbObject WorkerDesc = Kv.second.GetObject(); + const IoHash& WorkerId = Kv.first; + + RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); + + if (m_VerboseLogging) + { + zen::ExtendableStringBuilder<1024> ObjStr; +# if 0 + zen::CompactBinaryToJson(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); +# else + zen::CompactBinaryToYaml(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr); +# endif + } + } + + if (m_VerboseLogging) + { + EmitFunctionList(m_FunctionList); + } + + // Iterate over work items and dispatch or log them + + int ReturnValue = 0; + + Stopwatch ExecTimer; + + switch (Mode) + { + case kHttp: + // Forward requests to HTTP function service + ReturnValue = HttpExecute(); + break; + + case kDirect: + // Not currently supported + ReturnValue = LocalMessagingExecute(); + break; + + case kInproc: + // Handle execution in-core (by spawning child processes) + ReturnValue = InProcessExecute(); + break; + + case kDump: + // Dump high level information about actions to console + ReturnValue = DumpWorkItems(); + break; + + case kBeacon: + ReturnValue = BeaconExecute(); + break; + + case kBuildLog: + ReturnValue = BuildActionsLog(); + break; + + default: + ZEN_ERROR("Unknown operating mode! 
No work submitted"); + + ReturnValue = 1; + } + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } +} + +int +ExecCommand::InProcessExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + + return ExecUsingSession(ComputeSession); +} + +int +ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession) +{ + struct JobTracker + { + public: + inline void Insert(int LsnField) + { + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.insert(LsnField); + } + + inline bool IsEmpty() const + { + RwLock::SharedLockScope _(Lock); + return PendingJobs.empty(); + } + + inline void Remove(int CompleteLsn) + { + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.erase(CompleteLsn); + } + + inline size_t GetSize() const + { + RwLock::SharedLockScope _(Lock); + return PendingJobs.size(); + } + + private: + mutable RwLock Lock; + std::unordered_set<int> PendingJobs; + }; + + JobTracker PendingJobs; + + struct ActionSummaryEntry + { + int32_t Lsn = 0; + int RecordingIndex = 0; + IoHash ActionId; + std::string FunctionName; + int InputAttachments = 0; + uint64_t InputBytes = 0; + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + float WallSeconds = 0.0f; + float CpuSeconds = 0.0f; + uint64_t SubmittedTicks = 0; + uint64_t StartedTicks = 0; + std::string ExecutionLocation; + }; + + std::mutex SummaryLock; + std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries; + + ComputeSession.WaitUntilReady(); + + // Register as a client with the orchestrator (best-effort) + + std::string OrchestratorClientId; + + if 
(!m_OrchestratorUrl.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + + CbObjectWriter Ann; + Ann << "session_id"sv << GetSessionId(); + Ann << "hostname"sv << std::string_view(GetMachineName()); + + CbObjectWriter Meta; + Meta << "source"sv + << "zen-exec"sv; + Ann << "metadata"sv << Meta.Save(); + + auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); + if (Resp.IsSuccess()) + { + OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString()); + ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId); + } + else + { + ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode)); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + } + } + + Stopwatch OrchestratorHeartbeatTimer; + + auto SendOrchestratorHeartbeat = [&] { + if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) + { + return; + } + OrchestratorHeartbeatTimer.Reset(); + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId)); + } + catch (...) + { + } + }; + + auto ClientCleanup = MakeGuard([&] { + if (!OrchestratorClientId.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId)); + } + catch (...) 
+ { + } + } + }); + + // Create a queue to group all actions from this exec session + + CbObjectWriter Metadata; + Metadata << "source"sv + << "zen-exec"sv; + + auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save()); + const int QueueId = QueueResult.QueueId; + if (!QueueId) + { + ZEN_ERROR("failed to create compute queue"); + return 1; + } + + auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); }); + + if (!m_OutputPath.empty()) + { + zen::CreateDirectories(m_OutputPath); + } + + std::atomic<int> IsDraining{0}; + + auto DrainCompletedJobs = [&] { + if (IsDraining.exchange(1)) + { + return; + } + + auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); + + CbObjectWriter Cbo; + ComputeSession.GetQueueCompleted(QueueId, Cbo); + + if (CbObject Completed = Cbo.Save()) + { + for (auto& It : Completed["completed"sv]) + { + int32_t CompleteLsn = It.AsInt32(); + + CbPackage ResultPackage; + HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + + if (Response == HttpResponseCode::OK) + { + if (!m_OutputPath.empty() && ResultPackage) + { + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + + if (!m_Binary) + { + // Write the root object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed attachments + auto Attachments = ResultPackage.GetAttachments(); + + if (!Attachments.empty()) + { + std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn); + zen::CreateDirectories(AttDir); + + for (const CbAttachment& Att : Attachments) + { + ++OutputAttachments; + + IoHash AttHash = Att.GetHash(); + + if (Att.IsCompressedBinary()) + { + SharedBuffer Decompressed = 
Att.AsCompressedBinary().Decompress(); + OutputBytes += Decompressed.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + else + { + SharedBuffer Binary = Att.AsBinary(); + OutputBytes += Binary.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); + } + } + } + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", + m_OutputPath.string(), + CompleteLsn, + OutputAttachments); + } + } + else + { + CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); + + for (const CbAttachment& Att : ResultPackage.GetAttachments()) + { + ++OutputAttachments; + OutputBytes += Att.AsBinary().GetSize(); + } + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn); + } + } + + std::lock_guard Lock(SummaryLock); + if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end()) + { + It2->second.OutputAttachments = OutputAttachments; + It2->second.OutputBytes = OutputBytes; + } + } + + PendingJobs.Remove(CompleteLsn); + + ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize()); + } + } + } + }; + + // Describe workers + + ZEN_CONSOLE("describing {} workers", m_WorkerMap.size()); + + for (auto Kv : m_WorkerMap) + { + CbPackage WorkerDesc = Kv.second; + + ComputeSession.RegisterWorker(WorkerDesc); + } + + // Then submit work items + + int FailedWorkCounter = 0; + size_t RemainingWorkItems = m_RecordingReader->GetActionCount(); + int SubmittedWorkItems = 0; + + ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + + int OffsetCounter = m_Offset; + int StrideCounter = m_Stride; + + auto ShouldSchedule = [&]() -> bool { + if (m_Limit && SubmittedWorkItems >= m_Limit) + { 
+ // Limit reached, ignore + + return false; + } + + if (OffsetCounter && OffsetCounter--) + { + // Still in offset, ignore + + return false; + } + + if (--StrideCounter == 0) + { + StrideCounter = m_Stride; + + return true; + } + + return false; + }; + + int TargetParallelism = 8; + + if (OffsetCounter || StrideCounter || m_Limit) + { + TargetParallelism = 1; + } + + std::atomic<int> RecordingIndex{0}; + + m_RecordingReader->IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + // Enqueue job + + const int CurrentRecordingIndex = RecordingIndex++; + + Stopwatch SubmitTimer; + + const int Priority = 0; + + if (ShouldSchedule()) + { + if (m_VerboseLogging) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + eastl::hash_set<IoHash> ReferencedChunks; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ReferencedChunks.insert(AttachData); + ++AttachmentCount; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", + ActionId, + AttachmentCount, + NiceBytes(AttachmentBytes), + ObjStr); + } + + if (m_DumpActions) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ++AttachmentCount; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr); + } + + if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = + 
ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority)) + { + const int32_t LsnField = EnqueueResult.Lsn; + + --RemainingWorkItems; + ++SubmittedWorkItems; + + if (!m_QuietLogging) + { + ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", + SubmittedWorkItems, + LsnField, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + RemainingWorkItems); + } + + if (!m_OutputPath.empty()) + { + ActionSummaryEntry Entry; + Entry.Lsn = LsnField; + Entry.RecordingIndex = CurrentRecordingIndex; + Entry.ActionId = ActionId; + Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); + + if (!m_Binary) + { + // Write action object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ActionObject, YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed input attachments + std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField); + bool AttDirCreated = false; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + SharedBuffer Decompressed = Compressed.Decompress(); + + Entry.InputBytes += Decompressed.GetSize(); + + if (!AttDirCreated) + { + zen::CreateDirectories(AttDir); + AttDirCreated = true; + } + + zen::WriteFile(AttDir / AttachCid.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + }); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", + m_OutputPath.string(), + LsnField, + Entry.InputAttachments); + } + } + else + { + // Build a CbPackage from 
the action and write as .pkg + CbPackage ActionPackage; + ActionPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + + Entry.InputBytes += ChunkData.GetSize(); + ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); + } + }); + + CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized)); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField); + } + } + + std::lock_guard Lock(SummaryLock); + SummaryEntries.emplace(LsnField, std::move(Entry)); + } + + PendingJobs.Insert(LsnField); + } + else + { + if (!m_QuietLogging) + { + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + ZEN_ERROR( + "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). 
Work " + "descriptor " + "at: 'file://{}'", + std::string(FunctionName), + FunctionVersion, + BuildSystemVersion, + "<null>"); + + EmitFunctionListOnce(m_FunctionList); + } + + ++FailedWorkCounter; + } + } + + // Check for completed work + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + }, + TargetParallelism); + + // Wait until all pending work is complete + + while (!PendingJobs.IsEmpty()) + { + // TODO: improve this logic + zen::Sleep(500); + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + } + + // Merge timing data from queue history into summary entries + + if (!SummaryEntries.empty()) + { + // RunnerAction::State indices (can't include functionrunner.h from here) + constexpr int kStateNew = 0; + constexpr int kStatePending = 1; + constexpr int kStateRunning = 3; + constexpr int kStateCompleted = 4; // first terminal state + constexpr int kStateCount = 8; + + for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0)) + { + std::lock_guard Lock(SummaryLock); + if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end()) + { + // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) + uint64_t EndTick = 0; + for (int S = kStateCompleted; S < kStateCount; ++S) + { + if (HistEntry.Timestamps[S] != 0) + { + EndTick = HistEntry.Timestamps[S]; + break; + } + } + uint64_t StartTick = HistEntry.Timestamps[kStateNew]; + if (EndTick > StartTick) + { + It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); + } + It->second.CpuSeconds = HistEntry.CpuSeconds; + It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; + It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; + It->second.ExecutionLocation = HistEntry.ExecutionLocation; + } + } + } + + // Write summary file if output path is set + + if (!m_OutputPath.empty() && !SummaryEntries.empty()) + { + std::vector<ActionSummaryEntry> Sorted; + Sorted.reserve(SummaryEntries.size()); + for (auto& 
[_, Entry] : SummaryEntries) + { + Sorted.push_back(std::move(Entry)); + } + + std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { + return A.RecordingIndex < B.RecordingIndex; + }); + + auto FormatTimestamp = [](uint64_t Ticks) -> std::string { + if (Ticks == 0) + { + return "-"; + } + return DateTime(Ticks).ToString("%H:%M:%S.%s"); + }; + + ExtendableStringBuilder<4096> Summary; + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", + "LSN", + "Index", + "ActionId", + "Function", + "InAtt", + "InBytes", + "OutAtt", + "OutBytes", + "Wall(s)", + "CPU(s)", + "Submitted", + "Started", + "Location")); + Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "")); + + for (const ActionSummaryEntry& Entry : Sorted) + { + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", + Entry.Lsn, + Entry.RecordingIndex, + Entry.ActionId, + Entry.FunctionName, + Entry.InputAttachments, + NiceBytes(Entry.InputBytes), + Entry.OutputAttachments, + NiceBytes(Entry.OutputBytes), + Entry.WallSeconds, + Entry.CpuSeconds, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + Entry.ExecutionLocation)); + } + + std::filesystem::path SummaryPath = m_OutputPath / "summary.txt"; + std::string_view SummaryStr = Summary; + zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); + + ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); + + if (!m_Binary) + { + auto EscapeHtml = [](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '&': + Out += "&"; + break; + case '<': + Out += "<"; + break; + case '>': + Out += 
">"; + break; + case '"': + Out += """; + break; + case '\'': + Out += "'"; + break; + default: + Out += C; + } + } + return Out; + }; + + auto EscapeJson = [](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '"': + Out += "\\\""; + break; + case '\\': + Out += "\\\\"; + break; + case '\n': + Out += "\\n"; + break; + case '\r': + Out += "\\r"; + break; + case '\t': + Out += "\\t"; + break; + default: + if (static_cast<unsigned char>(C) < 0x20) + { + Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); + } + else + { + Out += C; + } + } + } + return Out; + }; + + ExtendableStringBuilder<8192> Html; + + Html.Append(std::string_view(R"(<!DOCTYPE html> +<html><head><meta charset="utf-8"><title>Exec Summary</title> +<style> +body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa} +#container{overflow-y:auto;height:calc(100vh - 120px)} +table{border-collapse:collapse;width:100%} +th,td{border:1px solid #ddd;padding:6px 10px;text-align:left;white-space:nowrap} +th{background:#f0f0f0;cursor:pointer;user-select:none;position:sticky;top:0;z-index:1} +th:hover{background:#e0e0e0} +th .arrow{font-size:0.7em;margin-left:4px} +tr:hover{background:#e8f0fe} +input{padding:6px 10px;margin-bottom:12px;width:300px;border:1px solid #ccc;border-radius:4px} +button{padding:6px 14px;margin-left:8px;margin-bottom:12px;border:1px solid #ccc;border-radius:4px;background:#f0f0f0;cursor:pointer} +button:hover{background:#e0e0e0} +a{color:#1a73e8;text-decoration:none} +a:hover{text-decoration:underline} +.num{text-align:right} +</style></head><body> +<h2>Exec Summary</h2> +<input type="text" id="filter" placeholder="Filter by function name..."><button id="csvBtn">Export CSV</button> +<div id="container"> +<table><thead><tr> +<th data-col="0">LSN <span class="arrow"></span></th> +<th data-col="1">Index <span class="arrow"></span></th> +<th 
data-col="2">Action ID <span class="arrow"></span></th> +<th data-col="3">Function <span class="arrow"></span></th> +<th data-col="4">In Attachments <span class="arrow"></span></th> +<th data-col="5">In Bytes <span class="arrow"></span></th> +<th data-col="6">Out Attachments <span class="arrow"></span></th> +<th data-col="7">Out Bytes <span class="arrow"></span></th> +<th data-col="8">Wall(s) <span class="arrow"></span></th> +<th data-col="9">CPU(s) <span class="arrow"></span></th> +<th data-col="10">Submitted <span class="arrow"></span></th> +<th data-col="11">Started <span class="arrow"></span></th> +<th data-col="12">Location <span class="arrow"></span></th> +</tr></thead><tbody> +<tr id="spacerTop"><td colspan="13"></td></tr> +<tr id="spacerBot"><td colspan="13"></td></tr> +</tbody></table></div> +<script> +const DATA=[ +)")); + + std::string_view ResultExt = ".result.yaml"; + std::string_view ActionExt = ".action.yaml"; + + for (const ActionSummaryEntry& Entry : Sorted) + { + std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); + std::string ActionIdStr = Entry.ActionId.ToHexString(); + std::string ActionLink; + if (!ActionExt.empty()) + { + ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); + } + + // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes, + // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, + // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, + // 17=location + Html.Append( + fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", + Entry.Lsn, + Entry.RecordingIndex, + ActionIdStr, + SafeName, + Entry.InputAttachments, + Entry.InputBytes, + Entry.OutputAttachments, + Entry.OutputBytes, + Entry.WallSeconds, + Entry.CpuSeconds, + EscapeJson(NiceBytes(Entry.InputBytes)), + EscapeJson(NiceBytes(Entry.OutputBytes)), + ActionLink, + Entry.SubmittedTicks, + 
Entry.StartedTicks, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + } + + Html.Append(fmt::format(R"(]; +const RESULT_EXT="{}"; +)", + ResultExt)); + + Html.Append(std::string_view(R"JS((function(){ +const ROW_H=33,BUF=20; +const container=document.getElementById("container"); +const tbody=container.querySelector("tbody"); +const headers=container.querySelectorAll("th"); +const filterInput=document.getElementById("filter"); +const spacerTop=document.getElementById("spacerTop"); +const spacerBot=document.getElementById("spacerBot"); +let view=[...DATA.keys()]; +let sortCol=-1,sortAsc=true; +const COLS=[ + {f:0,t:"n"},{f:1,t:"n"},{f:2,t:"s"},{f:3,t:"s"}, + {f:4,t:"n"},{f:5,t:"n"},{f:6,t:"n"},{f:7,t:"n"}, + {f:8,t:"n"},{f:9,t:"n"},{f:13,t:"n"},{f:14,t:"n"},{f:17,t:"s"} +]; +function rowHtml(i){ + const d=DATA[view[i]]; + const bg=i%2?' style="background:#f9f9f9"':''; + return '<tr'+bg+'>'+ + '<td class="num"><a href="'+d[0]+RESULT_EXT+'">'+d[0]+'</a></td>'+ + '<td class="num">'+d[1]+'</td>'+ + '<td><code>'+d[2]+'</code></td>'+ + '<td>'+d[3]+d[12]+'</td>'+ + '<td class="num">'+d[4]+'</td>'+ + '<td class="num">'+d[10]+'</td>'+ + '<td class="num">'+d[6]+'</td>'+ + '<td class="num">'+d[11]+'</td>'+ + '<td class="num">'+d[8].toFixed(2)+'</td>'+ + '<td class="num">'+d[9].toFixed(2)+'</td>'+ + '<td class="num">'+d[15]+'</td>'+ + '<td class="num">'+d[16]+'</td>'+ + '<td>'+d[17]+'</td></tr>'; +} +let lastFirst=-1,lastLast=-1; +function render(){ + const scrollTop=container.scrollTop; + const viewH=container.clientHeight; + let first=Math.floor(scrollTop/ROW_H)-BUF; + let last=Math.ceil((scrollTop+viewH)/ROW_H)+BUF; + if(first<0) first=0; + if(last>=view.length) last=view.length-1; + if(first===lastFirst&&last===lastLast) return; + lastFirst=first;lastLast=last; + const rows=[]; + for(let i=first;i<=last;i++) rows.push(rowHtml(i)); + spacerTop.style.height=(first*ROW_H)+'px'; + 
spacerBot.style.height=((view.length-1-last)*ROW_H)+'px'; + const mid=rows.join(''); + const topTr='<tr id="spacerTop"><td colspan="13" style="border:0;padding:0;height:'+spacerTop.style.height+'"></td></tr>'; + const botTr='<tr id="spacerBot"><td colspan="13" style="border:0;padding:0;height:'+spacerBot.style.height+'"></td></tr>'; + tbody.innerHTML=topTr+mid+botTr; +} +function applySort(){ + if(sortCol<0) return; + const c=COLS[sortCol]; + view.sort((a,b)=>{ + const va=DATA[a][c.f],vb=DATA[b][c.f]; + if(c.t==="n") return sortAsc?va-vb:vb-va; + return sortAsc?(va<vb?-1:va>vb?1:0):(va>vb?-1:va<vb?1:0); + }); +} +function rebuild(){ + const q=filterInput.value.toLowerCase(); + view=[]; + for(let i=0;i<DATA.length;i++){ + if(!q||DATA[i][3].toLowerCase().includes(q)) view.push(i); + } + applySort(); + lastFirst=lastLast=-1; + render(); +} +headers.forEach(th=>{ + th.addEventListener("click",()=>{ + const col=parseInt(th.dataset.col); + if(sortCol===col){sortAsc=!sortAsc}else{sortCol=col;sortAsc=true} + headers.forEach(h=>h.querySelector(".arrow").textContent=""); + th.querySelector(".arrow").textContent=sortAsc?"\u25B2":"\u25BC"; + applySort(); + lastFirst=lastLast=-1; + render(); + }); +}); +filterInput.addEventListener("input",()=>rebuild()); +let ticking=false; +container.addEventListener("scroll",()=>{ + if(!ticking){ticking=true;requestAnimationFrame(()=>{render();ticking=false})} +}); +rebuild(); +document.getElementById("csvBtn").addEventListener("click",()=>{ + const H=["LSN","Index","Action ID","Function","In Attachments","In Bytes","Out Attachments","Out Bytes","Wall(s)","CPU(s)","Submitted","Started","Location"]; + const esc=v=>{const s=String(v);return s.includes(',')||s.includes('"')||s.includes('\n')?'"'+s.replace(/"/g,'""')+'"':s}; + const rows=[H.join(",")]; + for(let i=0;i<view.length;i++){ + const d=DATA[view[i]]; + rows.push([d[0],d[1],d[2],d[3],d[4],d[5],d[6],d[7],d[8],d[9],d[15],d[16],d[17]].map(esc).join(",")); + } + const blob=new 
Blob([rows.join("\n")],{type:"text/csv"}); + const a=document.createElement("a"); + a.href=URL.createObjectURL(blob); + a.download="summary.csv"; + a.click(); + URL.revokeObjectURL(a.href); +}); +})(); +</script></body></html> +)JS")); + + std::filesystem::path HtmlPath = m_OutputPath / "summary.html"; + std::string_view HtmlStr = Html; + zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + + ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + } + } + + if (FailedWorkCounter) + { + return 1; + } + + return 0; +} + +int +ExecCommand::LocalMessagingExecute() +{ + // Non-HTTP work submission path + + // To be reimplemented using final transport + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +int +ExecCommand::HttpExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + + return ExecUsingSession(ComputeSession); +} + +int +ExecCommand::BeaconExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + + if (!m_OrchestratorUrl.empty()) + { + ZEN_CONSOLE_INFO("using orchestrator at {}", m_OrchestratorUrl); + ComputeSession.SetOrchestratorEndpoint(m_OrchestratorUrl); + ComputeSession.SetOrchestratorBasePath(TempPath); + } + else + { + ZEN_CONSOLE_INFO("note: using hard-coded local worker path"); + ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); + } + + return ExecUsingSession(ComputeSession); +} + +////////////////////////////////////////////////////////////////////////// + +void +ExecCommand::RegisterWorkerFunctionsFromDescription(const 
CbObject& WorkerDesc, const IoHash& WorkerId) +{ + const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerDesc["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } +} + +void +ExecCommand::EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList) +{ + if (m_FunctionListEmittedOnce == false) + { + EmitFunctionList(FunctionList); + + m_FunctionListEmittedOnce = true; + } +} + +int +ExecCommand::DumpWorkItems() +{ + std::atomic<int> EmittedCount{0}; + + eastl::hash_map<IoHash, uint64_t> SeenAttachments; // Attachment CID -> count of references + + m_RecordingReader->IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + eastl::hash_map<IoHash, CompressedBuffer> Attachments; + + uint64_t AttachmentBytes = 0; + uint64_t UncompressedAttachmentBytes = 0; + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + Attachments[AttachmentCid] = CompressedData; + + AttachmentBytes += CompressedData.GetCompressedSize(); + UncompressedAttachmentBytes += CompressedData.DecodeRawSize(); + + if (auto [Iter, Inserted] = SeenAttachments.insert({AttachmentCid, 1}); !Inserted) + { + ++Iter->second; + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + +# if 0 + zen::CompactBinaryToJson(ActionObject, ObjStr); + 
ZEN_CONSOLE("work item {} ({} attachments): {}", ActionId, Attachments.size(), ObjStr); +# else + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {}->{} bytes):\n{}", + ActionId, + Attachments.size(), + AttachmentBytes, + UncompressedAttachmentBytes, + ObjStr); +# endif + + ++EmittedCount; + }, + 1); + + ZEN_CONSOLE("emitted: {} actions", EmittedCount.load()); + + eastl::map<uint64_t, std::vector<IoHash>> ReferenceHistogram; + + for (const auto& [K, V] : SeenAttachments) + { + if (V > 1) + { + ReferenceHistogram[V].push_back(K); + } + } + + for (const auto& [RefCount, Cids] : ReferenceHistogram) + { + ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount); + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +int +ExecCommand::BuildActionsLog() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + if (m_RecordingPath.empty()) + { + throw OptionParseException("need to specify recording path", m_Options.help()); + } + + if (std::filesystem::exists(m_RecordingLogPath)) + { + throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help()); + } + + ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.StartRecording(Resolver, m_RecordingLogPath); + + return ExecUsingSession(ComputeSession); +} + +void +ExecCommand::EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList) +{ + ZEN_CONSOLE("=== Known functions:\n==========================="); + + ZEN_CONSOLE("{:30} {:36} {:36} {}", "function", "version", "build system", "worker id"); + + for (const FunctionDefinition& Func : FunctionList) + { + ZEN_CONSOLE("{:30} {:36} {:36} {}", Func.FunctionName, Func.FunctionVersion, 
Func.BuildSystemVersion, Func.WorkerId); + } + + ZEN_CONSOLE("==========================="); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h new file mode 100644 index 000000000..6311354c0 --- /dev/null +++ b/src/zen/cmds/exec_cmd.h @@ -0,0 +1,101 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +#include <zencompute/recordingreader.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/guid.h> +#include <zencore/iohash.h> + +#include <filesystem> +#include <functional> +#include <unordered_map> + +namespace zen { +class CbPackage; +class CbObject; +struct IoHash; +class ChunkResolver; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { +class ComputeServiceSession; +} + +namespace zen { + +/** + * Zen CLI command for executing functions from a recording + * + * Mostly for testing and debugging purposes + */ + +class ExecCommand : public ZenCmdBase +{ +public: + ExecCommand(); + ~ExecCommand(); + + static constexpr char Name[] = "exec"; + static constexpr char Description[] = "Execute functions from a recording"; + + virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{Name, Description}; + std::string m_HostName; + std::string m_OrchestratorUrl; + std::filesystem::path m_BeaconPath; + std::filesystem::path m_RecordingPath; + std::filesystem::path m_RecordingLogPath; + int m_Offset = 0; + int m_Stride = 1; + int m_Limit = 0; + bool m_Quiet = false; + std::string m_Mode{"http"}; + std::filesystem::path m_OutputPath; + bool m_Binary = false; + + struct FunctionDefinition + { + std::string FunctionName; + zen::Guid FunctionVersion; + zen::Guid BuildSystemVersion; + zen::IoHash WorkerId; + }; + + bool m_FunctionListEmittedOnce = false; + void 
EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList); + void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList); + + std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + bool m_VerboseLogging = false; + bool m_QuietLogging = false; + bool m_DumpActions = false; + + zen::ChunkResolver* m_ChunkResolver = nullptr; + zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; + + void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); + + int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession); + + // Execution modes + + int DumpWorkItems(); + int HttpExecute(); + int InProcessExecute(); + int LocalMessagingExecute(); + int BeaconExecute(); + int BuildActionsLog(); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/info_cmd.h b/src/zen/cmds/info_cmd.h index 231565bfd..dc108b8a2 100644 --- a/src/zen/cmds/info_cmd.h +++ b/src/zen/cmds/info_cmd.h @@ -9,6 +9,9 @@ namespace zen { class InfoCommand : public ZenCmdBase { public: + static constexpr char Name[] = "info"; + static constexpr char Description[] = "Show high level Zen server information"; + InfoCommand(); ~InfoCommand(); @@ -17,7 +20,7 @@ public: // virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"info", "Show high level zen store information"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; diff --git a/src/zen/cmds/print_cmd.cpp b/src/zen/cmds/print_cmd.cpp index 030cc8b66..c6b250fdf 100644 --- a/src/zen/cmds/print_cmd.cpp +++ b/src/zen/cmds/print_cmd.cpp @@ -84,7 +84,7 @@ PrintCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } else { - MakeSafeAbsolutePathÍnPlace(m_Filename); + MakeSafeAbsolutePathInPlace(m_Filename); Fc = ReadFile(m_Filename); } @@ -244,7 +244,7 @@ 
PrintPackageCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar if (m_Filename.empty()) throw OptionParseException("'--source' is required", m_Options.help()); - MakeSafeAbsolutePathÍnPlace(m_Filename); + MakeSafeAbsolutePathInPlace(m_Filename); FileContents Fc = ReadFile(m_Filename); IoBuffer Data = Fc.Flatten(); CbPackage Package; diff --git a/src/zen/cmds/print_cmd.h b/src/zen/cmds/print_cmd.h index 6c1529b7c..f4a97e218 100644 --- a/src/zen/cmds/print_cmd.h +++ b/src/zen/cmds/print_cmd.h @@ -11,6 +11,9 @@ namespace zen { class PrintCommand : public ZenCmdBase { public: + static constexpr char Name[] = "print"; + static constexpr char Description[] = "Print compact binary object"; + PrintCommand(); ~PrintCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"print", "Print compact binary object"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_Filename; bool m_ShowCbObjectTypeInfo = false; }; @@ -29,6 +32,9 @@ private: class PrintPackageCommand : public ZenCmdBase { public: + static constexpr char Name[] = "printpackage"; + static constexpr char Description[] = "Print compact binary package"; + PrintPackageCommand(); ~PrintPackageCommand(); @@ -37,7 +43,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"printpkg", "Print compact binary package"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_Filename; bool m_ShowCbObjectTypeInfo = false; }; diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index 519b68126..db931e49a 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -41,12 +41,10 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { +namespace projectstore_impl { using namespace std::literals; -#define ZEN_CLOUD_STORAGE 
"Cloud Storage" - void WriteAuthOptions(CbObjectWriter& Writer, std::string_view JupiterOpenIdProvider, std::string_view JupiterAccessToken, @@ -500,7 +498,7 @@ namespace { return {}; } -} // namespace +} // namespace projectstore_impl /////////////////////////////////////// @@ -522,6 +520,7 @@ DropProjectCommand::~DropProjectCommand() void DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -611,6 +610,7 @@ ProjectInfoCommand::~ProjectInfoCommand() void ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -697,6 +697,7 @@ CreateProjectCommand::~CreateProjectCommand() = default; void CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); using namespace std::literals; @@ -766,6 +767,7 @@ CreateOplogCommand::~CreateOplogCommand() = default; void CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); using namespace std::literals; @@ -990,6 +992,7 @@ ExportOplogCommand::~ExportOplogCommand() void ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1470,6 +1473,20 @@ ImportOplogCommand::ImportOplogCommand() "Enables both 'boost-worker-count' and 'boost-worker-memory' - may cause computer to be less responsive", cxxopts::value(m_BoostWorkers), "<boostworkermemory>"); + m_Options.add_option( + "", + "", + "allow-partial-block-requests", + "Allow request for partial chunk blocks.\n" + " false = only full block requests allowed\n" + " mixed = multiple partial block ranges requests per block allowed 
to zen cache, single partial block range " + "request per block to host\n" + " zencacheonly = multiple partial block ranges requests per block allowed to zen cache, only full block requests " + "allowed to host\n" + " true = multiple partial block ranges requests per block allowed to zen cache and host\n" + "Defaults to 'mixed'.", + cxxopts::value(m_AllowPartialBlockRequests), + "<allowpartialblockrequests>"); m_Options.parse_positional({"project", "oplog", "gcpath"}); m_Options.positional_help("[<projectid> <oplogid> [<gcpath>]]"); @@ -1482,6 +1499,7 @@ ImportOplogCommand::~ImportOplogCommand() void ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1514,6 +1532,13 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("'--oplog' is required", m_Options.help()); } + EPartialBlockRequestMode Mode = PartialBlockRequestModeFromString(m_AllowPartialBlockRequests); + if (Mode == EPartialBlockRequestMode::Invalid) + { + throw OptionParseException(fmt::format("'--allow-partial-block-requests' ('{}') is invalid", m_AllowPartialBlockRequests), + m_Options.help()); + } + HttpClient Http(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) @@ -1651,6 +1676,9 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg { Writer.AddBool("boostworkermemory"sv, true); } + + Writer.AddString("partialblockrequestmode", m_AllowPartialBlockRequests); + if (!m_FileDirectoryPath.empty()) { Writer.BeginObject("file"sv); @@ -1766,6 +1794,7 @@ SnapshotOplogCommand::~SnapshotOplogCommand() void SnapshotOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1830,6 +1859,7 @@ ProjectStatsCommand::~ProjectStatsCommand() 
void ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -1882,6 +1912,7 @@ ProjectOpDetailsCommand::~ProjectOpDetailsCommand() void ProjectOpDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -1997,6 +2028,7 @@ OplogMirrorCommand::~OplogMirrorCommand() void OplogMirrorCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2264,6 +2296,7 @@ OplogValidateCommand::~OplogValidateCommand() void OplogValidateCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2415,6 +2448,7 @@ OplogDownloadCommand::~OplogDownloadCommand() void OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2432,7 +2466,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a { m_SystemRootDir = PickDefaultSystemRootDirectory(); } - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); }; ParseSystemOptions(); @@ -2570,36 +2604,37 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a StorageInstance Storage; - ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2; + ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2; ClientSettings.MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? 
RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u; - Storage.BuildStorageHttp = std::make_unique<HttpClient>(ResolveRes.HostUrl, ClientSettings); + Storage.BuildStorageHttp = std::make_unique<HttpClient>(ResolveRes.Cloud.Address, ClientSettings); + Storage.BuildStorageHost = ResolveRes.Cloud; BuildStorageCache::Statistics StorageCacheStats; std::atomic<bool> AbortFlag(false); - if (!ResolveRes.CacheUrl.empty()) + if (!ResolveRes.Cache.Address.empty()) { Storage.CacheHttp = std::make_unique<HttpClient>( - ResolveRes.CacheUrl, + ResolveRes.Cache.Address, HttpClientSettings{ .LogCategory = "httpcacheclient", .ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, + .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2, .AllowResume = true, .RetryCount = 0, .MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u}, [&AbortFlag]() { return AbortFlag.load(); }); - Storage.CacheName = ResolveRes.CacheName; + Storage.CacheHost = ResolveRes.Cache; } if (!m_Quiet) { std::string StorageDescription = fmt::format("Cloud {}{}. SessionId {}. Namespace '{}', Bucket '{}'", - ResolveRes.HostName, - (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl), + ResolveRes.Cloud.Name, + (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address), Storage.BuildStorageHttp->GetSessionId(), m_Namespace, m_Bucket); @@ -2610,8 +2645,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a { std::string CacheDescription = fmt::format("Zen {}{}. SessionId {}. Namespace '{}', Bucket '{}'", - ResolveRes.CacheName, - (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl), + ResolveRes.Cache.Name, + (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? 
"" : fmt::format(" {}", ResolveRes.Cache.Address), Storage.CacheHttp->GetSessionId(), m_Namespace, m_Bucket); @@ -2627,11 +2662,10 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a Storage.BuildStorage = CreateJupiterBuildStorage(Log(), *Storage.BuildStorageHttp, StorageStats, m_Namespace, m_Bucket, m_AllowRedirect, StorageTempPath); - Storage.StorageName = ResolveRes.HostName; if (Storage.CacheHttp) { - Storage.BuildCacheStorage = CreateZenBuildStorageCache( + Storage.CacheStorage = CreateZenBuildStorageCache( *Storage.CacheHttp, StorageCacheStats, m_Namespace, diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h index 56ef858f5..1ba98b39e 100644 --- a/src/zen/cmds/projectstore_cmd.h +++ b/src/zen/cmds/projectstore_cmd.h @@ -16,6 +16,9 @@ class ProjectStoreCommand : public ZenCmdBase class DropProjectCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-drop"; + static constexpr char Description[] = "Drop project or project oplog"; + DropProjectCommand(); ~DropProjectCommand(); @@ -23,7 +26,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-drop", "Drop project or project oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -33,13 +36,16 @@ private: class ProjectInfoCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-info"; + static constexpr char Description[] = "Info on project or project oplog"; + ProjectInfoCommand(); ~ProjectInfoCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-info", "Info on project or project oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string 
m_ProjectName; std::string m_OplogName; @@ -48,6 +54,9 @@ private: class CreateProjectCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-create"; + static constexpr char Description[] = "Create a project"; + CreateProjectCommand(); ~CreateProjectCommand(); @@ -55,7 +64,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-create", "Create project, the project must not already exist."}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectId; std::string m_RootDir; @@ -68,6 +77,9 @@ private: class CreateOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-create"; + static constexpr char Description[] = "Create a project oplog"; + CreateOplogCommand(); ~CreateOplogCommand(); @@ -75,7 +87,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-create", "Create oplog in an existing project, the oplog must not already exist."}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectId; std::string m_OplogId; @@ -86,6 +98,9 @@ private: class ExportOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-export"; + static constexpr char Description[] = "Export project store oplog"; + ExportOplogCommand(); ~ExportOplogCommand(); @@ -93,8 +108,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-export", - "Export project store oplog to cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -145,6 +159,9 @@ private: class ImportOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-import"; + 
static constexpr char Description[] = "Import project store oplog"; + ImportOplogCommand(); ~ImportOplogCommand(); @@ -152,8 +169,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-import", - "Import project store oplog from cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -193,19 +209,23 @@ private: bool m_BoostWorkerCount = false; bool m_BoostWorkerMemory = false; bool m_BoostWorkers = false; + + std::string m_AllowPartialBlockRequests = "true"; }; class SnapshotOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-snapshot"; + static constexpr char Description[] = "Snapshot project store oplog"; + SnapshotOplogCommand(); ~SnapshotOplogCommand(); - virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-snapshot", "Snapshot external file references in project store oplog into zen"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -214,26 +234,32 @@ private: class ProjectStatsCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-stats"; + static constexpr char Description[] = "Stats on project store"; + ProjectStatsCommand(); ~ProjectStatsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-stats", "Stats info on project store"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class ProjectOpDetailsCommand : public ProjectStoreCommand { public: + static constexpr char 
Name[] = "project-op-details"; + static constexpr char Description[] = "Detail info on ops inside a project store oplog"; + ProjectOpDetailsCommand(); ~ProjectOpDetailsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-op-details", "Detail info on ops inside a project store oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Details = false; bool m_OpDetails = false; @@ -247,13 +273,16 @@ private: class OplogMirrorCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-mirror"; + static constexpr char Description[] = "Mirror project store oplog to file system"; + OplogMirrorCommand(); ~OplogMirrorCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-mirror", "Mirror oplog to file system"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -268,13 +297,16 @@ private: class OplogValidateCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-validate"; + static constexpr char Description[] = "Validate oplog for missing references"; + OplogValidateCommand(); ~OplogValidateCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-validate", "Validate oplog for missing references"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; diff --git a/src/zen/cmds/rpcreplay_cmd.h b/src/zen/cmds/rpcreplay_cmd.h index a6363b614..332a3126c 100644 --- 
a/src/zen/cmds/rpcreplay_cmd.h +++ b/src/zen/cmds/rpcreplay_cmd.h @@ -9,6 +9,9 @@ namespace zen { class RpcStartRecordingCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-start"; + static constexpr char Description[] = "Starts recording of cache rpc requests on a host"; + RpcStartRecordingCommand(); ~RpcStartRecordingCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-start", "Starts recording of cache rpc requests on a host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_RecordingPath; }; @@ -24,6 +27,9 @@ private: class RpcStopRecordingCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-stop"; + static constexpr char Description[] = "Stops recording of cache rpc requests on a host"; + RpcStopRecordingCommand(); ~RpcStopRecordingCommand(); @@ -31,13 +37,16 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-stop", "Stops recording of cache rpc requests on a host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class RpcReplayCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-replay"; + static constexpr char Description[] = "Replays a previously recorded session of rpc requests"; + RpcReplayCommand(); ~RpcReplayCommand(); @@ -45,7 +54,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-replay", "Replays a previously recorded session of cache rpc requests to a target host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_RecordingPath; bool m_OnHost = false; diff --git a/src/zen/cmds/run_cmd.h b/src/zen/cmds/run_cmd.h index 570a2e63a..300c08c5b 100644 --- a/src/zen/cmds/run_cmd.h +++ 
b/src/zen/cmds/run_cmd.h @@ -9,6 +9,9 @@ namespace zen { class RunCommand : public ZenCmdBase { public: + static constexpr char Name[] = "run"; + static constexpr char Description[] = "Run command with special options"; + RunCommand(); ~RunCommand(); @@ -17,7 +20,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"run", "Run executable"}; + cxxopts::Options m_Options{Name, Description}; int m_RunCount = 0; int m_RunTime = -1; std::string m_BaseDirectory; diff --git a/src/zen/cmds/serve_cmd.h b/src/zen/cmds/serve_cmd.h index ac74981f2..22f430948 100644 --- a/src/zen/cmds/serve_cmd.h +++ b/src/zen/cmds/serve_cmd.h @@ -11,6 +11,9 @@ namespace zen { class ServeCommand : public ZenCmdBase { public: + static constexpr char Name[] = "serve"; + static constexpr char Description[] = "Serve files from a directory"; + ServeCommand(); ~ServeCommand(); @@ -18,7 +21,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"serve", "Serve files from a tree"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; diff --git a/src/zen/cmds/status_cmd.h b/src/zen/cmds/status_cmd.h index dc103a196..df5df3066 100644 --- a/src/zen/cmds/status_cmd.h +++ b/src/zen/cmds/status_cmd.h @@ -11,6 +11,9 @@ namespace zen { class StatusCommand : public ZenCmdBase { public: + static constexpr char Name[] = "status"; + static constexpr char Description[] = "Show zen status"; + StatusCommand(); ~StatusCommand(); @@ -20,7 +23,7 @@ public: private: int GetLockFileEffectivePort() const; - cxxopts::Options m_Options{"status", "Show zen status"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; std::filesystem::path m_DataDir; }; diff --git a/src/zen/cmds/top_cmd.h b/src/zen/cmds/top_cmd.h index 74167ecfd..aeb196558 100644 --- a/src/zen/cmds/top_cmd.h +++ 
b/src/zen/cmds/top_cmd.h @@ -9,6 +9,9 @@ namespace zen { class TopCommand : public ZenCmdBase { public: + static constexpr char Name[] = "top"; + static constexpr char Description[] = "Monitor zen server activity"; + TopCommand(); ~TopCommand(); @@ -16,12 +19,15 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"top", "Show dev UI"}; + cxxopts::Options m_Options{Name, Description}; }; class PsCommand : public ZenCmdBase { public: + static constexpr char Name[] = "ps"; + static constexpr char Description[] = "Enumerate running zen server instances"; + PsCommand(); ~PsCommand(); @@ -29,7 +35,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"ps", "Enumerate running Zen server instances"}; + cxxopts::Options m_Options{Name, Description}; }; } // namespace zen diff --git a/src/zen/cmds/trace_cmd.h b/src/zen/cmds/trace_cmd.h index a6c9742b7..6eb0ba22b 100644 --- a/src/zen/cmds/trace_cmd.h +++ b/src/zen/cmds/trace_cmd.h @@ -6,11 +6,12 @@ namespace zen { -/** Scrub storage - */ class TraceCommand : public ZenCmdBase { public: + static constexpr char Name[] = "trace"; + static constexpr char Description[] = "Control zen realtime tracing"; + TraceCommand(); ~TraceCommand(); @@ -18,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"trace", "Control zen realtime tracing"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Stop = false; std::string m_TraceHost; diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp new file mode 100644 index 000000000..da06ce305 --- /dev/null +++ b/src/zen/cmds/ui_cmd.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "ui_cmd.h" + +#include <zencore/except_fmt.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zenutil/consoletui.h> +#include <zenutil/zenserverprocess.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <shellapi.h> +#endif + +namespace zen { + +namespace { + + struct RunningServerInfo + { + uint16_t Port; + uint32_t Pid; + std::string SessionId; + std::string CmdLine; + }; + + static std::vector<RunningServerInfo> CollectRunningServers() + { + std::vector<RunningServerInfo> Servers; + ZenServerState State; + if (!State.InitializeReadOnly()) + return Servers; + + State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { + StringBuilder<25> SessionSB; + Entry.GetSessionId().ToString(SessionSB); + std::error_code CmdLineEc; + std::string CmdLine = GetProcessCommandLine(static_cast<int>(Entry.Pid.load()), CmdLineEc); + Servers.push_back({Entry.EffectiveListenPort.load(), Entry.Pid.load(), std::string(SessionSB.c_str()), std::move(CmdLine)}); + }); + + return Servers; + } + +} // namespace + +UiCommand::UiCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_options()("a,all", "Open dashboard for all running instances", cxxopts::value(m_All)->default_value("false")); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", + "p", + "path", + "Dashboard path (default: /dashboard/)", + cxxopts::value(m_DashboardPath)->default_value("/dashboard/"), + "<path>"); + m_Options.parse_positional("path"); +} + +UiCommand::~UiCommand() +{ +} + +void +UiCommand::OpenBrowser(std::string_view HostName) +{ + // Allow shortcuts for specifying dashboard path, and ensure it is in a format we expect + // (leading slash, trailing slash if no file extension) + + if (!m_DashboardPath.empty()) + { + if (m_DashboardPath[0] != '/') + { + m_DashboardPath = "/dashboard/" + m_DashboardPath; + 
} + + if (m_DashboardPath.find_last_of('.') == std::string::npos && m_DashboardPath.back() != '/') + { + m_DashboardPath += '/'; + } + } + + bool Success = false; + + ExtendableStringBuilder<256> FullUrl; + FullUrl << HostName << m_DashboardPath; + +#if ZEN_PLATFORM_WINDOWS + HINSTANCE Result = ShellExecuteA(nullptr, "open", FullUrl.c_str(), nullptr, nullptr, SW_SHOWNORMAL); + Success = reinterpret_cast<intptr_t>(Result) > 32; +#else + // Validate URL doesn't contain shell metacharacters that could lead to command injection + std::string_view FullUrlView = FullUrl; + constexpr std::string_view DangerousChars = ";|&$`\\\"'<>(){}[]!#*?~\n\r"; + if (FullUrlView.find_first_of(DangerousChars) != std::string_view::npos) + { + throw OptionParseException(fmt::format("URL contains invalid characters: '{}'", FullUrl), m_Options.help()); + } + +# if ZEN_PLATFORM_MAC + std::string Command = fmt::format("open \"{}\"", FullUrl); +# elif ZEN_PLATFORM_LINUX + std::string Command = fmt::format("xdg-open \"{}\"", FullUrl); +# else + ZEN_NOT_IMPLEMENTED("Browser launching not implemented on this platform"); +# endif + + Success = system(Command.c_str()) == 0; +#endif + + if (!Success) + { + throw zen::runtime_error("Failed to launch browser for '{}'", FullUrl); + } + + ZEN_CONSOLE("Web browser launched for '{}' successfully", FullUrl); +} + +void +UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + using namespace std::literals; + + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return; + } + + // Resolve target server + uint16_t ServerPort = 0; + + if (m_HostName.empty()) + { + // Auto-discover running instances. 
+ std::vector<RunningServerInfo> Servers = CollectRunningServers(); + + if (m_All) + { + if (Servers.empty()) + { + throw OptionParseException("No running Zen server instances found", m_Options.help()); + } + + for (const auto& Server : Servers) + { + OpenBrowser(fmt::format("http://localhost:{}", Server.Port)); + } + return; + } + + // If multiple are found and we have an interactive terminal, present a picker + // instead of silently using the first one. + if (Servers.size() > 1 && IsTuiAvailable()) + { + std::vector<std::string> Labels; + Labels.reserve(Servers.size() + 1); + Labels.push_back(fmt::format("(all {} instances)", Servers.size())); + + const int32_t Cols = static_cast<int32_t>(TuiConsoleColumns()); + constexpr int32_t kIndicator = 3; // " ▶ " or " " prefix + constexpr int32_t kSeparator = 2; // " " before cmdline + constexpr int32_t kEllipsis = 3; // "..." + + for (const auto& Server : Servers) + { + std::string Label = fmt::format("port {:<5} pid {:<7} session {}", Server.Port, Server.Pid, Server.SessionId); + + if (!Server.CmdLine.empty()) + { + int32_t Available = Cols - kIndicator - kSeparator - static_cast<int32_t>(Label.size()); + if (Available > kEllipsis) + { + Label += " "; + if (static_cast<int32_t>(Server.CmdLine.size()) <= Available) + { + Label += Server.CmdLine; + } + else + { + Label.append(Server.CmdLine, 0, static_cast<size_t>(Available - kEllipsis)); + Label += "..."; + } + } + } + + Labels.push_back(std::move(Label)); + } + + int SelectedIdx = TuiPickOne("Multiple Zen server instances found. 
Select one to open:", Labels); + if (SelectedIdx < 0) + return; // User cancelled + + if (SelectedIdx == 0) + { + // "All" selected + for (const auto& Server : Servers) + { + OpenBrowser(fmt::format("http://localhost:{}", Server.Port)); + } + return; + } + + ServerPort = Servers[SelectedIdx - 1].Port; + m_HostName = fmt::format("http://localhost:{}", ServerPort); + } + + if (m_HostName.empty()) + { + // Single or zero instances, or not an interactive terminal: + // fall back to default resolution (picks first instance or returns empty) + m_HostName = ResolveTargetHostSpec("", ServerPort); + } + } + else + { + if (m_All) + { + throw OptionParseException("--all cannot be used together with --hosturl", m_Options.help()); + } + m_HostName = ResolveTargetHostSpec(m_HostName, ServerPort); + } + + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", m_Options.help()); + } + + OpenBrowser(m_HostName); +} + +} // namespace zen diff --git a/src/zen/cmds/ui_cmd.h b/src/zen/cmds/ui_cmd.h new file mode 100644 index 000000000..c74cdbbd0 --- /dev/null +++ b/src/zen/cmds/ui_cmd.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "../zen.h" + +#include <filesystem> + +namespace zen { + +class UiCommand : public ZenCmdBase +{ +public: + UiCommand(); + ~UiCommand(); + + static constexpr char Name[] = "ui"; + static constexpr char Description[] = "Launch web browser with zen server UI"; + + virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + void OpenBrowser(std::string_view HostName); + + cxxopts::Options m_Options{Name, Description}; + std::string m_HostName; + std::string m_DashboardPath = "/dashboard/"; + bool m_All = false; +}; + +} // namespace zen diff --git a/src/zen/cmds/up_cmd.h b/src/zen/cmds/up_cmd.h index 2e822d5fc..270db7f88 100644 --- a/src/zen/cmds/up_cmd.h +++ b/src/zen/cmds/up_cmd.h @@ -11,6 +11,9 @@ namespace zen { class UpCommand : public ZenCmdBase { public: + static constexpr char Name[] = "up"; + static constexpr char Description[] = "Bring zen server up"; + UpCommand(); ~UpCommand(); @@ -18,7 +21,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"up", "Bring up zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; bool m_ShowConsole = false; bool m_ShowLog = false; @@ -28,6 +31,9 @@ private: class AttachCommand : public ZenCmdBase { public: + static constexpr char Name[] = "attach"; + static constexpr char Description[] = "Add a sponsor process to a running zen service"; + AttachCommand(); ~AttachCommand(); @@ -35,7 +41,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"attach", "Add a sponsor process to a running zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; int m_OwnerPid = 0; std::filesystem::path m_DataDir; @@ -44,6 +50,9 @@ private: class DownCommand : public ZenCmdBase { public: + static constexpr char Name[] = "down"; 
+ static constexpr char Description[] = "Bring zen server down"; + DownCommand(); ~DownCommand(); @@ -51,7 +60,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"down", "Bring down zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; bool m_ForceTerminate = false; std::filesystem::path m_ProgramBaseDir; diff --git a/src/zen/cmds/vfs_cmd.h b/src/zen/cmds/vfs_cmd.h index 5deaa02fa..9009c774b 100644 --- a/src/zen/cmds/vfs_cmd.h +++ b/src/zen/cmds/vfs_cmd.h @@ -9,6 +9,9 @@ namespace zen { class VfsCommand : public StorageCommand { public: + static constexpr char Name[] = "vfs"; + static constexpr char Description[] = "Manage virtual file system"; + VfsCommand(); ~VfsCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"vfs", "Manage virtual file system"}; + cxxopts::Options m_Options{Name, Description}; std::string m_Verb; std::string m_HostName; diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp index adf0e61f0..10f5ad8e1 100644 --- a/src/zen/cmds/wipe_cmd.cpp +++ b/src/zen/cmds/wipe_cmd.cpp @@ -33,7 +33,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { +namespace wipe_impl { static std::atomic<bool> AbortFlag = false; static std::atomic<bool> PauseFlag = false; static bool IsVerbose = false; @@ -49,10 +49,11 @@ namespace { : GetMediumWorkerPool(EWorkloadType::Burst); } -#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ - if (IsVerbose) \ - { \ - ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \ +#undef ZEN_CONSOLE_VERBOSE +#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) 
\ + if (IsVerbose) \ + { \ + ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \ } static void SignalCallbackHandler(int SigNum) @@ -505,7 +506,7 @@ namespace { } return CleanWipe; } -} // namespace +} // namespace wipe_impl WipeCommand::WipeCommand() { @@ -532,6 +533,7 @@ WipeCommand::~WipeCommand() = default; void WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace wipe_impl; ZEN_UNUSED(GlobalOptions); signal(SIGINT, SignalCallbackHandler); @@ -549,7 +551,7 @@ WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ProgressMode = (IsVerbose || m_PlainProgress) ? ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty; BoostWorkerThreads = m_BoostWorkerThreads; - MakeSafeAbsolutePathÍnPlace(m_Directory); + MakeSafeAbsolutePathInPlace(m_Directory); if (!IsDir(m_Directory)) { diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 6e6f5d863..af265d898 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -398,7 +398,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** } else { - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); } std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; @@ -815,7 +815,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** if (Results.size() != m_ChunkIds.size()) { throw std::runtime_error( - fmt::format("failed to get workspace share batch - invalid result count recevied (expected: {}, received: {}", + fmt::format("failed to get workspace share batch - invalid result count received (expected: {}, received: {}", m_ChunkIds.size(), Results.size())); } diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 83606df67..b758c061b 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -8,16 +8,12 @@ #include <zencore/logging.h> #include <zencore/windows.h> #include 
<zenremotestore/operationlogoutput.h> +#include <zenutil/consoletui.h> ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -# include <sys/ioctl.h> -# include <unistd.h> -#endif - ////////////////////////////////////////////////////////////////////////// namespace zen { @@ -31,35 +27,12 @@ GetConsoleHandle() } #endif -static bool -CheckStdoutTty() -{ -#if ZEN_PLATFORM_WINDOWS - HANDLE hStdOut = GetConsoleHandle(); - DWORD dwMode = 0; - static bool IsConsole = ::GetConsoleMode(hStdOut, &dwMode); - return IsConsole; -#else - return isatty(fileno(stdout)); -#endif -} - -static bool -IsStdoutTty() -{ - static bool StdoutIsTty = CheckStdoutTty(); - return StdoutIsTty; -} - static void OutputToConsoleRaw(const char* String, size_t Length) { #if ZEN_PLATFORM_WINDOWS HANDLE hStdOut = GetConsoleHandle(); -#endif - -#if ZEN_PLATFORM_WINDOWS - if (IsStdoutTty()) + if (TuiIsStdoutTty()) { WriteConsoleA(hStdOut, String, (DWORD)Length, 0, 0); } @@ -85,26 +58,6 @@ OutputToConsoleRaw(const StringBuilderBase& SB) } uint32_t -GetConsoleColumns(uint32_t Default) -{ -#if ZEN_PLATFORM_WINDOWS - HANDLE hStdOut = GetConsoleHandle(); - CONSOLE_SCREEN_BUFFER_INFO csbi; - if (GetConsoleScreenBufferInfo(hStdOut, &csbi) == TRUE) - { - return (uint32_t)(csbi.srWindow.Right - csbi.srWindow.Left + 1); - } -#else - struct winsize w; - if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) == 0) - { - return (uint32_t)w.ws_col; - } -#endif - return Default; -} - -uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode) { switch (InMode) @@ -165,7 +118,7 @@ ProgressBar::PopLogOperation(Mode InMode) } ProgressBar::ProgressBar(Mode InMode, std::string_view InSubTask) -: m_Mode((!IsStdoutTty() && InMode == Mode::Pretty) ? Mode::Plain : InMode) +: m_Mode((!TuiIsStdoutTty() && InMode == Mode::Pretty) ? 
Mode::Plain : InMode) , m_LastUpdateMS((uint64_t)-1) , m_PausedMS(0) , m_SubTask(InSubTask) @@ -245,6 +198,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) const std::string Details = (!NewState.Details.empty()) ? fmt::format(": {}", NewState.Details) : ""; const std::string Output = fmt::format("{} {}% ({}){}\n", Task, PercentDone, NiceTimeSpanMs(ElapsedTimeMS), Details); OutputToConsoleRaw(Output); + m_State = NewState; } else if (m_Mode == Mode::Pretty) { @@ -253,10 +207,11 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) size_t ProgressBarCount = (ProgressBarSize * PercentDone) / 100; uint64_t Completed = NewState.TotalCount - NewState.RemainingCount; uint64_t ETAElapsedMS = ElapsedTimeMS -= m_PausedMS; - uint64_t ETAMS = - (NewState.Status == State::EStatus::Running) && (PercentDone > 5) ? (ETAElapsedMS * NewState.RemainingCount) / Completed : 0; + uint64_t ETAMS = ((m_State.TotalCount == NewState.TotalCount) && (NewState.Status == State::EStatus::Running)) && (PercentDone > 5) + ? 
(ETAElapsedMS * NewState.RemainingCount) / Completed + : 0; - uint32_t ConsoleColumns = GetConsoleColumns(1024); + uint32_t ConsoleColumns = TuiConsoleColumns(1024); const std::string PercentString = fmt::format("{:#3}%", PercentDone); @@ -435,19 +390,19 @@ class ConsoleOpLogOutput : public OperationLogOutput { public: ConsoleOpLogOutput(zen::ProgressBar::Mode InMode) : m_Mode(InMode) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override { - logging::EmitConsoleLogMessage(LogLevel, Format, Args); + logging::EmitConsoleLogMessage(Point, Args); } - virtual void SetLogOperationName(std::string_view Name) { zen::ProgressBar::SetLogOperationName(m_Mode, Name); } - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) + virtual void SetLogOperationName(std::string_view Name) override { zen::ProgressBar::SetLogOperationName(m_Mode, Name); } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { zen::ProgressBar::SetLogOperationProgress(m_Mode, StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() { return GetUpdateDelayMS(m_Mode); } + virtual uint32_t GetProgressUpdateDelayMS() override { return GetUpdateDelayMS(m_Mode); } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); } + virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); } private: zen::ProgressBar::Mode m_Mode; diff --git a/src/zen/progressbar.h b/src/zen/progressbar.h index bbdb008d4..cb1c7023b 100644 --- a/src/zen/progressbar.h +++ b/src/zen/progressbar.h @@ -76,7 +76,6 @@ private: }; uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode); -uint32_t GetConsoleColumns(uint32_t Default); OperationLogOutput* CreateConsoleLogOutput(ProgressBar::Mode InMode); 
diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua index ab094fef3..f889c3296 100644 --- a/src/zen/xmake.lua +++ b/src/zen/xmake.lua @@ -6,15 +6,12 @@ target("zen") add_files("**.cpp") add_files("zen.cpp", {unity_ignored = true }) add_deps("zencore", "zenhttp", "zenremotestore", "zenstore", "zenutil") + add_deps("zencompute", "zennet") add_deps("cxxopts", "fmt") add_packages("json11") add_includedirs(".") set_symbols("debug") - if is_mode("release") then - set_optimize("fastest") - end - if is_plat("windows") then add_files("zen.rc") add_ldflags("/subsystem:console,5.02") diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 09a2e4f91..9a466da2e 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -11,6 +11,7 @@ #include "cmds/cache_cmd.h" #include "cmds/copy_cmd.h" #include "cmds/dedup_cmd.h" +#include "cmds/exec_cmd.h" #include "cmds/info_cmd.h" #include "cmds/print_cmd.h" #include "cmds/projectstore_cmd.h" @@ -21,6 +22,7 @@ #include "cmds/status_cmd.h" #include "cmds/top_cmd.h" #include "cmds/trace_cmd.h" +#include "cmds/ui_cmd.h" #include "cmds/up_cmd.h" #include "cmds/version_cmd.h" #include "cmds/vfs_cmd.h" @@ -39,7 +41,8 @@ #include <zencore/trace.h> #include <zencore/windows.h> #include <zenhttp/httpcommon.h> -#include <zenutil/environmentoptions.h> +#include <zenutil/config/environmentoptions.h> +#include <zenutil/consoletui.h> #include <zenutil/logging.h> #include <zenutil/workerpools.h> #include <zenutil/zenserverprocess.h> @@ -53,7 +56,6 @@ #include "progressbar.h" #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include <zencore/testing.h> #endif @@ -122,7 +124,7 @@ ZenCmdBase::ParseOptions(int argc, char** argv) bool ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv) { - CmdOptions.set_width(GetConsoleColumns(80)); + CmdOptions.set_width(TuiConsoleColumns(80)); cxxopts::ParseResult Result; @@ -192,6 +194,84 @@ ZenCmdBase::GetSubCommand(cxxopts::Options&, return argc; } +ZenSubCmdBase::ZenSubCmdBase(std::string_view 
Name, std::string_view Description) +: m_SubOptions(std::string(Name), std::string(Description)) +{ + m_SubOptions.add_options()("h,help", "Print help"); +} + +void +ZenCmdWithSubCommands::AddSubCommand(ZenSubCmdBase& SubCmd) +{ + m_SubCommands.push_back(&SubCmd); +} + +bool +ZenCmdWithSubCommands::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOptions*/) +{ + return true; +} + +void +ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + std::vector<cxxopts::Options*> SubOptionPtrs; + SubOptionPtrs.reserve(m_SubCommands.size()); + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + SubOptionPtrs.push_back(&SubCmd->SubOptions()); + } + + cxxopts::Options* MatchedSubOption = nullptr; + std::vector<char*> SubCommandArguments; + int ParentArgc = GetSubCommand(Options(), argc, argv, SubOptionPtrs, MatchedSubOption, SubCommandArguments); + + if (!ParseOptions(Options(), ParentArgc, argv)) + { + return; + } + + if (MatchedSubOption == nullptr) + { + ExtendableStringBuilder<128> VerbList; + for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands) + { + if (!First) + { + VerbList.Append(", "); + } + VerbList.Append(SubCmd->SubOptions().program()); + First = false; + } + throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), Options().help()); + } + + ZenSubCmdBase* MatchedSubCmd = nullptr; + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + if (&SubCmd->SubOptions() == MatchedSubOption) + { + MatchedSubCmd = SubCmd; + break; + } + } + ZEN_ASSERT(MatchedSubCmd != nullptr); + + // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand + // works without requiring parent options like --hosturl to be populated. 
+ if (!ParseOptions(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) + { + return; + } + + if (!OnParentOptionsParsed(GlobalOptions)) + { + return; + } + + MatchedSubCmd->Run(GlobalOptions); +} + static ReturnCode GetReturnCodeFromHttpResult(const HttpClientError& Ex) { @@ -316,22 +396,25 @@ main(int argc, char** argv) } #endif // ZEN_WITH_TRACE - AttachCommand AttachCmd; - BenchCommand BenchCmd; - BuildsCommand BuildsCmd; - CacheDetailsCommand CacheDetailsCmd; - CacheGetCommand CacheGetCmd; - CacheGenerateCommand CacheGenerateCmd; - CacheInfoCommand CacheInfoCmd; - CacheStatsCommand CacheStatsCmd; - CopyCommand CopyCmd; - CopyStateCommand CopyStateCmd; - CreateOplogCommand CreateOplogCmd; - CreateProjectCommand CreateProjectCmd; - DedupCommand DedupCmd; - DownCommand DownCmd; - DropCommand DropCmd; - DropProjectCommand ProjectDropCmd; + AttachCommand AttachCmd; + BenchCommand BenchCmd; + BuildsCommand BuildsCmd; + CacheDetailsCommand CacheDetailsCmd; + CacheGetCommand CacheGetCmd; + CacheGenerateCommand CacheGenerateCmd; + CacheInfoCommand CacheInfoCmd; + CacheStatsCommand CacheStatsCmd; + CopyCommand CopyCmd; + CopyStateCommand CopyStateCmd; + CreateOplogCommand CreateOplogCmd; + CreateProjectCommand CreateProjectCmd; + DedupCommand DedupCmd; + DownCommand DownCmd; + DropCommand DropCmd; + DropProjectCommand ProjectDropCmd; +#if ZEN_WITH_COMPUTE_SERVICES + ExecCommand ExecCmd; +#endif // ZEN_WITH_COMPUTE_SERVICES ExportOplogCommand ExportOplogCmd; FlushCommand FlushCmd; GcCommand GcCmd; @@ -360,6 +443,7 @@ main(int argc, char** argv) LoggingCommand LoggingCmd; TopCommand TopCmd; TraceCommand TraceCmd; + UiCommand UiCmd; UpCommand UpCmd; VersionCommand VersionCmd; VfsCommand VfsCmd; @@ -375,53 +459,57 @@ main(int argc, char** argv) const char* CmdSummary; } Commands[] = { // clang-format off - {"attach", &AttachCmd, "Add a sponsor process to a running zen service"}, - {"bench", &BenchCmd, "Utility command for 
benchmarking"}, - {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description}, - {"cache-details", &CacheDetailsCmd, "Details on cache"}, - {"cache-info", &CacheInfoCmd, "Info on cache, namespace or bucket"}, + {AttachCommand::Name, &AttachCmd, AttachCommand::Description}, + {BenchCommand::Name, &BenchCmd, BenchCommand::Description}, + {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description}, + {CacheDetailsCommand::Name, &CacheDetailsCmd, CacheDetailsCommand::Description}, + {CacheInfoCommand::Name, &CacheInfoCmd, CacheInfoCommand::Description}, {CacheGetCommand::Name, &CacheGetCmd, CacheGetCommand::Description}, {CacheGenerateCommand::Name, &CacheGenerateCmd, CacheGenerateCommand::Description}, - {"cache-stats", &CacheStatsCmd, "Stats on cache"}, - {"copy", &CopyCmd, "Copy file(s)"}, - {"copy-state", &CopyStateCmd, "Copy zen server disk state"}, - {"dedup", &DedupCmd, "Dedup files"}, - {"down", &DownCmd, "Bring zen server down"}, - {"drop", &DropCmd, "Drop cache namespace or bucket"}, - {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"}, - {"gc-stop", &GcStopCmd, "Request cancel of running garbage collection in zen storage"}, - {"gc", &GcCmd, "Garbage collect zen storage"}, - {"info", &InfoCmd, "Show high level Zen server information"}, - {"jobs", &JobCmd, "Show/cancel zen background jobs"}, - {"logs", &LoggingCmd, "Show/control zen logging"}, - {"oplog-create", &CreateOplogCmd, "Create a project oplog"}, - {"oplog-export", &ExportOplogCmd, "Export project store oplog"}, - {"oplog-import", &ImportOplogCmd, "Import project store oplog"}, - {"oplog-mirror", &OplogMirrorCmd, "Mirror project store oplog to file system"}, - {"oplog-snapshot", &SnapshotOplogCmd, "Snapshot project store oplog"}, + {CacheStatsCommand::Name, &CacheStatsCmd, CacheStatsCommand::Description}, + {CopyCommand::Name, &CopyCmd, CopyCommand::Description}, + {CopyStateCommand::Name, &CopyStateCmd, CopyStateCommand::Description}, + {DedupCommand::Name, &DedupCmd, 
DedupCommand::Description}, + {DownCommand::Name, &DownCmd, DownCommand::Description}, + {DropCommand::Name, &DropCmd, DropCommand::Description}, +#if ZEN_WITH_COMPUTE_SERVICES + {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, +#endif + {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description}, + {GcStopCommand::Name, &GcStopCmd, GcStopCommand::Description}, + {GcCommand::Name, &GcCmd, GcCommand::Description}, + {InfoCommand::Name, &InfoCmd, InfoCommand::Description}, + {JobCommand::Name, &JobCmd, JobCommand::Description}, + {LoggingCommand::Name, &LoggingCmd, LoggingCommand::Description}, + {CreateOplogCommand::Name, &CreateOplogCmd, CreateOplogCommand::Description}, + {ExportOplogCommand::Name, &ExportOplogCmd, ExportOplogCommand::Description}, + {ImportOplogCommand::Name, &ImportOplogCmd, ImportOplogCommand::Description}, + {OplogMirrorCommand::Name, &OplogMirrorCmd, OplogMirrorCommand::Description}, + {SnapshotOplogCommand::Name, &SnapshotOplogCmd, SnapshotOplogCommand::Description}, {OplogDownloadCommand::Name, &OplogDownload, OplogDownloadCommand::Description}, - {"oplog-validate", &OplogValidateCmd, "Validate oplog for missing references"}, - {"print", &PrintCmd, "Print compact binary object"}, - {"printpackage", &PrintPkgCmd, "Print compact binary package"}, - {"project-create", &CreateProjectCmd, "Create a project"}, - {"project-op-details", &ProjectOpDetailsCmd, "Detail info on ops inside a project store oplog"}, - {"project-drop", &ProjectDropCmd, "Drop project or project oplog"}, - {"project-info", &ProjectInfoCmd, "Info on project or project oplog"}, - {"project-stats", &ProjectStatsCmd, "Stats on project store"}, - {"ps", &PsCmd, "Enumerate running zen server instances"}, - {"rpc-record-replay", &RpcReplayCmd, "Replays a previously recorded session of rpc requests"}, - {"rpc-record-start", &RpcStartRecordingCmd, "Starts recording of cache rpc requests on a host"}, - {"rpc-record-stop", &RpcStopRecordingCmd, "Stops recording of cache 
rpc requests on a host"}, - {"run", &RunCmd, "Run command with special options"}, - {"scrub", &ScrubCmd, "Scrub zen storage (verify data integrity)"}, - {"serve", &ServeCmd, "Serve files from a directory"}, - {"status", &StatusCmd, "Show zen status"}, - {"top", &TopCmd, "Monitor zen server activity"}, - {"trace", &TraceCmd, "Control zen realtime tracing"}, - {"up", &UpCmd, "Bring zen server up"}, + {OplogValidateCommand::Name, &OplogValidateCmd, OplogValidateCommand::Description}, + {PrintCommand::Name, &PrintCmd, PrintCommand::Description}, + {PrintPackageCommand::Name, &PrintPkgCmd, PrintPackageCommand::Description}, + {CreateProjectCommand::Name, &CreateProjectCmd, CreateProjectCommand::Description}, + {ProjectOpDetailsCommand::Name, &ProjectOpDetailsCmd, ProjectOpDetailsCommand::Description}, + {DropProjectCommand::Name, &ProjectDropCmd, DropProjectCommand::Description}, + {ProjectInfoCommand::Name, &ProjectInfoCmd, ProjectInfoCommand::Description}, + {ProjectStatsCommand::Name, &ProjectStatsCmd, ProjectStatsCommand::Description}, + {PsCommand::Name, &PsCmd, PsCommand::Description}, + {RpcReplayCommand::Name, &RpcReplayCmd, RpcReplayCommand::Description}, + {RpcStartRecordingCommand::Name, &RpcStartRecordingCmd, RpcStartRecordingCommand::Description}, + {RpcStopRecordingCommand::Name, &RpcStopRecordingCmd, RpcStopRecordingCommand::Description}, + {RunCommand::Name, &RunCmd, RunCommand::Description}, + {ScrubCommand::Name, &ScrubCmd, ScrubCommand::Description}, + {ServeCommand::Name, &ServeCmd, ServeCommand::Description}, + {StatusCommand::Name, &StatusCmd, StatusCommand::Description}, + {TopCommand::Name, &TopCmd, TopCommand::Description}, + {TraceCommand::Name, &TraceCmd, TraceCommand::Description}, + {UiCommand::Name, &UiCmd, UiCommand::Description}, + {UpCommand::Name, &UpCmd, UpCommand::Description}, {VersionCommand::Name, &VersionCmd, VersionCommand::Description}, - {"vfs", &VfsCmd, "Manage virtual file system"}, - {"flush", &FlushCmd, "Flush storage"}, + 
{VfsCommand::Name, &VfsCmd, VfsCommand::Description}, + {FlushCommand::Name, &FlushCmd, FlushCommand::Description}, {WipeCommand::Name, &WipeCmd, WipeCommand::Description}, {WorkspaceCommand::Name, &WorkspaceCmd, WorkspaceCommand::Description}, {WorkspaceShareCommand::Name, &WorkspaceShareCmd, WorkspaceShareCommand::Description}, @@ -538,6 +626,9 @@ main(int argc, char** argv) Options.add_options()("corelimit", "Limit concurrency", cxxopts::value(CoreLimit)); + ZenLoggingCmdLineOptions LoggingCmdLineOptions; + LoggingCmdLineOptions.AddCliOptions(Options, GlobalOptions.LoggingConfig); + #if ZEN_WITH_TRACE // We only have this in options for command line help purposes - we parse these argument separately earlier using // GetTraceOptionsFromCommandline() @@ -624,8 +715,8 @@ main(int argc, char** argv) } LimitHardwareConcurrency(CoreLimit); -#if ZEN_USE_SENTRY +#if ZEN_USE_SENTRY { EnvironmentOptions EnvOptions; @@ -671,12 +762,19 @@ main(int argc, char** argv) } #endif - zen::LoggingOptions LogOptions; - LogOptions.IsDebug = GlobalOptions.IsDebug; - LogOptions.IsVerbose = GlobalOptions.IsVerbose; - LogOptions.AllowAsync = false; + LoggingCmdLineOptions.ApplyOptions(GlobalOptions.LoggingConfig); + + const LoggingOptions LogOptions = {.IsDebug = GlobalOptions.IsDebug, + .IsVerbose = GlobalOptions.IsVerbose, + .IsTest = false, + .NoConsoleOutput = GlobalOptions.LoggingConfig.NoConsoleOutput, + .QuietConsole = GlobalOptions.LoggingConfig.QuietConsole, + .AbsLogFile = GlobalOptions.LoggingConfig.AbsLogFile, + .LogId = GlobalOptions.LoggingConfig.LogId}; zen::InitializeLogging(LogOptions); + ApplyLoggingOptions(Options, GlobalOptions.LoggingConfig); + std::set_terminate([]() { void* Frames[8]; uint32_t FrameCount = GetCallstack(2, 8, Frames); diff --git a/src/zen/zen.h b/src/zen/zen.h index 05d1e4ec8..06e5356a6 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -5,7 +5,8 @@ #include <zencore/except.h> #include <zencore/timer.h> #include <zencore/zencore.h> -#include 
<zenutil/commandlineoptions.h> +#include <zenutil/config/commandlineoptions.h> +#include <zenutil/config/loggingconfig.h> namespace zen { @@ -14,6 +15,8 @@ struct ZenCliOptions bool IsDebug = false; bool IsVerbose = false; + ZenLoggingConfig LoggingConfig; + // Arguments after " -- " on command line are passed through and not parsed std::string PassthroughCommandLine; std::string PassthroughArgs; @@ -76,4 +79,41 @@ class CacheStoreCommand : public ZenCmdBase virtual ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; } }; +// Base for individual subcommands +class ZenSubCmdBase +{ +public: + ZenSubCmdBase(std::string_view Name, std::string_view Description); + virtual ~ZenSubCmdBase() = default; + cxxopts::Options& SubOptions() { return m_SubOptions; } + virtual void Run(const ZenCliOptions& GlobalOptions) = 0; + +protected: + cxxopts::Options m_SubOptions; +}; + +// Base for commands that host subcommands - handles all dispatch boilerplate +class ZenCmdWithSubCommands : public ZenCmdBase +{ +public: + void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) final; + +protected: + void AddSubCommand(ZenSubCmdBase& SubCmd); + virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions); + +private: + std::vector<ZenSubCmdBase*> m_SubCommands; +}; + +class CacheStoreCmdWithSubCommands : public ZenCmdWithSubCommands +{ + ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; } +}; + +class StorageCmdWithSubCommands : public ZenCmdWithSubCommands +{ + ZenCmdCategory& CommandCategory() const override { return g_StorageCategory; } +}; + } // namespace zen diff --git a/src/zen/zen.rc b/src/zen/zen.rc index 661d75011..0617681a7 100644 --- a/src/zen/zen.rc +++ b/src/zen/zen.rc @@ -7,7 +7,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US #pragma code_page(1252) -101 ICON "..\\UnrealEngine.ico" +101 ICON "..\\zen.ico" VS_VERSION_INFO VERSIONINFO FILEVERSION 
ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 diff --git a/src/zenbase/include/zenbase/refcount.h b/src/zenbase/include/zenbase/refcount.h index 40ad7bca5..08bc6ae54 100644 --- a/src/zenbase/include/zenbase/refcount.h +++ b/src/zenbase/include/zenbase/refcount.h @@ -51,6 +51,9 @@ private: * NOTE: Unlike RefCounted, this class deletes the derived type when the reference count reaches zero. * It has no virtual destructor, so it's important that you either don't derive from it further, * or ensure that the derived class has a virtual destructor. + * + * This class is useful when you want to avoid adding a vtable to a class just to implement + * reference counting. */ template<typename T> diff --git a/src/zencompute-test/xmake.lua b/src/zencompute-test/xmake.lua new file mode 100644 index 000000000..1207bdefd --- /dev/null +++ b/src/zencompute-test/xmake.lua @@ -0,0 +1,8 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zencompute-test") + set_kind("binary") + set_group("tests") + add_headerfiles("**.h") + add_files("*.cpp") + add_deps("zencompute", "zencore") diff --git a/src/zencompute-test/zencompute-test.cpp b/src/zencompute-test/zencompute-test.cpp new file mode 100644 index 000000000..60aaeab1d --- /dev/null +++ b/src/zencompute-test/zencompute-test.cpp @@ -0,0 +1,16 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/zencompute.h> +#include <zencore/testing.h> + +#include <zencore/memory/newdelete.h> + +int +main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) +{ +#if ZEN_WITH_TESTS + return zen::testing::RunTestMain(argc, argv, "zencompute-test", zen::zencompute_forcelinktests); +#else + return 0; +#endif +} diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md new file mode 100644 index 000000000..f5188123f --- /dev/null +++ b/src/zencompute/CLAUDE.md @@ -0,0 +1,232 @@ +# zencompute Module + +Lambda-style compute function service. 
Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results. + +## Directory Structure + +``` +src/zencompute/ +├── include/zencompute/ # Public headers +│ ├── computeservice.h # ComputeServiceSession public API +│ ├── httpcomputeservice.h # HTTP service wrapper +│ ├── orchestratorservice.h # Worker registry and orchestration +│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push +│ ├── recordingreader.h # Recording/replay reader API +│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP) +│ └── mockimds.h # Test helper for cloud metadata +├── runners/ # Execution backends +│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup +│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling) +│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW +│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation +│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing +│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux +│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances +│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories +├── recording/ # Recording/replay subsystem +│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk +│ └── recordingreader.cpp # Read recordings back +├── timeline/ +│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking +├── testing/ +│ └── mockimds.cpp # Mock IMDS for cloud metadata tests +├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines) +├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines) +├── httporchestrator.cpp # Orchestrator HTTP API + WebSocket push +├── orchestratorservice.cpp # Worker registry, health probing +└── cloudmetadata.cpp # IMDS probing, termination monitoring +``` + +## Key Classes + +### `ComputeServiceSession` (computeservice.h) +Public API entry 
point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: +- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup` +- Scheduler thread that drains `m_UpdatedActions` and drives state transitions +- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` +- Queue map: `m_Queues` (QueueEntry objects) +- Action history ring: `m_ActionHistory` (bounded deque, default 1000) + +**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. + +### `RunnerAction` (runners/functionrunner.h) +Shared ref-counted struct representing one action through its lifecycle. + +**Key fields:** +- `ActionLsn` — global unique sequence number +- `QueueId` — 0 for implicit/unqueued actions +- `Worker` — descriptor + content hash +- `ActionObj` — CbObject with the action spec +- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread +- `RetryCount` — atomic int tracking how many times the action has been rescheduled +- `Timestamps[State::_Count]` — timestamp of each state transition + +**State machine (forward-only under normal flow, atomic):** +``` +New → Pending → Submitting → Running → Completed + → Failed + → Abandoned + → Cancelled +``` +`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. 
+ +### `LocalProcessRunner` (runners/localrunner.h) +Base for all local execution. Platform runners subclass this and override: +- `SubmitAction()` — fork/exec the worker process +- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject) +- `CancelRunningActions()` — signal all processes during shutdown +- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default) + +**Infrastructure owned by LocalProcessRunner:** +- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop +- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction` +- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion +- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression) +- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion) + +**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations: +- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK` +- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000 +- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 + +### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. + +### `HttpComputeService` (include/zencompute/httpcomputeservice.h) +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. + +## Action Lifecycle (End to End) + +1. 
**HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()` +2. **Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()` +3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event +4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners +5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread) +6. **Process launch** → Submitting → Running, added to `m_RunningMap` +7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs +8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()` +9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue. +10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement + +### Action Rescheduling + +Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API. + +**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier. + +**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. 
+ +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. + +**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. + +## Queue System + +Queues group actions from a single client session. A `QueueEntry` (internal) tracks: +- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled) +- `ActiveCount` — pending + running actions (atomic) +- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics) +- `ActiveLsns` — for cancellation lookup (under `m_Lock`) +- `FinishedLsns` — moved here when actions complete +- `IdleSince` — used for 15-minute automatic expiry +- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit + +**Queue state machine (`QueueState` enum):** +``` +Active → Draining → Cancelled + \ ↑ + ─────────────────────/ +``` +- **Active** — accepts new work, schedules pending work, finishes running work (initial state) +- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled) +- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining) + +Key operations: +- `CreateQueue(Tag)` → returns `QueueId` +- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation +- `CancelQueue(QueueId)` → marks all active LSNs for cancellation +- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible) +- `GetQueueCompleted(QueueId)` → CbWriter output of finished results +- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()` + +## HTTP API + +All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`). 
+ +### Global endpoints +| Method | Path | Description | +|--------|------|-------------| +| POST | `abandon` | Transition session to Abandoned state (409 if invalid) | +| GET | `jobs/history` | Action history (last N, with timestamps per state) | +| GET | `jobs/running` | In-flight actions with CPU metrics | +| GET | `jobs/completed` | Actions with results available | +| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{worker}` | Submit action for specific worker | +| POST | `jobs` | Submit action (worker resolved from descriptor) | +| GET | `workers` | List worker IDs | +| GET | `workers/all` | All workers with full descriptors | +| GET/POST | `workers/{worker}` | Get/register worker | + +### Queue-scoped endpoints +Queue ref is capture(1) in all `queues/{queueref}/...` routes. + +| Method | Path | Description | +|--------|------|-------------| +| GET | `queues` | List queue IDs | +| POST | `queues` | Create queue | +| GET/DELETE | `queues/{queueref}` | Status / delete | +| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) | +| GET | `queues/{queueref}/completed` | Queue's completed results | +| GET | `queues/{queueref}/history` | Queue's action history | +| GET | `queues/{queueref}/running` | Queue's running actions | +| POST | `queues/{queueref}/jobs` | Submit to queue | +| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | + +Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. + +## Concurrency Model + +**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: +1. `m_ResultsLock` +2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") +3. 
`m_PendingLock` +4. `m_QueueLock` + +**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. + +**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. + +**Thread ownership:** +- Scheduler thread — drives state transitions, owns `m_PendingActions` +- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock +- Worker pool threads — async submission, brief `SubmitAction()` calls +- HTTP threads — read-only access to results, queue status + +## Sandbox Layout + +Each action gets a unique numbered directory under `m_SandboxPath`: +``` +scratch/{counter}/ + worker/ ← worker binaries (or bind-mounted on Linux) + inputs/ ← decompressed action inputs + outputs/ ← written by worker process +``` + +On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`. + +## Adding a New HTTP Endpoint + +1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp` +2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h` +3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure +4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())` +5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values + +## Adding a New Runner Platform + +1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/` +2. 
Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)` +3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample) +4. `ProcessHandle` convention: store pid as `reinterpret_cast<void*>(static_cast<intptr_t>(pid))` for consistency with the base class +5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp` diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp new file mode 100644 index 000000000..65bac895f --- /dev/null +++ b/src/zencompute/cloudmetadata.cpp @@ -0,0 +1,1014 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/cloudmetadata.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +// All major cloud providers expose instance metadata at this link-local address. +// It is only routable from within a cloud VM; on bare-metal the TCP connect will +// fail, which is how we distinguish cloud from non-cloud environments. +static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; + +// Short connect timeout so that detection on non-cloud machines is fast. The IMDS +// is a local service on the hypervisor so 200ms is generous for actual cloud VMs. 
+static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) +{ +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint) +: m_Log(logging::Get("cloud")) +, m_DataDir(std::move(DataDir)) +, m_ImdsEndpoint(std::move(ImdsEndpoint)) +{ + ZEN_TRACE_CPU("CloudMetadata::CloudMetadata"); + + std::error_code Ec; + std::filesystem::create_directories(m_DataDir, Ec); + + DetectProvider(); + + if (m_Info.Provider != CloudProvider::None) + { + StartTerminationMonitor(); + } +} + +CloudMetadata::~CloudMetadata() +{ + ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata"); + m_MonitorEnabled = false; + m_MonitorEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +CloudProvider +CloudMetadata::GetProvider() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); +} + +CloudInstanceInfo +CloudMetadata::GetInstanceInfo() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info; }); +} + +bool +CloudMetadata::IsTerminationPending() const +{ + return m_TerminationPending.load(std::memory_order_relaxed); +} + +std::string +CloudMetadata::GetTerminationReason() const +{ + return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; }); +} + +void +CloudMetadata::Describe(CbWriter& Writer) const +{ + ZEN_TRACE_CPU("CloudMetadata::Describe"); + CloudInstanceInfo Info = GetInstanceInfo(); + + if (Info.Provider == CloudProvider::None) + { + return; + } + + Writer.BeginObject("cloud"); + Writer << "provider" << ToString(Info.Provider); + Writer << "instance_id" << Info.InstanceId; + Writer << "availability_zone" << 
Info.AvailabilityZone; + Writer << "is_spot" << Info.IsSpot; + Writer << "is_autoscaling" << Info.IsAutoscaling; + Writer << "termination_pending" << IsTerminationPending(); + + if (IsTerminationPending()) + { + Writer << "termination_reason" << GetTerminationReason(); + } + + Writer.EndObject(); +} + +void +CloudMetadata::DetectProvider() +{ + ZEN_TRACE_CPU("CloudMetadata::DetectProvider"); + + if (TryDetectAWS()) + { + return; + } + + if (TryDetectAzure()) + { + return; + } + + if (TryDetectGCP()) + { + return; + } + + ZEN_DEBUG("no cloud provider detected"); +} + +// AWS detection uses IMDSv2 which requires a session token obtained via PUT before +// any GET requests are allowed. This is more secure than IMDSv1 (which allowed +// unauthenticated GETs) and is the default on modern EC2 instances. The token has +// a 300-second TTL and is reused for termination polling. +bool +CloudMetadata::TryDetectAWS() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping AWS detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing AWS IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + // IMDSv2: acquire session token. The TTL header is mandatory; we request + // 300s which is sufficient for the detection phase. The token is also + // stored in m_AwsToken for reuse by the termination polling thread. 
+ HttpClient::KeyValueMap TokenHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"}); + HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders); + + if (!TokenResponse.IsSuccess()) + { + ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_AwsToken = std::string(TokenResponse.AsText()); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders); + if (IdResponse.IsSuccess()) + { + m_Info.InstanceId = std::string(IdResponse.AsText()); + } + + HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders); + if (AzResponse.IsSuccess()) + { + m_Info.AvailabilityZone = std::string(AzResponse.AsText()); + } + + // "spot" vs "on-demand" — determines whether the instance can be + // reclaimed by AWS with a 2-minute warning + HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); + if (LifecycleResponse.IsSuccess()) + { + m_Info.IsSpot = (LifecycleResponse.AsText() == "spot"); + } + + // This endpoint only exists on instances managed by an Auto Scaling + // Group. A successful response (regardless of value) means autoscaling. 
+ HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + m_Info.IsAutoscaling = true; + } + + m_Info.Provider = CloudProvider::AWS; + + ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Azure IMDS returns a single JSON document for the entire instance metadata, +// unlike AWS and GCP which use separate plain-text endpoints per field. The +// "Metadata: true" header is required; requests without it are rejected. +// The api-version parameter is mandatory and pins the response schema. +bool +CloudMetadata::TryDetectAzure() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping Azure detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing Azure IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders); + + if (!InstanceResponse.IsSuccess()) + { + ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + 
ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError); + WriteSentinelFile(SentinelPath); + return false; + } + + const json11::Json& Compute = Json["compute"]; + + m_Info.InstanceId = Compute["vmId"].string_value(); + m_Info.AvailabilityZone = Compute["location"].string_value(); + + // Azure spot VMs have priority "Spot"; regular VMs have "Regular" + std::string Priority = Compute["priority"].string_value(); + m_Info.IsSpot = (Priority == "Spot"); + + // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + std::string VmssName = Compute["vmScaleSetName"].string_value(); + m_Info.IsAutoscaling = !VmssName.empty(); + + m_Info.Provider = CloudProvider::Azure; + + ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests. +// Unlike AWS, there is no session token; the header itself is the auth mechanism +// (it prevents SSRF attacks since browsers won't send custom headers to the +// metadata endpoint). Each metadata field is fetched from a separate URL. 
+bool +CloudMetadata::TryDetectGCP() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping GCP detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing GCP metadata service"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + // Fetch instance ID + HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders); + + if (!IdResponse.IsSuccess()) + { + ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_Info.InstanceId = std::string(IdResponse.AsText()); + + // GCP returns the fully-qualified zone path "projects/<num>/zones/<zone>". + // Strip the prefix to get just the zone name (e.g. "us-central1-a"). 
+ HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders); + if (ZoneResponse.IsSuccess()) + { + std::string_view Zone = ZoneResponse.AsText(); + if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos) + { + Zone = Zone.substr(Pos + 1); + } + m_Info.AvailabilityZone = std::string(Zone); + } + + // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE") + HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders); + if (PreemptibleResponse.IsSuccess()) + { + m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE"); + } + + // Check for maintenance event + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + m_TerminationPending = true; + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + } + } + + m_Info.Provider = CloudProvider::GCP; + + ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Sentinel files are empty marker files whose mere existence signals that a +// previous detection attempt for a given provider failed. This avoids paying +// the connect-timeout cost on every startup for providers that are known to +// be absent. The files persist across process restarts; delete them manually +// (or remove the DataDir) to force re-detection. 
+void +CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path) +{ + try + { + BasicFile File; + File.Open(Path, BasicFile::Mode::kTruncate); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what()); + } +} + +bool +CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const +{ + return zen::IsFile(Path); +} + +void +CloudMetadata::ClearSentinelFiles() +{ + std::error_code Ec; + std::filesystem::remove(m_DataDir / ".isNotAWS", Ec); + std::filesystem::remove(m_DataDir / ".isNotAzure", Ec); + std::filesystem::remove(m_DataDir / ".isNotGCP", Ec); +} + +void +CloudMetadata::StartTerminationMonitor() +{ + ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId); + + m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this}; +} + +void +CloudMetadata::TerminationMonitorThread() +{ + SetCurrentThreadName("cloud_term_mon"); + + // Poll every 5 seconds. The Event is used as an interruptible sleep so + // that the destructor can wake us up immediately for a clean shutdown. + while (m_MonitorEnabled) + { + m_MonitorEvent.Wait(5000); + m_MonitorEvent.Reset(); + + if (!m_MonitorEnabled) + { + return; + } + + PollTermination(); + } +} + +void +CloudMetadata::PollTermination() +{ + try + { + CloudProvider Provider = m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); + + if (Provider == CloudProvider::AWS) + { + PollAWSTermination(); + } + else if (Provider == CloudProvider::Azure) + { + PollAzureTermination(); + } + else if (Provider == CloudProvider::GCP) + { + PollGCPTermination(); + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("termination poll error: {}", Ex.what()); + } +} + +// AWS termination signals: +// - /spot/instance-action: returns 200 with a JSON body ~2 minutes before +// a spot instance is reclaimed. Returns 404 when no action is pending. 
+// - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state. +// "InService" is normal; anything else (e.g. "Terminated:Wait") means +// the instance is being cycled out. +void +CloudMetadata::PollAWSTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response SpotResponse = ImdsClient.Get("/latest/meta-data/spot/instance-action", AuthHeaders); + if (SpotResponse.IsSuccess()) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); }); + ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText()); + } + return; + } + + HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + std::string_view State = AutoscaleResponse.AsText(); + if (State.find("InService") == std::string_view::npos) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); }); + ZEN_WARN("AWS autoscaling termination detected: {}", State); + } + } + } +} + +// Azure Scheduled Events API returns a JSON array of upcoming platform events. +// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events. +// Other event types like "Freeze" (live migration) are non-destructive and +// ignored. The Events array is empty when nothing is pending. 
+void +CloudMetadata::PollAzureTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response EventsResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders); + + if (!EventsResponse.IsSuccess()) + { + return; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(EventsResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + return; + } + + const json11::Json::array& Events = Json["Events"].array_items(); + for (const auto& Evt : Events) + { + std::string EventType = Evt["EventType"].string_value(); + if (EventType == "Preempt" || EventType == "Terminate" || EventType == "Reboot") + { + if (!m_TerminationPending.exchange(true)) + { + std::string EventStatus = Evt["EventStatus"].string_value(); + m_ReasonLock.WithExclusiveLock( + [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", EventType, EventStatus); }); + ZEN_WARN("Azure termination event detected: {} ({})", EventType, EventStatus); + } + return; + } + } +} + +// GCP maintenance-event returns "NONE" when nothing is pending, and a +// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is +// about to be live-migrated or terminated. Preemptible/spot VMs get a +// 30-second warning before termination. 
+void +CloudMetadata::PollGCPTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + ZEN_WARN("GCP maintenance event detected: {}", Event); + } + } + } +} + +} // namespace zen::compute + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +# include <zencompute/mockimds.h> + +# include <zencore/filesystem.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zenhttp/httpserver.h> + +# include <memory> +# include <thread> + +namespace zen::compute { + +TEST_SUITE_BEGIN("compute.cloudmetadata"); + +// --------------------------------------------------------------------------- +// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService +// --------------------------------------------------------------------------- + +struct TestImdsServer +{ + MockImdsService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); } + + 
std::filesystem::path DataDir() const { return m_TmpDir->Path() / "cloud"; } + + std::unique_ptr<CloudMetadata> CreateCloud() { return std::make_unique<CloudMetadata>(DataDir(), Endpoint()); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.aws") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + SUBCASE("detection basics") + { + Imds.Mock.Aws.InstanceId = "i-abc123"; + Imds.Mock.Aws.AvailabilityZone = "us-west-2b"; + Imds.Mock.Aws.LifeCycle = "on-demand"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "i-abc123"); + CHECK(Info.AvailabilityZone == "us-west-2b"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("autoscaling instance") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("spot termination") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate a spot 
interruption notice appearing + Imds.Mock.Aws.SpotAction = R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos); + } + + SUBCASE("autoscaling termination") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate ASG cycling the instance out + Imds.Mock.Aws.AutoscalingState = "Terminated:Wait"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos); + } + + SUBCASE("no termination when InService") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.azure") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::Azure; + + SUBCASE("detection basics") + { + Imds.Mock.Azure.VmId = "vm-test-1234"; + Imds.Mock.Azure.Location = "westeurope"; + Imds.Mock.Azure.Priority = "Regular"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::Azure); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "vm-test-1234"); + CHECK(Info.AvailabilityZone == "westeurope"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Azure.Priority = "Spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot 
== true); + } + + SUBCASE("vmss instance") + { + Imds.Mock.Azure.VmScaleSetName = "my-vmss"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("preempt termination") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Preempt"; + Imds.Mock.Azure.ScheduledEventStatus = "Scheduled"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos); + } + + SUBCASE("terminate event") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Terminate"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos); + } + + SUBCASE("no termination when events empty") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.gcp") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::GCP; + + SUBCASE("detection basics") + { + Imds.Mock.Gcp.InstanceId = "9876543210"; + Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b"; + Imds.Mock.Gcp.Preemptible = "FALSE"; + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::GCP); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "9876543210"); + CHECK(Info.AvailabilityZone == "europe-west1-b"); // zone prefix stripped + CHECK(Info.IsSpot == 
false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("preemptible instance") + { + Imds.Mock.Gcp.Preemptible = "TRUE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("maintenance event during detection") + { + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + // GCP sets termination pending immediately during detection if a + // maintenance event is active + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("maintenance event during polling") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("no termination when NONE") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// No provider +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.no_provider") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId.empty()); + CHECK(Info.AvailabilityZone.empty()); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); 
+} + +// --------------------------------------------------------------------------- +// Sentinel file management +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.sentinel_files") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto DataDir = Imds.DataDir(); + + SUBCASE("sentinels are written on failed detection") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("ClearSentinelFiles removes sentinels") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + + Cloud->ClearSentinelFiles(); + + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("only failed providers get sentinels") + { + // Switch to AWS — Azure and GCP never probed, so no sentinels for them + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } +} + +TEST_SUITE_END(); + +void +cloudmetadata_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp new file mode 100644 index 000000000..838d741b6 --- /dev/null +++ b/src/zencompute/computeservice.cpp @@ -0,0 +1,2236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" +# include "recording/actionrecorder.h" +# include "runners/localrunner.h" +# include "runners/remotehttprunner.h" +# if ZEN_PLATFORM_LINUX +# include "runners/linuxrunner.h" +# elif ZEN_PLATFORM_WINDOWS +# include "runners/windowsrunner.h" +# elif ZEN_PLATFORM_MAC +# include "runners/macrunner.h" +# endif + +# include <zencompute/recordingreader.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/trace.h> +# include <zencore/workthreadpool.h> +# include <zenutil/workerpools.h> +# include <zentelemetry/stats.h> +# include <zenhttp/httpclient.h> + +# include <set> +# include <deque> +# include <map> +# include <thread> +# include <unordered_map> +# include <unordered_set> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <EASTL/hash_set.h> +ZEN_THIRD_PARTY_INCLUDES_END + +using namespace std::literals; + +namespace zen { + +const char* +ToString(compute::ComputeServiceSession::SessionState State) +{ + using enum compute::ComputeServiceSession::SessionState; + switch (State) + { + case Created: + return "Created"; + case Ready: + return "Ready"; + case Draining: + return "Draining"; + case Paused: + return "Paused"; + case Abandoned: + return "Abandoned"; + case Sunset: + return "Sunset"; + } + return "Unknown"; +} + +const char* +ToString(compute::ComputeServiceSession::QueueState State) +{ + using enum compute::ComputeServiceSession::QueueState; + switch (State) + { + case Active: + return "active"; + case Draining: + return "draining"; + case Cancelled: + return "cancelled"; + } + return "unknown"; +} + +} // 
namespace zen + +namespace zen::compute { + +using SessionState = ComputeServiceSession::SessionState; + +static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); + +////////////////////////////////////////////////////////////////////////// + +struct ComputeServiceSession::Impl +{ + ComputeServiceSession* m_ComputeServiceSession; + ChunkResolver& m_ChunkResolver; + LoggerRef m_Log{logging::Get("compute")}; + + Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver) + : m_ComputeServiceSession(InComputeServiceSession) + , m_ChunkResolver(InChunkResolver) + , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + { + // Create a non-expiring, non-deletable implicit queue for legacy endpoints + auto Result = CreateQueue("implicit"sv, {}, {}); + m_ImplicitQueueId = Result.QueueId; + m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; }); + + m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this}; + } + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + bool RequestStateTransition(SessionState NewState); + void AbandonAllActions(); + + LoggerRef Log() { return m_Log; } + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + std::string m_OrchestratorEndpoint; + std::filesystem::path m_OrchestratorBasePath; + Stopwatch m_OrchestratorQueryTimer; + std::unordered_set<std::string> m_KnownWorkerUris; + + void UpdateCoordinatorState(); + + // Worker registration and discovery + + struct FunctionDefinition + { + std::string FunctionName; + Guid FunctionVersion; + Guid BuildSystemVersion; + IoHash WorkerId; + }; + + void RegisterWorker(CbPackage Worker); + WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + + // Action scheduling and 
tracking + + std::atomic<SessionState> m_SessionState{SessionState::Created}; + std::atomic<int32_t> m_ActionsCounter = 0; // sequence number + metrics::Meter m_ArrivalRate; + + RwLock m_PendingLock; + std::map<int, Ref<RunnerAction>> m_PendingActions; + + RwLock m_RunningLock; + std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; + + RwLock m_ResultsLock; + std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; + metrics::Meter m_ResultRate; + std::atomic<uint64_t> m_RetiredCount{0}; + + EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority); + EnqueueResult EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority); + + void GetCompleted(CbWriter& Cbo); + + HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + std::thread m_SchedulingThread; + std::atomic<bool> m_SchedulingThreadEnabled{true}; + Event m_SchedulingThreadEvent; + + void SchedulerThreadFunction(); + void SchedulePendingActions(); + + // Workers + + RwLock m_WorkerLock; + std::unordered_map<IoHash, CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + std::vector<IoHash> GetKnownWorkerIds(); + void SyncWorkersToRunner(FunctionRunner& Runner); + + // Runners + + DeferredDirectoryDeleter m_DeferredDeleter; + WorkerThreadPool& m_LocalSubmitPool; + WorkerThreadPool& m_RemoteSubmitPool; + RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup; + RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup; + + void ShutdownRunners(); + + // Recording + + void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + void StopRecording(); + + std::unique_ptr<ActionRecorder> m_Recorder; + + // History tracking + + RwLock m_ActionHistoryLock; + std::deque<ComputeServiceSession::ActionHistoryEntry> m_ActionHistory; + size_t m_HistoryLimit = 1000; + + // 
Queue tracking + + using QueueState = ComputeServiceSession::QueueState; + + struct QueueEntry : RefCounted + { + int QueueId; + bool Implicit{false}; + std::atomic<QueueState> State{QueueState::Active}; + std::atomic<int> ActiveCount{0}; // pending + running + std::atomic<int> CompletedCount{0}; // successfully completed + std::atomic<int> FailedCount{0}; // failed + std::atomic<int> AbandonedCount{0}; // abandoned + std::atomic<int> CancelledCount{0}; // cancelled + std::atomic<uint64_t> IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work + + RwLock m_Lock; + std::unordered_set<int> ActiveLsns; // for cancellation lookup + std::unordered_set<int> FinishedLsns; // completed/failed/cancelled LSNs + + std::string Tag; + CbObject Metadata; + CbObject Config; + }; + + int m_ImplicitQueueId{0}; + std::atomic<int> m_QueueCounter{0}; + RwLock m_QueueLock; + std::unordered_map<int, Ref<QueueEntry>> m_Queues; + + Ref<QueueEntry> FindQueue(int QueueId) + { + Ref<QueueEntry> Queue; + m_QueueLock.WithSharedLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + Queue = It->second; + } + }); + return Queue; + } + + ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config); + std::vector<int> GetQueueIds(); + ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId); + CbObject GetQueueMetadata(int QueueId); + CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DeleteQueue(int QueueId); + void DrainQueue(int QueueId); + ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState); + void ExpireCompletedQueues(); + + Stopwatch 
m_QueueExpiryTimer; + + std::vector<ComputeServiceSession::RunningActionInfo> GetRunningActions(); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetActionHistory(int Limit); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit); + + // Action submission + + [[nodiscard]] size_t QueryCapacity(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action); + [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] size_t GetSubmittedActionCount(); + + // Updates + + RwLock m_UpdatedActionsLock; + std::vector<Ref<RunnerAction>> m_UpdatedActions; + + void HandleActionUpdates(); + void PostUpdate(RunnerAction* Action); + + static constexpr int kDefaultMaxRetries = 3; + int GetMaxRetriesForQueue(int QueueId); + + ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + + ActionCounts GetActionCounts() + { + ActionCounts Counts; + Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { + size_t Count = 0; + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + ++Count; + } + } + return Count; + }); + return Counts; + } + + void EmitStats(CbObjectWriter& Cbo) + { + Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); + m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); + m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); + m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + Cbo << "actions_submitted"sv << GetSubmittedActionCount(); + 
EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); + EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); + } +}; + +bool +ComputeServiceSession::Impl::IsHealthy() +{ + return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned; +} + +bool +ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) +{ + SessionState Current = m_SessionState.load(std::memory_order_relaxed); + + for (;;) + { + if (Current == NewState) + { + return true; + } + + // Validate the transition + bool Valid = false; + + switch (Current) + { + case SessionState::Created: + Valid = (NewState == SessionState::Ready); + break; + case SessionState::Ready: + Valid = (NewState == SessionState::Draining); + break; + case SessionState::Draining: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused); + break; + case SessionState::Paused: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset); + break; + case SessionState::Abandoned: + Valid = (NewState == SessionState::Sunset); + break; + case SessionState::Sunset: + Valid = false; + break; + } + + // Allow jumping directly to Abandoned or Sunset from any non-terminal state + if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned) + { + Valid = true; + } + if (NewState == SessionState::Sunset && Current != SessionState::Sunset) + { + Valid = true; + } + + if (!Valid) + { + ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState)); + return false; + } + + if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState)); + + if (NewState == SessionState::Abandoned) + { + AbandonAllActions(); + } + + return true; + } + + // CAS failed, Current was updated — retry with the new value + } +} + +void +ComputeServiceSession::Impl::AbandonAllActions() +{ + // Collect all pending actions and mark them as 
Abandoned + std::vector<Ref<RunnerAction>> PendingToAbandon; + + m_PendingLock.WithSharedLock([&] { + PendingToAbandon.reserve(m_PendingActions.size()); + for (auto& [Lsn, Action] : m_PendingActions) + { + PendingToAbandon.push_back(Action); + } + }); + + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + + // Collect all running actions and mark them as Abandoned, then + // best-effort cancel via the local runner group + std::vector<Ref<RunnerAction>> RunningToAbandon; + + m_RunningLock.WithSharedLock([&] { + RunningToAbandon.reserve(m_RunningMap.size()); + for (auto& [Lsn, Action] : m_RunningMap) + { + RunningToAbandon.push_back(Action); + } + }); + + for (auto& Action : RunningToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + m_LocalRunnerGroup.CancelAction(Action->ActionLsn); + } + + ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size()); +} + +void +ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_OrchestratorEndpoint = Endpoint; +} + +void +ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_OrchestratorBasePath = std::move(BasePath); +} + +void +ComputeServiceSession::Impl::UpdateCoordinatorState() +{ + ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); + if (m_OrchestratorEndpoint.empty()) + { + return; + } + + // Poll faster when we have no discovered workers yet so remote runners come online quickly + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 
500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } + + m_OrchestratorQueryTimer.Reset(); + + try + { + HttpClient Client(m_OrchestratorEndpoint); + + HttpClient::Response Response = Client.Get("/orch/agents"); + + if (!Response.IsSuccess()) + { + ZEN_WARN("orchestrator query failed with status {}", static_cast<int>(Response.StatusCode)); + return; + } + + CbObject WorkerList = Response.AsObject(); + + std::unordered_set<std::string> ValidWorkerUris; + + for (auto& Item : WorkerList["workers"sv]) + { + CbObjectView Worker = Item.AsObjectView(); + + uint64_t Dt = Worker["dt"sv].AsUInt64(); + bool Reachable = Worker["reachable"sv].AsBool(); + std::string_view Uri = Worker["uri"sv].AsString(); + + // Skip stale workers (not seen in over 30 seconds) + if (Dt > 30000) + { + continue; + } + + // Skip workers that are not confirmed reachable + if (!Reachable) + { + continue; + } + + std::string UriStr{Uri}; + ValidWorkerUris.insert(UriStr); + + // Skip workers we already know about + if (m_KnownWorkerUris.contains(UriStr)) + { + continue; + } + + ZEN_INFO("discovered new worker at {}", UriStr); + + m_KnownWorkerUris.insert(UriStr); + + auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + SyncWorkersToRunner(*NewRunner); + m_RemoteRunnerGroup.AddRunner(NewRunner); + } + + // Remove workers that are no longer valid (stale or unreachable) + for (auto It = m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();) + { + if (!ValidWorkerUris.contains(*It)) + { + const std::string& ExpiredUri = *It; + ZEN_INFO("removing expired worker at {}", ExpiredUri); + + m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; }); + + It = m_KnownWorkerUris.erase(It); + } + else + { + ++It; + } + } + } + catch (const HttpClientError& Ex) + { + ZEN_WARN("orchestrator query error: {}", Ex.what()); + } + catch (const std::exception& 
Ex) + { + ZEN_WARN("orchestrator query unexpected error: {}", Ex.what()); + } +} + +void +ComputeServiceSession::Impl::WaitUntilReady() +{ + if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty()) + { + ZEN_INFO("waiting for remote runners..."); + + constexpr int MaxWaitSeconds = 120; + + for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++) + { + if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed)) + { + ZEN_WARN("shutdown requested while waiting for remote runners"); + return; + } + + const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity(); + + if (Capacity > 0) + { + ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity); + break; + } + + zen::Sleep(1000); + } + } + else + { + ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available"); + } + + RequestStateTransition(SessionState::Ready); +} + +void +ComputeServiceSession::Impl::Shutdown() +{ + RequestStateTransition(SessionState::Sunset); + + m_SchedulingThreadEnabled = false; + m_SchedulingThreadEvent.Set(); + if (m_SchedulingThread.joinable()) + { + m_SchedulingThread.join(); + } + + ShutdownRunners(); + + m_DeferredDeleter.Shutdown(); +} + +void +ComputeServiceSession::Impl::ShutdownRunners() +{ + m_LocalRunnerGroup.Shutdown(); + m_RemoteRunnerGroup.Shutdown(); +} + +void +ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) +{ + ZEN_INFO("starting recording to '{}'", RecordingPath); + + m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); + + ZEN_INFO("started recording to '{}'", RecordingPath); +} + +void +ComputeServiceSession::Impl::StopRecording() +{ + ZEN_INFO("stopping recording"); + + m_Recorder = nullptr; + + ZEN_INFO("stopped recording"); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::Impl::GetRunningActions() +{ + std::vector<ComputeServiceSession::RunningActionInfo> Result; + 
// Returns the most recent terminal-state actions from the bounded history
// ring. Limit <= 0 (or larger than the history) returns everything;
// otherwise the newest `Limit` entries are returned (history is ordered
// oldest-first).
std::vector<ComputeServiceSession::ActionHistoryEntry>
ComputeServiceSession::Impl::GetActionHistory(int Limit)
{
	RwLock::SharedLockScope _(m_ActionHistoryLock);

	if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size())
	{
		// Take the tail — the newest entries.
		return std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end());
	}

	return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end());
}

// Returns the finished-action history for one queue, oldest-first, optionally
// limited to the newest `Limit` entries. Queue membership is determined by
// snapshotting the queue's FinishedLsns set and filtering the global history.
std::vector<ComputeServiceSession::ActionHistoryEntry>
ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit)
{
	// Resolve the queue and snapshot its finished LSN set
	Ref<QueueEntry> Queue = FindQueue(QueueId);

	if (!Queue)
	{
		return {};
	}

	std::unordered_set<int> FinishedLsns;

	// Snapshot under the queue lock so we never hold the queue lock and the
	// history lock at the same time. NOTE(review): an action finishing between
	// the snapshot and the filter below is simply missed this round — appears
	// to be an accepted race; confirm.
	Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; });

	// Filter the global history to entries belonging to this queue.
	// m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering.
	std::vector<ActionHistoryEntry> Result;

	m_ActionHistoryLock.WithSharedLock([&] {
		for (const auto& Entry : m_ActionHistory)
		{
			if (FinishedLsns.contains(Entry.Lsn))
			{
				Result.push_back(Entry);
			}
		}
	});

	if (Limit > 0 && static_cast<size_t>(Limit) < Result.size())
	{
		// Keep only the newest `Limit` entries.
		Result.erase(Result.begin(), Result.end() - Limit);
	}

	return Result;
}

// Registers a worker descriptor package. Idempotent: the worker id is the
// hash of the descriptor, so a repeated registration is a no-op beyond the
// map assignment. On first registration the worker is propagated to all
// runner groups, the optional recorder, and the function lookup table.
void
ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
{
	ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker");
	RwLock::ExclusiveLockScope _(m_WorkerLock);

	// NOTE(review): if GetObject() returns CbObject by value this reference
	// binds into a temporary via GetHash() — verify GetObject()/GetHash()
	// return references, or copy the hash by value.
	const IoHash& WorkerId = Worker.GetObject().GetHash();

	// insert_or_assign().second is true only when the key was newly inserted.
	if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second)
	{
		// Note that since the convention currently is that WorkerId is equal to the hash
		// of the worker descriptor there is no chance that we get a second write with a
		// different descriptor. Thus we only need to call this the first time, when the
		// worker is added

		m_LocalRunnerGroup.RegisterWorker(Worker);
		m_RemoteRunnerGroup.RegisterWorker(Worker);

		if (m_Recorder)
		{
			m_Recorder->RegisterWorker(Worker);
		}

		CbObject WorkerObj = Worker.GetObject();

		// Populate worker database

		const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid();

		// Index every function the worker exposes so EnqueueAction can
		// resolve (name, version, buildsystem version) -> worker id.
		for (auto& Item : WorkerObj["functions"sv])
		{
			CbObjectView Function = Item.AsObjectView();

			std::string_view Function名;
			std::string_view FunctionName = Function["name"sv].AsString();
			const Guid       FunctionVersion = Function["version"sv].AsUuid();

			m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},
														   .FunctionVersion = FunctionVersion,
														   .BuildSystemVersion = WorkerBuildSystemVersion,
														   .WorkerId = WorkerId});
		}
	}
}

// Replays every registered worker into a freshly added runner so it knows
// the same worker set as the session. Copies the packages out under a short
// shared lock, then registers outside the lock.
void
ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner)
{
	ZEN_TRACE_CPU("SyncWorkersToRunner");

	std::vector<CbPackage> Workers;

	{
		RwLock::SharedLockScope _(m_WorkerLock);
		Workers.reserve(m_WorkerMap.size());
		for (const auto& [Id, Pkg] : m_WorkerMap)
		{
			Workers.push_back(Pkg);
		}
	}

	// Register outside the lock — RegisterWorker on a runner may be slow.
	for (const CbPackage& Worker : Workers)
	{
		Runner.RegisterWorker(Worker);
	}
}
m_WorkerMap) + { + Workers.push_back(Pkg); + } + } + + for (const CbPackage& Worker : Workers) + { + Runner.RegisterWorker(Worker); + } +} + +WorkerDesc +ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) +{ + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + const CbPackage& Desc = It->second; + return {Desc, WorkerId}; + } + + return {}; +} + +std::vector<IoHash> +ComputeServiceSession::Impl::GetKnownWorkerIds() +{ + std::vector<IoHash> WorkerIds; + + m_WorkerLock.WithSharedLock([&] { + WorkerIds.reserve(m_WorkerMap.size()); + for (const auto& [WorkerId, _] : m_WorkerMap) + { + WorkerIds.push_back(WorkerId); + } + }); + + return WorkerIds; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction"); + + // Resolve function to worker + + IoHash WorkerId{IoHash::Zero}; + CbPackage WorkerPackage; + + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + m_WorkerLock.WithSharedLock([&] { + for (const FunctionDefinition& FuncDef : m_FunctionList) + { + if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && + FuncDef.BuildSystemVersion == BuildSystemVersion) + { + WorkerId = FuncDef.WorkerId; + + break; + } + } + + if (WorkerId != IoHash::Zero) + { + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + WorkerPackage = It->second; + } + } + }); + + if (WorkerId == IoHash::Zero) + { + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker matches the action specification"; + + return {0, 
Writer.Save()}; + } + + if (WorkerPackage) + { + return EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); + } + + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker found despite match"; + + return {0, Writer.Save()}; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction"); + + if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) + { + CbObjectWriter Writer; + Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); + return {0, Writer.Save()}; + } + + const int ActionLsn = ++m_ActionsCounter; + + m_ArrivalRate.Mark(); + + Ref<RunnerAction> Pending{new RunnerAction(m_ComputeServiceSession)}; + + Pending->ActionLsn = ActionLsn; + Pending->QueueId = QueueId; + Pending->Worker = Worker; + Pending->ActionId = ActionObj.GetHash(); + Pending->ActionObj = ActionObj; + Pending->Priority = RequestPriority; + + // For now simply put action into pending state, so we can do batch scheduling + + ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + + Pending->SetActionState(RunnerAction::State::Pending); + + if (m_Recorder) + { + m_Recorder->RecordAction(Pending); + } + + CbObjectWriter Writer; + Writer << "lsn" << Pending->ActionLsn; + Writer << "worker" << Pending->Worker.WorkerId; + Writer << "action" << Pending->ActionId; + + return {Pending->ActionLsn, Writer.Save()}; +} + +SubmitResult +ComputeServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action) +{ + // Loosely round-robin scheduling of actions across runners. 
// Total number of actions submitted across both runner groups.
size_t
ComputeServiceSession::Impl::GetSubmittedActionCount()
{
	return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount();
}

// Fetches (and consumes) the result of a finished action by LSN.
// Returns OK with the result removed from the results map, Accepted while
// the action is still pending or running, NotFound otherwise.
HttpResponseCode
ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
	// This lock is held for the duration of the function since we need to
	// be sure that the action doesn't change state while we are checking the
	// different data structures

	RwLock::ExclusiveLockScope _(m_ResultsLock);

	if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
	{
		// One-shot: the result is handed out exactly once and then erased.
		OutResultPackage = std::move(It->second->GetResult());

		m_ResultsMap.erase(It);

		return HttpResponseCode::OK;
	}

	{
		RwLock::SharedLockScope __(m_PendingLock);

		if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
		{
			return HttpResponseCode::Accepted;
		}
	}

	// Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
	// always be taken after m_ResultsLock if both are needed

	{
		RwLock::SharedLockScope __(m_RunningLock);

		if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
		{
			return HttpResponseCode::Accepted;
		}
	}

	return HttpResponseCode::NotFound;
}

// Same contract as GetActionResult, but looks the action up by ActionId
// (content hash) instead of LSN, which requires linear scans of each map.
HttpResponseCode
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
	// This lock is held for the duration of the function since we need to
	// be sure that the action doesn't change state while we are checking the
	// different data structures

	RwLock::ExclusiveLockScope _(m_ResultsLock);

	for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
	{
		if (It->second->ActionId == ActionId)
		{
			// One-shot hand-out, as in GetActionResult.
			OutResultPackage = std::move(It->second->GetResult());

			m_ResultsMap.erase(It);

			return HttpResponseCode::OK;
		}
	}

	{
		RwLock::SharedLockScope __(m_PendingLock);

		for (const auto& [K, Pending] : m_PendingActions)
		{
			if (Pending->ActionId == ActionId)
			{
				return HttpResponseCode::Accepted;
			}
		}
	}

	// Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
	// always be taken after m_ResultsLock if both are needed

	{
		RwLock::SharedLockScope __(m_RunningLock);

		for (const auto& [K, v] : m_RunningMap)
		{
			if (v->ActionId == ActionId)
			{
				return HttpResponseCode::Accepted;
			}
		}
	}

	return HttpResponseCode::NotFound;
}

// Marks an action result as eligible for deferred deletion.
void
ComputeServiceSession::Impl::RetireActionResult(int ActionLsn)
{
	m_DeferredDeleter.MarkReady(ActionLsn);
}

// Serializes all completed (unconsumed) results as a "completed" array of
// {lsn, state} objects into the supplied writer.
void
ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
{
	Cbo.BeginArray("completed");

	m_ResultsLock.WithSharedLock([&] {
		for (auto& [Lsn, Action] : m_ResultsMap)
		{
			Cbo.BeginObject();
			Cbo << "lsn"sv << Lsn;
			Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
			Cbo.EndObject();
		}
	});

	Cbo.EndArray();
}

//////////////////////////////////////////////////////////////////////////
// Queue management

// Creates a new explicit queue with the given tag/metadata/config and
// returns its freshly allocated id. The queue starts idle (IdleSince set),
// which makes an unused queue eligible for expiry (see ExpireCompletedQueues).
ComputeServiceSession::CreateQueueResult
ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config)
{
	const int QueueId = ++m_QueueCounter;

	Ref<QueueEntry> Queue{new QueueEntry()};
	Queue->QueueId = QueueId;
	Queue->Tag = Tag;
	Queue->Metadata = std::move(Metadata);
	Queue->Config = std::move(Config);
	Queue->IdleSince = GetHifreqTimerValue();

	m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; });

	ZEN_DEBUG("created queue {}", QueueId);

	return {.QueueId = QueueId};
}
{.QueueId = QueueId}; +} + +std::vector<int> +ComputeServiceSession::Impl::GetQueueIds() +{ + std::vector<int> Ids; + + m_QueueLock.WithSharedLock([&] { + Ids.reserve(m_Queues.size()); + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + Ids.push_back(Id); + } + } + }); + + return Ids; +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::Impl::GetQueueStatus(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + const int Active = Queue->ActiveCount.load(std::memory_order_relaxed); + const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed); + const int Failed = Queue->FailedCount.load(std::memory_order_relaxed); + const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed); + const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed); + const QueueState QState = Queue->State.load(); + + return { + .IsValid = true, + .QueueId = QueueId, + .ActiveCount = Active, + .CompletedCount = Completed, + .FailedCount = Failed, + .AbandonedCount = AbandonedN, + .CancelledCount = CancelledN, + .State = QState, + .IsComplete = (Active == 0), + }; +} + +CbObject +ComputeServiceSession::Impl::GetQueueMetadata(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Metadata; +} + +CbObject +ComputeServiceSession::Impl::GetQueueConfig(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Config; +} + +void +ComputeServiceSession::Impl::CancelQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + Queue->State.store(QueueState::Cancelled); + + // Collect active LSNs snapshot for cancellation + std::vector<int> LsnsToCancel; + + Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); }); + + // Identify which 
LSNs are still pending (not yet dispatched to a runner) + std::vector<Ref<RunnerAction>> PendingActionsToCancel; + std::vector<int> RunningLsnsToCancel; + + m_PendingLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) + { + PendingActionsToCancel.push_back(It->second); + } + } + }); + + m_RunningLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + { + RunningLsnsToCancel.push_back(Lsn); + } + } + }); + + // Cancel pending actions by marking them as Cancelled; they will flow through + // HandleActionUpdates and eventually be removed from the pending map. + for (auto& Action : PendingActionsToCancel) + { + Action->SetActionState(RunnerAction::State::Cancelled); + } + + // Best-effort cancellation of running actions via the local runner group. + // Also set the action state to Cancelled directly so a subsequent Failed + // transition from the runner is blocked (Cancelled > Failed in the enum). 
+ for (int Lsn : RunningLsnsToCancel) + { + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) + { + It->second->SetActionState(RunnerAction::State::Cancelled); + } + }); + m_LocalRunnerGroup.CancelAction(Lsn); + } + + m_RemoteRunnerGroup.CancelRemoteQueue(QueueId); + + ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled", + QueueId, + PendingActionsToCancel.size(), + RunningLsnsToCancel.size()); + + // Wake up the scheduler to process the cancelled actions + m_SchedulingThreadEvent.Set(); +} + +void +ComputeServiceSession::Impl::DeleteQueue(int QueueId) +{ + // Never delete the implicit queue + { + Ref<QueueEntry> Queue = FindQueue(QueueId); + if (Queue && Queue->Implicit) + { + return; + } + } + + // Cancel any active work first + CancelQueue(QueueId); + + m_QueueLock.WithExclusiveLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + m_Queues.erase(It); + } + }); +} + +void +ComputeServiceSession::Impl::DrainQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + QueueState Expected = QueueState::Active; + Queue->State.compare_exchange_strong(Expected, QueueState::Draining); + ZEN_INFO("draining queue {}", QueueId); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = 
// Same queue validation and bookkeeping as EnqueueActionToQueue, but for an
// action whose worker is already resolved.
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority)
{
	Ref<QueueEntry> Queue = FindQueue(QueueId);

	if (!Queue)
	{
		CbObjectWriter Writer;
		Writer << "error"sv
			   << "queue not found"sv;
		return {0, Writer.Save()};
	}

	QueueState QState = Queue->State.load();
	if (QState == QueueState::Cancelled)
	{
		CbObjectWriter Writer;
		Writer << "error"sv
			   << "queue is cancelled"sv;
		return {0, Writer.Save()};
	}

	if (QState == QueueState::Draining)
	{
		CbObjectWriter Writer;
		Writer << "error"sv
			   << "queue is draining"sv;
		return {0, Writer.Save()};
	}

	EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority);

	if (Result.Lsn != 0)
	{
		// Track as active; clearing IdleSince blocks idle-expiry.
		Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
		Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
		Queue->IdleSince.store(0, std::memory_order_relaxed);
	}

	return Result;
}

// Writes the LSNs of this queue's finished actions that still have an
// unconsumed result, as a "completed" array. Note the nesting: the queue
// lock is taken before the results lock here.
void
ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
{
	Ref<QueueEntry> Queue = FindQueue(QueueId);

	Cbo.BeginArray("completed");

	if (Queue)
	{
		Queue->m_Lock.WithSharedLock([&] {
			m_ResultsLock.WithSharedLock([&] {
				for (int Lsn : Queue->FinishedLsns)
				{
					// Only report LSNs whose result hasn't been consumed yet.
					if (m_ResultsMap.contains(Lsn))
					{
						Cbo << Lsn;
					}
				}
			});
		});
	}

	Cbo.EndArray();
}

// Moves an action from the queue's active set to its finished set and
// updates the per-outcome counters. QueueId 0 (no queue) is a no-op.
void
ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState)
{
	if (QueueId == 0)
	{
		return;
	}

	Ref<QueueEntry> Queue = FindQueue(QueueId);

	if (!Queue)
	{
		return;
	}

	Queue->m_Lock.WithExclusiveLock([&] {
		Queue->ActiveLsns.erase(Lsn);
		Queue->FinishedLsns.insert(Lsn);
	});

	// fetch_sub returns the previous value; 1 -> this was the last active
	// action, so start the idle clock for queue expiry.
	const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
	if (PreviousActive == 1)
	{
		Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
	}

	switch (ActionState)
	{
		case RunnerAction::State::Completed:
			Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed);
			break;
		case RunnerAction::State::Abandoned:
			Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed);
			break;
		case RunnerAction::State::Cancelled:
			Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed);
			break;
		default:
			// Any other terminal state counts as a failure.
			Queue->FailedCount.fetch_add(1, std::memory_order_relaxed);
			break;
	}
}
// Deletes explicit queues that have had no active work for 15 minutes.
void
ComputeServiceSession::Impl::ExpireCompletedQueues()
{
	static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000;

	std::vector<int> ExpiredQueueIds;

	// Collect candidates under the shared lock; delete afterwards, since
	// DeleteQueue takes m_QueueLock exclusively.
	m_QueueLock.WithSharedLock([&] {
		for (const auto& [Id, Queue] : m_Queues)
		{
			if (Queue->Implicit)
			{
				continue;
			}
			const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed);
			// Idle == 0 means "has active work" (see EnqueueActionToQueue).
			if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0)
			{
				const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle);
				if (ElapsedMs >= ExpiryTimeMs)
				{
					ExpiredQueueIds.push_back(Id);
				}
			}
		}
	});

	for (int QueueId : ExpiredQueueIds)
	{
		ZEN_INFO("expiring idle queue {}", QueueId);
		DeleteQueue(QueueId);
	}
}

// One scheduling pass: pulls Pending actions (highest priority first, FIFO
// within a priority), limits the batch to current runner capacity, and
// submits them. Called from the scheduler thread.
void
ComputeServiceSession::Impl::SchedulePendingActions()
{
	ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
	int    ScheduledCount = 0;
	size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
	size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
	size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });

	// NOTE(review): function-local static — safe only while this runs on a
	// single (scheduler) thread; confirm if that ever changes.
	static Stopwatch DumpRunningTimer;

	// Scope guard: logs a summary on exit; Dismiss()ed on the early-out
	// paths below so uneventful passes stay quiet.
	auto _ = MakeGuard([&] {
		ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
				 ScheduledCount,
				 RunningCount,
				 m_RetiredCount.load(),
				 PendingCount,
				 ResultCount);

		// Every 30s, also dump the full set of running LSNs.
		if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
		{
			DumpRunningTimer.Reset();

			std::set<int> RunningList;
			m_RunningLock.WithSharedLock([&] {
				for (auto& [K, V] : m_RunningMap)
				{
					RunningList.insert(K);
				}
			});

			ExtendableStringBuilder<1024> RunningString;
			for (int i : RunningList)
			{
				if (RunningString.Size())
				{
					RunningString << ", ";
				}

				RunningString.Append(IntNum(i));
			}

			ZEN_INFO("running: {}", RunningString);
		}
	});

	size_t Capacity = QueryCapacity();

	if (!Capacity)
	{
		_.Dismiss();

		return;
	}

	std::vector<Ref<RunnerAction>> ActionsToSchedule;

	// Pull actions to schedule from the pending queue, we will
	// try to submit these to the runner outside of the lock. Note
	// that because of how the state transitions work it's not
	// actually the case that all of these actions will still be
	// pending by the time we try to submit them, but that's fine.
	//
	// Also note that the m_PendingActions list is not maintained
	// here; insertion/removal is done in HandleActionUpdates()

	m_PendingLock.WithExclusiveLock([&] {
		// Paused / Draining / later states accept no new dispatches.
		if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
		{
			return;
		}

		if (m_PendingActions.empty())
		{
			return;
		}

		for (auto& [Lsn, Pending] : m_PendingActions)
		{
			switch (Pending->ActionState())
			{
				case RunnerAction::State::Pending:
					ActionsToSchedule.push_back(Pending);
					break;

				case RunnerAction::State::Submitting:
					break; // already claimed by async submission

				case RunnerAction::State::Running:
				case RunnerAction::State::Completed:
				case RunnerAction::State::Failed:
				case RunnerAction::State::Abandoned:
				case RunnerAction::State::Cancelled:
					break;

				default:
				case RunnerAction::State::New:
					ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn);
					break;
			}
		}

		// Sort by priority descending, then by LSN ascending (FIFO within same priority)
		std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
			if (A->Priority != B->Priority)
			{
				return A->Priority > B->Priority;
			}
			return A->ActionLsn < B->ActionLsn;
		});

		// Never submit more than the runners can currently absorb.
		if (ActionsToSchedule.size() > Capacity)
		{
			ActionsToSchedule.resize(Capacity);
		}

		PendingCount = m_PendingActions.size();
	});

	if (ActionsToSchedule.empty())
	{
		_.Dismiss();
		return;
	}

	ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());

	Stopwatch                 SubmitTimer;
	std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);

	int NotAcceptedCount = 0;
	int ScheduledActionCount = 0;

	for (const SubmitResult& SubResult : SubmitResults)
	{
		if (SubResult.IsAccepted)
		{
			++ScheduledActionCount;
		}
		else
		{
			++NotAcceptedCount;
		}
	}

	ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
			 ScheduledActionCount,
			 NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
			 NotAcceptedCount);

	// Feed the updated numbers to the exit-guard log above.
	ScheduledCount += ScheduledActionCount;
	PendingCount -= ScheduledActionCount;
}

// Scheduler thread main loop: waits on the scheduling event (shorter timeout
// while work is pending), then drains state updates, handles the
// Draining -> Paused auto-transition, talks to the coordinator, schedules
// pending actions, and periodically sweeps idle queues.
void
ComputeServiceSession::Impl::SchedulerThreadFunction()
{
	SetCurrentThreadName("Function_Scheduler");

	auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); });

	do
	{
		// Poll faster while there is pending work to dispatch.
		int TimeoutMs = 500;

		auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });

		if (PendingCount)
		{
			TimeoutMs = 100;
		}

		const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs);

		if (m_SchedulingThreadEnabled == false)
		{
			return;
		}

		if (WasSignaled)
		{
			m_SchedulingThreadEvent.Reset();
		}

		ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
				  m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
				  PendingCount,
				  m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
				  m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
				  TimeoutMs);

		HandleActionUpdates();

		// Auto-transition Draining → Paused when all work is done
		if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
		{
			size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
			size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });

			if (Pending == 0 && Running == 0)
			{
				// CAS so we never clobber a state change that raced in.
				SessionState Expected = SessionState::Draining;
				if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
				{
					ZEN_INFO("session state: Draining -> Paused (all work completed)");
				}
			}
		}

		UpdateCoordinatorState();
		SchedulePendingActions();

		static constexpr uint64_t QueueExpirySweepIntervalMs = 30000;
		if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs)
		{
			m_QueueExpiryTimer.Reset();
			ExpireCompletedQueues();
		}
	} while (m_SchedulingThreadEnabled);
}
// Queues an action state-change notification for the scheduler thread and
// wakes it. Called by runners/actions on arbitrary threads.
void
ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action)
{
	m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); });
	m_SchedulingThreadEvent.Set();
}

// Returns the retry budget for a queue: its "max_retries" config value if
// positive, otherwise kDefaultMaxRetries. Queue 0 (implicit) always uses
// the default.
int
ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
{
	if (QueueId == 0)
	{
		return kDefaultMaxRetries;
	}

	CbObject Config = GetQueueConfig(QueueId);

	if (Config)
	{
		int Value = Config["max_retries"].AsInt32(0);

		if (Value > 0)
		{
			return Value;
		}
	}

	return kDefaultMaxRetries;
}

// Manually re-queues a Failed/Abandoned action from the results map, if it
// still has retry budget. Returns Success=false with an explanatory error
// otherwise.
ComputeServiceSession::RescheduleResult
ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
{
	Ref<RunnerAction>   Action;
	RunnerAction::State State;
	RescheduleResult    ValidationError;
	bool                Removed = false;

	// Find, validate, and remove atomically under a single lock scope to prevent
	// concurrent RescheduleAction calls from double-removing the same action.
	m_ResultsLock.WithExclusiveLock([&] {
		auto It = m_ResultsMap.find(ActionLsn);
		if (It == m_ResultsMap.end())
		{
			ValidationError = {.Success = false, .Error = "Action not found in results"};
			return;
		}

		Action = It->second;
		State = Action->ActionState();

		if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned)
		{
			ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"};
			return;
		}

		int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);
		if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries)
		{
			ValidationError = {.Success = false, .Error = "Retry limit reached"};
			return;
		}

		m_ResultsMap.erase(It);
		Removed = true;
	});

	if (!Removed)
	{
		return ValidationError;
	}

	// Undo the queue-side completion bookkeeping: the action becomes active
	// again and the terminal-state counter it incremented is rolled back.
	if (Action->QueueId != 0)
	{
		Ref<QueueEntry> Queue = FindQueue(Action->QueueId);

		if (Queue)
		{
			Queue->m_Lock.WithExclusiveLock([&] {
				Queue->FinishedLsns.erase(ActionLsn);
				Queue->ActiveLsns.insert(ActionLsn);
			});

			Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
			Queue->IdleSince.store(0, std::memory_order_relaxed);

			if (State == RunnerAction::State::Failed)
			{
				Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed);
			}
			else
			{
				Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed);
			}
		}
	}

	// Reset action state — this calls PostUpdate() internally
	Action->ResetActionStateToPending();

	int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
	ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount);

	return {.Success = true, .RetryCount = NewRetryCount};
}
// Drains the posted-update queue and advances each action through the
// pending/running/results maps according to its current state. Runs on the
// scheduler thread only.
void
ComputeServiceSession::Impl::HandleActionUpdates()
{
	ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates");

	// Drain the update queue atomically
	std::vector<Ref<RunnerAction>> UpdatedActions;
	m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); });

	std::unordered_set<int> SeenLsn;

	// Process each action's latest state, deduplicating by LSN.
	//
	// This is safe because state transitions are monotonically increasing by enum
	// rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so
	// SetActionState rejects any transition to a lower-ranked state. By the time
	// we read ActionState() here, it reflects the highest state reached — making
	// the first occurrence per LSN authoritative and duplicates redundant.
	for (Ref<RunnerAction>& Action : UpdatedActions)
	{
		const int ActionLsn = Action->ActionLsn;

		if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted)
		{
			switch (Action->ActionState())
			{
				// Newly enqueued — add to pending map for scheduling
				case RunnerAction::State::Pending:
					m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
					break;

				// Async submission in progress — remains in pending map
				case RunnerAction::State::Submitting:
					break;

				// Dispatched to a runner — move from pending to running
				case RunnerAction::State::Running:
					// Lock nesting: m_RunningLock outer, m_PendingLock inner,
					// matching the other map-transfer sites below.
					m_RunningLock.WithExclusiveLock([&] {
						m_PendingLock.WithExclusiveLock([&] {
							m_RunningMap.insert({ActionLsn, Action});
							m_PendingActions.erase(ActionLsn);
						});
					});
					ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
					break;

				// Terminal states — move to results, record history, notify queue
				case RunnerAction::State::Completed:
				case RunnerAction::State::Failed:
				case RunnerAction::State::Abandoned:
				case RunnerAction::State::Cancelled:
				{
					auto TerminalState = Action->ActionState();

					// Automatic retry for Failed/Abandoned actions with retries remaining.
					// Skip retries when the session itself is abandoned — those actions
					// were intentionally abandoned and should not be rescheduled.
					if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) &&
						m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned)
					{
						int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);

						if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
						{
							// Remove from whichever active map the action is in before resetting
							m_RunningLock.WithExclusiveLock([&] {
								m_PendingLock.WithExclusiveLock([&] {
									if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
									{
										m_PendingActions.erase(ActionLsn);
									}
									else
									{
										m_RunningMap.erase(FindIt);
									}
								});
							});

							// Reset triggers PostUpdate() which re-enters the action as Pending
							Action->ResetActionStateToPending();
							int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);

							ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})",
									 Action->ActionId,
									 ActionLsn,
									 NewRetryCount,
									 MaxRetries);
							// Exits the switch only; the action is now back in flight.
							break;
						}
					}

					// Remove from whichever active map the action is in
					m_RunningLock.WithExclusiveLock([&] {
						m_PendingLock.WithExclusiveLock([&] {
							if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
							{
								m_PendingActions.erase(ActionLsn);
							}
							else
							{
								m_RunningMap.erase(FindIt);
							}
						});
					});

					m_ResultsLock.WithExclusiveLock([&] {
						m_ResultsMap[ActionLsn] = Action;

						// Append to bounded action history ring
						m_ActionHistoryLock.WithExclusiveLock([&] {
							ActionHistoryEntry Entry{.Lsn = ActionLsn,
													 .QueueId = Action->QueueId,
													 .ActionId = Action->ActionId,
													 .WorkerId = Action->Worker.WorkerId,
													 .ActionDescriptor = Action->ActionObj,
													 .ExecutionLocation = std::move(Action->ExecutionLocation),
													 .Succeeded = TerminalState == RunnerAction::State::Completed,
													 .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed),
													 .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)};

							std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps));

							m_ActionHistory.push_back(std::move(Entry));

							// Bounded ring: drop the oldest entry past the limit.
							if (m_ActionHistory.size() > m_HistoryLimit)
							{
								m_ActionHistory.pop_front();
							}
						});
					});
					m_RetiredCount.fetch_add(1);
					m_ResultRate.Mark(1);
					ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}",
							  Action->ActionId,
							  ActionLsn,
							  TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE");
					NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
					break;
				}
			}
		}
	}
}

// Combined currently-available capacity across local and remote runners.
size_t
ComputeServiceSession::Impl::QueryCapacity()
{
	return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity();
}

// Batch submission: local runners get first refusal for the whole batch,
// anything they reject is forwarded to the remote runner group. The result
// vector is index-aligned with the input.
std::vector<SubmitResult>
ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
	ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
	std::vector<SubmitResult> Results(Actions.size());

	// First try submitting the batch to local runners in parallel

	std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
	std::vector<size_t>       RemoteIndices;
	std::vector<Ref<RunnerAction>> RemoteActions;

	for (size_t i = 0; i < Actions.size(); ++i)
	{
		if (LocalResults[i].IsAccepted)
		{
			Results[i] = std::move(LocalResults[i]);
		}
		else
		{
			// Remember original index so remote results can be merged back.
			RemoteIndices.push_back(i);
			RemoteActions.push_back(Actions[i]);
		}
	}

	// Submit remaining actions to remote runners in parallel
	if (!RemoteActions.empty())
	{
		std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);

		for (size_t j = 0; j < RemoteIndices.size(); ++j)
		{
			Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
		}
	}

	return Results;
}

//////////////////////////////////////////////////////////////////////////

// Public facade: all behavior lives in Impl (pimpl idiom).
ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver)
{
	m_Impl = std::make_unique<Impl>(this, InChunkResolver);
}

// Ensures the scheduler thread and runners are stopped before Impl is torn down.
ComputeServiceSession::~ComputeServiceSession()
{
	Shutdown();
}
+ComputeServiceSession::IsHealthy() +{ + return m_Impl->IsHealthy(); +} + +void +ComputeServiceSession::WaitUntilReady() +{ + m_Impl->WaitUntilReady(); +} + +void +ComputeServiceSession::Shutdown() +{ + m_Impl->Shutdown(); +} + +ComputeServiceSession::SessionState +ComputeServiceSession::GetSessionState() const +{ + return m_Impl->m_SessionState.load(std::memory_order_relaxed); +} + +bool +ComputeServiceSession::RequestStateTransition(SessionState NewState) +{ + return m_Impl->RequestStateTransition(NewState); +} + +void +ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_Impl->SetOrchestratorEndpoint(Endpoint); +} + +void +ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_Impl->SetOrchestratorBasePath(std::move(BasePath)); +} + +void +ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) +{ + m_Impl->StartRecording(InResolver, RecordingPath); +} + +void +ComputeServiceSession::StopRecording() +{ + m_Impl->StopRecording(); +} + +ComputeServiceSession::ActionCounts +ComputeServiceSession::GetActionCounts() +{ + return m_Impl->GetActionCounts(); +} + +void +ComputeServiceSession::EmitStats(CbObjectWriter& Cbo) +{ + m_Impl->EmitStats(Cbo); +} + +std::vector<IoHash> +ComputeServiceSession::GetKnownWorkerIds() +{ + return m_Impl->GetKnownWorkerIds(); +} + +WorkerDesc +ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) +{ + return m_Impl->GetWorkerDescriptor(WorkerId); +} + +void +ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner"); + +# if ZEN_PLATFORM_LINUX + auto* NewRunner = new LinuxProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_WINDOWS + auto* NewRunner = new 
WindowsProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_MAC + auto* NewRunner = + new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions); +# endif + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void +ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); + + auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool); + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority); +} +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config)); +} + +CbObject +ComputeServiceSession::GetQueueMetadata(int QueueId) +{ + return m_Impl->GetQueueMetadata(QueueId); +} + +CbObject +ComputeServiceSession::GetQueueConfig(int QueueId) +{ + return m_Impl->GetQueueConfig(QueueId); +} + +std::vector<int> +ComputeServiceSession::GetQueueIds() +{ + return m_Impl->GetQueueIds(); +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::GetQueueStatus(int QueueId) +{ + return m_Impl->GetQueueStatus(QueueId); +} + +void 
+ComputeServiceSession::CancelQueue(int QueueId) +{ + m_Impl->CancelQueue(QueueId); +} + +void +ComputeServiceSession::DrainQueue(int QueueId) +{ + m_Impl->DrainQueue(QueueId); +} + +void +ComputeServiceSession::DeleteQueue(int QueueId) +{ + m_Impl->DeleteQueue(QueueId); +} + +void +ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + m_Impl->GetQueueCompleted(QueueId, Cbo); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority); +} + +void +ComputeServiceSession::RegisterWorker(CbPackage Worker) +{ + m_Impl->RegisterWorker(Worker); +} + +HttpResponseCode +ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + return m_Impl->GetActionResult(ActionLsn, OutResultPackage); +} + +HttpResponseCode +ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + return m_Impl->FindActionResult(ActionId, OutResultPackage); +} + +void +ComputeServiceSession::RetireActionResult(int ActionLsn) +{ + m_Impl->RetireActionResult(ActionLsn); +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RescheduleAction(int ActionLsn) +{ + return m_Impl->RescheduleAction(ActionLsn); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::GetRunningActions() +{ + return m_Impl->GetRunningActions(); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::GetActionHistory(int Limit) +{ + return m_Impl->GetActionHistory(Limit); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> 
+ComputeServiceSession::GetQueueHistory(int QueueId, int Limit) +{ + return m_Impl->GetQueueHistory(QueueId, Limit); +} + +void +ComputeServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +ComputeServiceSession::PostUpdate(RunnerAction* Action) +{ + m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +computeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp new file mode 100644 index 000000000..e82a40781 --- /dev/null +++ b/src/zencompute/httpcomputeservice.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httpcomputeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/system.h> +# include <zencore/thread.h> +# include <zencore/trace.h> +# include <zencore/uid.h> +# include <zenstore/cidstore.h> +# include <zentelemetry/stats.h> + +# include <span> +# include <unordered_map> + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); + +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; +auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); }; + 
//////////////////////////////////////////////////////////////////////////

// Pimpl for HttpComputeService: owns the HTTP route table, the underlying
// compute session, and the registry that maps externally-visible remote-queue
// tokens onto internal queue ids.
struct HttpComputeService::Impl
{
    HttpComputeService* m_Self;
    CidStore&           m_CidStore;
    IHttpStatsService&  m_StatsService;
    LoggerRef           m_Log;
    std::filesystem::path m_BaseDir;
    HttpRequestRouter   m_Router;
    ComputeServiceSession m_ComputeService;
    SystemMetricsTracker m_MetricsTracker;

    // Metrics

    metrics::OperationTiming m_HttpRequests;

    // Per-remote-queue metadata, shared across all lookup maps below.

    struct RemoteQueueInfo : RefCounted
    {
        int QueueId = 0;
        Oid Token;
        std::string IdempotencyKey; // empty if no idempotency key was provided
        std::string ClientHostname; // empty if no hostname was provided
    };

    // Remote queue registry — all three maps share the same RemoteQueueInfo objects.
    // All maps are guarded by m_RemoteQueueLock.

    RwLock m_RemoteQueueLock;
    std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken;   // Token → info
    std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId;              // QueueId → info
    std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag;          // idempotency key → info

    LoggerRef Log() { return m_Log; }

    // Translate an external Oid token (resp. a textual route capture) into the
    // internal integer queue id. NOTE(review): presumably returns 0 when the
    // reference does not resolve — callers visible later in this file treat 0
    // as "response already written"; confirm against the definitions.
    int ResolveQueueToken(const Oid& Token);
    int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture);

    // Aggregate counters for attachment ingestion (totals vs newly-stored).
    struct IngestStats
    {
        int Count = 0;
        int NewCount = 0;
        uint64_t Bytes = 0;
        uint64_t NewBytes = 0;
    };

    IngestStats IngestPackageAttachments(const CbPackage& Package);
    bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
    void HandleWorkersGet(HttpServerRequest& HttpReq);
    void HandleWorkersAllGet(HttpServerRequest& HttpReq);
    void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
    void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);

    void RegisterRoutes();

    Impl(HttpComputeService* Self,
         CidStore& InCidStore,
IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) + : m_Self(Self) + , m_CidStore(InCidStore) + , m_StatsService(StatsService) + , m_Log(logging::Get("compute")) + , m_BaseDir(BaseDir) + , m_ComputeService(InCidStore) + { + m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.WaitUntilReady(); + m_StatsService.RegisterHandler("compute", *m_Self); + RegisterRoutes(); + } +}; + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::RegisterRoutes() +{ + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + m_Router.AddMatcher("queue", DecimalMatcher); + m_Router.AddMatcher("oidtoken", OidMatcher); + m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || OidMatcher(Str); }); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "abandon", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + + if (Success) + { + CbObjectWriter Cbo; + Cbo << "state"sv + << "Abandoned"sv; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Abandoned from current state"sv; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, 
+ HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + m_ComputeService.GetCompleted(Cbo); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + 
HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Running = m_ComputeService.GetRunningActions(); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + // Once we've initiated the response we can mark the result + // as retired, allowing the service to free any associated + // resources. Note that there still needs to be a delay + // to allow the transmission to complete, it would be better + // if we could issue this once the response is fully sent... 
+ m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = 
QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash 
DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + 
TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersAllGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2))); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << 
Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + // Local-only queue listing and creation + + m_Router.RegisterRoute( + "queues", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObjectWriter Cbo; + Cbo.BeginArray("queues"sv); + + for (const int QueueId : m_ComputeService.GetQueueIds()) + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + continue; + } + + Cbo.BeginObject(); + WriteQueueDescription(Cbo, QueueId, Status); + Cbo.EndObject(); + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kPost: + { + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + Metadata = 
Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + ComputeServiceSession::CreateQueueResult Result = + m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + + CbObjectWriter Cbo; + Cbo << "queue_id"sv << Result.QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + // Queue creation routes — these remain separate since local creates a plain queue + // while remote additionally generates an OID token for external access. + + m_Router.RegisterRoute( + "queues/remote", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // Extract optional fields from the request body. + // idempotency_key: when present, we return the existing remote queue token for this + // key rather than creating a new queue, making the endpoint safe to call concurrently. + // hostname: human-readable origin context stored alongside the queue for diagnostics. + // metadata: arbitrary CbObject metadata propagated from the originating queue. + // config: arbitrary CbObject config propagated from the originating queue. + std::string IdempotencyKey; + std::string ClientHostname; + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + IdempotencyKey = std::string(Body["idempotency_key"sv].AsString()); + ClientHostname = std::string(Body["hostname"sv].AsString()); + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + // Stamp the forwarding node's hostname into the metadata so that the + // remote side knows which node originated the queue. 
+ if (!ClientHostname.empty()) + { + CbObjectWriter MetaWriter; + for (auto Field : Metadata) + { + MetaWriter.AddField(Field.GetName(), Field); + } + MetaWriter << "via"sv << ClientHostname; + Metadata = MetaWriter.Save(); + } + + RwLock::ExclusiveLockScope _(m_RemoteQueueLock); + + if (!IdempotencyKey.empty()) + { + if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end()) + { + Ref<RemoteQueueInfo> Existing = It->second; + if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid) + { + CbObjectWriter Cbo; + Cbo << "queue_token"sv << Existing->Token.ToString(); + Cbo << "queue_id"sv << Existing->QueueId; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // Queue has since expired — clean up stale entries and fall through to create a new one + m_RemoteQueuesByToken.erase(Existing->Token); + m_RemoteQueuesByQueueId.erase(Existing->QueueId); + m_RemoteQueuesByTag.erase(It); + } + } + + ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + Ref<RemoteQueueInfo> InfoRef(new RemoteQueueInfo()); + InfoRef->QueueId = Result.QueueId; + InfoRef->Token = Oid::NewOid(); + InfoRef->IdempotencyKey = std::move(IdempotencyKey); + InfoRef->ClientHostname = std::move(ClientHostname); + + m_RemoteQueuesByToken[InfoRef->Token] = InfoRef; + m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef; + if (!InfoRef->IdempotencyKey.empty()) + { + m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef; + } + + CbObjectWriter Cbo; + Cbo << "queue_token"sv << InfoRef->Token.ToString(); + Cbo << "queue_id"sv << InfoRef->QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. 
+ + m_Router.RegisterRoute( + "queues/{queueref}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kDelete: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.CancelQueue(QueueId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "queues/{queueref}/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.DrainQueue(QueueId); + + // Return updated queue status + Status = m_ComputeService.GetQueueStatus(QueueId); + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + 
return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + m_ComputeService.GetQueueCompleted(QueueId, Cbo); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/running", + [this](HttpRouterRequest& Req) { + 
HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + if (QueueId == 0) + { + return; + } + // Filter global running list to this queue + auto AllRunning = m_ComputeService.GetRunningActions(); + std::vector<ComputeServiceSession::RunningActionInfo> Running; + for (auto& Info : AllRunning) + if (Info.QueueId == QueueId) + Running.push_back(Info); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2)); + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + 
Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. 
{} new ({} attachments)", + QueueId, + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = 
IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ZEN_UNUSED(QueueId); + + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); +} + +////////////////////////////////////////////////////////////////////////// + 
+HttpComputeService::HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) +: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +{ +} + +HttpComputeService::~HttpComputeService() +{ + m_Impl->m_StatsService.UnregisterHandler("compute", *this); +} + +void +HttpComputeService::Shutdown() +{ + m_Impl->m_ComputeService.Shutdown(); +} + +ComputeServiceSession::ActionCounts +HttpComputeService::GetActionCounts() +{ + return m_Impl->m_ComputeService.GetActionCounts(); +} + +const char* +HttpComputeService::BaseUri() const +{ + return "/compute/"; +} + +void +HttpComputeService::HandleRequest(HttpServerRequest& Request) +{ + ZEN_TRACE_CPU("HttpComputeService::HandleRequest"); + metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests); + + if (m_Impl->m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpComputeService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + m_Impl->m_ComputeService.EmitStats(Cbo); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status) +{ + Cbo << "queue_id"sv << Status.QueueId; + Cbo << "active_count"sv << Status.ActiveCount; + Cbo << "completed_count"sv << Status.CompletedCount; + Cbo << "failed_count"sv << Status.FailedCount; + Cbo << "abandoned_count"sv << Status.AbandonedCount; + Cbo << "cancelled_count"sv << Status.CancelledCount; + Cbo << "state"sv << ToString(Status.State); + Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled); + Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining); + Cbo << "is_complete"sv << 
Status.IsComplete; + + if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId)) + { + Cbo << "metadata"sv << Meta; + } + + if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId)) + { + Cbo << "config"sv << Cfg; + } + + { + RwLock::SharedLockScope $(m_RemoteQueueLock); + if (auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end()) + { + Cbo << "queue_token"sv << It->second->Token.ToString(); + if (!It->second->ClientHostname.empty()) + { + Cbo << "hostname"sv << It->second->ClientHostname; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + +int +HttpComputeService::Impl::ResolveQueueToken(const Oid& Token) +{ + RwLock::SharedLockScope $(m_RemoteQueueLock); + + auto It = m_RemoteQueuesByToken.find(Token); + + if (It != m_RemoteQueuesByToken.end()) + { + return It->second->QueueId; + } + + return 0; +} + +int +HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture) +{ + if (OidMatcher(Capture)) + { + // Remote OID token — accessible from any client + const Oid Token = Oid::FromHexString(Capture); + const int QueueId = ResolveQueueToken(Token); + + if (QueueId == 0) + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + return QueueId; + } + + // Local integer queue ID — restricted to local machine requests + if (!HttpReq.IsLocalMachineRequest()) + { + HttpReq.WriteResponse(HttpResponseCode::Forbidden); + return 0; + } + + return ParseInt<int>(Capture).value_or(0); +} + +HttpComputeService::Impl::IngestStats +HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +{ + IngestStats Stats; + + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + 
Stats.Bytes += CompressedSize; + ++Stats.Count; + + const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + Stats.NewBytes += CompressedSize; + ++Stats.NewCount; + } + } + + return Stats; +} + +bool +HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList) +{ + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + return NeedList.empty(); +} + +void +HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds()) + { + Cbo << WorkerId; + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq) +{ + std::vector<IoHash> WorkerIds = m_ComputeService.GetKnownWorkerIds(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + for (const IoHash& WorkerId : WorkerIds) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject(); + Cbo.EndObject(); + } + + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId) +{ + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); + } + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject WorkerSpec = 
HttpReq.ReadPayloadObject(); + + HashKeySet ChunkSet; + WorkerSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerSpec); + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + m_ComputeService.RegisterWorker(WorkerPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + ResponseWriter.AddHash(Hash); + }); + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); + CbObject WorkerSpec = WorkerSpecPackage.GetObject(); + + std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + 
AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_ComputeService.RegisterWorker(WorkerSpecPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } +} + +////////////////////////////////////////////////////////////////////////// + +void +httpcomputeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp new file mode 100644 index 000000000..6cbe01e04 --- /dev/null +++ b/src/zencompute/httporchestrator.cpp @@ -0,0 +1,650 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httporchestrator.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencompute/orchestratorservice.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/system.h> + +namespace zen::compute { + +// Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes +static bool +IsValidWorkerId(std::string_view Id) +{ + if (Id.size() < 3 || Id.size() > 64) + { + return false; + } + for (char c : Id) + { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') + { + continue; + } + return false; + } + return true; +} + +// Shared announce payload parser used by both the HTTP POST route and the +// WebSocket message handler. Returns the worker ID on success (empty on +// validation failure). The returned WorkerAnnouncement has string_view +// fields that reference the supplied CbObjectView, so the CbObject must +// outlive the returned announcement. 
+static std::string_view +ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann) +{ + Ann.Id = Data["id"].AsString(""); + Ann.Uri = Data["uri"].AsString(""); + + if (!IsValidWorkerId(Ann.Id)) + { + return {}; + } + + if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://")) + { + return {}; + } + + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Platform = Data["platform"].AsString(""); + Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); + Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); + Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); + Ann.BytesReceived = Data["bytes_received"].AsUInt64(0); + Ann.BytesSent = Data["bytes_sent"].AsUInt64(0); + Ann.ActionsPending = Data["actions_pending"].AsInt32(0); + Ann.ActionsRunning = Data["actions_running"].AsInt32(0); + Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0); + Ann.ActiveQueues = Data["active_queues"].AsInt32(0); + Ann.Provisioner = Data["provisioner"].AsString(""); + + if (auto Metrics = Data["metrics"].AsObjectView()) + { + Ann.Cpus = Metrics["lp_count"].AsInt32(0); + if (Ann.Cpus <= 0) + { + Ann.Cpus = 1; + } + } + + return Ann.Id; +} + +HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) +, m_Hostname(GetMachineName()) +{ + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + + // dummy endpoint for websocket clients + m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + Cbo << "hostname" << std::string_view(m_Hostname); + 
Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provision", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + + if (WorkerId.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash " + "characters and uri must start with http:// or https://"); + } + + m_Service->AnnounceWorker(Ann); + + HttpReq.WriteResponse(HttpResponseCode::OK); + +# if ZEN_WITH_WEBSOCKETS + // Notify push thread that state may have changed + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "agents", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline/{workerid}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::string_view WorkerId = Req.GetCapture(1); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + auto 
LimitStr = Params.GetValue("limit"); + + std::optional<DateTime> From; + std::optional<DateTime> To; + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + int Limit = !LimitStr.empty() ? zen::ParseInt<int>(LimitStr).value_or(0) : 0; + + CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit); + + if (!Result) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + + DateTime From = DateTime(0); + DateTime To = DateTime::Now(); + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + CbObject Result = m_Service->GetAllTimelines(From, To); + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + // Client tracking endpoints + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::ClientAnnouncement Ann; + Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero); + Ann.Hostname = 
Data["hostname"].AsString(""); + Ann.Address = HttpReq.GetRemoteAddress(); + + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + Ann.Metadata = CbObject::Clone(MetadataView); + } + + std::string ClientId = m_Service->AnnounceClient(Ann); + + CbObjectWriter ResponseObj; + ResponseObj << "id" << std::string_view(ClientId); + HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save()); + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/update", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + CbObject MetadataObj; + CbObject Data = HttpReq.ReadPayloadObject(); + if (Data) + { + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + MetadataObj = CbObject::Clone(MetadataView); + } + } + + if (m_Service->UpdateClient(ClientId, std::move(MetadataObj))) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/complete", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + if (m_Service->CompleteClient(ClientId)) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "clients/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = 
Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit)); + }, + HttpVerb::kGet); + +# if ZEN_WITH_WEBSOCKETS + + // Start the WebSocket push thread + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); +# endif +} + +HttpOrchestratorService::~HttpOrchestratorService() +{ + Shutdown(); +} + +void +HttpOrchestratorService::Shutdown() +{ +# if ZEN_WITH_WEBSOCKETS + if (!m_PushEnabled.exchange(false)) + { + return; + } + + // Stop the push thread first, before touching connections. This ensures + // the push thread is no longer reading m_WsConnections or calling into + // m_Service when we start tearing things down. + m_PushEvent.Set(); + if (m_PushThread.joinable()) + { + m_PushThread.join(); + } + + // Clean up worker WebSocket connections — collect IDs under lock, then + // notify the service outside the lock to avoid lock-order inversions. + std::vector<std::string> WorkerIds; + m_WorkerWsLock.WithExclusiveLock([&] { + WorkerIds.reserve(m_WorkerWsMap.size()); + for (const auto& [Conn, Id] : m_WorkerWsMap) + { + WorkerIds.push_back(Id); + } + m_WorkerWsMap.clear(); + }); + for (const auto& Id : WorkerIds) + { + m_Service->SetWorkerWebSocketConnected(Id, false); + } + + // Now that the push thread is gone, release all dashboard connections. 
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); +# endif +} + +const char* +HttpOrchestratorService::BaseUri() const +{ + return "/orch/"; +} + +void +HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + if (!m_PushEnabled.load()) + { + return; + } + + ZEN_INFO("WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Wake push thread to send initial state immediately + m_PushEvent.Set(); +} + +void +HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + // Only handle binary messages from workers when the feature is enabled. + if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary) + { + return; + } + + std::string WorkerId = HandleWorkerWebSocketMessage(Msg); + if (WorkerId.empty()) + { + return; + } + + // Check if this is a new worker WebSocket connection + bool IsNewWorkerWs = false; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It == m_WorkerWsMap.end()) + { + m_WorkerWsMap[&Conn] = WorkerId; + IsNewWorkerWs = true; + } + }); + + if (IsNewWorkerWs) + { + m_Service->SetWorkerWebSocketConnected(WorkerId, true); + } + + m_PushEvent.Set(); +} + +std::string +HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) +{ + // Workers send CbObject in native binary format over the WebSocket to + // avoid the lossy CbObject↔JSON round-trip. 
+ CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); + if (!Data) + { + ZEN_WARN("worker WebSocket message is not a valid CbObject"); + return {}; + } + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + if (WorkerId.empty()) + { + ZEN_WARN("invalid worker announcement via WebSocket"); + return {}; + } + + m_Service->AnnounceWorker(Ann); + return std::string(WorkerId); +} + +void +HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, + [[maybe_unused]] uint16_t Code, + [[maybe_unused]] std::string_view Reason) +{ + ZEN_INFO("WebSocket client disconnected (code {})", Code); + + // Check if this was a worker WebSocket connection; collect the ID under + // the worker lock, then notify the service outside the lock. + std::string DisconnectedWorkerId; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It != m_WorkerWsMap.end()) + { + DisconnectedWorkerId = std::move(It->second); + m_WorkerWsMap.erase(It); + } + }); + + if (!DisconnectedWorkerId.empty()) + { + m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false); + m_PushEvent.Set(); + } + + if (!m_PushEnabled.load()) + { + return; + } + + // Remove from dashboard connections + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} +# endif + +////////////////////////////////////////////////////////////////////////// +// +// Push thread +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::PushThreadFunction() +{ + SetCurrentThreadName("orch_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(2000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + // Snapshot current connections + std::vector<Ref<WebSocketConnection>> Connections; + 
m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + continue; + } + + // Build combined JSON with worker list, provisioning history, clients, and client history + CbObject WorkerList = m_Service->GetWorkerList(); + CbObject History = m_Service->GetProvisioningHistory(50); + CbObject ClientList = m_Service->GetClientList(); + CbObject ClientHistory = m_Service->GetClientHistory(50); + + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname)); + + // Emit workers array from worker list + ExtendableStringBuilder<2048> WorkerJson; + WorkerList.ToJson(WorkerJson); + std::string_view WorkerJsonView = WorkerJson.ToView(); + // Strip outer braces: {"workers":[...]} -> "workers":[...] + if (WorkerJsonView.size() >= 2) + { + JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit events array from history + ExtendableStringBuilder<2048> HistoryJson; + History.ToJson(HistoryJson); + std::string_view HistoryJsonView = HistoryJson.ToView(); + if (HistoryJsonView.size() >= 2) + { + JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit clients array from client list + ExtendableStringBuilder<2048> ClientJson; + ClientList.ToJson(ClientJson); + std::string_view ClientJsonView = ClientJson.ToView(); + if (ClientJsonView.size() >= 2) + { + JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit client_events array from client history + ExtendableStringBuilder<2048> ClientHistoryJson; + ClientHistory.ToJson(ClientHistoryJson); + std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView(); + if (ClientHistoryJsonView.size() >= 2) + { + JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); + } + + 
JsonBuilder.Append("}"); + std::string_view Json = JsonBuilder.ToView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } + } +} +# endif + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +#include <atomic> +#include <filesystem> +#include <string> +#include <thread> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. 
 *
 * To avoid a ~200ms connect timeout penalty on every startup when running on
 * bare-metal or non-cloud machines, failed probes write sentinel files
 * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have
 * a sentinel present. Delete the sentinel files to force re-detection.
 *
 * When a provider is detected, a background thread polls for termination
 * signals every 5 seconds (spot interruption, autoscaling lifecycle changes,
 * scheduled maintenance). The termination state is exposed as an atomic bool
 * so the compute server can include it in coordinator announcements and react
 * to imminent shutdown.
 *
 * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a
 * shared RwLock; the background monitor thread acquires the exclusive lock
 * only when writing the termination reason (a one-time transition). The
 * termination-pending flag itself is a lock-free atomic.
 *
 * Usage:
 *   auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud");
 *   if (Cloud->IsTerminationPending()) { ... }
 *   Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB
 */
class CloudMetadata
{
public:
	/** Synchronously probes cloud providers and starts the termination monitor
	 * if a provider is detected. Creates DataDir if it does not exist.
	 */
	explicit CloudMetadata(std::filesystem::path DataDir);

	/** Synchronously probes cloud providers at the given IMDS endpoint.
	 * Intended for testing — allows redirecting all IMDS queries to a local
	 * mock HTTP server instead of the real 169.254.169.254 endpoint.
	 */
	CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint);

	/** Stops the termination monitor thread and joins it. */
	~CloudMetadata();

	// Non-copyable: owns a background thread and lock state
	CloudMetadata(const CloudMetadata&) = delete;
	CloudMetadata& operator=(const CloudMetadata&) = delete;

	CloudProvider	  GetProvider() const;
	CloudInstanceInfo GetInstanceInfo() const;
	bool			  IsTerminationPending() const;
	std::string		  GetTerminationReason() const;

	/** Writes a "cloud" sub-object into the compact binary writer if a provider
	 * was detected. No-op when running on bare metal.
	 */
	void Describe(CbWriter& Writer) const;

	/** Executes a single termination-poll cycle for the detected provider.
	 * Public so tests can drive poll cycles synchronously without relying on
	 * the background thread's 5-second timer.
	 */
	void PollTermination();

	/** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure,
	 * .isNotGCP) from DataDir so subsequent detection probes are not skipped.
	 * Primarily intended for tests that need to reset state between sub-cases.
	 */
	void ClearSentinelFiles();

private:
	/** Tries each provider in order, stops on first successful detection. */
	void DetectProvider();
	bool TryDetectAWS();
	bool TryDetectAzure();
	bool TryDetectGCP();

	void WriteSentinelFile(const std::filesystem::path& Path);
	bool HasSentinelFile(const std::filesystem::path& Path) const;

	void StartTerminationMonitor();
	void TerminationMonitorThread();
	void PollAWSTermination();
	void PollAzureTermination();
	void PollGCPTermination();

	LoggerRef Log() { return m_Log; }

	LoggerRef			  m_Log;
	std::filesystem::path m_DataDir;
	std::string			  m_ImdsEndpoint;

	// Guards m_Info (written during detection, read by the accessors above)
	mutable RwLock	  m_InfoLock;
	CloudInstanceInfo m_Info;

	std::atomic<bool> m_TerminationPending{false};

	// Guards m_TerminationReason (one-time write by the monitor thread)
	mutable RwLock m_ReasonLock;
	std::string	   m_TerminationReason;

	// IMDSv2 session token, acquired during AWS detection and reused for
	// subsequent termination polling. Has a 300s TTL on the AWS side; if it
	// expires mid-run the poll requests will get 401s which we treat as
	// non-terminal (the monitor simply retries next cycle).
	std::string m_AwsToken;

	std::thread		  m_MonitorThread;
	std::atomic<bool> m_MonitorEnabled{true};
	Event			  m_MonitorEvent;
};

void cloudmetadata_forcelink(); // internal

} // namespace zen::compute
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
new file mode 100644
index 000000000..65ec5f9ee
--- /dev/null
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -0,0 +1,262 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/zencompute.h>

#if ZEN_WITH_COMPUTE_SERVICES

#	include <zencore/compactbinary.h>
#	include <zencore/compactbinarypackage.h>
#	include <zencore/iohash.h>
#	include <zenstore/zenstore.h>
#	include <zenhttp/httpcommon.h>

#	include <filesystem>

namespace zen {
class ChunkResolver;
class CbObjectWriter;
} // namespace zen

namespace zen::compute {

class ActionRecorder;
class ComputeServiceSession;
class IActionResultHandler;
class LocalProcessRunner;
class RemoteHttpRunner;
struct RunnerAction;
struct SubmitResult;

struct WorkerDesc
{
	CbPackage Descriptor;
	IoHash	  WorkerId{IoHash::Zero};

	// A descriptor is valid only when it carries a non-zero worker id
	inline operator bool() const { return WorkerId != IoHash::Zero; }
};

/**
 * Lambda style compute function service
 *
 * The responsibility of this class is to accept function execution requests, and
 * schedule them using one or more FunctionRunner instances. It will basically always
 * accept requests, queueing them if necessary, and then hand them off to runners
 * as they become available.
 *
 * This is typically fronted by an API service that handles communication with clients.
 */
class ComputeServiceSession final
{
public:
	/**
	 * Session lifecycle state machine.
	 *
	 * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset
	 * Backward transitions: Draining -> Ready, Paused -> Ready
	 * Automatic transition: Draining -> Paused (when pending + running reaches 0)
	 * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset
	 * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out)
	 *
	 * | State     | Accept new actions | Schedule pending | Finish running     |
	 * |-----------|--------------------|------------------|--------------------|
	 * | Created   | No                 | No               | N/A                |
	 * | Ready     | Yes                | Yes              | Yes                |
	 * | Draining  | No                 | Yes              | Yes                |
	 * | Paused    | No                 | No               | No                 |
	 * | Abandoned | No                 | No               | No (all abandoned) |
	 * | Sunset    | No                 | No               | No                 |
	 */
	enum class SessionState
	{
		Created,   // Initial state before WaitUntilReady completes
		Ready,	   // Normal operating state; accepts and schedules work
		Draining,  // Stops accepting new work; finishes existing; auto-transitions to Paused when empty
		Paused,	   // Idle; no work accepted or scheduled; can resume to Ready
		Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out
		Sunset	   // Terminal; triggers full shutdown
	};

	ComputeServiceSession(ChunkResolver& InChunkResolver);
	~ComputeServiceSession();

	void WaitUntilReady();
	void Shutdown();
	bool IsHealthy();

	SessionState GetSessionState() const;

	// Request a state transition. Returns false if the transition is invalid.
	// Sunset can be reached from any non-Sunset state.
	bool RequestStateTransition(SessionState NewState);

	// Orchestration

	void SetOrchestratorEndpoint(std::string_view Endpoint);
	void SetOrchestratorBasePath(std::filesystem::path BasePath);

	// Worker registration and discovery

	void					 RegisterWorker(CbPackage Worker);
	[[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId);
	[[nodiscard]] std::vector<IoHash> GetKnownWorkerIds();

	// Action runners

	void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
	void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName);

	// Action submission

	struct EnqueueResult
	{
		int		 Lsn;
		CbObject ResponseMessage;

		// An Lsn of 0 signals that the enqueue failed
		inline operator bool() const { return Lsn != 0; }
	};

	[[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority);
	[[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority);

	// Queue management
	//
	// Queues group actions submitted by a single client session. They allow
	// cancelling or polling completion of all actions in the group.

	struct CreateQueueResult
	{
		int QueueId = 0; // 0 if creation failed
	};

	enum class QueueState
	{
		Active,
		Draining,
		Cancelled,
	};

	struct QueueStatus
	{
		bool	   IsValid		  = false;
		int		   QueueId		  = 0;
		int		   ActiveCount	  = 0; // pending + running (not yet completed)
		int		   CompletedCount = 0; // successfully completed
		int		   FailedCount	  = 0; // failed
		int		   AbandonedCount = 0; // abandoned
		int		   CancelledCount = 0; // cancelled
		QueueState State		  = QueueState::Active;
		bool	   IsComplete	  = false; // ActiveCount == 0
	};

	[[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {});
	[[nodiscard]] std::vector<int>	GetQueueIds();
	[[nodiscard]] QueueStatus		GetQueueStatus(int QueueId);
	[[nodiscard]] CbObject			GetQueueMetadata(int QueueId);
	[[nodiscard]] CbObject			GetQueueConfig(int QueueId);
	void							CancelQueue(int QueueId);
	void							DrainQueue(int QueueId);
	void							DeleteQueue(int QueueId);
	void							GetQueueCompleted(int QueueId, CbWriter& Cbo);

	// Queue-scoped action submission. Actions submitted via these methods are
	// tracked under the given queue in addition to the global LSN-based tracking.

	[[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
	[[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);

	// Completed action tracking

	[[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage);
	[[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage);
	void						   RetireActionResult(int ActionLsn);

	// Action rescheduling

	struct RescheduleResult
	{
		bool		Success = false;
		std::string Error;
		int			RetryCount = 0;
	};

	[[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn);

	void GetCompleted(CbWriter&);

	// Running action tracking

	struct RunningActionInfo
	{
		int	   Lsn;
		int	   QueueId;
		IoHash ActionId;
		float  CpuUsagePercent; // -1.0 if not yet sampled
		float  CpuSeconds;		// 0.0 if not yet sampled
	};

	[[nodiscard]] std::vector<RunningActionInfo> GetRunningActions();

	// Action history tracking (note that this is separate from completed action tracking, and
	// will include actions which have been retired and no longer have their results available)

	struct ActionHistoryEntry
	{
		int			Lsn;
		int			QueueId = 0;
		IoHash		ActionId;
		IoHash		WorkerId;
		CbObject	ActionDescriptor;
		std::string ExecutionLocation;
		bool		Succeeded;
		float		CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled
		int			RetryCount = 0;	   // number of times this action was rescheduled
		// sized to match RunnerAction::State::_Count but we can't use the enum here
		// for dependency reasons, so just use a fixed size array and static assert in
		// the implementation file
		uint64_t Timestamps[8] = {};
	};

	[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
	[[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100);

	// Stats reporting

	struct ActionCounts
	{
		int Pending		 = 0;
		int Running		 = 0;
		int Completed	 = 0;
		int ActiveQueues = 0;
	};

	[[nodiscard]] ActionCounts GetActionCounts();

	void EmitStats(CbObjectWriter& Cbo);

	// Recording

	void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
	void StopRecording();

private:
	void PostUpdate(RunnerAction* Action);

	friend class FunctionRunner;
	friend struct RunnerAction;

	// Pimpl: keeps the heavy implementation details out of this public header
	struct Impl;
	std::unique_ptr<Impl> m_Impl;
};

void computeservice_forcelink();

} // namespace zen::compute

namespace zen {
const char* ToString(compute::ComputeServiceSession::SessionState State);
const char* ToString(compute::ComputeServiceSession::QueueState State);
} // namespace zen

#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
new file mode 100644
index 000000000..ee1cd2614
--- /dev/null
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -0,0 +1,54 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/zencompute.h>

#if ZEN_WITH_COMPUTE_SERVICES

#	include "zencompute/computeservice.h"

#	include <zenhttp/httpserver.h>

#	include <filesystem>
#	include <memory>

namespace zen {
class CidStore;
}

namespace zen::compute {

/**
 * HTTP interface for compute service
 */
class HttpComputeService : public HttpService, public IHttpStatsProvider
{
public:
	HttpComputeService(CidStore& InCidStore,
					   IHttpStatsService& StatsService,
					   const std::filesystem::path& BaseDir,
					   int32_t MaxConcurrentActions = 0);
	~HttpComputeService();

	void Shutdown();

	[[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts();

	const char* BaseUri() const override;
	void		HandleRequest(HttpServerRequest& Request) override;

	// IHttpStatsProvider

	void HandleStatsRequest(HttpServerRequest& Request) override;

private:
	// Pimpl: implementation details (session, runners) stay out of the header
	struct Impl;
	std::unique_ptr<Impl> m_Impl;
};

void httpcomputeservice_forcelink();

} // namespace zen::compute

#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
new file mode 100644
index 000000000..da5c5dfc3
--- /dev/null
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -0,0 +1,101 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/zencompute.h>

#include <zencore/logging.h>
#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/websocket.h>

#include <atomic>
#include <filesystem>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// NOTE(review): unconditionally defined in a public header — consider wrapping
// in #if !defined(ZEN_WITH_WEBSOCKETS) so a build system override is possible.
#define ZEN_WITH_WEBSOCKETS 1

namespace zen::compute {

class OrchestratorService;

// Experimental helper, to see if we can get rid of some error-prone
// boilerplate when declaring loggers as class members.

class LoggerHelper
{
public:
	LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {}

	LoggerRef operator()() { return m_Log; }

private:
	LoggerRef m_Log;
};

/**
 * Orchestrator HTTP service with WebSocket push support
 *
 * Normal HTTP requests are routed through the HttpRequestRouter as before.
 * WebSocket clients connecting to /orch/ws receive periodic state broadcasts
 * from a dedicated push thread, eliminating the need for polling.
 */

class HttpOrchestratorService : public HttpService
#if ZEN_WITH_WEBSOCKETS
,
								public IWebSocketHandler
#endif
{
public:
	explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false);
	~HttpOrchestratorService();

	// Non-copyable: owns a push thread and live WebSocket connections
	HttpOrchestratorService(const HttpOrchestratorService&) = delete;
	HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete;

	/**
	 * Gracefully shut down the WebSocket push thread and release connections.
	 * Must be called while the ASIO io_context is still alive. The destructor
	 * also calls this, so it is safe (but not ideal) to omit the explicit call.
	 */
	void Shutdown();

	virtual const char* BaseUri() const override;
	virtual void		HandleRequest(HttpServerRequest& Request) override;

	// IWebSocketHandler
#if ZEN_WITH_WEBSOCKETS
	void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
	void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
	void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
#endif

private:
	HttpRequestRouter					 m_Router;
	LoggerHelper						 Log{"orch"};
	std::unique_ptr<OrchestratorService> m_Service;
	std::string							 m_Hostname;

	// WebSocket push

#if ZEN_WITH_WEBSOCKETS
	// Guards m_WsConnections (dashboard subscribers)
	RwLock								  m_WsConnectionsLock;
	std::vector<Ref<WebSocketConnection>> m_WsConnections;
	std::thread							  m_PushThread;
	std::atomic<bool>					  m_PushEnabled{false};
	Event								  m_PushEvent;
	void								  PushThreadFunction();

	// Worker WebSocket connections (worker→orchestrator persistent links)
	RwLock												  m_WorkerWsLock;
	std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID
	std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg);
#endif
};

} // namespace zen::compute
diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h
new file mode 100644
index 000000000..521722e63
--- /dev/null
+++ b/src/zencompute/include/zencompute/mockimds.h
@@ -0,0 +1,102 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/cloudmetadata.h>
#include <zenhttp/httpserver.h>

#include <string>

#if ZEN_WITH_TESTS

namespace zen::compute {

/**
 * Mock IMDS (Instance Metadata Service) for testing CloudMetadata.
 *
 * Implements an HttpService that responds to the same URL paths as the real
 * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata).
 * Tests configure which provider is "active" and set the desired response
 * values, then pass the mock server's address as the ImdsEndpoint to the
 * CloudMetadata constructor.
 *
 * When a request arrives for a provider that is not the ActiveProvider, the
 * mock returns 404, causing CloudMetadata to write a sentinel file and move
 * on to the next provider — exactly like a failed probe on bare metal.
 *
 * All config fields are public and can be mutated between poll cycles to
 * simulate state changes (e.g. a spot interruption appearing mid-run).
 *
 * Usage:
 *   MockImdsService Mock;
 *   Mock.ActiveProvider = CloudProvider::AWS;
 *   Mock.Aws.InstanceId = "i-test";
 *   // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint
 */
class MockImdsService : public HttpService
{
public:
	/** AWS IMDSv2 response configuration. */
	struct AwsConfig
	{
		std::string Token			 = "mock-aws-token-v2";
		std::string InstanceId		 = "i-0123456789abcdef0";
		std::string AvailabilityZone = "us-east-1a";
		std::string LifeCycle		 = "on-demand"; // "spot" or "on-demand"

		// Empty string → endpoint returns 404 (instance not in an ASG).
		// Non-empty → returned as the response body. "InService" means healthy;
		// anything else (e.g. "Terminated:Wait") triggers termination detection.
		std::string AutoscalingState;

		// Empty string → endpoint returns 404 (no spot interruption).
		// Non-empty → returned as the response body, signalling a spot reclaim.
		std::string SpotAction;
	};

	/** Azure IMDS response configuration. */
	struct AzureConfig
	{
		std::string VmId	 = "vm-12345678-1234-1234-1234-123456789abc";
		std::string Location = "eastus";
		std::string Priority = "Regular"; // "Spot" or "Regular"

		// Empty → instance is not in a VM Scale Set (no autoscaling).
		std::string VmScaleSetName;

		// Empty → no scheduled events. Set to "Preempt", "Terminate", or
		// "Reboot" to simulate a termination-class event.
		std::string ScheduledEventType;
		std::string ScheduledEventStatus = "Scheduled";
	};

	/** GCP metadata response configuration. */
	struct GcpConfig
	{
		std::string InstanceId		 = "1234567890123456789";
		std::string Zone			 = "projects/123456/zones/us-central1-a";
		std::string Preemptible		 = "FALSE"; // "TRUE" or "FALSE"
		std::string MaintenanceEvent = "NONE";	// "NONE" or event description
	};

	/** Which provider's endpoints respond successfully.
	 * Requests targeting other providers receive 404.
	 */
	CloudProvider ActiveProvider = CloudProvider::None;

	AwsConfig	Aws;
	AzureConfig Azure;
	GcpConfig	Gcp;

	const char* BaseUri() const override;
	void		HandleRequest(HttpServerRequest& Request) override;

private:
	void HandleAwsRequest(HttpServerRequest& Request);
	void HandleAzureRequest(HttpServerRequest& Request);
	void HandleGcpRequest(HttpServerRequest& Request);
};

} // namespace zen::compute

#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h
new file mode 100644
index 000000000..071e902b3
--- /dev/null
+++ b/src/zencompute/include/zencompute/orchestratorservice.h
@@ -0,0 +1,177 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+ +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/uid.h> + +# include <deque> +# include <optional> +# include <filesystem> +# include <memory> +# include <string> +# include <string_view> +# include <thread> +# include <unordered_map> + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. "windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. 
"horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + 
std::unique_ptr<WorkerTimelineStore> m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque<ProvisioningEvent> m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map<std::string, KnownClient> m_KnownClients; + + RwLock m_ClientLogLock; + std::deque<ClientEvent> m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic<bool> m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h new file mode 100644 index 000000000..3f233fae0 --- /dev/null +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -0,0 +1,129 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencompute/zencompute.h> + +#include <zencompute/computeservice.h> +#include <zencompute/zencompute.h> +#include <zencore/basicfile.h> +#include <zencore/compactbinarybuilder.h> +#include <zenstore/cidstore.h> +#include <zenstore/gc.h> +#include <zenstore/zenstore.h> + +#include <filesystem> +#include <functional> +#include <unordered_map> + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + +class RecordingReaderBase +{ + RecordingReaderBase(const RecordingReaderBase&) = delete; + RecordingReaderBase& operator=(const RecordingReaderBase&) = delete; + +public: + RecordingReaderBase() = default; + virtual ~RecordingReaderBase() = 0; + virtual std::unordered_map<IoHash, CbPackage> ReadWorkers() = 0; + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism) = 0; + virtual size_t GetActionCount() const = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Reader for recordings done via the zencompute recording system, which + * have a shared chunk store and a log of actions with pointers into the + * chunk store for their data. 
+ */ +class RecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit RecordingReader(const std::filesystem::path& RecordingPath); + ~RecordingReader(); + + virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override; + + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + +private: + std::filesystem::path m_RecordingLogDir; + BasicFile m_WorkerDataFile; + BasicFile m_ActionDataFile; + GcManager m_Gc; + CidStore m_CidStore{m_Gc}; + + // ChunkResolver interface + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + + struct ActionEntry + { + IoHash ActionId; + uint64_t Offset; + uint64_t Size; + }; + + std::vector<ActionEntry> m_Actions; + + void ScanActions(); +}; + +////////////////////////////////////////////////////////////////////////// + +struct LocalResolver : public ChunkResolver +{ + LocalResolver(const LocalResolver&) = delete; + LocalResolver& operator=(const LocalResolver&) = delete; + + LocalResolver() = default; + ~LocalResolver() = default; + + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + void Add(const IoHash& Cid, IoBuffer Data); + +private: + RwLock MapLock; + std::unordered_map<IoHash, IoBuffer> Attachments; +}; + +/** + * This is a reader for UE/DDB recordings, which have a different layout on + * disk (no shared chunk store) + */ +class UeRecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit UeRecordingReader(const std::filesystem::path& RecordingPath); + ~UeRecordingReader(); + + virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override; + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + virtual IoBuffer 
FindChunkByCid(const IoHash& DecompressedId) override; + +private: + std::filesystem::path m_RecordingDir; + LocalResolver m_LocalResolver; + std::vector<std::filesystem::path> m_WorkDirs; + + CbPackage ReadAction(std::filesystem::path WorkDir); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h new file mode 100644 index 000000000..00be4d4a0 --- /dev/null +++ b/src/zencompute/include/zencompute/zencompute.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +namespace zen { + +void zencompute_forcelinktests(); + +} diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp new file mode 100644 index 000000000..9ea695305 --- /dev/null +++ b/src/zencompute/orchestratorservice.cpp @@ -0,0 +1,710 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zencompute/orchestratorservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zenhttp/httpclient.h> + +# include "timeline/workertimeline.h" + +namespace zen::compute { + +OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_TimelineStore(std::make_unique<WorkerTimelineStore>(DataDir / "timelines")) +, m_EnableWorkerWebSocket(EnableWorkerWebSocket) +{ + m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this}; +} + +OrchestratorService::~OrchestratorService() +{ + m_ProbeThreadEnabled = false; + m_ProbeThreadEvent.Set(); + if (m_ProbeThread.joinable()) + { + m_ProbeThread.join(); + } +} + +CbObject +OrchestratorService::GetWorkerList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "uri" << Worker.BaseUri; + Cbo << "hostname" << Worker.Hostname; + if (!Worker.Platform.empty()) + { + Cbo << "platform" << std::string_view(Worker.Platform); + } + Cbo << "cpus" << Worker.Cpus; + Cbo << "cpu_usage" << Worker.CpuUsagePercent; + Cbo << "memory_total" << Worker.MemoryTotalBytes; + Cbo << "memory_used" << Worker.MemoryUsedBytes; + Cbo << "bytes_received" << Worker.BytesReceived; + Cbo << "bytes_sent" << Worker.BytesSent; + Cbo << "actions_pending" << Worker.ActionsPending; + Cbo << "actions_running" << Worker.ActionsRunning; + Cbo << "actions_completed" << Worker.ActionsCompleted; + Cbo << "active_queues" << Worker.ActiveQueues; + if (!Worker.Provisioner.empty()) + { + Cbo << "provisioner" << std::string_view(Worker.Provisioner); + } + if (Worker.Reachable != ReachableState::Unknown) + { + Cbo << "reachable" << (Worker.Reachable == ReachableState::Reachable); + } + 
if (Worker.WsConnected) + { + Cbo << "ws_connected" << true; + } + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker"); + + bool IsNew = false; + std::string EvictedId; + std::string EvictedHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end()); + + // If a different worker ID already maps to the same URI, the old entry + // is stale (e.g. a previous Horde lease on the same machine). Remove it + // so the dashboard doesn't show duplicates. + if (IsNew) + { + for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It) + { + if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id) + { + EvictedId = It->first; + EvictedHostname = It->second.Hostname; + m_KnownWorkers.erase(It); + break; + } + } + } + + auto& Worker = m_KnownWorkers[std::string(Ann.Id)]; + Worker.BaseUri = Ann.Uri; + Worker.Hostname = Ann.Hostname; + if (!Ann.Platform.empty()) + { + Worker.Platform = Ann.Platform; + } + Worker.Cpus = Ann.Cpus; + Worker.CpuUsagePercent = Ann.CpuUsagePercent; + Worker.MemoryTotalBytes = Ann.MemoryTotalBytes; + Worker.MemoryUsedBytes = Ann.MemoryUsedBytes; + Worker.BytesReceived = Ann.BytesReceived; + Worker.BytesSent = Ann.BytesSent; + Worker.ActionsPending = Ann.ActionsPending; + Worker.ActionsRunning = Ann.ActionsRunning; + Worker.ActionsCompleted = Ann.ActionsCompleted; + Worker.ActiveQueues = Ann.ActiveQueues; + if (!Ann.Provisioner.empty()) + { + Worker.Provisioner = Ann.Provisioner; + } + Worker.LastSeen.Reset(); + }); + + if (!EvictedId.empty()) + { + ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id); + RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname); + } + + if (IsNew) + { + 
RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname); + } +} + +bool +OrchestratorService::IsWorkerWebSocketEnabled() const +{ + return m_EnableWorkerWebSocket; +} + +void +OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected) +{ + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(std::string(WorkerId)); + if (It == m_KnownWorkers.end()) + { + return; + } + + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.WsConnected = Connected; + It->second.Reachable = Connected ? ReachableState::Reachable : ReachableState::Unreachable; + + if (Connected) + { + ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + } + else + { + ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId); + } + }); + + // Record provisioning events for state transitions outside the lock + if (Connected && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname); + } + else if (!Connected && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname); + } +} + +CbObject +OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline"); + + Ref<WorkerTimeline> Timeline = m_TimelineStore->Find(WorkerId); + if (!Timeline) + { + return {}; + } + + std::vector<WorkerTimeline::Event> Events; + + if (From || To) + { + DateTime StartTime = From.value_or(DateTime(0)); + DateTime EndTime = To.value_or(DateTime::Now()); + Events = Timeline->QueryTimeline(StartTime, EndTime); + } + else if (Limit > 0) + { + Events = Timeline->QueryRecent(Limit); + } + else + { + Events = 
Timeline->QueryRecent(); + } + + WorkerTimeline::TimeRange Range = Timeline->GetTimeRange(); + + CbObjectWriter Cbo; + Cbo << "worker_id" << WorkerId; + Cbo << "event_count" << static_cast<int32_t>(Timeline->GetEventCount()); + + if (Range) + { + Cbo.AddDateTime("time_first", Range.First); + Cbo.AddDateTime("time_last", Range.Last); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : Events) + { + Cbo.BeginObject(); + Cbo << "type" << WorkerTimeline::ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged) + { + Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState); + Cbo << "state" << RunnerAction::ToString(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetAllTimelines(DateTime From, DateTime To) +{ + ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines"); + + DateTime StartTime = From; + DateTime EndTime = To; + + auto AllInfo = m_TimelineStore->GetAllWorkerInfo(); + + CbObjectWriter Cbo; + Cbo.AddDateTime("from", StartTime); + Cbo.AddDateTime("to", EndTime); + + Cbo.BeginArray("workers"); + for (const auto& Info : AllInfo) + { + if (!Info.Range || Info.Range.Last < StartTime || Info.Range.First > EndTime) + { + continue; + } + + Cbo.BeginObject(); + Cbo << "worker_id" << Info.WorkerId; + Cbo.AddDateTime("time_first", Info.Range.First); + Cbo.AddDateTime("time_last", Info.Range.Last); + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void +OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname) +{ + ProvisioningEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .WorkerId = std::string(WorkerId), + 
.Hostname = std::string(Hostname), + }; + + m_ProvisioningLogLock.WithExclusiveLock([&] { + m_ProvisioningLog.push_back(std::move(Evt)); + while (m_ProvisioningLog.size() > kMaxProvisioningEvents) + { + m_ProvisioningLog.pop_front(); + } + }); +} + +CbObject +OrchestratorService::GetProvisioningHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("events"); + + m_ProvisioningLogLock.WithSharedLock([&] { + // Return last N events, newest first + int Count = 0; + for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ProvisioningEvent::Type::Joined: + Cbo << "type" + << "joined"; + break; + case ProvisioningEvent::Type::Left: + Cbo << "type" + << "left"; + break; + case ProvisioningEvent::Type::Returned: + Cbo << "type" + << "returned"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "worker_id" << std::string_view(Evt.WorkerId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +std::string +OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceClient"); + + std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString()); + + bool IsNew = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(ClientId); + IsNew = (It == m_KnownClients.end()); + + auto& Client = m_KnownClients[ClientId]; + Client.SessionId = Ann.SessionId; + Client.Hostname = Ann.Hostname; + if (!Ann.Address.empty()) + { + Client.Address = Ann.Address; + } + if (Ann.Metadata) + { + Client.Metadata = Ann.Metadata; + } + Client.LastSeen.Reset(); + }); + + if (IsNew) + { + RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname); + } + else 
+ { + RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname); + } + + return ClientId; +} + +bool +OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata) +{ + ZEN_TRACE_CPU("OrchestratorService::UpdateClient"); + + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + if (Metadata) + { + It->second.Metadata = std::move(Metadata); + } + It->second.LastSeen.Reset(); + } + }); + + return Found; +} + +bool +OrchestratorService::CompleteClient(std::string_view ClientId) +{ + ZEN_TRACE_CPU("OrchestratorService::CompleteClient"); + + std::string Hostname; + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + Hostname = It->second.Hostname; + m_KnownClients.erase(It); + } + }); + + if (Found) + { + RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname); + } + + return Found; +} + +CbObject +OrchestratorService::GetClientList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientList"); + CbObjectWriter Cbo; + Cbo.BeginArray("clients"); + + m_KnownClientsLock.WithSharedLock([&] { + for (const auto& [ClientId, Client] : m_KnownClients) + { + Cbo.BeginObject(); + Cbo << "id" << ClientId; + if (Client.SessionId) + { + Cbo << "session_id" << Client.SessionId; + } + Cbo << "hostname" << std::string_view(Client.Hostname); + if (!Client.Address.empty()) + { + Cbo << "address" << std::string_view(Client.Address); + } + Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs(); + if (Client.Metadata) + { + Cbo << "metadata" << Client.Metadata; + } + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetClientHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + 
CbObjectWriter Cbo; + Cbo.BeginArray("client_events"); + + m_ClientLogLock.WithSharedLock([&] { + int Count = 0; + for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ClientEvent::Type::Connected: + Cbo << "type" + << "connected"; + break; + case ClientEvent::Type::Disconnected: + Cbo << "type" + << "disconnected"; + break; + case ClientEvent::Type::Updated: + Cbo << "type" + << "updated"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "client_id" << std::string_view(Evt.ClientId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname) +{ + ClientEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .ClientId = std::string(ClientId), + .Hostname = std::string(Hostname), + }; + + m_ClientLogLock.WithExclusiveLock([&] { + m_ClientLog.push_back(std::move(Evt)); + while (m_ClientLog.size() > kMaxClientEvents) + { + m_ClientLog.pop_front(); + } + }); +} + +void +OrchestratorService::ProbeThreadFunction() +{ + ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction"); + SetCurrentThreadName("orch_probe"); + + bool IsFirstProbe = true; + + do + { + if (!IsFirstProbe) + { + m_ProbeThreadEvent.Wait(5'000); + m_ProbeThreadEvent.Reset(); + } + else + { + IsFirstProbe = false; + } + + if (m_ProbeThreadEnabled == false) + { + return; + } + + m_ProbeThreadEvent.Reset(); + + // Snapshot worker IDs and URIs under shared lock + struct WorkerSnapshot + { + std::string Id; + std::string Uri; + bool WsConnected = false; + }; + std::vector<WorkerSnapshot> Snapshots; + + m_KnownWorkersLock.WithSharedLock([&] { + Snapshots.reserve(m_KnownWorkers.size()); + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + 
Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected}); + } + }); + + // Probe each worker outside the lock + for (const auto& Snap : Snapshots) + { + if (m_ProbeThreadEnabled == false) + { + return; + } + + // Workers with an active WebSocket connection are known-reachable; + // skip the HTTP health probe for them. + if (Snap.WsConnected) + { + continue; + } + + ReachableState NewState = ReachableState::Unreachable; + + try + { + HttpClient Client(Snap.Uri, + {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}}); + HttpClient::Response Response = Client.Get("/health/"); + if (Response.IsSuccess()) + { + NewState = ReachableState::Reachable; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } + + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(Snap.Id); + if (It != m_KnownWorkers.end()) + { + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.Reachable = NewState; + It->second.LastProbed.Reset(); + + if (PrevState != NewState) + { + if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri); + } + else if (NewState == ReachableState::Reachable) + { + ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); + } + else if (PrevState == ReachableState::Reachable) + { + ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); + } + else + { + ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri); + } + } + } + }); + + // Record provisioning events for state transitions outside the lock + if (PrevState != NewState) + { + if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable) + { + 
RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname); + } + else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname); + } + } + } + + // Sweep expired clients (5-minute timeout) + static constexpr int64_t kClientTimeoutMs = 5 * 60 * 1000; + + struct ExpiredClient + { + std::string Id; + std::string Hostname; + }; + std::vector<ExpiredClient> ExpiredClients; + + m_KnownClientsLock.WithExclusiveLock([&] { + for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();) + { + if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs) + { + ExpiredClients.push_back({It->first, It->second.Hostname}); + It = m_KnownClients.erase(It); + } + else + { + ++It; + } + } + }); + + for (const auto& Expired : ExpiredClients) + { + ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id); + RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname); + } + } while (m_ProbeThreadEnabled); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/recording/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp new file mode 100644 index 000000000..90141ca55 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.cpp @@ -0,0 +1,258 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "actionrecorder.h" + +#include "../runners/functionrunner.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RecordingFileWriter::RecordingFileWriter() +{ +} + +RecordingFileWriter::~RecordingFileWriter() +{ + Close(); +} + +void +RecordingFileWriter::Open(std::filesystem::path FilePath) +{ + using namespace std::literals; + + m_File.Open(FilePath, BasicFile::Mode::kTruncate); + m_File.Write("----DDC2----DATA", 16, 0); + m_FileOffset = 16; + + std::filesystem::path TocPath = FilePath.replace_extension(".ztoc"); + m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate); + + m_TocWriter << "version"sv << 1; + m_TocWriter.BeginArray("toc"sv); +} + +void +RecordingFileWriter::Close() +{ + m_TocWriter.EndArray(); + CbObject Toc = m_TocWriter.Save(); + + std::error_code Ec; + m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec); +} + +void +RecordingFileWriter::AppendObject(const CbObject& Object, const IoHash& ObjectHash) +{ + RwLock::ExclusiveLockScope _(m_FileLock); + + MemoryView ObjectView = Object.GetBuffer().GetView(); + + std::error_code Ec; + m_File.Write(ObjectView, m_FileOffset, Ec); + + if (Ec) + { + throw std::system_error(Ec, "failed writing to archive"); + } + + m_TocWriter.BeginArray(); + m_TocWriter.AddHash(ObjectHash); + m_TocWriter.AddInteger(m_FileOffset); + m_TocWriter.AddInteger(gsl::narrow<int>(ObjectView.GetSize())); + m_TocWriter.EndArray(); + + m_FileOffset += ObjectView.GetSize(); +} + +////////////////////////////////////////////////////////////////////////// + 
+ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath) +: m_ChunkResolver(InChunkResolver) +, m_RecordingLogDir(RecordingLogPath) +{ + std::error_code Ec; + CreateDirectories(m_RecordingLogDir, Ec); + + if (Ec) + { + ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat"); + m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat"); + + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +ActionRecorder::~ActionRecorder() +{ + Shutdown(); +} + +void +ActionRecorder::Shutdown() +{ + m_CidStore.Flush(); +} + +void +ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage) +{ + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + + m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId); + + std::unordered_set<IoHash> AddedChunks; + uint64_t AddedBytes = 0; + + // First add all attachments from the worker package itself + + for (const CbAttachment& Attachment : WorkerPackage.GetAttachments()) + { + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + IoBuffer Data = Buffer.GetCompressed().Flatten().AsIoBuffer(); + + const IoHash ChunkHash = Buffer.DecodeRawHash(); + + CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly); + + AddedChunks.insert(ChunkHash); + + if (Result.New) + { + AddedBytes += Data.GetSize(); + } + } + + // Not all attachments will be present in the worker package, so we need to add + // all referenced chunks to ensure that the recording is self-contained and not + // referencing data in the main CID store + + CbObject 
WorkerDescriptor = WorkerPackage.GetObject(); + + WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + + if (!AddedChunks.contains(AttachmentCid)) + { + IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly); + + if (Result.New) + { + AddedBytes += AttachmentData.GetSize(); + } + } + else + { + ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId); + } + + AddedChunks.insert(AttachmentCid); + } + }); + + ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes); +} + +bool +ActionRecorder::RecordAction(Ref<RunnerAction> Action) +{ + bool AllGood = true; + + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsHash(); + IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData); + + if (ChunkData) + { + if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary) + { + IoHash DecompressedHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize); + + OodleCompressor Compressor; + OodleCompressionLevel CompressionLevel; + uint64_t BlockSize = 0; + if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize)) + { + if (Compressor == OodleCompressor::NotSet) + { + CompositeBuffer Decompressed = Compressed.DecompressToComposite(); + CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed), + OodleCompressor::Mermaid, + OodleCompressionLevel::Fast, + BlockSize); + + ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer(); + } + } + } + + const uint64_t ChunkSize = 
ChunkData.GetSize(); + + m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly); + ++m_ChunkCounter; + m_ChunkBytesCounter.fetch_add(ChunkSize); + } + else + { + AllGood = false; + + ZEN_WARN("could not resolve chunk {}", AttachData); + } + }); + + if (AllGood) + { + m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId); + ++m_ActionsCounter; + + return true; + } + else + { + return false; + } +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recording/actionrecorder.h b/src/zencompute/recording/actionrecorder.h new file mode 100644 index 000000000..2827b6ac7 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.h @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/computeservice.h> +#include <zencompute/zencompute.h> +#include <zencore/basicfile.h> +#include <zencore/compactbinarybuilder.h> +#include <zenstore/cidstore.h> +#include <zenstore/gc.h> +#include <zenstore/zenstore.h> + +#include <filesystem> +#include <functional> +#include <map> +#include <unordered_map> + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + +struct RecordingFileWriter +{ + RecordingFileWriter(RecordingFileWriter&&) = delete; + RecordingFileWriter& operator=(RecordingFileWriter&&) = delete; + + RwLock m_FileLock; + BasicFile m_File; + uint64_t m_FileOffset = 0; + CbObjectWriter m_TocWriter; + BasicFile m_TocFile; + + RecordingFileWriter(); + ~RecordingFileWriter(); + + void Open(std::filesystem::path FilePath); + void Close(); + void AppendObject(const CbObject& Object, const IoHash& ObjectHash); +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Recording "runner" implementation + * + * This class writes out all actions and 
their attachments to a recording directory
 * in a format that can be read back by the RecordingReader.
 *
 * The contents of the recording directory will be self-contained, with all referenced
 * attachments stored in the recording directory itself, so that the recording can be
 * moved or shared without needing to maintain references to the main CID store.
 *
 */

class ActionRecorder
{
public:
    // Creates/cleans RecordingLogPath, opens the worker and action archives,
    // and initializes the recording's private CID store.
    ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath);
    ~ActionRecorder();

    ActionRecorder(const ActionRecorder&) = delete;
    ActionRecorder& operator=(const ActionRecorder&) = delete;

    // Flushes the recording CID store; called from the destructor as well.
    void Shutdown();
    // Appends the worker descriptor and copies all of its attachment chunks
    // into the recording CID store.
    void RegisterWorker(const CbPackage& WorkerPackage);
    // Records one action and its attachment chunks. Returns false (and does
    // not append the action) if any referenced chunk could not be resolved.
    bool RecordAction(Ref<RunnerAction> Action);

private:
    // Source of attachment chunk data (typically the main CID store).
    ChunkResolver& m_ChunkResolver;
    std::filesystem::path m_RecordingLogDir;

    RecordingFileWriter m_WorkersFile;
    RecordingFileWriter m_ActionsFile;
    GcManager m_Gc;
    CidStore m_CidStore{m_Gc};
    // Running totals of recorded chunks/bytes/actions.
    std::atomic<int> m_ChunkCounter{0};
    std::atomic<uint64_t> m_ChunkBytesCounter{0};
    std::atomic<int> m_ActionsCounter{0};
};

} // namespace zen::compute

#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/recording/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp
new file mode 100644
index 000000000..1c1a119cf
--- /dev/null
+++ b/src/zencompute/recording/recordingreader.cpp
@@ -0,0 +1,335 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+ +#include "zencompute/recordingreader.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_PLATFORM_WINDOWS +# define ZEN_BUILD_ACTION L"Build.action" +# define ZEN_WORKER_UCB L"worker.ucb" +# else +# define ZEN_BUILD_ACTION "Build.action" +# define ZEN_WORKER_UCB "worker.ucb" +# endif + +////////////////////////////////////////////////////////////////////////// + +struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor +{ + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) + { + ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick); + + if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0) + { + WorkDirs.push_back(Parent); + } + else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0) + { + WorkerDirs.push_back(Parent); + } + } + + virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) + { + ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes); + + return true; + } + + std::vector<std::filesystem::path> WorkerDirs; + std::vector<std::filesystem::path> WorkDirs; +}; + +////////////////////////////////////////////////////////////////////////// + +void +IterateOverArray(auto Array, auto Func, int TargetParallelism) +{ +# if ZEN_CONCRT_AVAILABLE + if (TargetParallelism > 1) + { + concurrency::simple_partitioner Chunker(Array.size() / TargetParallelism); + 
concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }); + + return; + } +# else + ZEN_UNUSED(TargetParallelism); +# endif + + for (const auto& Item : Array) + { + Func(Item); + } +} + +////////////////////////////////////////////////////////////////////////// + +RecordingReaderBase::~RecordingReaderBase() = default; + +////////////////////////////////////////////////////////////////////////// + +RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath) +{ + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +RecordingReader::~RecordingReader() +{ + m_CidStore.Flush(); +} + +size_t +RecordingReader::GetActionCount() const +{ + return m_Actions.size(); +} + +IoBuffer +RecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId)) + { + return Chunk; + } + + ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId); + + return {}; +} + +std::unordered_map<zen::IoHash, zen::CbPackage> +RecordingReader::ReadWorkers() +{ + std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap; + + { + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc"); + CbObject Toc = TocFile.Object; + + m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView Entry = It.AsArrayView(); + CbFieldViewIterator Vit = Entry.CreateViewIterator(); + + const IoHash WorkerId = Vit++->AsHash(); + const uint64_t Offset = Vit++->AsInt64(0); + const uint64_t Size = Vit++->AsInt64(0); + + IoBuffer WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size); + CbObject WorkerDesc = LoadCompactBinaryObject(WorkerRange); + CbPackage& WorkerPkg = WorkerMap[WorkerId]; 
+ WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + } + }); + } + } + + // Scan actions as well (this should be called separately, ideally) + + ScanActions(); + + return WorkerMap; +} + +void +RecordingReader::ScanActions() +{ + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc"); + CbObject Toc = TocFile.Object; + + m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView ArrayEntry = It.AsArrayView(); + CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator(); + + ActionEntry Entry; + Entry.ActionId = Vit++->AsHash(); + Entry.Offset = Vit++->AsInt64(0); + Entry.Size = Vit++->AsInt64(0); + + m_Actions.push_back(Entry); + } +} + +void +RecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism) +{ + IterateOverArray( + m_Actions, + [&](const ActionEntry& Entry) { + CbObject ActionDesc = LoadCompactBinaryObject(m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size)); + + Callback(ActionDesc, Entry.ActionId); + }, + TargetParallelism); +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer +LocalResolver::FindChunkByCid(const IoHash& DecompressedId) +{ + RwLock::SharedLockScope _(MapLock); + if (auto It = Attachments.find(DecompressedId); It != Attachments.end()) + { + return It->second; + } + + return {}; +} + +void +LocalResolver::Add(const IoHash& 
Cid, IoBuffer Data) +{ + RwLock::ExclusiveLockScope _(MapLock); + Data.SetContentType(ZenContentType::kCompressedBinary); + Attachments[Cid] = Data; +} + +/// + +UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath) +{ +} + +UeRecordingReader::~UeRecordingReader() +{ +} + +size_t +UeRecordingReader::GetActionCount() const +{ + return m_WorkDirs.size(); +} + +IoBuffer +UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + return m_LocalResolver.FindChunkByCid(DecompressedId); +} + +std::unordered_map<zen::IoHash, zen::CbPackage> +UeRecordingReader::ReadWorkers() +{ + std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap; + + FileSystemTraversal Traversal; + RecordingTreeVisitor Visitor; + Traversal.TraverseFileSystem(m_RecordingDir, Visitor); + + m_WorkDirs = std::move(Visitor.WorkDirs); + + for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs) + { + CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb"); + CbObject WorkerDesc = WorkerFile.Object; + const IoHash& WorkerId = WorkerFile.Hash; + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten(); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + } + + return WorkerMap; +} + +void +UeRecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int ParallelismTarget) +{ + IterateOverArray( + m_WorkDirs, + [&](const std::filesystem::path& WorkDir) { + CbPackage WorkPackage = ReadAction(WorkDir); + CbObject 
ActionObject = WorkPackage.GetObject(); + const IoHash& ActionId = WorkPackage.GetObjectHash(); + + Callback(ActionObject, ActionId); + }, + ParallelismTarget); +} + +CbPackage +UeRecordingReader::ReadAction(std::filesystem::path WorkDir) +{ + CbPackage WorkPackage; + std::filesystem::path WorkDescPath = WorkDir / "Build.action"; + CbObjectFromFile ActionFile = LoadCompactBinaryObject(WorkDescPath); + CbObject& ActionObject = ActionFile.Object; + + WorkPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten(); + + m_LocalResolver.Add(AttachmentCid, AttachmentData); + + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + ZEN_ASSERT(AttachmentCid == RawHash); + WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + + return WorkPackage; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp new file mode 100644 index 000000000..4fad2cf70 --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "deferreddeleter.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/thread.h> + +# include <algorithm> +# include <chrono> + +namespace zen::compute { + +using namespace std::chrono_literals; + +using Clock = std::chrono::steady_clock; + +// Default deferral: how long to wait before attempting deletion. +// This gives memory-mapped file handles time to close naturally. 
static constexpr auto DeferralPeriod = 60s;

// Shortened deferral after MarkReady(): the client has collected results
// so handles should be released soon, but we still wait briefly.
static constexpr auto ReadyGracePeriod = 5s;

// Interval between retry attempts for directories that failed deletion.
static constexpr auto RetryInterval = 5s;

// Give up on a directory after this many failed remove_all() attempts.
static constexpr int MaxRetries = 10;

/// Starts the background deletion thread immediately. m_Thread is the last
/// data member of the class, so all other members are fully constructed by
/// the time ThreadFunction() begins running.
DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this)
{
}

/// Destruction drains all pending directories and joins the thread (see
/// Shutdown(), which is idempotent and may already have been called).
DeferredDirectoryDeleter::~DeferredDirectoryDeleter()
{
    Shutdown();
}

/// Queue a directory for deletion after the full DeferralPeriod.
/// Thread-safe; wakes the background thread.
void
DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path)
{
    {
        std::lock_guard Lock(m_Mutex);
        m_Queue.push_back({ActionLsn, std::move(Path)});
    }
    m_Cv.notify_one();
}

/// Shorten the deferral for all queued directories with this ActionLsn to
/// ReadyGracePeriod. Thread-safe; wakes the background thread.
void
DeferredDirectoryDeleter::MarkReady(int ActionLsn)
{
    {
        std::lock_guard Lock(m_Mutex);
        m_ReadyLsns.push_back(ActionLsn);
    }
    m_Cv.notify_one();
}

/// Signal the background thread to delete everything immediately and exit,
/// then join it. Safe to call more than once: the second call finds the
/// thread no longer joinable and returns.
void
DeferredDirectoryDeleter::Shutdown()
{
    {
        std::lock_guard Lock(m_Mutex);
        m_Done = true;
    }
    m_Cv.notify_one();

    if (m_Thread.joinable())
    {
        m_Thread.join();
    }
}

/// Background thread body. All shared state (m_Queue, m_ReadyLsns, m_Done)
/// is accessed only under m_Mutex; the local PendingList is owned by this
/// thread exclusively and needs no locking.
void
DeferredDirectoryDeleter::ThreadFunction()
{
    SetCurrentThreadName("ZenDirCleanup");

    // A directory awaiting deletion: earliest time we may delete it, and
    // how many deletion attempts have failed so far.
    struct PendingEntry
    {
        int ActionLsn;
        std::filesystem::path Path;
        Clock::time_point ReadyTime;
        int Attempts = 0;
    };

    std::vector<PendingEntry> PendingList;

    // remove_all() with an error_code reports success for an already-missing
    // path, so enqueueing a non-existent directory is harmless.
    auto TryDelete = [](PendingEntry& Entry) -> bool {
        std::error_code Ec;
        std::filesystem::remove_all(Entry.Path, Ec);
        return !Ec;
    };

    for (;;)
    {
        bool Shutting = false;

        // Drain the incoming queue and process MarkReady signals

        {
            std::unique_lock Lock(m_Mutex);

            if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done)
            {
                if (PendingList.empty())
                {
                    // Nothing pending: sleep until new work or shutdown.
                    m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; });
                }
                else
                {
                    // Sleep until the earliest pending deadline, but wake
                    // early for new work, MarkReady, or shutdown. A timeout
                    // here simply falls through to the processing pass below.
                    auto NextReady = PendingList.front().ReadyTime;
                    for (const auto& Entry : PendingList)
                    {
                        if (Entry.ReadyTime < NextReady)
                        {
                            NextReady = Entry.ReadyTime;
                        }
                    }

                    m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; });
                }
            }

            // Move new items into PendingList with the full deferral deadline
            auto Now = Clock::now();
            for (auto& Entry : m_Queue)
            {
                PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0});
            }
            m_Queue.clear();

            // Apply MarkReady: shorten ReadyTime for matching entries.
            // Deadlines are only ever moved earlier, never pushed back.
            for (int Lsn : m_ReadyLsns)
            {
                for (auto& Entry : PendingList)
                {
                    if (Entry.ActionLsn == Lsn)
                    {
                        auto NewReady = Now + ReadyGracePeriod;
                        if (NewReady < Entry.ReadyTime)
                        {
                            Entry.ReadyTime = NewReady;
                        }
                    }
                }
            }
            m_ReadyLsns.clear();

            // Snapshot the shutdown flag while still holding the lock.
            Shutting = m_Done;
        }

        // Process items whose deferral period has elapsed (or all items on shutdown)

        auto Now = Clock::now();

        for (size_t i = 0; i < PendingList.size();)
        {
            auto& Entry = PendingList[i];

            if (!Shutting && Now < Entry.ReadyTime)
            {
                ++i;
                continue;
            }

            if (TryDelete(Entry))
            {
                if (Entry.Attempts > 0)
                {
                    ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path);
                }

                // Unordered removal: swap with the last element and pop.
                PendingList[i] = std::move(PendingList.back());
                PendingList.pop_back();
            }
            else
            {
                ++Entry.Attempts;

                if (Entry.Attempts >= MaxRetries)
                {
                    // Permanent failure: drop the entry so shutdown can't
                    // hang on an undeletable directory.
                    ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts);
                    PendingList[i] = std::move(PendingList.back());
                    PendingList.pop_back();
                }
                else
                {
                    ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts);
                    Entry.ReadyTime = Now + RetryInterval;
                    ++i;
                }
            }
        }

        // Exit once shutdown is requested and nothing remains

        if (Shutting && PendingList.empty())
        {
            return;
        }
    }
}

} // namespace zen::compute

#endif

#if ZEN_WITH_TESTS

# include <zencore/testing.h>

namespace zen::compute {

// Empty function referenced from outside this translation unit so the
// linker keeps the tests below — presumably part of the project's test
// registration scheme (declared "internal" in the header).
void
deferreddeleter_forcelink()
{
}

} // namespace zen::compute

#endif

#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES

# include <zencore/testutils.h>

namespace zen::compute {

TEST_SUITE_BEGIN("compute.deferreddeleter");

// Destruction runs Shutdown(), which deletes all pending directories
// immediately — no test below needs to wait out the 60s deferral.
TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory")
{
    ScopedTemporaryDirectory TempDir;
    std::filesystem::path DirToDelete = TempDir.Path() / "subdir";
    CreateDirectories(DirToDelete / "nested");

    CHECK(std::filesystem::exists(DirToDelete));

    {
        DeferredDirectoryDeleter Deleter;
        Deleter.Enqueue(1, DirToDelete);
    }

    CHECK(!std::filesystem::exists(DirToDelete));
}

// Many queued directories are all removed by the time the deleter is gone.
TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories")
{
    ScopedTemporaryDirectory TempDir;

    constexpr int NumDirs = 10;
    std::vector<std::filesystem::path> Dirs;

    for (int i = 0; i < NumDirs; ++i)
    {
        auto Dir = TempDir.Path() / std::to_string(i);
        CreateDirectories(Dir / "child");
        Dirs.push_back(std::move(Dir));
    }

    {
        DeferredDirectoryDeleter Deleter;
        for (int i = 0; i < NumDirs; ++i)
        {
            CHECK(std::filesystem::exists(Dirs[i]));
            Deleter.Enqueue(100 + i, Dirs[i]);
        }
    }

    for (const auto& Dir : Dirs)
    {
        CHECK(!std::filesystem::exists(Dir));
    }
}

// A second Shutdown() call must be a no-op (thread already joined).
TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent")
{
    ScopedTemporaryDirectory TempDir;
    std::filesystem::path Dir = TempDir.Path() / "idempotent";
    CreateDirectories(Dir);

    DeferredDirectoryDeleter Deleter;
    Deleter.Enqueue(42, Dir);
    Deleter.Shutdown();
    Deleter.Shutdown();

    CHECK(!std::filesystem::exists(Dir));
}

// remove_all() on a missing path succeeds, so this must not hang or warn
// its way to MaxRetries.
TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath")
{
    ScopedTemporaryDirectory TempDir;
    std::filesystem::path NoSuchDir = TempDir.Path() / "does_not_exist";

    {
        DeferredDirectoryDeleter Deleter;
        Deleter.Enqueue(99, NoSuchDir);
    }
}

// Calling Shutdown() explicitly (rather than via the destructor) also
// flushes pending deletions.
TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction")
{
    ScopedTemporaryDirectory TempDir;
    std::filesystem::path Dir = TempDir.Path() / "explicit";
    CreateDirectories(Dir / "inner");

    DeferredDirectoryDeleter Deleter;
    Deleter.Enqueue(7, Dir);
    Deleter.Shutdown();

    CHECK(!std::filesystem::exists(Dir));
}

TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral")
{
    ScopedTemporaryDirectory TempDir;
    std::filesystem::path Dir = TempDir.Path() / "markready";
    CreateDirectories(Dir / "child");

    DeferredDirectoryDeleter Deleter;
    Deleter.Enqueue(50, Dir);

    // Without MarkReady the full deferral (60s) would apply.
    // MarkReady shortens it to 5s, and shutdown bypasses even that.
    Deleter.MarkReady(50);
    Deleter.Shutdown();

    CHECK(!std::filesystem::exists(Dir));
}

TEST_SUITE_END();

} // namespace zen::compute

#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h
new file mode 100644
index 000000000..9b010aa0f
--- /dev/null
+++ b/src/zencompute/runners/deferreddeleter.h
@@ -0,0 +1,68 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "zencompute/computeservice.h"

#if ZEN_WITH_COMPUTE_SERVICES

# include <condition_variable>
# include <deque>
# include <filesystem>
# include <mutex>
# include <thread>
# include <vector>

namespace zen::compute {

/// Deletes directories on a background thread to avoid blocking callers.
/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics).
///
/// Enqueued directories wait for a deferral period before deletion, giving
/// file handles time to close. Call MarkReady() with the ActionLsn to shorten
/// the wait to a brief grace period (e.g. once a client has collected results).
/// On shutdown, all pending directories are deleted immediately.
+class DeferredDirectoryDeleter +{ + DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete; + DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete; + +public: + DeferredDirectoryDeleter(); + ~DeferredDirectoryDeleter(); + + /// Enqueue a directory for deferred deletion, associated with an action LSN. + void Enqueue(int ActionLsn, std::filesystem::path Path); + + /// Signal that the action result has been consumed and the directory + /// can be deleted after a short grace period instead of the full deferral. + void MarkReady(int ActionLsn); + + /// Drain the queue and join the background thread. Idempotent. + void Shutdown(); + +private: + struct QueueEntry + { + int ActionLsn; + std::filesystem::path Path; + }; + + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::deque<QueueEntry> m_Queue; + std::vector<int> m_ReadyLsns; + bool m_Done = false; + std::thread m_Thread; + void ThreadFunction(); +}; + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS +namespace zen::compute { +void deferreddeleter_forcelink(); // internal +} // namespace zen::compute +#endif diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp new file mode 100644 index 000000000..768cdf1e1 --- /dev/null +++ b/src/zencompute/runners/functionrunner.cpp @@ -0,0 +1,365 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/filesystem.h> +# include <zencore/trace.h> + +# include <fmt/format.h> +# include <vector> + +namespace zen::compute { + +FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") +{ +} + +FunctionRunner::~FunctionRunner() = default; + +size_t +FunctionRunner::QueryCapacity() +{ + return 1; +} + +std::vector<SubmitResult> +FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +void +FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) +{ + if (m_DumpActions) + { + std::string UniqueId = fmt::format("{}.ddb", ActionLsn); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); + } +} + +////////////////////////////////////////////////////////////////////////// + +void +BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner) +{ + m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); +} + +size_t +BaseRunnerGroup::QueryCapacity() +{ + size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + for (const auto& Runner : m_Runners) + { + TotalCapacity += Runner->QueryCapacity(); + } + }); + return TotalCapacity; +} + +SubmitResult +BaseRunnerGroup::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); + int Index = InitialIndex; + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return {.IsAccepted = false, .Reason = "No runners available"}; + } + + do + { + 
while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + + auto& Runner = m_Runners[Index++]; + + SubmitResult Result = Runner->SubmitAction(Action); + + if (Result.IsAccepted == true) + { + m_NextSubmitIndex = Index % RunnerCount; + + return Result; + } + + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + } while (Index != InitialIndex); + + return {.IsAccepted = false}; +} + +std::vector<SubmitResult> +BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); + } + + // Query capacity per runner and compute total + std::vector<size_t> Capacities(RunnerCount); + size_t TotalCapacity = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = m_Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + + if (TotalCapacity == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); + } + + // Distribute actions across runners proportionally to their available capacity + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + if (Capacities[i] == 0) + { + continue; + } + + size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity; + Share = std::min(Share, Capacities[i]); + + for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + } + } + + // Assign any remaining actions to runners with capacity (round-robin) + for (int i = 0; ActionIdx < Actions.size(); 
i = (i + 1) % RunnerCount) + { + if (Capacities[i] > PerRunnerActions[i].size()) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + ++ActionIdx; + } + } + + // Submit batches per runner + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + } + } + + // Reassemble results in original action order + std::vector<SubmitResult> Results(Actions.size()); + std::vector<size_t> PerRunnerIdx(RunnerCount, 0); + + for (size_t i = 0; i < Actions.size(); ++i) + { + size_t RunnerIdx = ActionRunnerIndex[i]; + size_t Idx = PerRunnerIdx[RunnerIdx]++; + Results[i] = std::move(PerRunnerResults[RunnerIdx][Idx]); + } + + return Results; +} + +size_t +BaseRunnerGroup::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + size_t TotalCount = 0; + + for (const auto& Runner : m_Runners) + { + TotalCount += Runner->GetSubmittedActionCount(); + } + + return TotalCount; +} + +void +BaseRunnerGroup::RegisterWorker(CbPackage Worker) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->RegisterWorker(Worker); + } +} + +void +BaseRunnerGroup::Shutdown() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->Shutdown(); + } +} + +bool +BaseRunnerGroup::CancelAction(int ActionLsn) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + if (Runner->CancelAction(ActionLsn)) + { + return true; + } + } + + return false; +} + +void +BaseRunnerGroup::CancelRemoteQueue(int QueueId) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->CancelRemoteQueue(QueueId); + } +} + +////////////////////////////////////////////////////////////////////////// + +RunnerAction::RunnerAction(ComputeServiceSession* 
OwnerSession) : m_OwnerSession(OwnerSession) +{ + this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); +} + +RunnerAction::~RunnerAction() +{ +} + +bool +RunnerAction::ResetActionStateToPending() +{ + // Only allow reset from Failed or Abandoned states + State CurrentState = m_ActionState.load(); + + if (CurrentState != State::Failed && CurrentState != State::Abandoned) + { + return false; + } + + if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending)) + { + return false; + } + + // Clear timestamps from Submitting through _Count + for (int i = static_cast<int>(State::Submitting); i < static_cast<int>(State::_Count); ++i) + { + this->Timestamps[i] = 0; + } + + // Record new Pending timestamp + this->Timestamps[static_cast<int>(State::Pending)] = DateTime::Now().GetTicks(); + + // Clear execution fields + ExecutionLocation.clear(); + CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); + CpuSeconds.store(0.0f, std::memory_order_relaxed); + + // Increment retry count + RetryCount.fetch_add(1, std::memory_order_relaxed); + + // Re-enter the scheduler pipeline + m_OwnerSession->PostUpdate(this); + + return true; +} + +void +RunnerAction::SetActionState(State NewState) +{ + ZEN_ASSERT(NewState < State::_Count); + this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); + + do + { + if (State CurrentState = m_ActionState.load(); CurrentState == NewState) + { + // No state change + return; + } + else + { + if (NewState <= CurrentState) + { + // Cannot transition to an earlier or same state + return; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) + { + // Successful state change + + m_OwnerSession->PostUpdate(this); + + return; + } + } + } while (true); +} + +void +RunnerAction::SetResult(CbPackage&& Result) +{ + m_Result = std::move(Result); +} + +CbPackage& +RunnerAction::GetResult() +{ + ZEN_ASSERT(IsCompleted()); + return m_Result; +} + +} // namespace zen::compute + +#endif // 
ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
new file mode 100644
index 000000000..f67414dbb
--- /dev/null
+++ b/src/zencompute/runners/functionrunner.h
@@ -0,0 +1,214 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/computeservice.h>

#if ZEN_WITH_COMPUTE_SERVICES

# include <atomic>
# include <filesystem>
# include <vector>

namespace zen::compute {

/// Outcome of offering an action to a runner.
struct SubmitResult
{
    bool IsAccepted = false;  // true if the runner took ownership of the action
    std::string Reason;       // optional human-readable rejection reason
};

/** Base interface for classes implementing a remote execution "runner"
 */
class FunctionRunner : public RefCounted
{
    // Non-movable (and, via RefCounted reference semantics, not meant to be copied).
    FunctionRunner(FunctionRunner&&) = delete;
    FunctionRunner& operator=(FunctionRunner&&) = delete;

public:
    /// BasePath is the runner's working root; dumped actions go under
    /// <BasePath>/actions (see MaybeDumpAction in the .cpp).
    FunctionRunner(std::filesystem::path BasePath);
    virtual ~FunctionRunner() = 0;

    virtual void Shutdown() = 0;
    virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;

    [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;
    [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
    [[nodiscard]] virtual bool IsHealthy() = 0;
    /// Number of actions the runner can accept; base implementation returns 1.
    [[nodiscard]] virtual size_t QueryCapacity();
    /// Batch submission; base implementation calls SubmitAction() per action.
    [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);

    // Best-effort cancellation of a specific in-flight action. Returns true if the
    // cancellation signal was successfully sent. The action will transition to Cancelled
    // asynchronously once the platform-level termination completes.
    virtual bool CancelAction(int /*ActionLsn*/) { return false; }

    // Cancel the remote queue corresponding to the given local QueueId.
    // Only meaningful for remote runners; local runners ignore this.
    virtual void CancelRemoteQueue(int /*QueueId*/) {}

protected:
    std::filesystem::path m_ActionsPath;  // destination for dumped action objects
    bool m_DumpActions = false;           // enables MaybeDumpAction()
    void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject);
};

/** Base class for RunnerGroup that operates on generic FunctionRunner references.
 * All scheduling, capacity, and lifecycle logic lives here.
 */
class BaseRunnerGroup
{
public:
    size_t QueryCapacity();
    SubmitResult SubmitAction(Ref<RunnerAction> Action);
    std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
    size_t GetSubmittedActionCount();
    void RegisterWorker(CbPackage Worker);
    void Shutdown();
    bool CancelAction(int ActionLsn);
    void CancelRemoteQueue(int QueueId);

    size_t GetRunnerCount()
    {
        return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
    }

protected:
    void AddRunnerInternal(FunctionRunner* Runner);

    RwLock m_RunnersLock;                    // guards m_Runners
    std::vector<Ref<FunctionRunner>> m_Runners;
    std::atomic<int> m_NextSubmitIndex{0};   // round-robin cursor for SubmitAction()
};

/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
 */
template<typename RunnerType>
struct RunnerGroup : public BaseRunnerGroup
{
    void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); }

    /// Remove (and Shutdown()) every runner matching Pred; returns the
    /// number of runners removed.
    template<typename Predicate>
    size_t RemoveRunnerIf(Predicate&& Pred)
    {
        size_t RemovedCount = 0;
        m_RunnersLock.WithExclusiveLock([&] {
            auto It = m_Runners.begin();
            while (It != m_Runners.end())
            {
                if (Pred(static_cast<RunnerType&>(**It)))
                {
                    (*It)->Shutdown();
                    It = m_Runners.erase(It);
                    ++RemovedCount;
                }
                else
                {
                    ++It;
                }
            }
        });
        return RemovedCount;
    }
};

/**
 * This represents an action going through different stages of scheduling and execution.
 */
struct RunnerAction : public RefCounted
{
    explicit RunnerAction(ComputeServiceSession* OwnerSession);
    ~RunnerAction();

    int ActionLsn = 0;               // local sequence number identifying this action
    int QueueId = 0;
    WorkerDesc Worker;
    IoHash ActionId;
    CbObject ActionObj;
    int Priority = 0;
    std::string ExecutionLocation; // "local" or remote hostname

    // CPU usage and total CPU time of the running process, sampled periodically by the local runner.
    // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
    // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled.
    std::atomic<float> CpuUsagePercent{-1.0f};
    std::atomic<float> CpuSeconds{0.0f};
    std::atomic<int> RetryCount{0};

    // Lifecycle states. SetActionState() only allows transitions to a
    // strictly later enumerator, so declaration order defines the pipeline.
    enum class State
    {
        New,
        Pending,
        Submitting,
        Running,
        Completed,
        Failed,
        Abandoned,
        Cancelled,
        _Count
    };

    static const char* ToString(State _)
    {
        switch (_)
        {
            case State::New:
                return "New";
            case State::Pending:
                return "Pending";
            case State::Submitting:
                return "Submitting";
            case State::Running:
                return "Running";
            case State::Completed:
                return "Completed";
            case State::Failed:
                return "Failed";
            case State::Abandoned:
                return "Abandoned";
            case State::Cancelled:
                return "Cancelled";
            default:
                return "Unknown";
        }
    }

    /// Inverse of ToString(); unrecognized names map to Default.
    static State FromString(std::string_view Name, State Default = State::Failed)
    {
        for (int i = 0; i < static_cast<int>(State::_Count); ++i)
        {
            if (Name == ToString(static_cast<State>(i)))
            {
                return static_cast<State>(i);
            }
        }
        return Default;
    }

    // Per-state entry timestamps (DateTime ticks), indexed by State.
    uint64_t Timestamps[static_cast<int>(State::_Count)] = {};

    State ActionState() const { return m_ActionState; }
    void SetActionState(State NewState);

    bool IsSuccess() const { return ActionState() == State::Completed; }
    bool ResetActionStateToPending();
    /// True once the action has reached any terminal state.
    bool IsCompleted() const
    {
        return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == State::Abandoned ||
               ActionState() == State::Cancelled;
    }

    void SetResult(CbPackage&& Result);
    CbPackage& GetResult();

    ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; }

private:
    std::atomic<State> m_ActionState = State::New;
    ComputeServiceSession* m_OwnerSession = nullptr;  // non-owning; posts scheduler updates
    CbPackage m_Result;
};

} // namespace zen::compute

#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp new file mode 100644 index 000000000..e79a6c90f --- /dev/null +++ b/src/zencompute/runners/linuxrunner.cpp @@ -0,0 +1,734 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "linuxrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <sched.h> +# include <signal.h> +# include <sys/mount.h> +# include <sys/stat.h> +# include <sys/syscall.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. 
+ + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + int MkdirIfNeeded(const char* Path, mode_t Mode) + { + if (mkdir(Path, Mode) != 0 && errno != EEXIST) + { + return -1; + } + return 0; + } + + int BindMountReadOnly(const char* Src, const char* Dst) + { + if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + return -1; + } + + // Remount read-only + if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0) + { + return -1; + } + + return 0; + } + + // Set up namespace-based sandbox isolation in the child process. + // This is called after fork(), before execve(). All operations must be + // async-signal-safe. + // + // The sandbox layout after pivot_root: + // / -> the sandbox directory (tmpfs-like, was SandboxPath) + // /usr -> bind-mount of host /usr (read-only) + // /lib -> bind-mount of host /lib (read-only) + // /lib64 -> bind-mount of host /lib64 (read-only, optional) + // /etc -> bind-mount of host /etc (read-only) + // /worker -> bind-mount of worker directory (read-only) + // /proc -> proc filesystem + // /dev -> tmpfs with null, zero, urandom + void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd) + { + // 1. 
Unshare user, mount, and network namespaces + if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno); + } + + // 2. Write UID/GID mappings + // Must deny setgroups first (required by kernel for unprivileged user namespaces) + { + int Fd = open("/proc/self/setgroups", O_WRONLY); + if (Fd >= 0) + { + WriteToFd(Fd, "deny", 4); + close(Fd); + } + // setgroups file may not exist on older kernels; not fatal + } + + { + // uid_map: map our UID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Uid)); + + int Fd = open("/proc/self/uid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + { + // gid_map: map our GID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Gid)); + + int Fd = open("/proc/self/gid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + // 3. Privatize the entire mount tree so our mounts don't propagate + if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno); + } + + // 4. 
Create mount points inside the sandbox and bind-mount system directories + + // Helper macro-like pattern for building paths inside sandbox + // We use stack buffers since we can't allocate heap memory safely + char MountPoint[4096]; + + auto BuildPath = [&](const char* Suffix) -> const char* { + snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix); + return MountPoint; + }; + + // /usr (required) + if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno); + } + if (BindMountReadOnly("/usr", BuildPath("usr")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno); + } + + // /lib (required) + if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno); + } + if (BindMountReadOnly("/lib", BuildPath("lib")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); + } + + // /lib64 (optional — not all distros have it) + { + struct stat St; + if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) + { + if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0) + { + BindMountReadOnly("/lib64", BuildPath("lib64")); + // Failure is non-fatal for lib64 + } + } + } + + // /etc (required — for resolv.conf, ld.so.cache, etc.) + if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); + } + if (BindMountReadOnly("/etc", BuildPath("etc")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); + } + + // /worker — bind-mount worker directory (contains the executable) + if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); + } + if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno); + } + + // 5. 
Mount /proc inside sandbox + if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno); + } + if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno); + } + + // 6. Mount tmpfs /dev and bind-mount essential device nodes + if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno); + } + if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno); + } + + // Bind-mount /dev/null, /dev/zero, /dev/urandom + { + char DevSrc[64]; + char DevDst[4096]; + + auto BindDev = [&](const char* Name) { + snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name); + snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name); + + // Create the file to mount over + int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666); + if (Fd >= 0) + { + close(Fd); + } + mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr); + // Non-fatal if individual devices fail + }; + + BindDev("null"); + BindDev("zero"); + BindDev("urandom"); + } + + // 7. pivot_root to sandbox + // pivot_root requires the new root and put_old to be mount points. + // Bind-mount sandbox onto itself to make it a mount point. + if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno); + } + + // Create .pivot_old inside sandbox + char PivotOld[4096]; + snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath); + if (MkdirIfNeeded(PivotOld, 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno); + } + + if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno); + } + + // 8. Now inside new root. 
Clean up old root. + if (chdir("/") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno); + } + + if (umount2("/.pivot_old", MNT_DETACH) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno); + } + + rmdir("/.pivot_old"); + } + +} // anonymous namespace + +LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("namespace sandboxing enabled for child processes"); + } +} + +SubmitResult +LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + // Pre-compute all path strings before fork() for async-signal-safety. 
+ + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::string ExePathStr; + std::string SandboxedExePathStr; + + if (m_Sandboxed) + { + // After pivot_root, the worker dir is at /worker inside the new root + std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath); + SandboxedExePathStr = SandboxedExePath.string(); + // We still need the real path for logging + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + else + { + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + + std::string BuildArg = "-Build=build.action"; + + // argv[0] should be the path the child will see + const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr; + + std::vector<char*> ArgV; + ArgV.push_back(const_cast<char*>(ChildExePath.data())); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: get uid/gid for namespace mapping, create error pipe + uid_t CurrentUid = 0; + gid_t CurrentGid = 0; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + CurrentUid = getuid(); + CurrentGid = getgid(); + + if (pipe2(ErrorPipe, O_CLOEXEC) != 0) + { + throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno)); + } + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, 
WorkerPathStr.c_str(), ErrorPipe[1]); + + // After pivot_root, CWD is "/" which is the sandbox root. + // execve with the sandboxed path. + execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. + char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +LinuxProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + 
+ m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +LinuxProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. 
+ + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +LinuxProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +static uint64_t +ReadProcStatCpuTicks(pid_t Pid) +{ + char Path[64]; + snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast<int>(Pid)); + + char Buf[256]; + int Fd = open(Path, O_RDONLY); + if (Fd < 0) + { + return 0; + } + + ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1); + close(Fd); + + if (Len <= 0) + { + return 0; + } + + Buf[Len] = '\0'; + + // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + const char* P = strrchr(Buf, ')'); + if (!P) + { + return 0; + } + + P += 2; // skip ') ' + + // Remaining fields (space-separated, 0-indexed from here): + // 0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags + // 7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime + unsigned long UTime = 0; + unsigned long STime = 0; + sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime); + return UTime + STime; +} + +void +LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + static const long ClkTck = sysconf(_SC_CLK_TCK); + + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + const uint64_t NowTicks = GetHifreqTimerValue(); + const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid); + + if (CurrentOsTicks == 0) + { + // Process gone or /proc entry unreadable — record timestamp without updating usage + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = 0; + return; + } + + // Cumulative CPU seconds (absolute, available from first sample) + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / ClkTck), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) * 1000.0 
/ ClkTck / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h new file mode 100644 index 000000000..266de366b --- /dev/null +++ b/src/zencompute/runners/linuxrunner.h @@ -0,0 +1,44 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +namespace zen::compute { + +/** Native Linux process runner for executing Linux worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using Linux namespaces: + user, mount, and network namespaces are unshared so the child has no network + access and can only see the sandbox directory (with system libraries bind-mounted + read-only). This requires no special privileges thanks to user namespaces. 
+ */ +class LinuxProcessRunner : public LocalProcessRunner +{ +public: + LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp new file mode 100644 index 000000000..7aaefb06e --- /dev/null +++ b/src/zencompute/runners/localrunner.cpp @@ -0,0 +1,674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/system.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> +# include <zenstore/cidstore.h> + +# include <span> + +namespace zen::compute { + +using namespace std::literals; + +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("local_exec")) +, m_ChunkResolver(Resolver) +, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) +, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, 
m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) +{ + SystemMetrics Sm = GetSystemMetricsForReporting(); + + m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); + + bool DidCleanup = false; + + if (std::filesystem::is_directory(m_ActionsPath)) + { + ZEN_INFO("Cleaning '{}'", m_ActionsPath); + + std::error_code Ec; + CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); + } + + DidCleanup = true; + } + + if (std::filesystem::is_directory(m_SandboxPath)) + { + ZEN_INFO("Cleaning '{}'", m_SandboxPath); + std::error_code Ec; + CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); + } + + DidCleanup = true; + } + + // We clean out all workers on startup since we can't know they are good. 
They could be bad + // due to tampering, malware (which I also mean to include AV and antimalware software) or + // other processes we have no control over + if (std::filesystem::is_directory(m_WorkerPath)) + { + ZEN_INFO("Cleaning '{}'", m_WorkerPath); + std::error_code Ec; + CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); + } + + DidCleanup = true; + } + + if (DidCleanup) + { + ZEN_INFO("Cleanup complete"); + } + + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; + +# if ZEN_PLATFORM_WINDOWS + // Suppress any error dialogs caused by missing dependencies + UINT OldMode = ::SetErrorMode(0); + ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); +# endif + + m_AcceptNewActions = true; +} + +LocalProcessRunner::~LocalProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); + } +} + +void +LocalProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } + + CancelRunningActions(); +} + +std::filesystem::path +LocalProcessRunner::CreateNewSandbox() +{ + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); + std::string UniqueId = std::to_string(++m_SandboxCounter); + std::filesystem::path Path = m_SandboxPath / UniqueId; + zen::CreateDirectories(Path); + + return Path; +} + +void +LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); + if (m_DumpActions) + { + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + const IoHash& WorkerId = WorkerPackage.GetObjectHash(); + + std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); + std::filesystem::path Path = m_ActionsPath / 
UniqueId; + + zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); + + ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) { + std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); + zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); + }); + + ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); + } +} + +size_t +LocalProcessRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return 0; + } + + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); + + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) + { + return 0; + } + else + { + return MaxRunningActions - InFlightCount; + } +} + +std::vector<SubmitResult> +LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity upfront and accept what fits. + // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. 
+ + size_t Available = QueryCapacity(); + + std::vector<SubmitResult> Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) + { + const Ref<RunnerAction>& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; + } + + return Results; +} + +std::optional<LocalProcessRunner::PreparedAction> +LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return std::nullopt; + } + + if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) + { + return std::nullopt; + } + } + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + + MaybeDumpAction(ActionLsn, ActionObj); + + 
std::filesystem::path SandboxPath = CreateNewSandbox(); + + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + + CbPackage WorkerPackage = Action->Worker.Descriptor; + + std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); + + // Write out action + + zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + + // Manifest inputs in sandbox + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); + } + + zen::WriteFile(FilePath, DataBuffer); + }); + + Action->ExecutionLocation = "local"; + + SandboxGuard.Dismiss(); + + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} + +SubmitResult +LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; +} + +size_t +LocalProcessRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RunningMap.size(); +} + +std::filesystem::path +LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); + RwLock::SharedLockScope _(m_WorkerLock); + + std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); + + if (!std::filesystem::exists(WorkerDir)) + { + _.ReleaseNow(); + + RwLock::ExclusiveLockScope $(m_WorkerLock); + + if (!std::filesystem::exists(WorkerDir)) + { + 
ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); + } + } + + return WorkerDir; +} + +void +LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback) +{ + std::string_view Name = FileEntry["name"sv].AsString(); + const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + CompressedBuffer Compressed; + + if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) + { + Compressed = Attachment->AsCompressedBinary(); + } + else + { + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); + } + + uint64_t DataRawSize = 0; + IoHash DataRawHash; + Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); + + if (DataRawSize != Size) + { + throw std::runtime_error( + fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size)); + } + } + + ChunkReferenceCallback(ChunkHash, Compressed); + + std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. 
+ { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + + SharedBuffer Decompressed = Compressed.Decompress(); + zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); +} + +void +LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback) +{ + CbObject WorkerDescription = WorkerPackage.GetObject(); + + // Manifest worker in Sandbox + + for (auto& It : WorkerDescription["executables"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + std::string_view Name = It.AsString(); + std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < 
RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + + zen::CreateDirectories(DirPath); + } + + for (auto& It : WorkerDescription["files"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); +} + +CbPackage +LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) +{ + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); + std::filesystem::path OutputFile = SandboxPath / "build.output"; + FileContents OutputData = zen::ReadFile(OutputFile); + + if (OutputData.ErrorCode) + { + throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); + } + + CbPackage OutputPackage; + CbObject Output = zen::LoadCompactBinaryObject(OutputData.Flatten()); + + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalRawAttachmentBytes = 0; + + Output.IterateAttachments([&](CbFieldView Field) { + IoHash Hash = Field.AsHash(); + std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; + FileContents ChunkData = zen::ReadFile(OutputPath); + + if (ChunkData.ErrorCode) + { + throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); + } + + uint64_t ChunkDataRawSize = 0; + IoHash ChunkDataHash; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); + + if (!AttachmentBuffer) + { + throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); + } + + TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + TotalRawAttachmentBytes += ChunkDataRawSize; + + CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); + 
OutputPackage.AddAttachment(Attachment); + }); + + OutputPackage.SetObject(Output); + + ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", + OutputPackage.GetAttachments().size(), + NiceBytes(TotalAttachmentBytes), + NiceBytes(TotalRawAttachmentBytes)); + + return OutputPackage; +} + +void +LocalProcessRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("LocalProcessRunner_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + // On Windows it's possible to wait on process handles, so we wait for either a process to exit + // or for the monitor event to be signaled (which indicates we should check for cancellation + // or shutdown). This could be further improved by using a completion port and registering process + // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing + // with an enormous number of concurrent processes. + // + // On other platforms we just wait on the monitor event and poll for process exits at intervals. 
+# if ZEN_PLATFORM_WINDOWS + auto WaitOnce = [&] { + HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; + + uint32_t NumHandles = 0; + + WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); + + m_RunningLock.WithSharedLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) + { + Ref<RunningAction> Action = It->second; + + WaitHandles[NumHandles++] = Action->ProcessHandle; + } + }); + + DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); + + // return true if a handle was signaled + return (WaitResult <= NumHandles); + }; +# else + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; +# endif + + while (!WaitOnce()) + { + if (m_MonitorThreadEnabled == false) + { + return; + } + + SweepRunningActions(); + SampleRunningProcessCpu(); + } + + // Signal received + + SweepRunningActions(); + SampleRunningProcessCpu(); + } while (m_MonitorThreadEnabled); +} + +void +LocalProcessRunner::CancelRunningActions() +{ + // Base class is not directly usable — platform subclasses override this +} + +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; + + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) + { + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) + { + SampleProcessCpu(*Running); + } + } + }); +} + +void +LocalProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} + +void +LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. 
+ // Note that this must be called without holding any local locks + // otherwise we may end up with deadlocks. + + for (Ref<RunningAction> Running : CompletedActions) + { + const int ActionLsn = Running->Action->ActionLsn; + + if (Running->ExitCode == 0) + { + try + { + // Gather outputs + + CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); + + Running->Action->SetResult(std::move(OutputPackage)); + Running->Action->SetActionState(RunnerAction::State::Completed); + + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + + // Success -- continue with next iteration of the loop + continue; + } + catch (std::exception& Ex) + { + ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + } + } + + // Failed - clean up the sandbox in the background. + + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h new file mode 100644 index 000000000..7493e980b --- /dev/null +++ b/src/zencompute/runners/localrunner.h @@ -0,0 +1,138 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include <zencore/thread.h> +# include <zencore/zencore.h> +# include <zenstore/cidstore.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/logging.h> + +# include "deferreddeleter.h" + +# include <zencore/workthreadpool.h> + +# include <atomic> +# include <filesystem> +# include <optional> +# include <thread> + +namespace zen { +class CbPackage; +} + +namespace zen::compute { + +/** Direct process spawner + + This runner simply sets up a directory structure for each job and + creates a process to perform the computation in it. It is not very + efficient and is intended mostly for testing. + + */ + +class LocalProcessRunner : public FunctionRunner +{ + LocalProcessRunner(LocalProcessRunner&&) = delete; + LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; + +public: + LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~LocalProcessRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + [[nodiscard]] virtual bool IsHealthy() override { return true; } + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; + +protected: + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + + struct RunningAction : public RefCounted + { + Ref<RunnerAction> Action; + void* ProcessHandle = nullptr; + int ExitCode = 0; + std::filesystem::path SandboxPath; + + // State for periodic CPU usage sampling + uint64_t LastCpuSampleTicks = 0; // hifreq timer value at 
last sample + uint64_t LastCpuOsTicks = 0; // OS CPU ticks (platform-specific units) at last sample + }; + + std::atomic_bool m_AcceptNewActions; + ChunkResolver& m_ChunkResolver; + RwLock m_WorkerLock; + std::filesystem::path m_WorkerPath; + std::atomic<int32_t> m_SandboxCounter = 0; + std::filesystem::path m_SandboxPath; + int32_t m_MaxRunningActions = 64; // arbitrary limit for testing + + // if used in conjuction with m_ResultsLock, this lock must be taken *after* + // m_ResultsLock to avoid deadlocks + RwLock m_RunningLock; + std::unordered_map<int, Ref<RunningAction>> m_RunningMap; + + std::atomic<int32_t> m_SubmittingCount = 0; + DeferredDirectoryDeleter& m_DeferredDeleter; + WorkerThreadPool& m_WorkerPool; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + virtual void SweepRunningActions(); + virtual void CancelRunningActions(); + + // Sample CPU usage for all currently running processes (throttled per-action). + void SampleRunningProcessCpu(); + + // Override in platform runners to sample one process. Called under a shared RunningLock. + virtual void SampleProcessCpu(RunningAction& /*Running*/) {} + + // Shared preamble for SubmitAction: capacity check, sandbox creation, + // worker manifesting, action writing, input manifesting. + struct PreparedAction + { + int32_t ActionLsn; + std::filesystem::path SandboxPath; + std::filesystem::path WorkerPath; + CbPackage WorkerPackage; + }; + std::optional<PreparedAction> PrepareActionSubmission(Ref<RunnerAction> Action); + + // Shared post-processing for SweepRunningActions: gather outputs, + // set state, clean sandbox. 
+ void ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions); + + std::filesystem::path CreateNewSandbox(); + void ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback); + std::filesystem::path ManifestWorker(const WorkerDesc& Worker); + CbPackage GatherActionOutputs(std::filesystem::path SandboxPath); + + void DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp new file mode 100644 index 000000000..5cec90699 --- /dev/null +++ b/src/zencompute/runners/macrunner.cpp @@ -0,0 +1,491 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "macrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <libproc.h> +# include <sandbox.h> +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. 
+ + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + // Build a Seatbelt profile string that denies everything by default and + // allows only the minimum needed for the worker to execute: process ops, + // system library reads, worker directory (read-only), and sandbox directory + // (read-write). Network access is denied implicitly by the deny-default policy. 
+ std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath) + { + std::string Profile; + Profile.reserve(1024); + + Profile += "(version 1)\n"; + Profile += "(deny default)\n"; + Profile += "(allow process*)\n"; + Profile += "(allow sysctl-read)\n"; + Profile += "(allow file-read-metadata)\n"; + + // System library paths needed for dynamic linker and runtime + Profile += "(allow file-read* (subpath \"/usr\"))\n"; + Profile += "(allow file-read* (subpath \"/System\"))\n"; + Profile += "(allow file-read* (subpath \"/Library\"))\n"; + Profile += "(allow file-read* (subpath \"/dev\"))\n"; + Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n"; + Profile += "(allow file-read* (subpath \"/etc\"))\n"; + + // Worker directory: read-only + Profile += "(allow file-read* (subpath \""; + Profile += WorkerPath; + Profile += "\"))\n"; + + // Sandbox directory: read+write + Profile += "(allow file-read* file-write* (subpath \""; + Profile += SandboxPath; + Profile += "\"))\n"; + + return Profile; + } + +} // anonymous namespace + +MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("Seatbelt sandboxing enabled for child processes"); + } +} + +SubmitResult +MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("MacProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: build sandbox profile and create error pipe + std::string SandboxProfile; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr); + + if (pipe(ErrorPipe) != 0) + { + throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno)); + } + fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC); + fcntl(ErrorPipe[1], 
F_SETFD, FD_CLOEXEC); + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + // Apply Seatbelt sandbox profile + char* ErrorBuf = nullptr; + if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) + { + // sandbox_init failed — write error to pipe and exit + if (ErrorBuf) + { + WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0); + // WriteErrorAndExit does not return, but sandbox_free_error + // is not needed since we _exit + } + WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno); + } + if (ErrorBuf) + { + sandbox_free_error(ErrorBuf); + } + + if (chdir(SandboxPathStr.c_str()) != 0) + { + WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. 
+ char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +MacProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } 
+ }); + + ProcessCompletedActions(CompletedActions); +} + +void +MacProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +MacProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via 
waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +MacProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + struct proc_taskinfo Info; + if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0) + { + return; + } + + // pti_total_user and pti_total_system are in nanoseconds + const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // ns → ms: divide by 1,000,000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); + 
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h new file mode 100644 index 000000000..d653b923a --- /dev/null +++ b/src/zencompute/runners/macrunner.h @@ -0,0 +1,43 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +namespace zen::compute { + +/** Native macOS process runner for executing Mac worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using macOS Seatbelt + (sandbox_init): no network access and no filesystem access outside the + explicitly allowed sandbox and worker directories. This requires no elevation. 
+ */ +class MacProcessRunner : public LocalProcessRunner +{ +public: + MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp new file mode 100644 index 000000000..672636d06 --- /dev/null +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -0,0 +1,618 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "remotehttprunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/scopeguard.h> +# include <zencore/system.h> +# include <zencore/trace.h> +# include <zenhttp/httpcommon.h> +# include <zenstore/cidstore.h> + +# include <span> + +////////////////////////////////////////////////////////////////////////// + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("http_exec")) +, m_ChunkResolver{InChunkResolver} +, m_WorkerPool{InWorkerPool} 
+, m_HostName{HostName} +, m_BaseUrl{fmt::format("{}/compute", HostName)} +, m_Http(m_BaseUrl) +, m_InstanceId(Oid::NewOid()) +{ + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; +} + +RemoteHttpRunner::~RemoteHttpRunner() +{ + Shutdown(); +} + +void +RemoteHttpRunner::Shutdown() +{ + // TODO: should cleanly drain/cancel pending work + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +void +RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + CbPackage WorkerDesc = WorkerPackage; + + std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); + + HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); + + if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) + { + HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); + + if (DescResponse.StatusCode == HttpResponseCode::NotFound) + { + CbPackage Pkg = WorkerDesc; + + // Build response package by sending only the attachments + // the other end needs. We start with the full package and + // remove the attachments which are not needed. 
+ + { + std::unordered_set<IoHash> Needed; + + CbObject Response = DescResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + Needed.insert(NeedHash); + } + + std::unordered_set<IoHash> ToRemove; + + for (const CbAttachment& Attachment : Pkg.GetAttachments()) + { + const IoHash& Hash = Attachment.GetHash(); + + if (Needed.find(Hash) == Needed.end()) + { + ToRemove.insert(Hash); + } + } + + for (const IoHash& Hash : ToRemove) + { + int RemovedCount = Pkg.RemoveAttachment(Hash); + + ZEN_ASSERT(RemovedCount == 1); + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); + + if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + } + else if (!IsHttpSuccessCode(DescResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + else + { + ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); + } + } + else if (WorkerResponse.StatusCode == HttpResponseCode::OK) + { + // Already known from a previous run + } + else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", + WorkerId, + m_Http.GetBaseUri(), + WorkerUrl, + (int)WorkerResponse.StatusCode, + ToString(WorkerResponse.StatusCode)); + + // TODO: propagate error + } +} + +size_t +RemoteHttpRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + size_t RunningCount = m_RemoteRunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActions(const 
std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For larger batches, submit HTTP requests in parallel via the shared worker pool + + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + +SubmitResult +RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + Action->ExecutionLocation = m_HostName; + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + // Determine the submission URL. If the action belongs to a queue, ensure a + // corresponding remote queue exists on the target node and submit via it. 
+ + std::string SubmitUrl = "/jobs"; + if (const int QueueId = Action->QueueId; QueueId != 0) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + // Enqueue job. If the remote returns FailedDependency (424), it means it + // cannot resolve the worker/function — re-register the worker and retry once. + + CbObject Result; + HttpClient::Response WorkResponse; + HttpResponseCode WorkResponseCode{}; + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); + + RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } + + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (WorkResponseCode == HttpResponseCode::NotFound) + { + // Not all attachments are present + + // Build response package including all required attachments + + CbPackage Pkg; + Pkg.SetObject(ActionObj); + + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + // No such attachment + + return {.IsAccepted = false, .Reason = 
fmt::format("missing attachment {}", NeedHash)}; + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + + // TODO: include more information about the failure in the response + + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (PayloadResponse.StatusCode == HttpResponseCode::OK) + { + Result = PayloadResponse.AsObject(); + } + else + { + // Unexpected response + + const int ResponseStatusCode = (int)PayloadResponse.StatusCode; + + ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", + ActionId, + m_Http.GetBaseUri(), + SubmitUrl, + ResponseStatusCode, + ToString(ResponseStatusCode)); + + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", + ResponseStatusCode, + ToString(ResponseStatusCode), + m_Http.GetBaseUri(), + SubmitUrl)}; + } + } + + if (Result) + { + if (const int32_t LsnField = Result["lsn"].AsInt32(0)) + { + HttpRunningAction NewAction; + NewAction.Action = Action; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; + } + } + + return {}; +} + +Oid +RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) +{ + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + return It->second; + } + } + + // Build a stable idempotency key that uniquely identifies this (runner instance, local queue) + // pair. 
The server uses this to return the same remote queue token for concurrent or redundant + // requests, preventing orphaned remote queues when multiple threads race through here. + // Also send hostname so the server can associate the queue with its origin for diagnostics. + CbObjectWriter Body; + Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId); + Body << "hostname"sv << GetMachineName(); + if (Metadata) + { + Body << "metadata"sv << Metadata; + } + if (Config) + { + Body << "config"sv << Config; + } + + HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save()); + if (!Resp) + { + ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName); + return Oid::Zero; + } + + Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString()); + if (Token == Oid::Zero) + { + return Oid::Zero; + } + + ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName); + + RwLock::ExclusiveLockScope _(m_QueueTokenLock); + auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token); + return It->second; +} + +void +RemoteHttpRunner::CancelRemoteQueue(int QueueId) +{ + Oid Token; + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + Token = It->second; + } + } + + if (Token == Oid::Zero) + { + return; + } + + HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token)); + + if (Resp.StatusCode == HttpResponseCode::NoContent) + { + ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName); + } + else + { + ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode)); + } +} + +bool +RemoteHttpRunner::IsHealthy() +{ + if (HttpClient::Response Ready = m_Http.Get("/ready")) + { + return true; + } + else + { + // TODO: use response to propagate context + return false; + } +} + +size_t 
RemoteHttpRunner::GetSubmittedActionCount()
{
    // Snapshot under the shared lock; the count is advisory and may change
    // immediately after return.
    RwLock::SharedLockScope _(m_RunningLock);
    return m_RemoteRunningMap.size();
}

/** Monitor thread entry point.

    Repeatedly waits on m_MonitorThreadEvent with an adaptive timeout and sweeps
    the remote for completed actions on every timeout. When the event is
    signalled the loop performs one final sweep and then re-checks
    m_MonitorThreadEnabled, so a shutdown signal still drains completions once.
 */
void
RemoteHttpRunner::MonitorThreadFunction()
{
    SetCurrentThreadName("RemoteHttpRunner_Monitor");

    do
    {
        const int NormalWaitingTime = 200;
        int       WaitTimeMs        = NormalWaitingTime;
        // WaitOnce returns true when the event was signalled (wake-up request),
        // false on timeout (periodic sweep).
        auto WaitOnce  = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
        auto SweepOnce = [&] {
            const size_t RetiredCount = SweepRunningActions();

            // Adapt the polling interval: poll 4x faster when many actions are
            // in flight, 2x faster when the last sweep retired something, and
            // at the normal rate when idle.
            m_RunningLock.WithSharedLock([&] {
                if (m_RemoteRunningMap.size() > 16)
                {
                    WaitTimeMs = NormalWaitingTime / 4;
                }
                else
                {
                    if (RetiredCount)
                    {
                        WaitTimeMs = NormalWaitingTime / 2;
                    }
                    else
                    {
                        WaitTimeMs = NormalWaitingTime;
                    }
                }
            });
        };

        while (!WaitOnce())
        {
            SweepOnce();
        }

        // Signal received - this may mean we should quit

        SweepOnce();
    } while (m_MonitorThreadEnabled);
}

/** Polls the remote /jobs/completed endpoint, retires any of our actions it
    reports as finished, and propagates results/state to the owning
    RunnerAction objects.

    Completion callbacks (SetResult/SetActionState) are invoked strictly after
    all local locks are released to avoid lock-order deadlocks with callers.

    @return number of actions retired by this sweep
 */
size_t
RemoteHttpRunner::SweepRunningActions()
{
    ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions");
    std::vector<HttpRunningAction> CompletedActions;

    // Poll remote for list of completed actions

    HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv);

    if (CbObject Completed = ResponseCompleted.AsObject())
    {
        for (auto& FieldIt : Completed["completed"sv])
        {
            CbObjectView        EntryObj    = FieldIt.AsObjectView();
            const int32_t       CompleteLsn = EntryObj["lsn"sv].AsInt32();
            std::string_view    StateName   = EntryObj["state"sv].AsString();

            RunnerAction::State RemoteState = RunnerAction::FromString(StateName);

            // Always fetch to drain the result from the remote's results map,
            // but only keep the result package for successfully completed actions.
            HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn));

            m_RunningLock.WithExclusiveLock([&] {
                if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end())
                {
                    HttpRunningAction CompletedAction = std::move(CompleteIt->second);
                    CompletedAction.RemoteState       = RemoteState;

                    if (RemoteState == RunnerAction::State::Completed && ResponseJob)
                    {
                        CompletedAction.ActionResults = ResponseJob.AsPackage();
                    }

                    CompletedActions.push_back(std::move(CompletedAction));
                    m_RemoteRunningMap.erase(CompleteIt);
                }
                else
                {
                    // we received a completion notice for an action we don't know about,
                    // this can happen if the runner is used by multiple upstream schedulers,
                    // or if this compute node was recently restarted and lost track of
                    // previously scheduled actions
                }
            });
        }

        // The remote may piggy-back capacity metrics onto the completion reply;
        // use its logical processor count to cap our concurrency (never below 4,
        // and the cap only ever shrinks here).
        if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView())
        {
            // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0))
            if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0))
            {
                const int32_t NewCap = zen::Max(4, CpuCount);

                if (m_MaxRunningActions > NewCap)
                {
                    ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions);

                    m_MaxRunningActions = NewCap;
                }
            }
        }
    }

    // Notify outer. Note that this has to be done without holding any local locks
    // otherwise we may end up with deadlocks.

    for (HttpRunningAction& HttpAction : CompletedActions)
    {
        const int ActionLsn = HttpAction.Action->ActionLsn;

        ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
                  HttpAction.Action->ActionId,
                  ActionLsn,
                  HttpAction.RemoteActionLsn,
                  RunnerAction::ToString(HttpAction.RemoteState));

        if (HttpAction.RemoteState == RunnerAction::State::Completed)
        {
            HttpAction.Action->SetResult(std::move(HttpAction.ActionResults));
        }

        HttpAction.Action->SetActionState(HttpAction.RemoteState);
    }

    return CompletedActions.size();
}

} // namespace zen::compute

#endif

// ===== file: src/zencompute/runners/remotehttprunner.h (new) =====

// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "zencompute/computeservice.h"

#if ZEN_WITH_COMPUTE_SERVICES

# include "functionrunner.h"

# include <zencore/compactbinarypackage.h>
# include <zencore/logging.h>
# include <zencore/uid.h>
# include <zencore/workthreadpool.h>
# include <zencore/zencore.h>
# include <zenhttp/httpclient.h>

# include <atomic>
# include <filesystem>
# include <thread>
# include <unordered_map>

namespace zen {
class CidStore;
}

namespace zen::compute {

/** HTTP-based runner

    This implements a DDC remote compute execution strategy via REST API

 */

class RemoteHttpRunner : public FunctionRunner
{
    RemoteHttpRunner(RemoteHttpRunner&&)            = delete;
    RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete;

public:
    RemoteHttpRunner(ChunkResolver&               InChunkResolver,
                     const std::filesystem::path& BaseDir,
                     std::string_view             HostName,
                     WorkerThreadPool&            InWorkerPool);
    ~RemoteHttpRunner();

    virtual void                 Shutdown() override;
    virtual void                 RegisterWorker(const CbPackage& WorkerPackage) override;
    [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
    [[nodiscard]] virtual bool   IsHealthy() override;
    [[nodiscard]] virtual size_t GetSubmittedActionCount() override;
    [[nodiscard]] virtual size_t QueryCapacity() override;
    [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;
    virtual void                 CancelRemoteQueue(int QueueId) override;

    std::string_view GetHostName() const { return m_HostName; }

protected:
    LoggerRef Log() { return m_Log; }

private:
    LoggerRef         m_Log;
    ChunkResolver&    m_ChunkResolver;
    WorkerThreadPool& m_WorkerPool;
    std::string       m_HostName;
    std::string       m_BaseUrl;
    HttpClient        m_Http;

    int32_t m_MaxRunningActions = 256; // arbitrary limit for testing

    // Bookkeeping for an action that has been accepted by the remote end.
    struct HttpRunningAction
    {
        Ref<RunnerAction>   Action;
        int                 RemoteActionLsn = 0; // Remote LSN
        RunnerAction::State RemoteState     = RunnerAction::State::Failed;
        CbPackage           ActionResults;
    };

    RwLock                                     m_RunningLock;
    std::unordered_map<int, HttpRunningAction> m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn

    std::thread       m_MonitorThread;
    std::atomic<bool> m_MonitorThreadEnabled{true};
    Event             m_MonitorThreadEvent;
    void              MonitorThreadFunction();
    size_t            SweepRunningActions();

    RwLock                       m_QueueTokenLock;
    std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token

    // Stable identity for this runner instance, used as part of the idempotency key when
    // creating remote queues. Generated once at construction and never changes.
    Oid m_InstanceId;

    Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config);
};

} // namespace zen::compute

#endif

// ===== file: src/zencompute/runners/windowsrunner.cpp (new) =====

// Copyright Epic Games, Inc. All Rights Reserved.

#include "windowsrunner.h"

#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS

# include <zencore/compactbinary.h>
# include <zencore/compactbinarypackage.h>
# include <zencore/except.h>
# include <zencore/except_fmt.h>
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/scopeguard.h>
# include <zencore/trace.h>
# include <zencore/system.h>
# include <zencore/timer.h>

ZEN_THIRD_PARTY_INCLUDES_START
# include <userenv.h>
# include <aclapi.h>
# include <sddl.h>
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen::compute {

using namespace std::literals;

/** Constructs the runner; when Sandboxed is true, creates a per-process
    AppContainer profile used to isolate spawned worker processes.

    @throws zen::runtime_error if the AppContainer profile cannot be created.
 */
WindowsProcessRunner::WindowsProcessRunner(ChunkResolver&               Resolver,
                                           const std::filesystem::path& BaseDir,
                                           DeferredDirectoryDeleter&    Deleter,
                                           WorkerThreadPool&            WorkerPool,
                                           bool                         Sandboxed,
                                           int32_t                      MaxConcurrentActions)
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
    if (!m_Sandboxed)
    {
        return;
    }

    // Build a unique profile name per process to avoid collisions
    m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());

    // Clean up any stale profile from a previous crash
    DeleteAppContainerProfile(m_AppContainerName.c_str());

    PSID Sid = nullptr;

    HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
                                           m_AppContainerName.c_str(), // display name
                                           m_AppContainerName.c_str(), // description
                                           nullptr,                    // no capabilities
                                           0,                          // capability count
                                           &Sid);

    if (FAILED(Hr))
    {
        throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
    }

    // Ownership of the SID is ours; released with FreeSid in the destructor.
    m_AppContainerSid = Sid;

    ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
}

WindowsProcessRunner::~WindowsProcessRunner()
{
    if (m_AppContainerSid)
    {
        FreeSid(m_AppContainerSid);
        m_AppContainerSid = nullptr;
    }

    if (!m_AppContainerName.empty())
    {
        // Best-effort: remove the profile created in the constructor.
        DeleteAppContainerProfile(m_AppContainerName.c_str());
    }
}

/** Grants the runner's AppContainer SID the given access mask on Path,
    inheriting to children (files and subdirectories).

    Reads the existing DACL, merges in one EXPLICIT_ACCESS entry for the
    AppContainer SID, and writes the DACL back.

    @throws zen::runtime_error if any of the three security API calls fail.
 */
void
WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask)
{
    PACL                 ExistingDacl       = nullptr;
    PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr;

    DWORD Result = GetNamedSecurityInfoW(Path.c_str(),
                                         SE_FILE_OBJECT,
                                         DACL_SECURITY_INFORMATION,
                                         nullptr,
                                         nullptr,
                                         &ExistingDacl,
                                         nullptr,
                                         &SecurityDescriptor);

    if (Result != ERROR_SUCCESS)
    {
        throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
    }

    // SecurityDescriptor owns ExistingDacl's storage; freed on all exit paths.
    auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); });

    EXPLICIT_ACCESSW Access{};
    Access.grfAccessPermissions = AccessMask;
    Access.grfAccessMode        = SET_ACCESS;
    Access.grfInheritance       = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE;
    Access.Trustee.TrusteeForm  = TRUSTEE_IS_SID;
    Access.Trustee.TrusteeType  = TRUSTEE_IS_WELL_KNOWN_GROUP;
    Access.Trustee.ptstrName    = static_cast<LPWSTR>(m_AppContainerSid);

    PACL NewDacl = nullptr;

    Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl);
    if (Result != ERROR_SUCCESS)
    {
        throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
    }

    auto $1 = MakeGuard([&] { LocalFree(NewDacl); });

    Result = SetNamedSecurityInfoW(const_cast<LPWSTR>(Path.c_str()),
                                   SE_FILE_OBJECT,
                                   DACL_SECURITY_INFORMATION,
                                   nullptr,
                                   nullptr,
                                   NewDacl,
                                   nullptr);

    if (Result != ERROR_SUCCESS)
    {
        throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
    }
}

/** Spawns the worker process for Action (optionally inside the AppContainer)
    and registers it in m_RunningMap; does not wait for completion.

    @return SubmitResult with IsAccepted=false if preparation failed
    @throws via ThrowLastError / zen::runtime_error on process-creation failures
 */
SubmitResult
WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
    ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction");
    std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);

    if (!Prepared)
    {
        return SubmitResult{.IsAccepted = false};
    }

    // Set up environment variables

    CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();

    // CreateProcess environment block format: NUL-separated "VAR=value"
    // entries, terminated by a double NUL (ANSI block, since
    // CREATE_UNICODE_ENVIRONMENT is not set).
    StringBuilder<1024> EnvironmentBlock;

    for (auto& It : WorkerDescription["environment"sv])
    {
        EnvironmentBlock.Append(It.AsString());
        EnvironmentBlock.Append('\0');
    }
    EnvironmentBlock.Append('\0');
    EnvironmentBlock.Append('\0');

    // Execute process - this spawns the child process immediately without waiting
    // for completion

    std::string_view      ExecPath = WorkerDescription["path"sv].AsString();
    std::filesystem::path ExePath  = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred();

    ExtendableWideStringBuilder<512> CommandLine;
    CommandLine.Append(L'"');
    CommandLine.Append(ExePath.c_str());
    CommandLine.Append(L'"');
    CommandLine.Append(L" -Build=build.action");

    LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr;
    LPSECURITY_ATTRIBUTES lpThreadAttributes  = nullptr;
    BOOL                  bInheritHandles     = FALSE;
    DWORD                 dwCreationFlags     = 0;

    ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed);

    // CreateProcessW may modify the command-line buffer, so it must be a
    // writable, NUL-terminated wide string.
    CommandLine.EnsureNulTerminated();

    PROCESS_INFORMATION ProcessInformation{};

    if (m_Sandboxed)
    {
        // Grant AppContainer access to sandbox and worker directories
        GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS);
        GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE);

        // Set up extended startup info with AppContainer security capabilities
        SECURITY_CAPABILITIES SecurityCapabilities{};
        SecurityCapabilities.AppContainerSid = m_AppContainerSid;
        SecurityCapabilities.Capabilities    = nullptr;
        SecurityCapabilities.CapabilityCount = 0;

        // First call intentionally fails with ERROR_INSUFFICIENT_BUFFER and
        // reports the required attribute-list size.
        SIZE_T AttrListSize = 0;
        InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize);

        // NOTE(review): malloc result is not checked for nullptr before the
        // second InitializeProcThreadAttributeList call — confirm allocation
        // failure handling is acceptable here.
        auto AttrList = static_cast<PPROC_THREAD_ATTRIBUTE_LIST>(malloc(AttrListSize));
        auto $0       = MakeGuard([&] { free(AttrList); });

        if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize))
        {
            zen::ThrowLastError("InitializeProcThreadAttributeList failed");
        }

        auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); });

        if (!UpdateProcThreadAttribute(AttrList,
                                       0,
                                       PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES,
                                       &SecurityCapabilities,
                                       sizeof(SecurityCapabilities),
                                       nullptr,
                                       nullptr))
        {
            zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed");
        }

        STARTUPINFOEXW StartupInfoEx{};
        StartupInfoEx.StartupInfo.cb  = sizeof(STARTUPINFOEXW);
        StartupInfoEx.lpAttributeList = AttrList;

        dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT;

        BOOL Success = CreateProcessW(nullptr,
                                      CommandLine.Data(),
                                      lpProcessAttributes,
                                      lpThreadAttributes,
                                      bInheritHandles,
                                      dwCreationFlags,
                                      (LPVOID)EnvironmentBlock.Data(),
                                      Prepared->SandboxPath.c_str(),
                                      &StartupInfoEx.StartupInfo,
                                      /* out */ &ProcessInformation);

        if (!Success)
        {
            zen::ThrowLastError("Unable to launch sandboxed process");
        }
    }
    else
    {
        // NOTE(review): this uses STARTUPINFO (TCHAR-mapped) with
        // CreateProcessW — relies on UNICODE being defined so it resolves to
        // STARTUPINFOW; consider spelling STARTUPINFOW explicitly.
        STARTUPINFO StartupInfo{};
        StartupInfo.cb = sizeof StartupInfo;

        BOOL Success = CreateProcessW(nullptr,
                                      CommandLine.Data(),
                                      lpProcessAttributes,
                                      lpThreadAttributes,
                                      bInheritHandles,
                                      dwCreationFlags,
                                      (LPVOID)EnvironmentBlock.Data(),
                                      Prepared->SandboxPath.c_str(),
                                      &StartupInfo,
                                      /* out */ &ProcessInformation);

        if (!Success)
        {
            zen::ThrowLastError("Unable to launch process");
        }
    }

    // The thread handle is not needed; the process handle is kept for
    // sweep/cancel and closed when the action is retired.
    CloseHandle(ProcessInformation.hThread);

    Ref<RunningAction> NewAction{new RunningAction()};
    NewAction->Action        = Action;
    NewAction->ProcessHandle = ProcessInformation.hProcess;
    NewAction->SandboxPath   = std::move(Prepared->SandboxPath);

    {
        RwLock::ExclusiveLockScope _(m_RunningLock);

        m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
    }

    Action->SetActionState(RunnerAction::State::Running);

    return SubmitResult{.IsAccepted = true};
}

void
WindowsProcessRunner::SweepRunningActions()
{
    ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions");
    std::vector<Ref<RunningAction>> CompletedActions;

    m_RunningLock.WithExclusiveLock([&] {
        for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
        {
            Ref<RunningAction> Running = It->second;

            DWORD ExitCode  = 0;
            BOOL  IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode);

            // STILL_ACTIVE means the process has not exited yet; skip it.
            if (IsSuccess && ExitCode != STILL_ACTIVE)
            {
                CloseHandle(Running->ProcessHandle);
                Running->ProcessHandle = INVALID_HANDLE_VALUE;
                Running->ExitCode      = ExitCode;

                CompletedActions.push_back(std::move(Running));
                It = m_RunningMap.erase(It);
            }
            else
            {
                ++It;
            }
        }
    });

    // Completion callbacks run outside the lock (see base class contract).
    ProcessCompletedActions(CompletedActions);
}

/** Terminates every running action, waits (bounded) for the processes to exit,
    schedules their sandboxes for deferred deletion, and marks each action
    Failed. Called on shutdown/cancellation of the whole runner.
 */
void
WindowsProcessRunner::CancelRunningActions()
{
    ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions");
    Stopwatch                                 Timer;
    std::unordered_map<int, Ref<RunningAction>> RunningMap;

    // Detach the whole map under the lock so termination/waiting below can
    // proceed without holding it.
    m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });

    if (RunningMap.empty())
    {
        return;
    }

    ZEN_INFO("cancelling all running actions");

    // For expedience we initiate the process termination for all known
    // processes before attempting to wait for them to exit.

    // Initiate termination for all known processes before waiting for them to exit.

    for (const auto& Kv : RunningMap)
    {
        Ref<RunningAction> Running = Kv.second;

        BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222);

        if (!TermSuccess)
        {
            DWORD LastError = GetLastError();

            // ERROR_ACCESS_DENIED typically means the process already exited;
            // not worth warning about.
            if (LastError != ERROR_ACCESS_DENIED)
            {
                ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError));
            }
        }
    }

    // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded.

    for (auto& [Lsn, Running] : RunningMap)
    {
        if (Running->ProcessHandle != INVALID_HANDLE_VALUE)
        {
            DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000);

            if (WaitResult != WAIT_OBJECT_0)
            {
                ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult);
            }
            else
            {
                ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
            }

            CloseHandle(Running->ProcessHandle);
            Running->ProcessHandle = INVALID_HANDLE_VALUE;
        }

        m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
        Running->Action->SetActionState(RunnerAction::State::Failed);
    }

    ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}

/** Initiates cancellation of a single running action by LSN.

    @return true if termination was initiated; false if the action is unknown,
            already retired, or TerminateProcess failed
 */
bool
WindowsProcessRunner::CancelAction(int ActionLsn)
{
    ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction");

    // Hold the shared lock while terminating to prevent the sweep thread from
    // closing the handle between our lookup and TerminateProcess call.
    bool Sent = false;

    m_RunningLock.WithSharedLock([&] {
        auto It = m_RunningMap.find(ActionLsn);
        if (It == m_RunningMap.end())
        {
            return;
        }

        Ref<RunningAction> Target = It->second;
        if (Target->ProcessHandle == INVALID_HANDLE_VALUE)
        {
            return;
        }

        BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222);

        if (!TermSuccess)
        {
            DWORD LastError = GetLastError();

            if (LastError != ERROR_ACCESS_DENIED)
            {
                ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError));
            }

            return;
        }

        ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn);
        Sent = true;
    });

    // The monitor thread will pick up the process exit and mark the action as Failed.
    return Sent;
}

/** Samples cumulative and instantaneous CPU usage for a running child process
    via GetProcessTimes and publishes them on the owning RunnerAction.
    Silently returns if process times cannot be read.
 */
void
WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
{
    FILETIME CreationTime, ExitTime, KernelTime, UserTime;
    if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime))
    {
        return;
    }

    auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast<uint64_t>(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; };

    // FILETIME values are in 100-nanosecond intervals
    const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime);
    const uint64_t NowTicks       = GetHifreqTimerValue();

    // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds
    Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed);

    // Percent-usage needs two samples; skipped on the first call.
    if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
    {
        const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks);
        if (ElapsedMs > 0)
        {
            const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
            // 100ns → ms: divide by 10000; then as percent of elapsed ms
            const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0);
            Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
        }
    }

    Running.LastCpuSampleTicks = NowTicks;
    Running.LastCpuOsTicks     = CurrentOsTicks;
}

} // namespace zen::compute

#endif

// ===== file: src/zencompute/runners/windowsrunner.h (new) =====

// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "localrunner.h"

#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS

# include <zencore/windows.h>

# include <string>

namespace zen::compute {

/** Windows process runner using CreateProcessW for executing worker executables.

    Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
    input/output handling, and monitor thread infrastructure. Overrides only the
    platform-specific methods: process spawning, sweep, and cancellation.

    When Sandboxed is true, child processes are isolated using a Windows AppContainer:
    no network access (AppContainer blocks network by default when no capabilities are
    granted) and no filesystem access outside explicitly granted sandbox and worker
    directories. This requires no elevation.
 */
class WindowsProcessRunner : public LocalProcessRunner
{
public:
    WindowsProcessRunner(ChunkResolver&               Resolver,
                         const std::filesystem::path& BaseDir,
                         DeferredDirectoryDeleter&    Deleter,
                         WorkerThreadPool&            WorkerPool,
                         bool                         Sandboxed = false,
                         int32_t                      MaxConcurrentActions = 0);
    ~WindowsProcessRunner();

    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
    void                       SweepRunningActions() override;
    void                       CancelRunningActions() override;
    bool                       CancelAction(int ActionLsn) override;
    void                       SampleProcessCpu(RunningAction& Running) override;

private:
    // Adds an inheritable ACE for the AppContainer SID to Path's DACL.
    void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask);

    bool         m_Sandboxed = false;
    PSID         m_AppContainerSid = nullptr; // owned; freed with FreeSid in dtor
    std::wstring m_AppContainerName;
};

} // namespace zen::compute

#endif

// ===== file: src/zencompute/runners/winerunner.cpp (new) =====

// Copyright Epic Games, Inc. All Rights Reserved.
+ +#include "winerunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); +} + +SubmitResult +WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WineProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: wine <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string WinePathStr = m_WinePath; + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(WinePathStr.data()); + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno))); + } + + if (ChildPid == 0) + { + // Child process + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(WinePathStr.c_str(), ArgV.data(), Envp.data()); + + // execve only returns on failure + _exit(127); + } + + // Parent: store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new 
RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WineProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WineProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", 
Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h new file mode 100644 index 000000000..7df62e7c0 --- /dev/null +++ b/src/zencompute/runners/winerunner.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <string> + +namespace zen::compute { + +/** Wine-based process runner for executing Windows worker executables on Linux. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. 
+ */ +class WineProcessRunner : public LocalProcessRunner +{ +public: + WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + +private: + std::string m_WinePath = "wine"; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp new file mode 100644 index 000000000..dd09312df --- /dev/null +++ b/src/zencompute/testing/mockimds.cpp @@ -0,0 +1,205 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// 
--------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) 
+ if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp new 
file mode 100644 index 000000000..88ef5b62d --- /dev/null +++ b/src/zencompute/timeline/workertimeline.cpp @@ -0,0 +1,430 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "workertimeline.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/basicfile.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryfile.h> + +# include <algorithm> + +namespace zen::compute { + +WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId) +{ +} + +WorkerTimeline::~WorkerTimeline() +{ +} + +void +WorkerTimeline::RecordProvisioned() +{ + AppendEvent({ + .Type = EventType::WorkerProvisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordDeprovisioned() +{ + AppendEvent({ + .Type = EventType::WorkerDeprovisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId) +{ + AppendEvent({ + .Type = EventType::ActionAccepted, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + }); +} + +void +WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason) +{ + AppendEvent({ + .Type = EventType::ActionRejected, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .Reason = std::string(Reason), + }); +} + +void +WorkerTimeline::RecordActionStateChanged(int ActionLsn, + const IoHash& ActionId, + RunnerAction::State PreviousState, + RunnerAction::State NewState) +{ + AppendEvent({ + .Type = EventType::ActionStateChanged, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .ActionState = NewState, + .PreviousState = PreviousState, + }); +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + for (const auto& Evt : m_Events) + { + if 
(Evt.Timestamp >= StartTime && Evt.Timestamp <= EndTime) + { + Result.push_back(Evt); + } + } + }); + + return Result; +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryRecent(int Limit) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + const int Count = std::min(Limit, gsl::narrow<int>(m_Events.size())); + auto It = m_Events.end() - Count; + Result.assign(It, m_Events.end()); + }); + + return Result; +} + +size_t +WorkerTimeline::GetEventCount() const +{ + size_t Count = 0; + m_EventsLock.WithSharedLock([&] { Count = m_Events.size(); }); + return Count; +} + +WorkerTimeline::TimeRange +WorkerTimeline::GetTimeRange() const +{ + TimeRange Range; + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Range.First = m_Events.front().Timestamp; + Range.Last = m_Events.back().Timestamp; + } + }); + return Range; +} + +void +WorkerTimeline::AppendEvent(Event&& Evt) +{ + m_EventsLock.WithExclusiveLock([&] { + while (m_Events.size() >= m_MaxEvents) + { + m_Events.pop_front(); + } + + m_Events.push_back(std::move(Evt)); + }); +} + +const char* +WorkerTimeline::ToString(EventType Type) +{ + switch (Type) + { + case EventType::WorkerProvisioned: + return "provisioned"; + case EventType::WorkerDeprovisioned: + return "deprovisioned"; + case EventType::ActionAccepted: + return "accepted"; + case EventType::ActionRejected: + return "rejected"; + case EventType::ActionStateChanged: + return "state_changed"; + default: + return "unknown"; + } +} + +static WorkerTimeline::EventType +EventTypeFromString(std::string_view Str) +{ + if (Str == "provisioned") + return WorkerTimeline::EventType::WorkerProvisioned; + if (Str == "deprovisioned") + return WorkerTimeline::EventType::WorkerDeprovisioned; + if (Str == "accepted") + return WorkerTimeline::EventType::ActionAccepted; + if (Str == "rejected") + return WorkerTimeline::EventType::ActionRejected; + if (Str == "state_changed") + return WorkerTimeline::EventType::ActionStateChanged; + 
return WorkerTimeline::EventType::WorkerProvisioned; +} + +void +WorkerTimeline::WriteTo(const std::filesystem::path& Path) const +{ + CbObjectWriter Cbo; + Cbo << "worker_id" << m_WorkerId; + + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Cbo.AddDateTime("time_first", m_Events.front().Timestamp); + Cbo.AddDateTime("time_last", m_Events.back().Timestamp); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : m_Events) + { + Cbo.BeginObject(); + Cbo << "type" << ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == EventType::ActionStateChanged) + { + Cbo << "prev_state" << static_cast<int32_t>(Evt.PreviousState); + Cbo << "state" << static_cast<int32_t>(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + }); + + CbObject Obj = Cbo.Save(); + + BasicFile File(Path, BasicFile::Mode::kTruncate); + File.Write(Obj.GetBuffer().GetView(), 0); +} + +void +WorkerTimeline::ReadFrom(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + CbObject Root = std::move(Loaded.Object); + + if (!Root) + { + return; + } + + std::deque<Event> LoadedEvents; + + for (CbFieldView Field : Root["events"].AsArrayView()) + { + CbObjectView EventObj = Field.AsObjectView(); + + Event Evt; + Evt.Type = EventTypeFromString(EventObj["type"].AsString()); + Evt.Timestamp = EventObj["ts"].AsDateTime(); + + Evt.ActionLsn = EventObj["lsn"].AsInt32(); + Evt.ActionId = EventObj["action_id"].AsHash(); + + if (Evt.Type == EventType::ActionStateChanged) + { + Evt.PreviousState = static_cast<RunnerAction::State>(EventObj["prev_state"].AsInt32()); + Evt.ActionState = static_cast<RunnerAction::State>(EventObj["state"].AsInt32()); + } + + std::string_view Reason = EventObj["reason"].AsString(); + if 
(!Reason.empty()) + { + Evt.Reason = std::string(Reason); + } + + LoadedEvents.push_back(std::move(Evt)); + } + + m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); }); +} + +WorkerTimeline::TimeRange +WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + + if (!Loaded.Object) + { + return {}; + } + + return { + .First = Loaded.Object["time_first"].AsDateTime(), + .Last = Loaded.Object["time_last"].AsDateTime(), + }; +} + +// WorkerTimelineStore + +static constexpr std::string_view kTimelineExtension = ".ztimeline"; + +WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir)) +{ + std::error_code Ec; + std::filesystem::create_directories(m_PersistenceDir, Ec); +} + +Ref<WorkerTimeline> +WorkerTimelineStore::GetOrCreate(std::string_view WorkerId) +{ + // Fast path: check if it already exists in memory + { + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + } + + // Slow path: create under exclusive lock, loading from disk if available + RwLock::ExclusiveLockScope _(m_Lock); + + auto& Entry = m_Timelines[std::string(WorkerId)]; + if (!Entry) + { + Entry = Ref<WorkerTimeline>(new WorkerTimeline(WorkerId)); + + std::filesystem::path Path = TimelinePath(WorkerId); + std::error_code Ec; + if (std::filesystem::is_regular_file(Path, Ec)) + { + Entry->ReadFrom(Path); + } + } + return Entry; +} + +Ref<WorkerTimeline> +WorkerTimelineStore::Find(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + return {}; +} + +std::vector<std::string> +WorkerTimelineStore::GetActiveWorkerIds() const +{ + std::vector<std::string> Result; + + RwLock::SharedLockScope $(m_Lock); + 
Result.reserve(m_Timelines.size()); + for (const auto& [Id, _] : m_Timelines) + { + Result.push_back(Id); + } + + return Result; +} + +std::vector<WorkerTimelineStore::WorkerTimelineInfo> +WorkerTimelineStore::GetAllWorkerInfo() const +{ + std::unordered_map<std::string, WorkerTimeline::TimeRange> InfoMap; + + { + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + InfoMap[Id] = Timeline->GetTimeRange(); + } + } + + std::error_code Ec; + for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec)) + { + if (!Entry.is_regular_file()) + { + continue; + } + + const auto& Path = Entry.path(); + if (Path.extension().string() != kTimelineExtension) + { + continue; + } + + std::string Id = Path.stem().string(); + if (InfoMap.find(Id) == InfoMap.end()) + { + InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path); + } + } + + std::vector<WorkerTimelineInfo> Result; + Result.reserve(InfoMap.size()); + for (auto& [Id, Range] : InfoMap) + { + Result.push_back({.WorkerId = std::move(Id), .Range = Range}); + } + return Result; +} + +void +WorkerTimelineStore::Save(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + It->second->WriteTo(TimelinePath(WorkerId)); + } +} + +void +WorkerTimelineStore::SaveAll() +{ + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + Timeline->WriteTo(TimelinePath(Id)); + } +} + +std::filesystem::path +WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const +{ + return m_PersistenceDir / (std::string(WorkerId) + std::string(kTimelineExtension)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h new file mode 100644 index 000000000..87e19bc28 --- /dev/null +++ b/src/zencompute/timeline/workertimeline.h @@ -0,0 +1,169 @@ +// 
Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../runners/functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/refcount.h> +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/thread.h> +# include <zencore/timer.h> + +# include <deque> +# include <filesystem> +# include <string> +# include <string_view> +# include <unordered_map> +# include <vector> + +namespace zen::compute { + +struct RunnerAction; + +/** Worker activity timeline for tracking and visualizing worker activity over time. + * + * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle + * events (accept, reject, state changes) with timestamps, enabling time-range queries + * for dashboard visualization. + */ +class WorkerTimeline : public RefCounted +{ +public: + explicit WorkerTimeline(std::string_view WorkerId); + ~WorkerTimeline() override; + + struct TimeRange + { + DateTime First = DateTime(0); + DateTime Last = DateTime(0); + + explicit operator bool() const { return First.GetTicks() != 0; } + }; + + enum class EventType + { + WorkerProvisioned, + WorkerDeprovisioned, + ActionAccepted, + ActionRejected, + ActionStateChanged + }; + + static const char* ToString(EventType Type); + + struct Event + { + EventType Type; + DateTime Timestamp = DateTime(0); + + // Action context (only set for action events) + int ActionLsn = 0; + IoHash ActionId; + RunnerAction::State ActionState = RunnerAction::State::New; + RunnerAction::State PreviousState = RunnerAction::State::New; + + // Optional reason (e.g. rejection reason) + std::string Reason; + }; + + /** Record that this worker has been provisioned and is available for work. */ + void RecordProvisioned(); + + /** Record that this worker has been deprovisioned and is no longer available. */ + void RecordDeprovisioned(); + + /** Record that an action was accepted by this worker. 
*/ + void RecordActionAccepted(int ActionLsn, const IoHash& ActionId); + + /** Record that an action was rejected by this worker. */ + void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason); + + /** Record an action state transition on this worker. */ + void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState); + + /** Query events within a time range (inclusive). Returns events ordered by timestamp. */ + [[nodiscard]] std::vector<Event> QueryTimeline(DateTime StartTime, DateTime EndTime) const; + + /** Query the most recent N events. */ + [[nodiscard]] std::vector<Event> QueryRecent(int Limit = 100) const; + + /** Return the total number of recorded events. */ + [[nodiscard]] size_t GetEventCount() const; + + /** Return the time range covered by the events in this timeline. */ + [[nodiscard]] TimeRange GetTimeRange() const; + + [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; } + + /** Write the timeline to a file at the given path. */ + void WriteTo(const std::filesystem::path& Path) const; + + /** Read the timeline from a file at the given path. Replaces current in-memory events. */ + void ReadFrom(const std::filesystem::path& Path); + + /** Read only the time range from a persisted timeline file, without loading events. */ + [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path); + +private: + void AppendEvent(Event&& Evt); + + std::string m_WorkerId; + mutable RwLock m_EventsLock; + std::deque<Event> m_Events; + size_t m_MaxEvents = 10'000; +}; + +/** Manages a set of WorkerTimeline instances, keyed by worker ID. + * + * Provides thread-safe lookup and on-demand creation of timelines, backed by + * a persistence directory. Each timeline is stored as a separate file named + * {WorkerId}.ztimeline within the directory. 
+ */ +class WorkerTimelineStore +{ +public: + explicit WorkerTimelineStore(std::filesystem::path PersistenceDir); + ~WorkerTimelineStore() = default; + + WorkerTimelineStore(const WorkerTimelineStore&) = delete; + WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete; + + /** Get the timeline for a worker, creating one if it does not exist. + * If a persisted file exists on disk it will be loaded on first access. */ + Ref<WorkerTimeline> GetOrCreate(std::string_view WorkerId); + + /** Get the timeline for a worker, or null ref if it does not exist in memory. */ + [[nodiscard]] Ref<WorkerTimeline> Find(std::string_view WorkerId); + + /** Return the worker IDs of currently loaded (in-memory) timelines. */ + [[nodiscard]] std::vector<std::string> GetActiveWorkerIds() const; + + struct WorkerTimelineInfo + { + std::string WorkerId; + WorkerTimeline::TimeRange Range; + }; + + /** Return info for all known timelines (in-memory and on-disk), including time range. */ + [[nodiscard]] std::vector<WorkerTimelineInfo> GetAllWorkerInfo() const; + + /** Persist a single worker's timeline to disk. */ + void Save(std::string_view WorkerId); + + /** Persist all in-memory timelines to disk. */ + void SaveAll(); + +private: + [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const; + + std::filesystem::path m_PersistenceDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, Ref<WorkerTimeline>> m_Timelines; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua new file mode 100644 index 000000000..ed0af66a5 --- /dev/null +++ b/src/zencompute/xmake.lua @@ -0,0 +1,19 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. 
+ +target('zencompute') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_includedirs(".", {private=true}) + add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") + add_packages("json11") + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end + + if is_plat("windows") then + add_syslinks("Userenv") + end diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp new file mode 100644 index 000000000..1f3f6d3f9 --- /dev/null +++ b/src/zencompute/zencompute.cpp @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/zencompute.h" + +#if ZEN_WITH_TESTS +# include "runners/deferreddeleter.h" +# include <zencompute/cloudmetadata.h> +#endif + +namespace zen { + +void +zencompute_forcelinktests() +{ +#if ZEN_WITH_TESTS + compute::cloudmetadata_forcelink(); + compute::deferreddeleter_forcelink(); +#endif +} + +} // namespace zen diff --git a/src/zencore-test/zencore-test.cpp b/src/zencore-test/zencore-test.cpp index 68fc940ee..3d9a79283 100644 --- a/src/zencore-test/zencore-test.cpp +++ b/src/zencore-test/zencore-test.cpp @@ -1,47 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. -// zencore-test.cpp : Defines the entry point for the console application. 
-// - -#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zencore/zencore.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zencore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zencore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zencore-test", zen::zencore_forcelinktests); #else return 0; #endif diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp index 1f56ee6c3..96e121799 100644 --- a/src/zencore/base64.cpp +++ b/src/zencore/base64.cpp @@ -1,6 +1,10 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
#include <zencore/base64.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <string> namespace zen { @@ -11,7 +15,6 @@ static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; /** The table used to convert an ascii character into a 6 bit value */ -#if 0 static const uint8_t DecodingAlphabet[256] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f @@ -30,7 +33,6 @@ static const uint8_t DecodingAlphabet[256] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff }; -#endif // 0 template<typename CharType> uint32_t @@ -104,4 +106,194 @@ Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest) template uint32_t Base64::Encode<char>(const uint8_t* Source, uint32_t Length, char* Dest); template uint32_t Base64::Encode<wchar_t>(const uint8_t* Source, uint32_t Length, wchar_t* Dest); +template<typename CharType> +bool +Base64::Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength) +{ + // Length must be a multiple of 4 + if (Length % 4 != 0) + { + OutLength = 0; + return false; + } + + uint8_t* DecodedBytes = Dest; + + // Process 4 encoded characters at a time, producing 3 decoded bytes + while (Length > 0) + { + // Count padding characters at the end + uint32_t PadCount = 0; + if (Source[3] == '=') + { + PadCount++; + if (Source[2] == '=') + { + PadCount++; + } + } + + // Look up each character in the decoding table + uint8_t A = DecodingAlphabet[static_cast<uint8_t>(Source[0])]; + uint8_t B = DecodingAlphabet[static_cast<uint8_t>(Source[1])]; + uint8_t C = (PadCount >= 2) ? 
0 : DecodingAlphabet[static_cast<uint8_t>(Source[2])]; + uint8_t D = (PadCount >= 1) ? 0 : DecodingAlphabet[static_cast<uint8_t>(Source[3])]; + + // Check for invalid characters (0xFF means not in the base64 alphabet) + if (A == 0xFF || B == 0xFF || C == 0xFF || D == 0xFF) + { + OutLength = 0; + return false; + } + + // Reconstruct the 24-bit value from 4 6-bit chunks + uint32_t ByteTriplet = (A << 18) | (B << 12) | (C << 6) | D; + + // Extract the 3 bytes + *DecodedBytes++ = static_cast<uint8_t>(ByteTriplet >> 16); + if (PadCount < 2) + { + *DecodedBytes++ = static_cast<uint8_t>((ByteTriplet >> 8) & 0xFF); + } + if (PadCount < 1) + { + *DecodedBytes++ = static_cast<uint8_t>(ByteTriplet & 0xFF); + } + + Source += 4; + Length -= 4; + } + + OutLength = uint32_t(DecodedBytes - Dest); + return true; +} + +template bool Base64::Decode<char>(const char* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); +template bool Base64::Decode<wchar_t>(const wchar_t* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... 
#if ZEN_WITH_TESTS

using namespace std::string_literals;

TEST_SUITE_BEGIN("core.base64");

TEST_CASE("Base64")
{
    // Encode an arbitrary byte string and return the base64 text.
    auto EncodeString = [](std::string_view Input) -> std::string {
        std::string Result;
        Result.resize(Base64::GetEncodedDataSize(uint32_t(Input.size())));
        Base64::Encode(reinterpret_cast<const uint8_t*>(Input.data()), uint32_t(Input.size()), Result.data());
        return Result;
    };

    // Decode base64 text, asserting that decoding reports success.
    auto DecodeString = [](std::string_view Input) -> std::string {
        std::string Result;
        Result.resize(Base64::GetMaxDecodedDataSize(uint32_t(Input.size())));
        uint32_t DecodedLength = 0;
        bool Success = Base64::Decode(Input.data(), uint32_t(Input.size()), reinterpret_cast<uint8_t*>(Result.data()), DecodedLength);
        CHECK(Success);
        Result.resize(DecodedLength);
        return Result;
    };

    // RFC 4648 section 10 test vectors
    SUBCASE("Encode")
    {
        CHECK(EncodeString("") == ""s);
        CHECK(EncodeString("f") == "Zg=="s);
        CHECK(EncodeString("fo") == "Zm8="s);
        CHECK(EncodeString("foo") == "Zm9v"s);
        CHECK(EncodeString("foob") == "Zm9vYg=="s);
        CHECK(EncodeString("fooba") == "Zm9vYmE="s);
        CHECK(EncodeString("foobar") == "Zm9vYmFy"s);
    }

    SUBCASE("Decode")
    {
        CHECK(DecodeString("") == ""s);
        CHECK(DecodeString("Zg==") == "f"s);
        CHECK(DecodeString("Zm8=") == "fo"s);
        CHECK(DecodeString("Zm9v") == "foo"s);
        CHECK(DecodeString("Zm9vYg==") == "foob"s);
        CHECK(DecodeString("Zm9vYmE=") == "fooba"s);
        CHECK(DecodeString("Zm9vYmFy") == "foobar"s);
    }

    SUBCASE("RoundTrip")
    {
        auto RoundTrip = [&](const std::string& Input) {
            std::string Encoded = EncodeString(Input);
            std::string Decoded = DecodeString(Encoded);
            CHECK(Decoded == Input);
        };

        RoundTrip("Hello, World!");
        RoundTrip("Base64 encoding test with various lengths");
        RoundTrip("A");
        RoundTrip("AB");
        RoundTrip("ABC");
        RoundTrip("ABCD");
        RoundTrip("\x00\x01\x02\xff\xfe\xfd"s);
    }

    SUBCASE("BinaryRoundTrip")
    {
        // Exercise every possible byte value 0-255
        uint8_t AllBytes[256];
        for (int i = 0; i < 256; ++i)
        {
            AllBytes[i] = static_cast<uint8_t>(i);
        }

        char Encoded[Base64::GetEncodedDataSize(256)];

        // Use the character count returned by Encode rather than strlen():
        // nothing guarantees the output buffer is null-terminated, so the
        // previous strlen() call could read past the end of the array.
        const uint32_t EncodedLength = Base64::Encode(AllBytes, 256, Encoded);

        uint8_t Decoded[256];
        uint32_t DecodedLength = 0;
        bool Success = Base64::Decode(Encoded, EncodedLength, Decoded, DecodedLength);
        CHECK(Success);
        CHECK(DecodedLength == 256);
        CHECK(memcmp(AllBytes, Decoded, 256) == 0);
    }

    SUBCASE("DecodeInvalidInput")
    {
        uint8_t Dest[64];
        uint32_t OutLength = 0;

        // Length not a multiple of 4
        CHECK_FALSE(Base64::Decode("abc", 3u, Dest, OutLength));

        // Invalid character
        CHECK_FALSE(Base64::Decode("ab!d", 4u, Dest, OutLength));
    }

    SUBCASE("EncodedDataSize")
    {
        CHECK(Base64::GetEncodedDataSize(0) == 0);
        CHECK(Base64::GetEncodedDataSize(1) == 4);
        CHECK(Base64::GetEncodedDataSize(2) == 4);
        CHECK(Base64::GetEncodedDataSize(3) == 4);
        CHECK(Base64::GetEncodedDataSize(4) == 8);
        CHECK(Base64::GetEncodedDataSize(5) == 8);
        CHECK(Base64::GetEncodedDataSize(6) == 8);
    }

    SUBCASE("MaxDecodedDataSize")
    {
        CHECK(Base64::GetMaxDecodedDataSize(0) == 0);
        CHECK(Base64::GetMaxDecodedDataSize(4) == 3);
        CHECK(Base64::GetMaxDecodedDataSize(8) == 6);
        CHECK(Base64::GetMaxDecodedDataSize(12) == 9);
    }
}

TEST_SUITE_END();

#endif
BLAKE3::ToHexString(StringBuilderBase& outBuilder) const char str[65]; ToHexString(str); - outBuilder.AppendRange(str, &str[65]); + outBuilder.AppendRange(str, &str[StringLength]); return outBuilder; } @@ -200,6 +200,8 @@ BLAKE3Stream::GetHash() // return text; // } +TEST_SUITE_BEGIN("core.blake3"); + TEST_CASE("BLAKE3") { SUBCASE("Basics") @@ -237,6 +239,8 @@ TEST_CASE("BLAKE3") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/callstack.cpp b/src/zencore/callstack.cpp index 8aa1111bf..ee0b0625a 100644 --- a/src/zencore/callstack.cpp +++ b/src/zencore/callstack.cpp @@ -260,6 +260,8 @@ GetCallstackRaw(void* CaptureBuffer, int FramesToSkip, int FramesToCapture) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.callstack"); + TEST_CASE("Callstack.Basic") { void* Addresses[4]; @@ -272,6 +274,8 @@ TEST_CASE("Callstack.Basic") } } +TEST_SUITE_END(); + void callstack_forcelink() { diff --git a/src/zencore/commandline.cpp b/src/zencore/commandline.cpp index 426cf23d6..718ef9678 100644 --- a/src/zencore/commandline.cpp +++ b/src/zencore/commandline.cpp @@ -14,6 +14,7 @@ ZEN_THIRD_PARTY_INCLUDES_END # include <crt_externs.h> #endif +#include <locale.h> #include <functional> namespace zen { diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp index b43cc18f1..9c81305d0 100644 --- a/src/zencore/compactbinary.cpp +++ b/src/zencore/compactbinary.cpp @@ -1512,6 +1512,8 @@ uson_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinary"); + TEST_CASE("guid") { using namespace std::literals; @@ -1704,8 +1706,6 @@ TEST_CASE("uson.datetime") ////////////////////////////////////////////////////////////////////////// -TEST_SUITE_BEGIN("core.datetime"); - TEST_CASE("core.datetime.compare") { DateTime T1(2000, 12, 13); @@ -1732,10 +1732,6 @@ TEST_CASE("core.datetime.add") CHECK(dT + T1 - T2 == dT1); } -TEST_SUITE_END(); - -TEST_SUITE_BEGIN("core.timespan"); - TEST_CASE("core.timespan.compare") { TimeSpan T1(1000); diff --git 
a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp index 63c0b9c5c..a9ba30750 100644 --- a/src/zencore/compactbinarybuilder.cpp +++ b/src/zencore/compactbinarybuilder.cpp @@ -710,6 +710,8 @@ usonbuilder_forcelink() // return ""; // } +TEST_SUITE_BEGIN("core.compactbinarybuilder"); + TEST_CASE("usonbuilder.object") { using namespace std::literals; @@ -1530,6 +1532,8 @@ TEST_CASE("usonbuilder.stream") CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None); } } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp index abbec360a..da560a449 100644 --- a/src/zencore/compactbinaryjson.cpp +++ b/src/zencore/compactbinaryjson.cpp @@ -654,6 +654,8 @@ cbjson_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinaryjson"); + TEST_CASE("uson.json") { using namespace std::literals; @@ -872,6 +874,8 @@ TEST_CASE("json.uson") } } +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS } // namespace zen diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index ffe64f2e9..56a292ca6 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -805,6 +805,8 @@ usonpackage_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinarypackage"); + TEST_CASE("usonpackage") { using namespace std::literals; @@ -1343,6 +1345,8 @@ TEST_CASE("usonpackage.invalidpackage") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp index d7292f405..3e78f8ef1 100644 --- a/src/zencore/compactbinaryvalidation.cpp +++ b/src/zencore/compactbinaryvalidation.cpp @@ -753,10 +753,14 @@ usonvalidation_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinaryvalidation"); + TEST_CASE("usonvalidation") { SUBCASE("Basic") {} } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compactbinaryyaml.cpp 
b/src/zencore/compactbinaryyaml.cpp index 5122e952a..b7f2c55df 100644 --- a/src/zencore/compactbinaryyaml.cpp +++ b/src/zencore/compactbinaryyaml.cpp @@ -14,11 +14,6 @@ #include <string_view> #include <vector> -ZEN_THIRD_PARTY_INCLUDES_START -#include <ryml.hpp> -#include <ryml_std.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen { ////////////////////////////////////////////////////////////////////////// @@ -26,193 +21,349 @@ namespace zen { class CbYamlWriter { public: - explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_StrBuilder(InBuilder) { m_NodeStack.push_back(m_Tree.rootref()); } + explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_Builder(InBuilder) {} void WriteField(CbFieldView Field) { - ryml::NodeRef Node; + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); - if (m_IsFirst) + switch (Type) { - Node = Top(); + case CbFieldType::Object: + case CbFieldType::UniformObject: + WriteMapEntries(Field, 0); + break; + case CbFieldType::Array: + case CbFieldType::UniformArray: + WriteSeqEntries(Field, 0); + break; + default: + WriteScalarValue(Field); + m_Builder << '\n'; + break; + } + } + + void WriteMapEntry(CbFieldView Field, int32_t Indent) + { + WriteIndent(Indent); + WriteMapEntryContent(Field, Indent); + } + + void WriteSeqEntry(CbFieldView Field, int32_t Indent) + { + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); - m_IsFirst = false; + if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject) + { + bool First = true; + for (CbFieldView MapChild : Field) + { + if (First) + { + WriteIndent(Indent); + m_Builder << "- "; + First = false; + } + else + { + WriteIndent(Indent + 1); + } + WriteMapEntryContent(MapChild, Indent + 1); + } + } + else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray) + { + WriteIndent(Indent); + m_Builder << "-\n"; + WriteSeqEntries(Field, Indent + 1); } else { - Node = Top().append_child(); + WriteIndent(Indent); + m_Builder << 
"- "; + WriteScalarValue(Field); + m_Builder << '\n'; } + } - if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) +private: + void WriteMapEntries(CbFieldView MapField, int32_t Indent) + { + for (CbFieldView Child : MapField) { - Node.set_key_serialized(ryml::csubstr((const char*)Name.data(), Name.size())); + WriteIndent(Indent); + WriteMapEntryContent(Child, Indent); } + } + + void WriteMapEntryContent(CbFieldView Field, int32_t Indent) + { + std::u8string_view Name = Field.GetU8Name(); + m_Builder << std::string_view(reinterpret_cast<const char*>(Name.data()), Name.size()); - switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); + + if (IsContainer(Type)) { - case CbFieldType::Null: - Node.set_val("null"); - break; - case CbFieldType::Object: - case CbFieldType::UniformObject: - Node |= ryml::MAP; - m_NodeStack.push_back(Node); - for (CbFieldView It : Field) + m_Builder << ":\n"; + WriteFieldValue(Field, Indent + 1); + } + else + { + m_Builder << ": "; + WriteScalarValue(Field); + m_Builder << '\n'; + } + } + + void WriteSeqEntries(CbFieldView SeqField, int32_t Indent) + { + for (CbFieldView Child : SeqField) + { + CbValue Accessor = Child.GetValue(); + CbFieldType Type = Accessor.GetType(); + + if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject) + { + bool First = true; + for (CbFieldView MapChild : Child) { - WriteField(It); + if (First) + { + WriteIndent(Indent); + m_Builder << "- "; + First = false; + } + else + { + WriteIndent(Indent + 1); + } + WriteMapEntryContent(MapChild, Indent + 1); } - m_NodeStack.pop_back(); + } + else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray) + { + WriteIndent(Indent); + m_Builder << "-\n"; + WriteSeqEntries(Child, Indent + 1); + } + else + { + WriteIndent(Indent); + m_Builder << "- "; + WriteScalarValue(Child); + m_Builder << '\n'; + } + } + } + + void WriteFieldValue(CbFieldView 
Field, int32_t Indent) + { + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); + + switch (Type) + { + case CbFieldType::Object: + case CbFieldType::UniformObject: + WriteMapEntries(Field, Indent); break; case CbFieldType::Array: case CbFieldType::UniformArray: - Node |= ryml::SEQ; - m_NodeStack.push_back(Node); - for (CbFieldView It : Field) - { - WriteField(It); - } - m_NodeStack.pop_back(); + WriteSeqEntries(Field, Indent); break; - case CbFieldType::Binary: - { - ExtendableStringBuilder<256> Builder; - const MemoryView Value = Accessor.AsBinary(); - ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); - - Node.set_key_serialized(Builder.c_str()); - } + case CbFieldType::CustomById: + WriteCustomById(Field.GetValue().AsCustomById(), Indent); break; - case CbFieldType::String: - { - const std::u8string_view U8String = Accessor.AsU8String(); - Node.set_val(ryml::csubstr((const char*)U8String.data(), U8String.size())); - } + case CbFieldType::CustomByName: + WriteCustomByName(Field.GetValue().AsCustomByName(), Indent); + break; + default: + WriteScalarValue(Field); + m_Builder << '\n'; + break; + } + } + + void WriteScalarValue(CbFieldView Field) + { + CbValue Accessor = Field.GetValue(); + switch (Accessor.GetType()) + { + case CbFieldType::Null: + m_Builder << "null"; + break; + case CbFieldType::BoolFalse: + m_Builder << "false"; + break; + case CbFieldType::BoolTrue: + m_Builder << "true"; break; case CbFieldType::IntegerPositive: - Node << Accessor.AsIntegerPositive(); + m_Builder << Accessor.AsIntegerPositive(); break; case CbFieldType::IntegerNegative: - Node << Accessor.AsIntegerNegative(); + m_Builder << Accessor.AsIntegerNegative(); break; case 
CbFieldType::Float32: if (const float Value = Accessor.AsFloat32(); std::isfinite(Value)) - { - Node << Value; - } + m_Builder.Append(fmt::format("{}", Value)); else - { - Node << "null"; - } + m_Builder << "null"; break; case CbFieldType::Float64: if (const double Value = Accessor.AsFloat64(); std::isfinite(Value)) - { - Node << Value; - } + m_Builder.Append(fmt::format("{}", Value)); else + m_Builder << "null"; + break; + case CbFieldType::String: { - Node << "null"; + const std::u8string_view U8String = Accessor.AsU8String(); + WriteString(std::string_view(reinterpret_cast<const char*>(U8String.data()), U8String.size())); } break; - case CbFieldType::BoolFalse: - Node << "false"; - break; - case CbFieldType::BoolTrue: - Node << "true"; + case CbFieldType::Hash: + WriteString(Accessor.AsHash().ToHexString()); break; case CbFieldType::ObjectAttachment: case CbFieldType::BinaryAttachment: - Node << Accessor.AsAttachment().ToHexString(); - break; - case CbFieldType::Hash: - Node << Accessor.AsHash().ToHexString(); + WriteString(Accessor.AsAttachment().ToHexString()); break; case CbFieldType::Uuid: - Node << fmt::format("{}", Accessor.AsUuid()); + WriteString(fmt::format("{}", Accessor.AsUuid())); break; case CbFieldType::DateTime: - Node << DateTime(Accessor.AsDateTimeTicks()).ToIso8601(); + WriteString(DateTime(Accessor.AsDateTimeTicks()).ToIso8601()); break; case CbFieldType::TimeSpan: if (const TimeSpan Span(Accessor.AsTimeSpanTicks()); Span.GetDays() == 0) - { - Node << Span.ToString("%h:%m:%s.%n"); - } + WriteString(Span.ToString("%h:%m:%s.%n")); else - { - Node << Span.ToString("%d.%h:%m:%s.%n"); - } + WriteString(Span.ToString("%d.%h:%m:%s.%n")); break; case CbFieldType::ObjectId: - Node << fmt::format("{}", Accessor.AsObjectId()); + WriteString(fmt::format("{}", Accessor.AsObjectId())); break; - case CbFieldType::CustomById: - { - CbCustomById Custom = Accessor.AsCustomById(); + case CbFieldType::Binary: + WriteBase64(Accessor.AsBinary()); + break; + 
default: + ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); + break; + } + } - Node |= ryml::MAP; + void WriteCustomById(CbCustomById Custom, int32_t Indent) + { + WriteIndent(Indent); + m_Builder << "Id: "; + m_Builder.Append(fmt::format("{}", Custom.Id)); + m_Builder << '\n'; + + WriteIndent(Indent); + m_Builder << "Data: "; + WriteBase64(Custom.Data); + m_Builder << '\n'; + } - ryml::NodeRef IdNode = Node.append_child(); - IdNode.set_key("Id"); - IdNode.set_val_serialized(fmt::format("{}", Custom.Id)); + void WriteCustomByName(CbCustomByName Custom, int32_t Indent) + { + WriteIndent(Indent); + m_Builder << "Name: "; + WriteString(std::string_view(reinterpret_cast<const char*>(Custom.Name.data()), Custom.Name.size())); + m_Builder << '\n'; + + WriteIndent(Indent); + m_Builder << "Data: "; + WriteBase64(Custom.Data); + m_Builder << '\n'; + } - ryml::NodeRef DataNode = Node.append_child(); - DataNode.set_key("Data"); + void WriteBase64(MemoryView Value) + { + ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); + ExtendableStringBuilder<256> Buf; + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Buf.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Buf.Data() + EncodedIndex); + WriteString(Buf.ToView()); + } - ExtendableStringBuilder<256> Builder; - const MemoryView& Value = Custom.Data; - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + void WriteString(std::string_view Str) + { + if (NeedsQuoting(Str)) + { + m_Builder << '\''; + for (char C : Str) + { + if (C == '\'') + m_Builder << "''"; + else + m_Builder << C; + } + m_Builder << '\''; + } + else + { + 
m_Builder << Str; + } + } - DataNode.set_val_serialized(Builder.c_str()); - } - break; - case CbFieldType::CustomByName: - { - CbCustomByName Custom = Accessor.AsCustomByName(); + void WriteIndent(int32_t Indent) + { + for (int32_t I = 0; I < Indent; ++I) + m_Builder << " "; + } - Node |= ryml::MAP; + static bool NeedsQuoting(std::string_view Str) + { + if (Str.empty()) + return false; - ryml::NodeRef NameNode = Node.append_child(); - NameNode.set_key("Name"); - std::string_view Name = std::string_view((const char*)Custom.Name.data(), Custom.Name.size()); - NameNode.set_val_serialized(std::string(Name)); + char First = Str[0]; + if (First == ' ' || First == '\n' || First == '\t' || First == '\r' || First == '*' || First == '&' || First == '%' || + First == '@' || First == '`') + return true; - ryml::NodeRef DataNode = Node.append_child(); - DataNode.set_key("Data"); + if (Str.size() >= 2 && Str[0] == '<' && Str[1] == '<') + return true; - ExtendableStringBuilder<256> Builder; - const MemoryView& Value = Custom.Data; - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + char Last = Str.back(); + if (Last == ' ' || Last == '\n' || Last == '\t' || Last == '\r') + return true; - DataNode.set_val_serialized(Builder.c_str()); - } - break; - default: - ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); - break; + for (char C : Str) + { + if (C == '#' || C == ':' || C == '-' || C == '?' 
|| C == ',' || C == '\n' || C == '{' || C == '}' || C == '[' || C == ']' || + C == '\'' || C == '"') + return true; } - if (m_NodeStack.size() == 1) + return false; + } + + static bool IsContainer(CbFieldType Type) + { + switch (Type) { - std::string Yaml = ryml::emitrs_yaml<std::string>(m_Tree); - m_StrBuilder << Yaml; + case CbFieldType::Object: + case CbFieldType::UniformObject: + case CbFieldType::Array: + case CbFieldType::UniformArray: + case CbFieldType::CustomById: + case CbFieldType::CustomByName: + return true; + default: + return false; } } -private: - StringBuilderBase& m_StrBuilder; - bool m_IsFirst = true; - - ryml::Tree m_Tree; - std::vector<ryml::NodeRef> m_NodeStack; - ryml::NodeRef& Top() { return m_NodeStack.back(); } + StringBuilderBase& m_Builder; }; void @@ -229,12 +380,40 @@ CompactBinaryToYaml(const CbArrayView& Array, StringBuilderBase& Builder) Writer.WriteField(Array.AsFieldView()); } +void +CompactBinaryToYaml(MemoryView Data, StringBuilderBase& InBuilder) +{ + std::vector<CbFieldView> Fields = ReadCompactBinaryStream(Data); + if (Fields.empty()) + return; + + CbYamlWriter Writer(InBuilder); + if (Fields.size() == 1) + { + Writer.WriteField(Fields[0]); + return; + } + + if (Fields[0].HasName()) + { + for (const CbFieldView& Field : Fields) + Writer.WriteMapEntry(Field, 0); + } + else + { + for (const CbFieldView& Field : Fields) + Writer.WriteSeqEntry(Field, 0); + } +} + #if ZEN_WITH_TESTS void cbyaml_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinaryyaml"); + TEST_CASE("uson.yaml") { using namespace std::literals; @@ -347,6 +526,8 @@ mixed_seq: )"sv); } } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp index 252ac9045..ed2b16384 100644 --- a/src/zencore/compositebuffer.cpp +++ b/src/zencore/compositebuffer.cpp @@ -297,6 +297,9 @@ CompositeBuffer::IterateRange(uint64_t Offset, } #if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("core.compositebuffer"); + 
TEST_CASE("CompositeBuffer Null") { CompositeBuffer Buffer; @@ -462,6 +465,8 @@ TEST_CASE("CompositeBuffer Composite") TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2); } +TEST_SUITE_END(); + void compositebuffer_forcelink() { diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp index 25ed0fc46..6aa0adce0 100644 --- a/src/zencore/compress.cpp +++ b/src/zencore/compress.cpp @@ -2420,6 +2420,8 @@ private: #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.compress"); + TEST_CASE("CompressedBuffer") { uint8_t Zeroes[1024]{}; @@ -2967,6 +2969,8 @@ TEST_CASE("CompressedBufferReader") } } +TEST_SUITE_END(); + void compress_forcelink() { diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp index 09eebb6ae..049854b42 100644 --- a/src/zencore/crypto.cpp +++ b/src/zencore/crypto.cpp @@ -449,6 +449,8 @@ crypto_forcelink() { } +TEST_SUITE_BEGIN("core.crypto"); + TEST_CASE("crypto.bits") { using CryptoBits256Bit = CryptoBits<256>; @@ -500,6 +502,8 @@ TEST_CASE("crypto.aes") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 92a065707..8ed63565c 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -194,7 +194,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) FindClose(hFind); } - return true; + return Success; } bool @@ -1022,7 +1022,7 @@ TryCloneFile(const std::filesystem::path& FromPath, const std::filesystem::path& return false; } fchmod(ToFd, 0666); - ScopedFd $To = { FromFd }; + ScopedFd $To = { ToFd }; ioctl(ToFd, FICLONE, FromFd); @@ -1112,7 +1112,8 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP size_t FileSizeBytes = Stat.st_size; - fchown(ToFd, Stat.st_uid, Stat.st_gid); + int $Ignore = fchown(ToFd, Stat.st_uid, Stat.st_gid); + ZEN_UNUSED($Ignore); // What's the appropriate error handling here? 
// Copy impl const size_t BufferSize = Min(FileSizeBytes, 64u << 10); @@ -1326,11 +1327,6 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin { BytesRead = size_t(dwNumberOfBytesRead); } - else if ((BytesRead != NumberOfBytesToRead)) - { - Ec = MakeErrorCode(ERROR_HANDLE_EOF); - return; - } else { Ec = MakeErrorCodeFromLastError(); @@ -1344,20 +1340,15 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin { BytesRead = size_t(ReadResult); } - else if ((BytesRead != NumberOfBytesToRead)) - { - Ec = MakeErrorCode(EIO); - return; - } else { Ec = MakeErrorCodeFromLastError(); return; } #endif - Size -= NumberOfBytesToRead; - FileOffset += NumberOfBytesToRead; - Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead; + Size -= BytesRead; + FileOffset += BytesRead; + Data = reinterpret_cast<uint8_t*>(Data) + BytesRead; } } @@ -1408,7 +1399,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer const uint64_t ChunkSize = Min<uint64_t>(WriteSize, uint64_t(2) * 1024 * 1024 * 1024); #if ZEN_PLATFORM_WINDOWS - hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(WriteSize)); + hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(ChunkSize)); if (FAILED(hRes)) { Outfile.Close(); @@ -1417,7 +1408,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str()); } #else - if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize)) + if (write(Fd, DataPtr, ChunkSize) != int64_t(ChunkSize)) { close(Fd); std::error_code DummyEc; @@ -3069,7 +3060,7 @@ SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly) } void -MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path) +MakeSafeAbsolutePathInPlace(std::filesystem::path& Path) { if (!Path.empty()) { @@ -3091,7 +3082,7 @@ std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path) { 
std::filesystem::path Tmp(Path); - MakeSafeAbsolutePathÍnPlace(Tmp); + MakeSafeAbsolutePathInPlace(Tmp); return Tmp; } @@ -3319,6 +3310,8 @@ filesystem_forcelink() { } +TEST_SUITE_BEGIN("core.filesystem"); + TEST_CASE("filesystem") { using namespace std::filesystem; @@ -3543,7 +3536,6 @@ TEST_CASE("PathBuilder") Path.Reset(); Path.Append(fspath(L"/\u0119oo/")); Path /= L"bar"; - printf("%ls\n", Path.ToPath().c_str()); CHECK(Path.ToView() == L"/\u0119oo/bar"); CHECK(Path.ToPath() == L"\\\u0119oo\\bar"); # endif @@ -3614,6 +3606,8 @@ TEST_CASE("SharedMemory") CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/include/zencore/base64.h b/src/zencore/include/zencore/base64.h index 4d78b085f..08d9f3043 100644 --- a/src/zencore/include/zencore/base64.h +++ b/src/zencore/include/zencore/base64.h @@ -11,7 +11,11 @@ struct Base64 template<typename CharType> static uint32_t Encode(const uint8_t* Source, uint32_t Length, CharType* Dest); + template<typename CharType> + static bool Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); + static inline constexpr int32_t GetEncodedDataSize(uint32_t Size) { return ((Size + 2) / 3) * 4; } + static inline constexpr int32_t GetMaxDecodedDataSize(uint32_t Length) { return (Length / 4) * 3; } }; } // namespace zen diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h index e91fdc659..b6c93e937 100644 --- a/src/zencore/include/zencore/blockingqueue.h +++ b/src/zencore/include/zencore/blockingqueue.h @@ -2,6 +2,8 @@ #pragma once +#include <zencore/zencore.h> // For ZEN_ASSERT + #include <atomic> #include <condition_variable> #include <deque> diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h index 00c37e941..33f3e7bea 100644 --- a/src/zencore/include/zencore/compactbinaryfile.h +++ 
b/src/zencore/include/zencore/compactbinaryfile.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include <zencore/compactbinary.h> #include <zencore/iohash.h> diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h index aa2d2821d..4ce8009b8 100644 --- a/src/zencore/include/zencore/compactbinaryvalue.h +++ b/src/zencore/include/zencore/compactbinaryvalue.h @@ -128,17 +128,21 @@ CbValue::AsString(CbFieldError* OutError, std::string_view Default) const uint32_t ValueSizeByteCount; const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); - if (OutError) + if (ValueSize >= (uint64_t(1) << 31)) { - if (ValueSize >= (uint64_t(1) << 31)) + if (OutError) { *OutError = CbFieldError::RangeError; - return Default; } + return Default; + } + + if (OutError) + { *OutError = CbFieldError::None; } - return std::string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); + return std::string_view(Chars + ValueSizeByteCount, size_t(ValueSize)); } inline std::u8string_view @@ -148,17 +152,21 @@ CbValue::AsU8String(CbFieldError* OutError, std::u8string_view Default) const uint32_t ValueSizeByteCount; const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); - if (OutError) + if (ValueSize >= (uint64_t(1) << 31)) { - if (ValueSize >= (uint64_t(1) << 31)) + if (OutError) { *OutError = CbFieldError::RangeError; - return Default; } + return Default; + } + + if (OutError) + { *OutError = CbFieldError::None; } - return std::u8string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); + return std::u8string_view(Chars + ValueSizeByteCount, size_t(ValueSize)); } inline uint64_t diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index f28863679..16e2b59f8 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -64,80 +64,80 @@ std::filesystem::path PathFromHandle(void* NativeHandle, 
std::error_code& Ec); */ std::filesystem::path CanonicalPath(std::filesystem::path InPath, std::error_code& Ec); -/** Query file size +/** Check if a path exists and is a regular file (throws) */ bool IsFile(const std::filesystem::path& Path); -/** Query file size +/** Check if a path exists and is a regular file (does not throw) */ bool IsFile(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Check if a path exists and is a directory (throws) */ bool IsDir(const std::filesystem::path& Path); -/** Query file size +/** Check if a path exists and is a directory (does not throw) */ bool IsDir(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Delete file at path, if it exists (throws) */ bool RemoveFile(const std::filesystem::path& Path); -/** Query file size +/** Delete file at path, if it exists (does not throw) */ bool RemoveFile(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Delete directory at path, if it exists (throws) */ bool RemoveDir(const std::filesystem::path& Path); -/** Query file size +/** Delete directory at path, if it exists (does not throw) */ bool RemoveDir(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Query file size (throws) */ uint64_t FileSizeFromPath(const std::filesystem::path& Path); -/** Query file size +/** Query file size (does not throw) */ uint64_t FileSizeFromPath(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size from native file handle +/** Query file size from native file handle (throws) */ uint64_t FileSizeFromHandle(void* NativeHandle); -/** Query file size from native file handle +/** Query file size from native file handle (does not throw) */ uint64_t FileSizeFromHandle(void* NativeHandle, std::error_code& Ec); /** Get a native time tick of last modification time */ -uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec); +uint64_t 
GetModificationTickFromPath(const std::filesystem::path& Filename); /** Get a native time tick of last modification time */ -uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename); +uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec); bool TryGetFileProperties(const std::filesystem::path& Path, uint64_t& OutSize, uint64_t& OutModificationTick, uint32_t& OutNativeModeOrAttributes); -/** Move a file, if the files are not on the same drive the function will fail +/** Move/rename a file, if the files are not on the same drive the function will fail (throws) */ void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); -/** Move a file, if the files are not on the same drive the function will fail +/** Move/rename a file, if the files are not on the same drive the function will fail */ void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); -/** Move a directory, if the files are not on the same drive the function will fail +/** Move/rename a directory, if the files are not on the same drive the function will fail (throws) */ void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); -/** Move a directory, if the files are not on the same drive the function will fail +/** Move/rename a directory, if the files are not on the same drive the function will fail */ void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); @@ -421,7 +421,7 @@ uint32_t MakeFileModeReadOnly(uint32_t FileMode, bool ReadOnly); bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly, std::error_code& Ec); bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly); -void MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path); +void MakeSafeAbsolutePathInPlace(std::filesystem::path& Path); [[nodiscard]] 
std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path); class SharedMemory diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h index 4e877e219..8abfd4b6e 100644 --- a/src/zencore/include/zencore/hashutils.h +++ b/src/zencore/include/zencore/hashutils.h @@ -2,6 +2,10 @@ #pragma once +#include <cstddef> +#include <functional> +#include <type_traits> + namespace zen { template<typename T> diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h index 182768ff6..82c201edd 100644 --- a/src/zencore/include/zencore/iobuffer.h +++ b/src/zencore/include/zencore/iobuffer.h @@ -426,22 +426,39 @@ private: class IoBufferBuilder { public: - static IoBuffer MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset = 0, uint64_t Size = ~0ull); - static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName); - static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull); - /** Make sure buffer data is memory resident, but avoid memory mapping data from files - */ - static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer); - inline static IoBuffer MakeFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Wrap, Memory.GetData(), Memory.GetSize()); } - inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz) + static IoBuffer MakeFromFile(const std::filesystem::path& FileName, + uint64_t Offset = 0, + uint64_t Size = ~0ull, + ZenContentType ContentType = ZenContentType::kBinary); + static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType = ZenContentType::kBinary); + static IoBuffer MakeFromFileHandle(void* FileHandle, + uint64_t Offset = 0, + uint64_t Size = ~0ull, + ZenContentType ContentType = ZenContentType::kBinary); + inline static IoBuffer MakeFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary) + { + IoBuffer 
NewBuffer(IoBuffer::Wrap, Memory.GetData(), Memory.GetSize()); + NewBuffer.SetContentType(ContentType); + return NewBuffer; + } + inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz, ZenContentType ContentType = ZenContentType::kBinary) { if (Sz) { - return IoBuffer(IoBuffer::Clone, Ptr, Sz); + IoBuffer NewBuffer(IoBuffer::Clone, Ptr, Sz); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } return {}; } - inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize()); } + inline static IoBuffer MakeCloneFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary) + { + return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize(), ContentType); + } + + /** Make sure buffer data is memory resident, but avoid memory mapping data from files + */ + static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer); }; void iobuffer_forcelink(); diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h index 00af68b0a..ece17a85e 100644 --- a/src/zencore/include/zencore/logbase.h +++ b/src/zencore/include/zencore/logbase.h @@ -4,96 +4,85 @@ #include <string_view> -#define ZEN_LOG_LEVEL_TRACE 0 -#define ZEN_LOG_LEVEL_DEBUG 1 -#define ZEN_LOG_LEVEL_INFO 2 -#define ZEN_LOG_LEVEL_WARN 3 -#define ZEN_LOG_LEVEL_ERROR 4 -#define ZEN_LOG_LEVEL_CRITICAL 5 -#define ZEN_LOG_LEVEL_OFF 6 - -#define ZEN_LEVEL_NAME_TRACE std::string_view("trace", 5) -#define ZEN_LEVEL_NAME_DEBUG std::string_view("debug", 5) -#define ZEN_LEVEL_NAME_INFO std::string_view("info", 4) -#define ZEN_LEVEL_NAME_WARNING std::string_view("warning", 7) -#define ZEN_LEVEL_NAME_ERROR std::string_view("error", 5) -#define ZEN_LEVEL_NAME_CRITICAL std::string_view("critical", 8) -#define ZEN_LEVEL_NAME_OFF std::string_view("off", 3) - -namespace zen::logging::level { +namespace zen::logging { enum LogLevel : int { - Trace = ZEN_LOG_LEVEL_TRACE, - Debug = ZEN_LOG_LEVEL_DEBUG, - 
Info = ZEN_LOG_LEVEL_INFO, - Warn = ZEN_LOG_LEVEL_WARN, - Err = ZEN_LOG_LEVEL_ERROR, - Critical = ZEN_LOG_LEVEL_CRITICAL, - Off = ZEN_LOG_LEVEL_OFF, + Trace, + Debug, + Info, + Warn, + Err, + Critical, + Off, LogLevelCount }; LogLevel ParseLogLevelString(std::string_view String); std::string_view ToStringView(LogLevel Level); -} // namespace zen::logging::level - -namespace zen::logging { - -void SetLogLevel(level::LogLevel NewLogLevel); -level::LogLevel GetLogLevel(); +void SetLogLevel(LogLevel NewLogLevel); +LogLevel GetLogLevel(); -} // namespace zen::logging +struct SourceLocation +{ + constexpr SourceLocation() = default; + constexpr SourceLocation(const char* InFilename, int InLine) : Filename(InFilename), Line(InLine) {} -namespace spdlog { -class logger; -} + constexpr operator bool() const noexcept { return Line != 0; } -namespace zen::logging { + const char* Filename{nullptr}; + int Line{0}; +}; -struct SourceLocation +/** This encodes the constant parts of a log message which can be emitted once + * and then referred to by log events. + * + * It's *critical* that instances of this struct are permanent and never + * destroyed, as log messages will refer to them by pointer. The easiest way + * to ensure this is to create them as function-local statics. + * + * The logging macros already do this for you so this should not be something + * you normally would need to worry about. 
+ */ +struct LogPoint { - constexpr SourceLocation() = default; - constexpr SourceLocation(const char* filename_in, int line_in, const char* funcname_in) - : filename(filename_in) - , line(line_in) - , funcname(funcname_in) - { - } - - constexpr bool empty() const noexcept { return line == 0; } - - // IMPORTANT NOTE: the layout of this class must match the spdlog::source_loc class - // since we currently pass a pointer to it into spdlog after casting it to - // spdlog::source_loc* - // - // This is intended to be an intermediate state, before we (probably) transition off - // spdlog entirely - - const char* filename{nullptr}; - int line{0}; - const char* funcname{nullptr}; + SourceLocation Location; + LogLevel Level; + std::string_view FormatString; }; +class Logger; + } // namespace zen::logging namespace zen { +// Lightweight non-owning handle to a Logger. Loggers are owned by the Registry +// via Ref<Logger>; LoggerRef exists as a cheap (raw pointer) handle that can be +// stored in members and passed through logging macros without requiring the +// complete Logger type or incurring refcount overhead on every log call. 
struct LoggerRef { LoggerRef() = default; - LoggerRef(spdlog::logger& InLogger) : SpdLogger(&InLogger) {} + LoggerRef(logging::Logger& InLogger) : m_Logger(&InLogger) {} + // This exists so that logging macros can pass LoggerRef or LogCategory + // to ZEN_LOG without needing to know which one it is LoggerRef Logger() { return *this; } - bool ShouldLog(int Level) const; - inline operator bool() const { return SpdLogger != nullptr; } + bool ShouldLog(logging::LogLevel Level) const; + inline operator bool() const { return m_Logger != nullptr; } + + inline logging::Logger* operator->() const { return m_Logger; } + inline logging::Logger& operator*() const { return *m_Logger; } - void SetLogLevel(logging::level::LogLevel NewLogLevel); - logging::level::LogLevel GetLogLevel(); + void SetLogLevel(logging::LogLevel NewLogLevel); + logging::LogLevel GetLogLevel(); + void Flush(); - spdlog::logger* SpdLogger = nullptr; +private: + logging::Logger* m_Logger = nullptr; }; } // namespace zen diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index afbbbd3ee..4b593c19e 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -9,16 +9,9 @@ #if ZEN_PLATFORM_WINDOWS # define ZEN_LOG_SECTION(Id) ZEN_DATA_SECTION(Id) -# pragma section(".zlog$f", read) # pragma section(".zlog$l", read) -# pragma section(".zlog$m", read) -# pragma section(".zlog$s", read) -# define ZEN_DECLARE_FUNCTION static constinit ZEN_LOG_SECTION(".zlog$f") char FuncName[] = __FUNCTION__; -# define ZEN_LOG_FUNCNAME FuncName #else # define ZEN_LOG_SECTION(Id) -# define ZEN_DECLARE_FUNCTION -# define ZEN_LOG_FUNCNAME static_cast<const char*>(__func__) #endif namespace zen::logging { @@ -31,39 +24,35 @@ void FlushLogging(); LoggerRef Default(); void SetDefault(std::string_view NewDefaultLoggerId); LoggerRef ConsoleLog(); +void ResetConsoleLog(); void SuppressConsoleLog(); LoggerRef ErrorLog(); void SetErrorLog(std::string_view LoggerId); 
LoggerRef Get(std::string_view Name); -void ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers); +void ConfigureLogLevels(LogLevel Level, std::string_view Loggers); void RefreshLogLevels(); -void RefreshLogLevels(level::LogLevel DefaultLevel); - +void RefreshLogLevels(LogLevel DefaultLevel); + +/** LogCategory allows for the creation of log categories that can be used with + * the logging macros just like a logger reference. The main purpose of this is + * to allow for static log categories in global scope where we can't actually + * go ahead and instantiate a logger immediately because the logging system may + * not be initialized yet. + */ struct LogCategory { - inline LogCategory(std::string_view InCategory) : CategoryName(InCategory) {} - - inline zen::LoggerRef Logger() - { - if (LoggerRef) - { - return LoggerRef; - } + inline LogCategory(std::string_view InCategory) : m_CategoryName(InCategory) {} - LoggerRef = zen::logging::Get(CategoryName); - return LoggerRef; - } + LoggerRef Logger(); - std::string CategoryName; - zen::LoggerRef LoggerRef; +private: + std::string m_CategoryName; + LoggerRef m_LoggerRef; }; -void EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args); -void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Message); -void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Message); -void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args); -void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Format, fmt::format_args Args); +void EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args); +void EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args); template<typename... 
T> auto @@ -78,15 +67,14 @@ namespace zen { extern LoggerRef TheDefaultLogger; -inline LoggerRef -Log() -{ - if (TheDefaultLogger) - { - return TheDefaultLogger; - } - return zen::logging::ConsoleLog(); -} +/** + * This is the default logger, which any ZEN_INFO et al will get if there's + * no Log() function declared in the current scope. + * + * Typically, classes which want to log to its own channel will declare a Log() + * member function which returns a LoggerRef created at construction time. + */ +LoggerRef Log(); using logging::ConsoleLog; using logging::ErrorLog; @@ -97,12 +85,6 @@ using zen::ConsoleLog; using zen::ErrorLog; using zen::Log; -inline consteval bool -LogIsErrorLevel(int LogLevel) -{ - return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical); -}; - #if ZEN_BUILD_DEBUG # define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \ while (false) \ @@ -116,75 +98,66 @@ LogIsErrorLevel(int LogLevel) } #endif -#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - ZEN_DECLARE_FUNCTION \ - static constinit ZEN_LOG_SECTION(".zlog$s") char FileName[] = __FILE__; \ - static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \ - static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::SourceLocation Location{FileName, __LINE__, ZEN_LOG_FUNCNAME}; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, \ - Location, \ - InLevel, \ - std::string_view(FormatString, sizeof FormatString - 1), \ - zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) 
\ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") \ + zen::logging::LogPoint LogPoint{zen::logging::SourceLocation{__FILE__, __LINE__}, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); -#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, \ - InLevel, \ - std::string_view(FormatString, sizeof FormatString - 1), \ - zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); #define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ static zen::logging::LogCategory Category { Name } -#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_WARN(Category, fmtstr, ...) 
ZEN_LOG(Category.Logger(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \ - ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) - -#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) - -#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - zen::logging::EmitConsoleLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ +#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) 
ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Critical, fmtstr, ##__VA_ARGS__) + +#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr, ##__VA_ARGS__) + +#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ } while (false) -#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_TRACE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_TRACE(fmtstr, ...) 
ZEN_CONSOLE_LOG(zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Critical, fmtstr, ##__VA_ARGS__) ////////////////////////////////////////////////////////////////////////// @@ -239,28 +212,28 @@ std::string_view EmitActivitiesForLogging(StringBuilderBase& OutString); #define ZEN_LOG_SCOPE(...) ScopedLazyActivity $Activity##__LINE__([&](StringBuilderBase& Out) { Out << fmt::format(__VA_ARGS__); }) -#define ZEN_SCOPED_WARN(fmtstr, ...) \ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_WARN(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG(Log(), zen::logging::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) -#define ZEN_SCOPED_ERROR(fmtstr, ...) \ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_ERROR(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) -#define ZEN_SCOPED_CRITICAL(fmtstr, ...) 
\ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_CRITICAL(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) ScopedActivityBase* GetThreadActivity(); diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h new file mode 100644 index 000000000..5060a8393 --- /dev/null +++ b/src/zencore/include/zencore/logging/ansicolorsink.h @@ -0,0 +1,33 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging/sink.h> + +#include <memory> + +namespace zen::logging { + +enum class ColorMode +{ + On, + Off, + Auto +}; + +class AnsiColorStdoutSink : public Sink +{ +public: + explicit AnsiColorStdoutSink(ColorMode Mode = ColorMode::Auto); + ~AnsiColorStdoutSink() override; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/asyncsink.h b/src/zencore/include/zencore/logging/asyncsink.h new file mode 100644 index 000000000..c49a1ccce --- /dev/null +++ b/src/zencore/include/zencore/logging/asyncsink.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging/sink.h> + +#include <memory> +#include <vector> + +namespace zen::logging { + +class AsyncSink : public Sink +{ +public: + explicit AsyncSink(std::vector<SinkPtr> InSinks); + ~AsyncSink() override; + + AsyncSink(const AsyncSink&) = delete; + AsyncSink& operator=(const AsyncSink&) = delete; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h new file mode 100644 index 000000000..11904d71d --- /dev/null +++ b/src/zencore/include/zencore/logging/formatter.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging/logmsg.h> +#include <zencore/logging/memorybuffer.h> + +#include <memory> + +namespace zen::logging { + +class Formatter +{ +public: + virtual ~Formatter() = default; + virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0; + virtual std::unique_ptr<Formatter> Clone() const = 0; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h new file mode 100644 index 000000000..ce021e1a5 --- /dev/null +++ b/src/zencore/include/zencore/logging/helpers.h @@ -0,0 +1,122 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logbase.h> +#include <zencore/logging/memorybuffer.h> + +#include <chrono> +#include <ctime> +#include <string_view> + +namespace zen::logging::helpers { + +inline void +AppendStringView(std::string_view Sv, MemoryBuffer& Dest) +{ + Dest.append(Sv.data(), Sv.data() + Sv.size()); +} + +inline void +AppendInt(int N, MemoryBuffer& Dest) +{ + fmt::format_int Formatted(N); + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); +} + +inline void +Pad2(int N, MemoryBuffer& Dest) +{ + if (N >= 0 && N < 100) + { + Dest.push_back(static_cast<char>('0' + N / 10)); + Dest.push_back(static_cast<char>('0' + N % 10)); + } + else + { + fmt::format_int Formatted(N); + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); + } +} + +inline void +Pad3(uint32_t N, MemoryBuffer& Dest) +{ + if (N < 1000) + { + Dest.push_back(static_cast<char>('0' + N / 100)); + Dest.push_back(static_cast<char>('0' + (N / 10) % 10)); + Dest.push_back(static_cast<char>('0' + N % 10)); + } + else + { + AppendInt(static_cast<int>(N), Dest); + } +} + +inline void +PadUint(size_t N, unsigned int Width, MemoryBuffer& Dest) +{ + fmt::format_int Formatted(N); + auto StrLen = static_cast<unsigned int>(Formatted.size()); + if (Width > StrLen) + { + for (unsigned int Pad = 0; Pad < Width - StrLen; ++Pad) + { + Dest.push_back('0'); + } + } + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); +} + +template<typename ToDuration> +inline ToDuration +TimeFraction(std::chrono::system_clock::time_point Tp) +{ + using std::chrono::duration_cast; + using std::chrono::seconds; + auto Duration = Tp.time_since_epoch(); + auto Secs = duration_cast<seconds>(Duration); + return duration_cast<ToDuration>(Duration) - duration_cast<ToDuration>(Secs); +} + +inline std::tm +SafeLocaltime(std::time_t Time) +{ + std::tm Result{}; +#if defined(_WIN32) + localtime_s(&Result, &Time); +#else + localtime_r(&Time, &Result); +#endif + return Result; +} + +inline 
const char* +ShortFilename(const char* Path) +{ + if (Path == nullptr) + { + return Path; + } + + const char* It = Path; + const char* LastSep = Path; + while (*It) + { + if (*It == '/' || *It == '\\') + { + LastSep = It + 1; + } + ++It; + } + return LastSep; +} + +inline std::string_view +LevelToShortString(LogLevel Level) +{ + return ToStringView(Level); +} + +} // namespace zen::logging::helpers diff --git a/src/zencore/include/zencore/logging/logger.h b/src/zencore/include/zencore/logging/logger.h new file mode 100644 index 000000000..39d1139a5 --- /dev/null +++ b/src/zencore/include/zencore/logging/logger.h @@ -0,0 +1,63 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging/sink.h> + +#include <atomic> +#include <memory> +#include <string_view> + +namespace zen::logging { + +class ErrorHandler +{ +public: + virtual ~ErrorHandler() = default; + virtual void HandleError(const std::string_view& Msg) = 0; +}; + +class Logger : public RefCounted +{ +public: + Logger(std::string_view InName, SinkPtr InSink); + Logger(std::string_view InName, std::span<const SinkPtr> InSinks); + ~Logger(); + + Logger(const Logger&) = delete; + Logger& operator=(const Logger&) = delete; + + void Log(const LogPoint& Point, fmt::format_args Args); + + bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); } + + void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); } + LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); } + + void SetFlushLevel(LogLevel InLevel) { m_FlushLevel.store(InLevel, std::memory_order_relaxed); } + LogLevel GetFlushLevel() const { return m_FlushLevel.load(std::memory_order_relaxed); } + + std::string_view Name() const; + + void SetSinks(std::vector<SinkPtr> InSinks); + void AddSink(SinkPtr InSink); + + void SetFormatter(std::unique_ptr<Formatter> InFormatter); + + void SetErrorHandler(ErrorHandler* Handler); + + void 
Flush(); + + Ref<Logger> Clone(std::string_view NewName) const; + +private: + void SinkIt(const LogMessage& Msg); + void FlushIfNeeded(LogLevel InLevel); + + struct Impl; + std::unique_ptr<Impl> m_Impl; + std::atomic<LogLevel> m_Level{Info}; + std::atomic<LogLevel> m_FlushLevel{Off}; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h new file mode 100644 index 000000000..1d8b6b1b7 --- /dev/null +++ b/src/zencore/include/zencore/logging/logmsg.h @@ -0,0 +1,66 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logbase.h> + +#include <chrono> +#include <string_view> + +namespace zen::logging { + +using LogClock = std::chrono::system_clock; + +struct LogMessage +{ + LogMessage() = default; + + LogMessage(const LogPoint& InPoint, std::string_view InLoggerName, std::string_view InPayload) + : m_LoggerName(InLoggerName) + , m_Level(InPoint.Level) + , m_Time(LogClock::now()) + , m_Source(InPoint.Location) + , m_Payload(InPayload) + , m_Point(&InPoint) + { + } + + std::string_view GetPayload() const { return m_Payload; } + int GetThreadId() const { return m_ThreadId; } + LogClock::time_point GetTime() const { return m_Time; } + LogLevel GetLevel() const { return m_Level; } + std::string_view GetLoggerName() const { return m_LoggerName; } + const SourceLocation& GetSource() const { return m_Source; } + const LogPoint& GetLogPoint() const { return *m_Point; } + + void SetThreadId(int InThreadId) { m_ThreadId = InThreadId; } + void SetPayload(std::string_view InPayload) { m_Payload = InPayload; } + void SetLoggerName(std::string_view InName) { m_LoggerName = InName; } + void SetLevel(LogLevel InLevel) { m_Level = InLevel; } + void SetTime(LogClock::time_point InTime) { m_Time = InTime; } + void SetSource(const SourceLocation& InSource) { m_Source = InSource; } + + mutable size_t ColorRangeStart = 0; + mutable size_t ColorRangeEnd = 0; + +private: + 
static constexpr LogPoint s_DefaultPoints[LogLevelCount] = { + {{}, Trace, {}}, + {{}, Debug, {}}, + {{}, Info, {}}, + {{}, Warn, {}}, + {{}, Err, {}}, + {{}, Critical, {}}, + {{}, Off, {}}, + }; + + std::string_view m_LoggerName; + LogLevel m_Level = Off; + std::chrono::system_clock::time_point m_Time; + SourceLocation m_Source; + std::string_view m_Payload; + const LogPoint* m_Point = &s_DefaultPoints[Off]; + int m_ThreadId = 0; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/memorybuffer.h b/src/zencore/include/zencore/logging/memorybuffer.h new file mode 100644 index 000000000..cd0ff324f --- /dev/null +++ b/src/zencore/include/zencore/logging/memorybuffer.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <fmt/format.h> + +namespace zen::logging { + +using MemoryBuffer = fmt::basic_memory_buffer<char, 250>; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/messageonlyformatter.h b/src/zencore/include/zencore/logging/messageonlyformatter.h new file mode 100644 index 000000000..ce25fe9a6 --- /dev/null +++ b/src/zencore/include/zencore/logging/messageonlyformatter.h @@ -0,0 +1,22 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging/formatter.h> +#include <zencore/logging/helpers.h> + +namespace zen::logging { + +class MessageOnlyFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr<Formatter> Clone() const override { return std::make_unique<MessageOnlyFormatter>(); } +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/msvcsink.h b/src/zencore/include/zencore/logging/msvcsink.h new file mode 100644 index 000000000..48ea1b915 --- /dev/null +++ b/src/zencore/include/zencore/logging/msvcsink.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging/sink.h> + +#if ZEN_PLATFORM_WINDOWS + +# include <mutex> + +namespace zen::logging { + +class MsvcSink : public Sink +{ +public: + MsvcSink(); + ~MsvcSink() override = default; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; + +private: + std::mutex m_Mutex; + std::unique_ptr<Formatter> m_Formatter; +}; + +} // namespace zen::logging + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/include/zencore/logging/nullsink.h b/src/zencore/include/zencore/logging/nullsink.h new file mode 100644 index 000000000..7ac5677c6 --- /dev/null +++ b/src/zencore/include/zencore/logging/nullsink.h @@ -0,0 +1,17 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging/sink.h> + +namespace zen::logging { + +class NullSink : public Sink +{ +public: + void Log(const LogMessage& /*Msg*/) override {} + void Flush() override {} + void SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/) override {} +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/registry.h b/src/zencore/include/zencore/logging/registry.h new file mode 100644 index 000000000..a4d3692d2 --- /dev/null +++ b/src/zencore/include/zencore/logging/registry.h @@ -0,0 +1,70 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging/logger.h> + +#include <chrono> +#include <memory> +#include <span> +#include <string> +#include <type_traits> +#include <utility> + +namespace zen::logging { + +class Registry +{ +public: + using LogLevels = std::span<const std::pair<std::string, LogLevel>>; + + static Registry& Instance(); + void Shutdown(); + + void Register(Ref<Logger> InLogger); + void Drop(const std::string& Name); + Ref<Logger> Get(const std::string& Name); + + void SetDefaultLogger(Ref<Logger> InLogger); + Logger* DefaultLoggerRaw(); + Ref<Logger> DefaultLogger(); + + void SetGlobalLevel(LogLevel Level); + LogLevel GetGlobalLevel() const; + void SetLevels(LogLevels Levels, LogLevel* DefaultLevel); + + void FlushAll(); + void FlushOn(LogLevel Level); + void FlushEvery(std::chrono::seconds Interval); + + // Change formatter on all registered loggers + void SetFormatter(std::unique_ptr<Formatter> InFormatter); + + // Apply function to all registered loggers. Note that the function will + // be called while the registry mutex is held, so it should be fast and + // not attempt to call back into the registry. + template<typename Func> + void ApplyAll(Func&& F) + { + ApplyAllImpl([](void* Ctx, Ref<Logger> L) { (*static_cast<std::decay_t<Func>*>(Ctx))(std::move(L)); }, &F); + } + + // Set error handler for all loggers in the registry. 
The handler is called + // if any logger encounters an error during logging or flushing. + // The caller must ensure the handler outlives the registry. + void SetErrorHandler(ErrorHandler* Handler); + +private: + void ApplyAllImpl(void (*Func)(void*, Ref<Logger>), void* Context); + + Registry(); + ~Registry(); + + Registry(const Registry&) = delete; + Registry& operator=(const Registry&) = delete; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/sink.h b/src/zencore/include/zencore/logging/sink.h new file mode 100644 index 000000000..172176a4e --- /dev/null +++ b/src/zencore/include/zencore/logging/sink.h @@ -0,0 +1,34 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/refcount.h> +#include <zencore/logging/formatter.h> +#include <zencore/logging/logmsg.h> + +#include <atomic> +#include <memory> + +namespace zen::logging { + +class Sink : public RefCounted +{ +public: + virtual ~Sink() = default; + + virtual void Log(const LogMessage& Msg) = 0; + virtual void Flush() = 0; + + virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) = 0; + + bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); } + void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); } + LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); } + +protected: + std::atomic<LogLevel> m_Level{Trace}; +}; + +using SinkPtr = Ref<Sink>; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/tracesink.h b/src/zencore/include/zencore/logging/tracesink.h new file mode 100644 index 000000000..785c51e10 --- /dev/null +++ b/src/zencore/include/zencore/logging/tracesink.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging/sink.h> + +namespace zen::logging { + +#if ZEN_WITH_TRACE + +/** + * A logging sink that forwards log messages to the trace system. + * + * Work-in-progress, not fully implemented. + */ + +class TraceSink : public Sink +{ +public: + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; +}; + +#endif + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/md5.h b/src/zencore/include/zencore/md5.h index d934dd86b..3b0b7cae6 100644 --- a/src/zencore/include/zencore/md5.h +++ b/src/zencore/include/zencore/md5.h @@ -43,6 +43,8 @@ public: MD5 GetHash(); private: + // Opaque storage for MD5_CTX (104 bytes, aligned to uint32_t) + alignas(4) uint8_t m_Context[104]; }; void md5_forcelink(); // internal diff --git a/src/zencore/include/zencore/meta.h b/src/zencore/include/zencore/meta.h index 82eb5cc30..20ec4ac6f 100644 --- a/src/zencore/include/zencore/meta.h +++ b/src/zencore/include/zencore/meta.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once /* This file contains utility functions for meta programming * diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h index 19e410d85..d97c433fd 100644 --- a/src/zencore/include/zencore/mpscqueue.h +++ b/src/zencore/include/zencore/mpscqueue.h @@ -22,10 +22,10 @@ namespace zen { template<typename ElementType> struct TypeCompatibleStorage { - ElementType* Data() { return (ElementType*)this; } - const ElementType* Data() const { return (const ElementType*)this; } + ElementType* Data() { return reinterpret_cast<ElementType*>(&Storage); } + const ElementType* Data() const { return reinterpret_cast<const ElementType*>(&Storage); } - alignas(ElementType) char DataMember; + alignas(ElementType) char Storage[sizeof(ElementType)]; }; /** Fast multi-producer/single-consumer unbounded concurrent queue. 
@@ -58,7 +58,7 @@ public: Tail = Next; Next = Tail->Next.load(std::memory_order_relaxed); - std::destroy_at((ElementType*)&Tail->Value); + std::destroy_at(Tail->Value.Data()); delete Tail; } } @@ -67,7 +67,7 @@ public: void Enqueue(ArgTypes&&... Args) { Node* New = new Node; - new (&New->Value) ElementType(std::forward<ArgTypes>(Args)...); + new (New->Value.Data()) ElementType(std::forward<ArgTypes>(Args)...); Node* Prev = Head.exchange(New, std::memory_order_acq_rel); Prev->Next.store(New, std::memory_order_release); @@ -82,7 +82,7 @@ public: return {}; } - ElementType* ValuePtr = (ElementType*)&Next->Value; + ElementType* ValuePtr = Next->Value.Data(); std::optional<ElementType> Res{std::move(*ValuePtr)}; std::destroy_at(ValuePtr); @@ -100,9 +100,11 @@ private: }; private: - std::atomic<Node*> Head; // accessed only by producers - alignas(hardware_constructive_interference_size) - Node* Tail; // accessed only by consumer, hence should be on a different cache line than `Head` + // Use a fixed constant to avoid GCC's -Winterference-size warning with std::hardware_destructive_interference_size + static constexpr std::size_t CacheLineSize = 64; + + alignas(CacheLineSize) std::atomic<Node*> Head; // accessed only by producers + alignas(CacheLineSize) Node* Tail; // accessed only by consumer, separate cache line from Head }; void mpscqueue_forcelink(); diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index e3b7a70d7..809312c7b 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -9,6 +9,10 @@ namespace zen { +#if ZEN_PLATFORM_WINDOWS +class JobObject; +#endif + /** Basic process abstraction */ class ProcessHandle @@ -46,6 +50,7 @@ private: /** Basic process creation */ + struct CreateProcOptions { enum @@ -63,6 +68,9 @@ struct CreateProcOptions const std::filesystem::path* WorkingDirectory = nullptr; uint32_t Flags = 0; std::filesystem::path StdoutFile; +#if ZEN_PLATFORM_WINDOWS + 
JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed +#endif }; #if ZEN_PLATFORM_WINDOWS @@ -99,12 +107,38 @@ private: std::vector<HandleType> m_ProcessHandles; }; +#if ZEN_PLATFORM_WINDOWS +/** Windows Job Object wrapper + * + * When configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, the OS will + * terminate all assigned child processes when the job handle is closed + * (including abnormal termination of the owning process). This provides + * an OS-level guarantee against orphaned child processes. + */ +class JobObject +{ +public: + JobObject(); + ~JobObject(); + JobObject(const JobObject&) = delete; + JobObject& operator=(const JobObject&) = delete; + + void Initialize(); + bool AssignProcess(void* ProcessHandle); + [[nodiscard]] bool IsValid() const; + +private: + void* m_JobHandle = nullptr; +}; +#endif // ZEN_PLATFORM_WINDOWS + bool IsProcessRunning(int pid); bool IsProcessRunning(int pid, std::error_code& OutEc); int GetCurrentProcessId(); int GetProcessId(CreateProcResult ProcId); std::filesystem::path GetProcessExecutablePath(int Pid, std::error_code& OutEc); +std::string GetProcessCommandLine(int Pid, std::error_code& OutEc); std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf = true); /** Wait for all threads in the current process to exit (except the calling thread) diff --git a/src/zencore/include/zencore/sentryintegration.h b/src/zencore/include/zencore/sentryintegration.h index faf1238b7..27e5a8a82 100644 --- a/src/zencore/include/zencore/sentryintegration.h +++ b/src/zencore/include/zencore/sentryintegration.h @@ -11,11 +11,9 @@ #if ZEN_USE_SENTRY -# include <memory> +# include <zencore/logging/logger.h> -ZEN_THIRD_PARTY_INCLUDES_START -# include <spdlog/logger.h> -ZEN_THIRD_PARTY_INCLUDES_END +# include <memory> namespace sentry { @@ -42,6 +40,7 @@ public: }; void Initialize(const Config& Conf, const std::string& 
CommandLine); + void Close(); void LogStartupInformation(); static void ClearCaches(); @@ -53,7 +52,7 @@ private: std::string m_SentryUserName; std::string m_SentryHostName; std::string m_SentryId; - std::shared_ptr<spdlog::logger> m_SentryLogger; + Ref<logging::Logger> m_SentryLogger; }; } // namespace zen diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h index c57e9f568..3d4c19282 100644 --- a/src/zencore/include/zencore/sharedbuffer.h +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -116,14 +116,15 @@ public: inline void Reset() { m_Buffer = nullptr; } inline bool GetFileReference(IoBufferFileReference& OutRef) const { - if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore()) + if (!IsNull()) { - return Core->GetFileReference(OutRef); - } - else - { - return false; + if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore()) + { + return Core->GetFileReference(OutRef); + } } + + return false; } [[nodiscard]] MemoryView GetView() const diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index cbff6454f..4deca63ed 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -8,7 +8,6 @@ #include <stdint.h> #include <string.h> #include <charconv> -#include <codecvt> #include <compare> #include <concepts> #include <optional> @@ -51,7 +50,7 @@ StringLength(const wchar_t* str) return wcslen(str); } -inline bool +inline int StringCompare(const char16_t* s1, const char16_t* s2) { char16_t c1, c2; @@ -66,7 +65,7 @@ StringCompare(const char16_t* s1, const char16_t* s2) ++s1; ++s2; } - return uint16_t(c1) - uint16_t(c2); + return int(uint16_t(c1)) - int(uint16_t(c2)); } inline bool @@ -122,10 +121,10 @@ public: StringBuilderImpl() = default; ~StringBuilderImpl(); - StringBuilderImpl(const StringBuilderImpl&) = delete; - StringBuilderImpl(const StringBuilderImpl&&) = delete; + StringBuilderImpl(const StringBuilderImpl&) = delete; + 
StringBuilderImpl(StringBuilderImpl&&) = delete; const StringBuilderImpl& operator=(const StringBuilderImpl&) = delete; - const StringBuilderImpl& operator=(const StringBuilderImpl&&) = delete; + StringBuilderImpl& operator=(StringBuilderImpl&&) = delete; inline size_t AddUninitialized(size_t Count) { @@ -374,9 +373,9 @@ protected: [[noreturn]] void Fail(const char* FailReason); // note: throws exception - C* m_Base; - C* m_CurPos; - C* m_End; + C* m_Base = nullptr; + C* m_CurPos = nullptr; + C* m_End = nullptr; bool m_IsDynamic = false; bool m_IsExtendable = false; }; @@ -773,8 +772,9 @@ std::optional<T> ParseInt(const std::string_view& Input) { T Out = 0; - const std::from_chars_result Result = std::from_chars(Input.data(), Input.data() + Input.size(), Out); - if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range) + const char* End = Input.data() + Input.size(); + const std::from_chars_result Result = std::from_chars(Input.data(), End, Out); + if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range || Result.ptr != End) { return std::nullopt; } @@ -797,6 +797,22 @@ HashStringDjb2(const std::string_view& InString) } constexpr uint32_t +HashStringDjb2(const std::span<const std::string_view> InStrings) +{ + uint32_t HashValue = 5381; + + for (const std::string_view& String : InStrings) + { + for (int CurChar : String) + { + HashValue = HashValue * 33 + CurChar; + } + } + + return HashValue; +} + +constexpr uint32_t HashStringAsLowerDjb2(const std::string_view& InString) { uint32_t HashValue = 5381; @@ -1249,6 +1265,8 @@ private: uint64_t LoMask, HiMask; }; +std::string HideSensitiveString(std::string_view String); + ////////////////////////////////////////////////////////////////////////// void string_forcelink(); // internal diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index aec2e0ce4..a67999e52 100644 --- a/src/zencore/include/zencore/system.h +++ 
b/src/zencore/include/zencore/system.h @@ -4,6 +4,8 @@ #include <zencore/zencore.h> +#include <chrono> +#include <memory> #include <string> namespace zen { @@ -12,6 +14,8 @@ class CbWriter; std::string GetMachineName(); std::string_view GetOperatingSystemName(); +std::string GetOperatingSystemVersion(); +std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos" std::string_view GetCpuName(); struct SystemMetrics @@ -25,6 +29,14 @@ struct SystemMetrics uint64_t AvailVirtualMemoryMiB = 0; uint64_t PageFileMiB = 0; uint64_t AvailPageFileMiB = 0; + uint64_t UptimeSeconds = 0; +}; + +/// Extended metrics that include CPU usage percentage, which requires +/// stateful delta tracking via SystemMetricsTracker. +struct ExtendedSystemMetrics : SystemMetrics +{ + float CpuUsagePercent = 0.0f; }; SystemMetrics GetSystemMetrics(); @@ -32,6 +44,31 @@ SystemMetrics GetSystemMetrics(); void SetCpuCountForReporting(int FakeCpuCount); SystemMetrics GetSystemMetricsForReporting(); +ExtendedSystemMetrics ApplyReportingOverrides(ExtendedSystemMetrics Metrics); + void Describe(const SystemMetrics& Metrics, CbWriter& Writer); +void Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer); + +/// Stateful tracker that computes CPU usage as a delta between consecutive +/// Query() calls. The first call returns CpuUsagePercent = 0 (no previous +/// sample). Thread-safe: concurrent calls are serialised internally. +/// CPU sampling is rate-limited to MinInterval (default 1 s); calls that +/// arrive sooner return the previously cached value. +class SystemMetricsTracker +{ +public: + explicit SystemMetricsTracker(std::chrono::milliseconds MinInterval = std::chrono::seconds(1)); + ~SystemMetricsTracker(); + + SystemMetricsTracker(const SystemMetricsTracker&) = delete; + SystemMetricsTracker& operator=(const SystemMetricsTracker&) = delete; + + /// Collect current metrics. CPU usage is computed as delta since last Query(). 
+ ExtendedSystemMetrics Query(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; } // namespace zen diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h index a00ee3166..8410216c4 100644 --- a/src/zencore/include/zencore/testing.h +++ b/src/zencore/include/zencore/testing.h @@ -43,8 +43,9 @@ public: TestRunner(); ~TestRunner(); - int ApplyCommandLine(int argc, char const* const* argv); - int Run(); + void SetDefaultSuiteFilter(const char* Pattern); + int ApplyCommandLine(int Argc, char const* const* Argv); + int Run(); private: struct Impl; @@ -59,6 +60,8 @@ private: return Runner.Run(); \ }() +int RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)()); + } // namespace zen::testing #endif diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h index e2a4f8346..2a789d18f 100644 --- a/src/zencore/include/zencore/testutils.h +++ b/src/zencore/include/zencore/testutils.h @@ -59,6 +59,33 @@ struct TrueType static const bool Enabled = true; }; +namespace utf8test { + + // 2-byte UTF-8 (Latin extended) + static constexpr const char kLatin[] = u8"café_résumé"; + static constexpr const wchar_t kLatinW[] = L"café_résumé"; + + // 2-byte UTF-8 (Cyrillic) + static constexpr const char kCyrillic[] = u8"данные"; + static constexpr const wchar_t kCyrillicW[] = L"данные"; + + // 3-byte UTF-8 (CJK) + static constexpr const char kCJK[] = u8"日本語"; + static constexpr const wchar_t kCJKW[] = L"日本語"; + + // Mixed scripts + static constexpr const char kMixed[] = u8"zen_éд日"; + static constexpr const wchar_t kMixedW[] = L"zen_éд日"; + + // 4-byte UTF-8 (supplementary plane) — string tests only, NOT filesystem + static constexpr const char kEmoji[] = u8"📦"; + static constexpr const wchar_t kEmojiW[] = L"📦"; + + // BMP-only test strings suitable for filesystem use + static constexpr const char* kFilenameSafe[] = {kLatin, kCyrillic, kCJK, kMixed}; + +} // namespace 
utf8test + } // namespace zen #endif // ZEN_WITH_TESTS diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h index de8f9399c..d0d710ee8 100644 --- a/src/zencore/include/zencore/thread.h +++ b/src/zencore/include/zencore/thread.h @@ -58,17 +58,27 @@ public: } private: - RwLock* m_Lock; + RwLock* m_Lock = nullptr; }; - inline void WithSharedLock(auto&& Fun) + inline auto WithSharedLock(auto&& Fun) { SharedLockScope $(*this); - Fun(); + return Fun(); } struct ExclusiveLockScope { + ExclusiveLockScope(const ExclusiveLockScope& Rhs) = delete; + ExclusiveLockScope(ExclusiveLockScope&& Rhs) : m_Lock(Rhs.m_Lock) { Rhs.m_Lock = nullptr; } + ExclusiveLockScope& operator=(ExclusiveLockScope&& Rhs) + { + ReleaseNow(); + m_Lock = Rhs.m_Lock; + Rhs.m_Lock = nullptr; + return *this; + } + ExclusiveLockScope& operator=(const ExclusiveLockScope& Rhs) = delete; ExclusiveLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireExclusive(); } ~ExclusiveLockScope() { ReleaseNow(); } @@ -82,13 +92,13 @@ public: } private: - RwLock* m_Lock; + RwLock* m_Lock = nullptr; }; - inline void WithExclusiveLock(auto&& Fun) + inline auto WithExclusiveLock(auto&& Fun) { ExclusiveLockScope $(*this); - Fun(); + return Fun(); } private: @@ -195,7 +205,7 @@ public: // false positive completion results. 
void AddCount(std::ptrdiff_t Count) { - std::atomic_ptrdiff_t Old = Counter.fetch_add(Count); + std::ptrdiff_t Old = Counter.fetch_add(Count); ZEN_ASSERT(Old > 0); } diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h index 99a565151..d17e018ea 100644 --- a/src/zencore/include/zencore/trace.h +++ b/src/zencore/include/zencore/trace.h @@ -13,6 +13,7 @@ ZEN_THIRD_PARTY_INCLUDES_START # define TRACE_IMPLEMENT 0 #endif #include <trace.h> +#include <lane_trace.h> #undef TRACE_IMPLEMENT ZEN_THIRD_PARTY_INCLUDES_END diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h index 9fe905f25..43ca14d38 100644 --- a/src/zencore/include/zencore/varint.h +++ b/src/zencore/include/zencore/varint.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include "intmath.h" diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h index fc55b513b..f79d39b61 100644 --- a/src/zencore/include/zencore/xxhash.h +++ b/src/zencore/include/zencore/xxhash.h @@ -87,7 +87,7 @@ struct XXH3_128Stream } private: - XXH3_state_s m_State; + XXH3_state_s m_State{}; }; struct XXH3_128Stream_deprecated diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h index 177a19fff..a31950b0b 100644 --- a/src/zencore/include/zencore/zencore.h +++ b/src/zencore/include/zencore/zencore.h @@ -70,26 +70,36 @@ protected: } // namespace zen -#define ZEN_ASSERT(x, ...) \ - do \ - { \ - if (x) [[unlikely]] \ - break; \ - zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ +#define ZEN_ASSERT(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \ } while (false) #ifndef NDEBUG -# define ZEN_ASSERT_SLOW(x, ...) 
\ - do \ - { \ - if (x) [[unlikely]] \ - break; \ - zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ +# define ZEN_ASSERT_SLOW(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \ } while (false) #else # define ZEN_ASSERT_SLOW(x, ...) #endif +// Internal: select between "expr" and "expr: message" forms. +// With no extra args: ZEN_ASSERT_MSG_("expr") -> "expr" +// With a message arg: ZEN_ASSERT_MSG_("expr", "msg") -> "expr" ": " "msg" +// With fmt-style args: ZEN_ASSERT_MSG_("expr", "msg", args...) -> "expr" ": " "msg" +// The extra fmt args are silently discarded here — use ZEN_ASSERT_FORMAT for those. +#define ZEN_ASSERT_MSG_SELECT_(_1, _2, N, ...) N +#define ZEN_ASSERT_MSG_1_(expr) expr +#define ZEN_ASSERT_MSG_2_(expr, msg, ...) expr ": " msg +#define ZEN_ASSERT_MSG_(expr, ...) ZEN_ASSERT_MSG_SELECT_(unused, ##__VA_ARGS__, ZEN_ASSERT_MSG_2_, ZEN_ASSERT_MSG_1_)(expr, ##__VA_ARGS__) + ////////////////////////////////////////////////////////////////////////// #define ZEN_NOT_IMPLEMENTED(...) 
ZEN_ASSERT(false, __VA_ARGS__) diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp index 5a686dc8e..fedf76edc 100644 --- a/src/zencore/intmath.cpp +++ b/src/zencore/intmath.cpp @@ -19,6 +19,8 @@ intmath_forcelink() { } +TEST_SUITE_BEGIN("core.intmath"); + TEST_CASE("intmath") { CHECK(FloorLog2(0x00) == 0); @@ -43,6 +45,12 @@ TEST_CASE("intmath") CHECK(FloorLog2_64(0x0000'0001'0000'0000ull) == 32); CHECK(FloorLog2_64(0x8000'0000'0000'0000ull) == 63); + CHECK(CountLeadingZeros(0x8000'0000u) == 0); + CHECK(CountLeadingZeros(0x0000'0000u) == 32); + CHECK(CountLeadingZeros(0x0000'0001u) == 31); + CHECK(CountLeadingZeros(0x0000'8000u) == 16); + CHECK(CountLeadingZeros(0x0001'0000u) == 15); + CHECK(CountLeadingZeros64(0x8000'0000'0000'0000ull) == 0); CHECK(CountLeadingZeros64(0x0000'0000'0000'0000ull) == 64); CHECK(CountLeadingZeros64(0x0000'0000'0000'0001ull) == 63); @@ -60,6 +68,8 @@ TEST_CASE("intmath") CHECK(ByteSwap(uint64_t(0x214d'6172'7469'6e21ull)) == 0x216e'6974'7261'4d21ull); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index be9b39e7a..c47c54981 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -592,15 +592,17 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer) } IoBuffer -IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size) +IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size, ZenContentType ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFileHandle"); - return IoBuffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size); + IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size); + Buffer.SetContentType(ContentType); + return Buffer; } IoBuffer -IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size) +IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size, ZenContentType 
ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFile"); @@ -632,8 +634,6 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of FileSize = Stat.st_size; #endif // ZEN_PLATFORM_WINDOWS - // TODO: should validate that offset is in range - if (Size == ~0ull) { Size = FileSize - Offset; @@ -652,7 +652,9 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of #if ZEN_PLATFORM_WINDOWS void* Fd = DataFile.Detach(); #endif - return IoBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize); + IoBuffer NewBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } #if !ZEN_PLATFORM_WINDOWS @@ -664,7 +666,7 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of } IoBuffer -IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) +IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromTemporaryFile"); @@ -703,7 +705,9 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) Handle = (void*)uintptr_t(Fd); #endif // ZEN_PLATFORM_WINDOWS - return IoBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); + IoBuffer NewBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } ////////////////////////////////////////////////////////////////////////// @@ -715,6 +719,8 @@ iobuffer_forcelink() { } +TEST_SUITE_BEGIN("core.iobuffer"); + TEST_CASE("IoBuffer") { zen::IoBuffer buffer1; @@ -752,6 +758,8 @@ TEST_CASE("IoBuffer.mmap") # endif } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp index 75c1be42b..d6a8a6479 100644 --- a/src/zencore/jobqueue.cpp +++ b/src/zencore/jobqueue.cpp @@ 
-90,7 +90,7 @@ public: uint64_t NewJobId = IdGenerator.fetch_add(1); if (NewJobId == 0) { - IdGenerator.fetch_add(1); + NewJobId = IdGenerator.fetch_add(1); } RefPtr<Job> NewJob(new Job()); NewJob->Queue = this; @@ -129,7 +129,7 @@ public: QueuedJobs.erase(It); } }); - ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. Reason: ''", NewJob->Id.Id, NewJob->Name, Ex.what()); + ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. Reason: '{}'", NewJob->Id.Id, NewJob->Name, Ex.what()); throw; } } @@ -221,11 +221,11 @@ public: std::vector<JobInfo> Jobs; QueueLock.WithSharedLock([&]() { - for (auto It : RunningJobs) + for (const auto& It : RunningJobs) { Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Running}); } - for (auto It : CompletedJobs) + for (const auto& It : CompletedJobs) { if (IsStale(It.second->EndTick)) { @@ -234,7 +234,7 @@ public: } Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Completed}); } - for (auto It : AbortedJobs) + for (const auto& It : AbortedJobs) { if (IsStale(It.second->EndTick)) { @@ -243,7 +243,7 @@ public: } Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Aborted}); } - for (auto It : QueuedJobs) + for (const auto& It : QueuedJobs) { Jobs.push_back({.Id = It->Id, .Status = JobStatus::Queued}); } @@ -337,7 +337,7 @@ public: std::atomic_bool InitializedFlag = false; RwLock QueueLock; std::deque<RefPtr<Job>> QueuedJobs; - std::unordered_map<uint64_t, Job*> RunningJobs; + std::unordered_map<uint64_t, RefPtr<Job>> RunningJobs; std::unordered_map<uint64_t, RefPtr<Job>> CompletedJobs; std::unordered_map<uint64_t, RefPtr<Job>> AbortedJobs; @@ -429,20 +429,16 @@ JobQueue::ToString(JobStatus Status) { case JobQueue::JobStatus::Queued: return "Queued"sv; - break; case JobQueue::JobStatus::Running: return "Running"sv; - break; case JobQueue::JobStatus::Aborted: return "Aborted"sv; - break; case JobQueue::JobStatus::Completed: return "Completed"sv; - break; default: ZEN_ASSERT(false); + return ""sv; } - 
return ""sv; } std::unique_ptr<JobQueue> @@ -460,6 +456,8 @@ jobqueue_forcelink() { } +TEST_SUITE_BEGIN("core.jobqueue"); + TEST_CASE("JobQueue") { std::unique_ptr<JobQueue> Queue(MakeJobQueue(2, "queue")); @@ -580,6 +578,8 @@ TEST_CASE("JobQueue") } JobsLatch.Wait(); } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index a6697c443..099518637 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -2,208 +2,128 @@ #include "zencore/logging.h" +#include <zencore/logging/ansicolorsink.h> +#include <zencore/logging/logger.h> +#include <zencore/logging/messageonlyformatter.h> +#include <zencore/logging/nullsink.h> +#include <zencore/logging/registry.h> #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zencore/memory/llm.h> -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/details/registry.h> -#include <spdlog/sinks/null_sink.h> -#include <spdlog/sinks/stdout_color_sinks.h> -#include <spdlog/spdlog.h> -ZEN_THIRD_PARTY_INCLUDES_END +#include <mutex> #if ZEN_PLATFORM_WINDOWS # pragma section(".zlog$a", read) -# pragma section(".zlog$f", read) -# pragma section(".zlog$m", read) -# pragma section(".zlog$s", read) +# pragma section(".zlog$l", read) # pragma section(".zlog$z", read) #endif namespace zen { -// We shadow the underlying spdlog default logger, in order to avoid a bunch of overhead LoggerRef TheDefaultLogger; +LoggerRef +Log() +{ + if (TheDefaultLogger) + { + return TheDefaultLogger; + } + return zen::logging::ConsoleLog(); +} + } // namespace zen namespace zen::logging { -using MemoryBuffer_t = fmt::basic_memory_buffer<char, 250>; - -struct LoggingContext -{ - inline LoggingContext(); - inline ~LoggingContext(); - - zen::logging::MemoryBuffer_t MessageBuffer; - - inline std::string_view Message() const { return std::string_view(MessageBuffer.data(), MessageBuffer.size()); } -}; 
+////////////////////////////////////////////////////////////////////////// -LoggingContext::LoggingContext() +LoggerRef +LogCategory::Logger() { -} + // This should be thread safe since zen::logging::Get() will return + // the same logger instance for the same category name. Also the + // LoggerRef is simply a pointer. + if (!m_LoggerRef) + { + m_LoggerRef = zen::logging::Get(m_CategoryName); + } -LoggingContext::~LoggingContext() -{ + return m_LoggerRef; } -////////////////////////////////////////////////////////////////////////// - static inline bool -IsErrorLevel(int LogLevel) +IsErrorLevel(LogLevel InLevel) { - return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical); + return (InLevel == Err || InLevel == Critical); }; -static_assert(sizeof(spdlog::source_loc) == sizeof(SourceLocation)); -static_assert(offsetof(spdlog::source_loc, filename) == offsetof(SourceLocation, filename)); -static_assert(offsetof(spdlog::source_loc, line) == offsetof(SourceLocation, line)); -static_assert(offsetof(spdlog::source_loc, funcname) == offsetof(SourceLocation, funcname)); - void -EmitLogMessage(LoggerRef& Logger, int LogLevel, const std::string_view Message) +EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args) { ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - Logger.SpdLogger->log(InLevel, Message); - if (IsErrorLevel(LogLevel)) - { - if (LoggerRef ErrLogger = zen::logging::ErrorLog()) - { - ErrLogger.SpdLogger->log(InLevel, Message); - } - } -} -void -EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitLogMessage(Logger, LogLevel, LogCtx.Message()); -} + Logger->Log(Lp, Args); -void -EmitLogMessage(LoggerRef& Logger, const SourceLocation& 
InLocation, int LogLevel, const std::string_view Message) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::source_loc& Location = *reinterpret_cast<const spdlog::source_loc*>(&InLocation); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - Logger.SpdLogger->log(Location, InLevel, Message); - if (IsErrorLevel(LogLevel)) + if (IsErrorLevel(Lp.Level)) { if (LoggerRef ErrLogger = zen::logging::ErrorLog()) { - ErrLogger.SpdLogger->log(Location, InLevel, Message); + ErrLogger->Log(Lp, Args); } } } void -EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, std::string_view Format, fmt::format_args Args) +EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args) { ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitLogMessage(Logger, InLocation, LogLevel, LogCtx.Message()); -} - -void -EmitConsoleLogMessage(int LogLevel, const std::string_view Message) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - ConsoleLog().SpdLogger->log(InLevel, Message); -} - -#define ZEN_COLOR_YELLOW "\033[0;33m" -#define ZEN_COLOR_RED "\033[0;31m" -#define ZEN_BRIGHT_COLOR_RED "\033[1;31m" -#define ZEN_COLOR_RESET "\033[0m" - -void -EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - - // We are not using a format option for console which include log level since it would interfere with normal console output - - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - switch (InLevel) - { - case spdlog::level::level_enum::warn: - fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET); - break; - case spdlog::level::level_enum::err: - 
fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET); - break; - case spdlog::level::level_enum::critical: - fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET); - break; - default: - break; - } - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitConsoleLogMessage(LogLevel, LogCtx.Message()); + ConsoleLog()->Log(Lp, Args); } } // namespace zen::logging -namespace zen::logging::level { +namespace zen::logging { -spdlog::level::level_enum -to_spdlog_level(LogLevel NewLogLevel) -{ - return static_cast<spdlog::level::level_enum>((int)NewLogLevel); -} +constinit std::string_view LevelNames[] = {std::string_view("trace", 5), + std::string_view("debug", 5), + std::string_view("info", 4), + std::string_view("warning", 7), + std::string_view("error", 5), + std::string_view("critical", 8), + std::string_view("off", 3)}; LogLevel -to_logging_level(spdlog::level::level_enum NewLogLevel) -{ - return static_cast<LogLevel>((int)NewLogLevel); -} - -constinit std::string_view LevelNames[] = {ZEN_LEVEL_NAME_TRACE, - ZEN_LEVEL_NAME_DEBUG, - ZEN_LEVEL_NAME_INFO, - ZEN_LEVEL_NAME_WARNING, - ZEN_LEVEL_NAME_ERROR, - ZEN_LEVEL_NAME_CRITICAL, - ZEN_LEVEL_NAME_OFF}; - -level::LogLevel ParseLogLevelString(std::string_view Name) { - for (int Level = 0; Level < level::LogLevelCount; ++Level) + for (int Level = 0; Level < LogLevelCount; ++Level) { if (LevelNames[Level] == Name) - return static_cast<level::LogLevel>(Level); + { + return static_cast<LogLevel>(Level); + } } if (Name == "warn") { - return level::Warn; + return Warn; } if (Name == "err") { - return level::Err; + return Err; } - return level::Off; + return Off; } std::string_view -ToStringView(level::LogLevel Level) +ToStringView(LogLevel Level) { - if (int(Level) < level::LogLevelCount) + if (int(Level) < LogLevelCount) { return LevelNames[int(Level)]; } @@ -211,17 +131,17 @@ ToStringView(level::LogLevel 
Level) return "None"; } -} // namespace zen::logging::level +} // namespace zen::logging ////////////////////////////////////////////////////////////////////////// namespace zen::logging { RwLock LogLevelsLock; -std::string LogLevels[level::LogLevelCount]; +std::string LogLevels[LogLevelCount]; void -ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) +ConfigureLogLevels(LogLevel Level, std::string_view Loggers) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -230,18 +150,18 @@ ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) } void -RefreshLogLevels(level::LogLevel* DefaultLevel) +RefreshLogLevels(LogLevel* DefaultLevel) { ZEN_MEMSCOPE(ELLMTag::Logging); - spdlog::details::registry::log_levels Levels; + std::vector<std::pair<std::string, LogLevel>> Levels; { RwLock::SharedLockScope _(LogLevelsLock); - for (int i = 0; i < level::LogLevelCount; ++i) + for (int i = 0; i < LogLevelCount; ++i) { - level::LogLevel CurrentLevel{i}; + LogLevel CurrentLevel{i}; std::string_view Spec = LogLevels[i]; @@ -251,7 +171,7 @@ RefreshLogLevels(level::LogLevel* DefaultLevel) if (auto CommaPos = Spec.find_first_of(','); CommaPos != std::string_view::npos) { - LoggerName = Spec.substr(CommaPos + 1); + LoggerName = Spec.substr(0, CommaPos); Spec.remove_prefix(CommaPos + 1); } else @@ -260,24 +180,16 @@ RefreshLogLevels(level::LogLevel* DefaultLevel) Spec = {}; } - Levels[LoggerName] = to_spdlog_level(CurrentLevel); + Levels.emplace_back(std::move(LoggerName), CurrentLevel); } } } - if (DefaultLevel) - { - spdlog::level::level_enum SpdDefaultLevel = to_spdlog_level(*DefaultLevel); - spdlog::details::registry::instance().set_levels(Levels, &SpdDefaultLevel); - } - else - { - spdlog::details::registry::instance().set_levels(Levels, nullptr); - } + Registry::Instance().SetLevels(Levels, DefaultLevel); } void -RefreshLogLevels(level::LogLevel DefaultLevel) +RefreshLogLevels(LogLevel DefaultLevel) { RefreshLogLevels(&DefaultLevel); } @@ -289,21 +201,21 @@ 
RefreshLogLevels() } void -SetLogLevel(level::LogLevel NewLogLevel) +SetLogLevel(LogLevel NewLogLevel) { - spdlog::set_level(to_spdlog_level(NewLogLevel)); + Registry::Instance().SetGlobalLevel(NewLogLevel); } -level::LogLevel +LogLevel GetLogLevel() { - return level::to_logging_level(spdlog::get_level()); + return Registry::Instance().GetGlobalLevel(); } LoggerRef Default() { - ZEN_ASSERT(TheDefaultLogger); + ZEN_ASSERT(TheDefaultLogger, "logging::InitializeLogging() must be called before using the logger"); return TheDefaultLogger; } @@ -312,10 +224,10 @@ SetDefault(std::string_view NewDefaultLoggerId) { ZEN_MEMSCOPE(ELLMTag::Logging); - auto NewDefaultLogger = spdlog::get(std::string(NewDefaultLoggerId)); + Ref<Logger> NewDefaultLogger = Registry::Instance().Get(std::string(NewDefaultLoggerId)); ZEN_ASSERT(NewDefaultLogger); - spdlog::set_default_logger(NewDefaultLogger); + Registry::Instance().SetDefaultLogger(NewDefaultLogger); TheDefaultLogger = LoggerRef(*NewDefaultLogger); } @@ -338,11 +250,11 @@ SetErrorLog(std::string_view NewErrorLoggerId) } else { - auto NewErrorLogger = spdlog::get(std::string(NewErrorLoggerId)); + Ref<Logger> NewErrorLogger = Registry::Instance().Get(std::string(NewErrorLoggerId)); ZEN_ASSERT(NewErrorLogger); - TheErrorLogger = LoggerRef(*NewErrorLogger.get()); + TheErrorLogger = LoggerRef(*NewErrorLogger.Get()); } } @@ -353,39 +265,75 @@ Get(std::string_view Name) { ZEN_MEMSCOPE(ELLMTag::Logging); - std::shared_ptr<spdlog::logger> Logger = spdlog::get(std::string(Name)); + Ref<Logger> FoundLogger = Registry::Instance().Get(std::string(Name)); - if (!Logger) + if (!FoundLogger) { g_LoggerMutex.WithExclusiveLock([&] { - Logger = spdlog::get(std::string(Name)); + FoundLogger = Registry::Instance().Get(std::string(Name)); - if (!Logger) + if (!FoundLogger) { - Logger = Default().SpdLogger->clone(std::string(Name)); - spdlog::apply_logger_env_levels(Logger); - spdlog::register_logger(Logger); + FoundLogger = 
Default()->Clone(std::string(Name)); + Registry::Instance().Register(FoundLogger); } }); } - return *Logger; + return *FoundLogger; } -std::once_flag ConsoleInitFlag; -std::shared_ptr<spdlog::logger> ConLogger; +std::once_flag ConsoleInitFlag; +Ref<Logger> ConLogger; void SuppressConsoleLog() { + ZEN_MEMSCOPE(ELLMTag::Logging); + if (ConLogger) { - spdlog::drop("console"); + Registry::Instance().Drop("console"); ConLogger = {}; } - ConLogger = spdlog::null_logger_mt("console"); + + SinkPtr NullSinkPtr(new NullSink()); + ConLogger = Ref<Logger>(new Logger("console", std::vector<SinkPtr>{NullSinkPtr})); + Registry::Instance().Register(ConLogger); } +#define ZEN_COLOR_YELLOW "\033[0;33m" +#define ZEN_COLOR_RED "\033[0;31m" +#define ZEN_BRIGHT_COLOR_RED "\033[1;31m" +#define ZEN_COLOR_RESET "\033[0m" + +class ConsoleFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + switch (Msg.GetLevel()) + { + case Warn: + fmt::format_to(fmt::appender(Dest), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET); + break; + case Err: + fmt::format_to(fmt::appender(Dest), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET); + break; + case Critical: + fmt::format_to(fmt::appender(Dest), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET); + break; + default: + break; + } + + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr<Formatter> Clone() const override { return std::make_unique<ConsoleFormatter>(); } +}; + LoggerRef ConsoleLog() { @@ -394,10 +342,10 @@ ConsoleLog() std::call_once(ConsoleInitFlag, [&] { if (!ConLogger) { - ConLogger = spdlog::stdout_color_mt("console"); - spdlog::apply_logger_env_levels(ConLogger); - - ConLogger->set_pattern("%v"); + SinkPtr ConsoleSink(new AnsiColorStdoutSink()); + ConsoleSink->SetFormatter(std::make_unique<ConsoleFormatter>()); + ConLogger = Ref<Logger>(new Logger("console", std::vector<SinkPtr>{ConsoleSink})); + Registry::Instance().Register(ConLogger); } 
}); @@ -405,17 +353,29 @@ ConsoleLog() } void +ResetConsoleLog() +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + LoggerRef ConLog = ConsoleLog(); + + ConLog->SetFormatter(std::make_unique<ConsoleFormatter>()); +} + +void InitializeLogging() { ZEN_MEMSCOPE(ELLMTag::Logging); - TheDefaultLogger = *spdlog::default_logger_raw(); + TheDefaultLogger = *Registry::Instance().DefaultLoggerRaw(); } void ShutdownLogging() { - spdlog::shutdown(); + ZEN_MEMSCOPE(ELLMTag::Logging); + + Registry::Instance().Shutdown(); TheDefaultLogger = {}; } @@ -449,7 +409,7 @@ EnableVTMode() void FlushLogging() { - spdlog::details::registry::instance().flush_all(); + Registry::Instance().FlushAll(); } } // namespace zen::logging @@ -457,21 +417,27 @@ FlushLogging() namespace zen { bool -LoggerRef::ShouldLog(int Level) const +LoggerRef::ShouldLog(logging::LogLevel Level) const { - return SpdLogger->should_log(static_cast<spdlog::level::level_enum>(Level)); + return m_Logger->ShouldLog(Level); } void -LoggerRef::SetLogLevel(logging::level::LogLevel NewLogLevel) +LoggerRef::SetLogLevel(logging::LogLevel NewLogLevel) { - SpdLogger->set_level(to_spdlog_level(NewLogLevel)); + m_Logger->SetLevel(NewLogLevel); } -logging::level::LogLevel +logging::LogLevel LoggerRef::GetLogLevel() { - return logging::level::to_logging_level(SpdLogger->level()); + return m_Logger->GetLevel(); +} + +void +LoggerRef::Flush() +{ + m_Logger->Flush(); } thread_local ScopedActivityBase* t_ScopeStack = nullptr; @@ -532,6 +498,8 @@ logging_forcelink() using namespace std::literals; +TEST_SUITE_BEGIN("core.logging"); + TEST_CASE("simple.bread") { ExtendableStringBuilder<256> Crumbs; @@ -580,6 +548,8 @@ TEST_CASE("simple.bread") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp new file mode 100644 index 000000000..540d22359 --- /dev/null +++ b/src/zencore/logging/ansicolorsink.cpp @@ -0,0 +1,273 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#include <zencore/logging/ansicolorsink.h> +#include <zencore/logging/helpers.h> +#include <zencore/logging/messageonlyformatter.h> + +#include <cstdio> +#include <cstdlib> +#include <mutex> + +#if defined(_WIN32) +# include <io.h> +# define ZEN_ISATTY _isatty +# define ZEN_FILENO _fileno +#else +# include <unistd.h> +# define ZEN_ISATTY isatty +# define ZEN_FILENO fileno +#endif + +namespace zen::logging { + +// Default formatter replicating spdlog's %+ pattern: +// [YYYY-MM-DD HH:MM:SS.mmm] [logger_name] [level] message\n +class DefaultConsoleFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + // timestamp + auto Secs = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); + if (Secs != m_LastLogSecs) + { + m_LastLogSecs = Secs; + m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + } + + Dest.push_back('['); + helpers::AppendInt(m_CachedLocalTm.tm_year + 1900, Dest); + Dest.push_back('-'); + helpers::Pad2(m_CachedLocalTm.tm_mon + 1, Dest); + Dest.push_back('-'); + helpers::Pad2(m_CachedLocalTm.tm_mday, Dest); + Dest.push_back(' '); + helpers::Pad2(m_CachedLocalTm.tm_hour, Dest); + Dest.push_back(':'); + helpers::Pad2(m_CachedLocalTm.tm_min, Dest); + Dest.push_back(':'); + helpers::Pad2(m_CachedLocalTm.tm_sec, Dest); + Dest.push_back('.'); + auto Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); + helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + + // logger name + if (Msg.GetLoggerName().size() > 0) + { + Dest.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + } + + // level (colored range) + Dest.push_back('['); + Msg.ColorRangeStart = Dest.size(); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + Msg.ColorRangeEnd = Dest.size(); + 
Dest.push_back(']'); + Dest.push_back(' '); + + // message + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultConsoleFormatter>(); } + +private: + std::chrono::seconds m_LastLogSecs{0}; + std::tm m_CachedLocalTm{}; +}; + +static constexpr std::string_view s_Reset = "\033[m"; + +static std::string_view +GetColorForLevel(LogLevel InLevel) +{ + using namespace std::string_view_literals; + switch (InLevel) + { + case Trace: + return "\033[37m"sv; // white + case Debug: + return "\033[36m"sv; // cyan + case Info: + return "\033[32m"sv; // green + case Warn: + return "\033[33m\033[1m"sv; // bold yellow + case Err: + return "\033[31m\033[1m"sv; // bold red + case Critical: + return "\033[1m\033[41m"sv; // bold on red background + default: + return s_Reset; + } +} + +struct AnsiColorStdoutSink::Impl +{ + explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) {} + + static bool IsColorTerminal() + { + // If stdout is not a TTY, no color + if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) + { + return false; + } + + // NO_COLOR convention (https://no-color.org/) + if (std::getenv("NO_COLOR") != nullptr) + { + return false; + } + + // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit") + if (std::getenv("COLORTERM") != nullptr) + { + return true; + } + + // Check TERM for known color-capable values + const char* Term = std::getenv("TERM"); + if (Term != nullptr) + { + std::string_view TermView(Term); + // "dumb" terminals do not support color + if (TermView == "dumb") + { + return false; + } + // Match against known color-capable terminal types. + // TERM often includes suffixes like "-256color", so we use substring matching. 
+ constexpr std::string_view ColorTerms[] = { + "alacritty", + "ansi", + "color", + "console", + "cygwin", + "gnome", + "konsole", + "kterm", + "linux", + "msys", + "putty", + "rxvt", + "screen", + "tmux", + "vt100", + "vt102", + "xterm", + }; + for (std::string_view Candidate : ColorTerms) + { + if (TermView.find(Candidate) != std::string_view::npos) + { + return true; + } + } + } + +#if defined(_WIN32) + // Windows console supports ANSI color by default in modern versions + return true; +#else + // Unknown terminal — be conservative + return false; +#endif + } + + static bool ResolveColorMode(ColorMode Mode) + { + switch (Mode) + { + case ColorMode::On: + return true; + case ColorMode::Off: + return false; + case ColorMode::Auto: + default: + return IsColorTerminal(); + } + } + + void Log(const LogMessage& Msg) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); + + if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart) + { + // Print pre-color range + fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File); + + // Print color + std::string_view Color = GetColorForLevel(Msg.GetLevel()); + fwrite(Color.data(), 1, Color.size(), m_File); + + // Print colored range + fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File); + + // Reset color + fwrite(s_Reset.data(), 1, s_Reset.size(), m_File); + + // Print remainder + fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File); + } + else + { + fwrite(Formatted.data(), 1, Formatted.size(), m_File); + } + + fflush(m_File); + } + + void Flush() + { + std::lock_guard<std::mutex> Lock(m_Mutex); + fflush(m_File); + } + + void SetFormatter(std::unique_ptr<Formatter> InFormatter) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + m_Formatter = std::move(InFormatter); + } + +private: + std::mutex m_Mutex; + std::unique_ptr<Formatter> m_Formatter; + FILE* m_File = stdout; + bool 
m_UseColor = true; +}; + +AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique<Impl>(Mode)) +{ +} + +AnsiColorStdoutSink::~AnsiColorStdoutSink() = default; + +void +AnsiColorStdoutSink::Log(const LogMessage& Msg) +{ + m_Impl->Log(Msg); +} + +void +AnsiColorStdoutSink::Flush() +{ + m_Impl->Flush(); +} + +void +AnsiColorStdoutSink::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +} // namespace zen::logging diff --git a/src/zencore/logging/asyncsink.cpp b/src/zencore/logging/asyncsink.cpp new file mode 100644 index 000000000..02bf9f3ba --- /dev/null +++ b/src/zencore/logging/asyncsink.cpp @@ -0,0 +1,212 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/logging/asyncsink.h> + +#include <zencore/blockingqueue.h> +#include <zencore/logging/logmsg.h> +#include <zencore/thread.h> + +#include <future> +#include <string> +#include <thread> + +namespace zen::logging { + +struct AsyncLogMessage +{ + enum class Type : uint8_t + { + Log, + Flush, + Shutdown + }; + + Type MsgType = Type::Log; + + // Points to the LogPoint from upstream logging code. LogMessage guarantees + // this is always valid (either a static LogPoint from ZEN_LOG macros or one + // of the per-level default LogPoints). 
+ const LogPoint* Point = nullptr; + + int ThreadId = 0; + std::string OwnedPayload; + std::string OwnedLoggerName; + std::chrono::system_clock::time_point Time; + + std::shared_ptr<std::promise<void>> FlushPromise; +}; + +struct AsyncSink::Impl +{ + explicit Impl(std::vector<SinkPtr> InSinks) : m_Sinks(std::move(InSinks)) + { + m_WorkerThread = std::thread([this]() { + zen::SetCurrentThreadName("AsyncLog"); + WorkerLoop(); + }); + } + + ~Impl() + { + AsyncLogMessage ShutdownMsg; + ShutdownMsg.MsgType = AsyncLogMessage::Type::Shutdown; + m_Queue.Enqueue(std::move(ShutdownMsg)); + + if (m_WorkerThread.joinable()) + { + m_WorkerThread.join(); + } + } + + void Log(const LogMessage& Msg) + { + AsyncLogMessage AsyncMsg; + AsyncMsg.OwnedPayload = std::string(Msg.GetPayload()); + AsyncMsg.OwnedLoggerName = std::string(Msg.GetLoggerName()); + AsyncMsg.ThreadId = Msg.GetThreadId(); + AsyncMsg.Time = Msg.GetTime(); + AsyncMsg.Point = &Msg.GetLogPoint(); + AsyncMsg.MsgType = AsyncLogMessage::Type::Log; + + m_Queue.Enqueue(std::move(AsyncMsg)); + } + + void Flush() + { + auto Promise = std::make_shared<std::promise<void>>(); + auto Future = Promise->get_future(); + + AsyncLogMessage FlushMsg; + FlushMsg.MsgType = AsyncLogMessage::Type::Flush; + FlushMsg.FlushPromise = std::move(Promise); + + m_Queue.Enqueue(std::move(FlushMsg)); + + Future.get(); + } + + void SetFormatter(std::unique_ptr<Formatter> InFormatter) + { + for (auto& CurrentSink : m_Sinks) + { + CurrentSink->SetFormatter(InFormatter->Clone()); + } + } + +private: + void ForwardLogToSinks(const AsyncLogMessage& AsyncMsg) + { + LogMessage Reconstructed(*AsyncMsg.Point, AsyncMsg.OwnedLoggerName, AsyncMsg.OwnedPayload); + Reconstructed.SetTime(AsyncMsg.Time); + Reconstructed.SetThreadId(AsyncMsg.ThreadId); + + for (auto& CurrentSink : m_Sinks) + { + if (CurrentSink->ShouldLog(Reconstructed.GetLevel())) + { + try + { + CurrentSink->Log(Reconstructed); + } + catch (const std::exception&) + { + } + } + } + } + + void 
FlushSinks() + { + for (auto& CurrentSink : m_Sinks) + { + try + { + CurrentSink->Flush(); + } + catch (const std::exception&) + { + } + } + } + + void WorkerLoop() + { + AsyncLogMessage Msg; + while (m_Queue.WaitAndDequeue(Msg)) + { + switch (Msg.MsgType) + { + case AsyncLogMessage::Type::Log: + { + ForwardLogToSinks(Msg); + break; + } + + case AsyncLogMessage::Type::Flush: + { + FlushSinks(); + if (Msg.FlushPromise) + { + Msg.FlushPromise->set_value(); + } + break; + } + + case AsyncLogMessage::Type::Shutdown: + { + m_Queue.CompleteAdding(); + + AsyncLogMessage Remaining; + while (m_Queue.WaitAndDequeue(Remaining)) + { + if (Remaining.MsgType == AsyncLogMessage::Type::Log) + { + ForwardLogToSinks(Remaining); + } + else if (Remaining.MsgType == AsyncLogMessage::Type::Flush) + { + FlushSinks(); + if (Remaining.FlushPromise) + { + Remaining.FlushPromise->set_value(); + } + } + } + + FlushSinks(); + return; + } + } + } + } + + std::vector<SinkPtr> m_Sinks; + BlockingQueue<AsyncLogMessage> m_Queue; + std::thread m_WorkerThread; +}; + +AsyncSink::AsyncSink(std::vector<SinkPtr> InSinks) : m_Impl(std::make_unique<Impl>(std::move(InSinks))) +{ +} + +AsyncSink::~AsyncSink() = default; + +void +AsyncSink::Log(const LogMessage& Msg) +{ + m_Impl->Log(Msg); +} + +void +AsyncSink::Flush() +{ + m_Impl->Flush(); +} + +void +AsyncSink::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +} // namespace zen::logging diff --git a/src/zencore/logging/logger.cpp b/src/zencore/logging/logger.cpp new file mode 100644 index 000000000..dd1675bb1 --- /dev/null +++ b/src/zencore/logging/logger.cpp @@ -0,0 +1,142 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zencore/logging/logger.h> +#include <zencore/thread.h> + +#include <string> +#include <vector> + +namespace zen::logging { + +struct Logger::Impl +{ + std::string m_Name; + std::vector<SinkPtr> m_Sinks; + ErrorHandler* m_ErrorHandler = nullptr; +}; + +Logger::Logger(std::string_view InName, SinkPtr InSink) : m_Impl(std::make_unique<Impl>()) +{ + m_Impl->m_Name = InName; + m_Impl->m_Sinks.push_back(std::move(InSink)); +} + +Logger::Logger(std::string_view InName, std::span<const SinkPtr> InSinks) : m_Impl(std::make_unique<Impl>()) +{ + m_Impl->m_Name = InName; + m_Impl->m_Sinks.assign(InSinks.begin(), InSinks.end()); +} + +Logger::~Logger() = default; + +void +Logger::Log(const LogPoint& Point, fmt::format_args Args) +{ + if (!ShouldLog(Point.Level)) + { + return; + } + + fmt::basic_memory_buffer<char, 250> Buffer; + fmt::vformat_to(fmt::appender(Buffer), Point.FormatString, Args); + + LogMessage LogMsg(Point, m_Impl->m_Name, std::string_view(Buffer.data(), Buffer.size())); + LogMsg.SetThreadId(GetCurrentThreadId()); + SinkIt(LogMsg); + FlushIfNeeded(Point.Level); +} + +void +Logger::SinkIt(const LogMessage& Msg) +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + if (CurrentSink->ShouldLog(Msg.GetLevel())) + { + try + { + CurrentSink->Log(Msg); + } + catch (const std::exception& Ex) + { + if (m_Impl->m_ErrorHandler) + { + m_Impl->m_ErrorHandler->HandleError(Ex.what()); + } + } + } + } +} + +void +Logger::FlushIfNeeded(LogLevel InLevel) +{ + if (InLevel >= m_FlushLevel.load(std::memory_order_relaxed)) + { + Flush(); + } +} + +void +Logger::Flush() +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + try + { + CurrentSink->Flush(); + } + catch (const std::exception& Ex) + { + if (m_Impl->m_ErrorHandler) + { + m_Impl->m_ErrorHandler->HandleError(Ex.what()); + } + } + } +} + +void +Logger::SetSinks(std::vector<SinkPtr> InSinks) +{ + m_Impl->m_Sinks = std::move(InSinks); +} + +void +Logger::AddSink(SinkPtr InSink) +{ + 
m_Impl->m_Sinks.push_back(std::move(InSink)); +} + +void +Logger::SetErrorHandler(ErrorHandler* Handler) +{ + m_Impl->m_ErrorHandler = Handler; +} + +void +Logger::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + CurrentSink->SetFormatter(InFormatter->Clone()); + } +} + +std::string_view +Logger::Name() const +{ + return m_Impl->m_Name; +} + +Ref<Logger> +Logger::Clone(std::string_view NewName) const +{ + Ref<Logger> Cloned(new Logger(NewName, m_Impl->m_Sinks)); + Cloned->SetLevel(m_Level.load(std::memory_order_relaxed)); + Cloned->SetFlushLevel(m_FlushLevel.load(std::memory_order_relaxed)); + Cloned->SetErrorHandler(m_Impl->m_ErrorHandler); + return Cloned; +} + +} // namespace zen::logging diff --git a/src/zencore/logging/msvcsink.cpp b/src/zencore/logging/msvcsink.cpp new file mode 100644 index 000000000..457a4d6e1 --- /dev/null +++ b/src/zencore/logging/msvcsink.cpp @@ -0,0 +1,80 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS + +# include <zencore/logging/helpers.h> +# include <zencore/logging/messageonlyformatter.h> +# include <zencore/logging/msvcsink.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <Windows.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::logging { + +// Default formatter for MSVC debug output: [level] message\n +// For error/critical messages with source info, prepends file(line): so that +// the message is clickable in the Visual Studio Output window. 
+class DefaultMsvcFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + const auto& Source = Msg.GetSource(); + if (Msg.GetLevel() >= LogLevel::Err && Source) + { + helpers::AppendStringView(Source.Filename, Dest); + Dest.push_back('('); + helpers::AppendInt(Source.Line, Dest); + Dest.push_back(')'); + Dest.push_back(':'); + Dest.push_back(' '); + } + + Dest.push_back('['); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultMsvcFormatter>(); } +}; + +MsvcSink::MsvcSink() : m_Formatter(std::make_unique<DefaultMsvcFormatter>()) +{ +} + +void +MsvcSink::Log(const LogMessage& Msg) +{ + std::lock_guard<std::mutex> Lock(m_Mutex); + + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); + + // Null-terminate for OutputDebugStringA + Formatted.push_back('\0'); + + OutputDebugStringA(Formatted.data()); +} + +void +MsvcSink::Flush() +{ + // Nothing to flush for OutputDebugString +} + +void +MsvcSink::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + std::lock_guard<std::mutex> Lock(m_Mutex); + m_Formatter = std::move(InFormatter); +} + +} // namespace zen::logging + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/logging/registry.cpp b/src/zencore/logging/registry.cpp new file mode 100644 index 000000000..3ed1fb0df --- /dev/null +++ b/src/zencore/logging/registry.cpp @@ -0,0 +1,330 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zencore/logging/registry.h> + +#include <zencore/logging/ansicolorsink.h> +#include <zencore/logging/messageonlyformatter.h> + +#include <atomic> +#include <condition_variable> +#include <mutex> +#include <thread> +#include <unordered_map> + +namespace zen::logging { + +struct Registry::Impl +{ + Impl() + { + // Create default logger with a stdout color sink + SinkPtr DefaultSink(new AnsiColorStdoutSink()); + m_DefaultLogger = Ref<Logger>(new Logger("", DefaultSink)); + m_Loggers[""] = m_DefaultLogger; + } + + ~Impl() { StopPeriodicFlush(); } + + void Register(Ref<Logger> InLogger) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + if (m_ErrorHandler) + { + InLogger->SetErrorHandler(m_ErrorHandler); + } + m_Loggers[std::string(InLogger->Name())] = std::move(InLogger); + } + + void Drop(const std::string& Name) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + m_Loggers.erase(Name); + } + + Ref<Logger> Get(const std::string& Name) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + auto It = m_Loggers.find(Name); + if (It != m_Loggers.end()) + { + return It->second; + } + return {}; + } + + void SetDefaultLogger(Ref<Logger> InLogger) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + if (InLogger) + { + m_Loggers[std::string(InLogger->Name())] = InLogger; + } + m_DefaultLogger = std::move(InLogger); + } + + Logger* DefaultLoggerRaw() { return m_DefaultLogger.Get(); } + + Ref<Logger> DefaultLogger() + { + std::lock_guard<std::mutex> Lock(m_Mutex); + return m_DefaultLogger; + } + + void SetGlobalLevel(LogLevel Level) + { + m_GlobalLevel.store(Level, std::memory_order_relaxed); + std::lock_guard<std::mutex> Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetLevel(Level); + } + } + + LogLevel GetGlobalLevel() const { return m_GlobalLevel.load(std::memory_order_relaxed); } + + void SetLevels(Registry::LogLevels Levels, LogLevel* DefaultLevel) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + + if (DefaultLevel) + { + 
m_GlobalLevel.store(*DefaultLevel, std::memory_order_relaxed); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetLevel(*DefaultLevel); + } + } + + for (auto& [LoggerName, Level] : Levels) + { + auto It = m_Loggers.find(LoggerName); + if (It != m_Loggers.end()) + { + It->second->SetLevel(Level); + } + } + } + + void FlushAll() + { + std::lock_guard<std::mutex> Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + try + { + CurLogger->Flush(); + } + catch (const std::exception&) + { + } + } + } + + void FlushOn(LogLevel Level) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + m_FlushLevel = Level; + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetFlushLevel(Level); + } + } + + void FlushEvery(std::chrono::seconds Interval) + { + StopPeriodicFlush(); + + m_PeriodicFlushRunning.store(true, std::memory_order_relaxed); + + m_FlushThread = std::thread([this, Interval] { + while (m_PeriodicFlushRunning.load(std::memory_order_relaxed)) + { + { + std::unique_lock<std::mutex> Lock(m_PeriodicFlushMutex); + m_PeriodicFlushCv.wait_for(Lock, Interval, [this] { return !m_PeriodicFlushRunning.load(std::memory_order_relaxed); }); + } + + if (m_PeriodicFlushRunning.load(std::memory_order_relaxed)) + { + FlushAll(); + } + } + }); + } + + void SetFormatter(std::unique_ptr<Formatter> InFormatter) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetFormatter(InFormatter->Clone()); + } + } + + void ApplyAll(void (*Func)(void*, Ref<Logger>), void* Context) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + Func(Context, CurLogger); + } + } + + void SetErrorHandler(ErrorHandler* Handler) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + m_ErrorHandler = Handler; + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetErrorHandler(Handler); + } + } + + void Shutdown() + { + StopPeriodicFlush(); + FlushAll(); + + 
std::lock_guard<std::mutex> Lock(m_Mutex); + m_Loggers.clear(); + m_DefaultLogger = nullptr; + } + +private: + void StopPeriodicFlush() + { + if (m_FlushThread.joinable()) + { + m_PeriodicFlushRunning.store(false, std::memory_order_relaxed); + { + std::lock_guard<std::mutex> Lock(m_PeriodicFlushMutex); + m_PeriodicFlushCv.notify_one(); + } + m_FlushThread.join(); + } + } + + std::mutex m_Mutex; + std::unordered_map<std::string, Ref<Logger>> m_Loggers; + Ref<Logger> m_DefaultLogger; + std::atomic<LogLevel> m_GlobalLevel{Trace}; + LogLevel m_FlushLevel{Off}; + ErrorHandler* m_ErrorHandler = nullptr; + + // Periodic flush + std::atomic<bool> m_PeriodicFlushRunning{false}; + std::mutex m_PeriodicFlushMutex; + std::condition_variable m_PeriodicFlushCv; + std::thread m_FlushThread; +}; + +Registry& +Registry::Instance() +{ + static Registry s_Instance; + return s_Instance; +} + +Registry::Registry() : m_Impl(std::make_unique<Impl>()) +{ +} + +Registry::~Registry() = default; + +void +Registry::Register(Ref<Logger> InLogger) +{ + m_Impl->Register(std::move(InLogger)); +} + +void +Registry::Drop(const std::string& Name) +{ + m_Impl->Drop(Name); +} + +Ref<Logger> +Registry::Get(const std::string& Name) +{ + return m_Impl->Get(Name); +} + +void +Registry::SetDefaultLogger(Ref<Logger> InLogger) +{ + m_Impl->SetDefaultLogger(std::move(InLogger)); +} + +Logger* +Registry::DefaultLoggerRaw() +{ + return m_Impl->DefaultLoggerRaw(); +} + +Ref<Logger> +Registry::DefaultLogger() +{ + return m_Impl->DefaultLogger(); +} + +void +Registry::SetGlobalLevel(LogLevel Level) +{ + m_Impl->SetGlobalLevel(Level); +} + +LogLevel +Registry::GetGlobalLevel() const +{ + return m_Impl->GetGlobalLevel(); +} + +void +Registry::SetLevels(LogLevels Levels, LogLevel* DefaultLevel) +{ + m_Impl->SetLevels(Levels, DefaultLevel); +} + +void +Registry::FlushAll() +{ + m_Impl->FlushAll(); +} + +void +Registry::FlushOn(LogLevel Level) +{ + m_Impl->FlushOn(Level); +} + +void 
+Registry::FlushEvery(std::chrono::seconds Interval) +{ + m_Impl->FlushEvery(Interval); +} + +void +Registry::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +void +Registry::ApplyAllImpl(void (*Func)(void*, Ref<Logger>), void* Context) +{ + m_Impl->ApplyAll(Func, Context); +} + +void +Registry::SetErrorHandler(ErrorHandler* Handler) +{ + m_Impl->SetErrorHandler(Handler); +} + +void +Registry::Shutdown() +{ + m_Impl->Shutdown(); +} + +} // namespace zen::logging diff --git a/src/zencore/logging/tracesink.cpp b/src/zencore/logging/tracesink.cpp new file mode 100644 index 000000000..8a6f4e40c --- /dev/null +++ b/src/zencore/logging/tracesink.cpp @@ -0,0 +1,92 @@ + +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/logbase.h> +#include <zencore/logging/tracesink.h> +#include <zencore/string.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#if ZEN_WITH_TRACE + +namespace zen::logging { + +UE_TRACE_CHANNEL_DEFINE(LogChannel) + +UE_TRACE_EVENT_BEGIN(Logging, LogCategory, NoSync | Important) + UE_TRACE_EVENT_FIELD(const void*, CategoryPointer) + UE_TRACE_EVENT_FIELD(uint8_t, DefaultVerbosity) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, Name) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Logging, LogMessageSpec, NoSync | Important) + UE_TRACE_EVENT_FIELD(const void*, LogPoint) + UE_TRACE_EVENT_FIELD(const void*, CategoryPointer) + UE_TRACE_EVENT_FIELD(int32_t, Line) + UE_TRACE_EVENT_FIELD(uint8_t, Verbosity) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FileName) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FormatString) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Logging, LogMessage, NoSync) + UE_TRACE_EVENT_FIELD(const void*, LogPoint) + UE_TRACE_EVENT_FIELD(uint64_t, Cycle) + UE_TRACE_EVENT_FIELD(uint8_t[], FormatArgs) +UE_TRACE_EVENT_END() + +void +TraceLogCategory(const logging::Logger* Category, const char* Name, logging::LogLevel DefaultVerbosity) +{ + uint16_t 
NameLen = uint16_t(strlen(Name)); + UE_TRACE_LOG(Logging, LogCategory, LogChannel, NameLen * sizeof(ANSICHAR)) + << LogCategory.CategoryPointer(Category) << LogCategory.DefaultVerbosity(uint8_t(DefaultVerbosity)) + << LogCategory.Name(Name, NameLen); +} + +void +TraceLogMessageSpec(const void* LogPoint, + const logging::Logger* Category, + logging::LogLevel Verbosity, + const std::string_view File, + int32_t Line, + const std::string_view Format) +{ + uint16_t FileNameLen = uint16_t(File.size()); + uint16_t FormatStringLen = uint16_t(Format.size()); + uint32_t DataSize = (FileNameLen * sizeof(ANSICHAR)) + (FormatStringLen * sizeof(ANSICHAR)); + UE_TRACE_LOG(Logging, LogMessageSpec, LogChannel, DataSize) + << LogMessageSpec.LogPoint(LogPoint) << LogMessageSpec.CategoryPointer(Category) << LogMessageSpec.Line(Line) + << LogMessageSpec.Verbosity(uint8_t(Verbosity)) << LogMessageSpec.FileName(File.data(), FileNameLen) + << LogMessageSpec.FormatString(Format.data(), FormatStringLen); +} + +void +TraceLogMessageInternal(const void* LogPoint, int32_t EncodedFormatArgsSize, const uint8_t* EncodedFormatArgs) +{ + UE_TRACE_LOG(Logging, LogMessage, LogChannel) << LogMessage.LogPoint(LogPoint) << LogMessage.Cycle(GetHifreqTimerValue()) + << LogMessage.FormatArgs(EncodedFormatArgs, EncodedFormatArgsSize); +} + +////////////////////////////////////////////////////////////////////////// + +void +TraceSink::Log(const LogMessage& Msg) +{ + ZEN_UNUSED(Msg); +} + +void +TraceSink::Flush() +{ +} + +void +TraceSink::SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/) +{ + // This sink doesn't use a formatter since it just forwards the raw format + // args to the trace system +} + +} // namespace zen::logging + +#endif diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp index 4ec145697..f8cfee3ac 100644 --- a/src/zencore/md5.cpp +++ b/src/zencore/md5.cpp @@ -56,9 +56,9 @@ struct MD5_CTX unsigned char digest[16]; /* actual digest after MD5Final call */ }; -void MD5Init(); -void 
MD5Update(); -void MD5Final(); +void MD5Init(MD5_CTX* mdContext); +void MD5Update(MD5_CTX* mdContext, unsigned char* inBuf, unsigned int inLen); +void MD5Final(MD5_CTX* mdContext); /* ********************************************************************** @@ -342,6 +342,23 @@ Transform(uint32_t* buf, uint32_t* in) #undef G #undef H #undef I +#undef ROTATE_LEFT +#undef S11 +#undef S12 +#undef S13 +#undef S14 +#undef S21 +#undef S22 +#undef S23 +#undef S24 +#undef S31 +#undef S32 +#undef S33 +#undef S34 +#undef S41 +#undef S42 +#undef S43 +#undef S44 namespace zen { @@ -353,28 +370,32 @@ MD5 MD5::Zero; // Initialized to all zeroes MD5Stream::MD5Stream() { + static_assert(sizeof(MD5_CTX) <= sizeof(m_Context)); Reset(); } void MD5Stream::Reset() { + MD5Init(reinterpret_cast<MD5_CTX*>(m_Context)); } MD5Stream& MD5Stream::Append(const void* Data, size_t ByteCount) { - ZEN_UNUSED(Data); - ZEN_UNUSED(ByteCount); - + MD5Update(reinterpret_cast<MD5_CTX*>(m_Context), (unsigned char*)Data, (unsigned int)ByteCount); return *this; } MD5 MD5Stream::GetHash() { - MD5 md5{}; + MD5_CTX FinalCtx; + memcpy(&FinalCtx, m_Context, sizeof(MD5_CTX)); + MD5Final(&FinalCtx); + MD5 md5{}; + memcpy(md5.Hash, FinalCtx.digest, 16); return md5; } @@ -391,7 +412,7 @@ MD5::FromHexString(const char* string) { MD5 md5; - ParseHexBytes(string, 40, md5.Hash); + ParseHexBytes(string, 2 * sizeof md5.Hash, md5.Hash); return md5; } @@ -411,7 +432,7 @@ MD5::ToHexString(StringBuilderBase& outBuilder) const char str[41]; ToHexString(str); - outBuilder.AppendRange(str, &str[40]); + outBuilder.AppendRange(str, &str[StringLength]); return outBuilder; } @@ -437,6 +458,8 @@ md5_forcelink() // return md5text; // } +TEST_SUITE_BEGIN("core.md5"); + TEST_CASE("MD5") { using namespace std::literals; @@ -451,13 +474,15 @@ TEST_CASE("MD5") MD5::String_t Buffer; Result.ToHexString(Buffer); - CHECK(Output.compare(Buffer)); + CHECK(Output.compare(Buffer) == 0); MD5 Reresult = MD5::FromHexString(Buffer); 
Reresult.ToHexString(Buffer); - CHECK(Output.compare(Buffer)); + CHECK(Output.compare(Buffer) == 0); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/memoryview.cpp b/src/zencore/memoryview.cpp index 1f6a6996c..1654b1766 100644 --- a/src/zencore/memoryview.cpp +++ b/src/zencore/memoryview.cpp @@ -18,6 +18,8 @@ namespace zen { #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.memoryview"); + TEST_CASE("MemoryView") { { @@ -35,6 +37,8 @@ TEST_CASE("MemoryView") CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8); } +TEST_SUITE_END(); + void memory_forcelink() { diff --git a/src/zencore/memtrack/callstacktrace.cpp b/src/zencore/memtrack/callstacktrace.cpp index a5b7fede6..4a7068568 100644 --- a/src/zencore/memtrack/callstacktrace.cpp +++ b/src/zencore/memtrack/callstacktrace.cpp @@ -169,13 +169,13 @@ private: std::atomic_uint64_t Key; std::atomic_uint32_t Value; - inline uint64 GetKey() const { return Key.load(std::memory_order_relaxed); } + inline uint64 GetKey() const { return Key.load(std::memory_order_acquire); } inline uint32_t GetValue() const { return Value.load(std::memory_order_relaxed); } - inline bool IsEmpty() const { return Key.load(std::memory_order_relaxed) == 0; } + inline bool IsEmpty() const { return Key.load(std::memory_order_acquire) == 0; } inline void SetKeyValue(uint64_t InKey, uint32_t InValue) { - Value.store(InValue, std::memory_order_release); - Key.store(InKey, std::memory_order_relaxed); + Value.store(InValue, std::memory_order_relaxed); + Key.store(InKey, std::memory_order_release); } static inline uint32_t KeyHash(uint64_t Key) { return static_cast<uint32_t>(Key); } static inline void ClearEntries(FEncounteredCallstackSetEntry* Entries, int32_t EntryCount) diff --git a/src/zencore/memtrack/tagtrace.cpp b/src/zencore/memtrack/tagtrace.cpp index 70a74365d..fca4a2ec3 100644 --- a/src/zencore/memtrack/tagtrace.cpp +++ b/src/zencore/memtrack/tagtrace.cpp @@ -186,7 +186,7 @@ FTagTrace::AnnounceSpecialTags() const { auto 
EmitTag = [](const char16_t* DisplayString, int32_t Tag, int32_t ParentTag) { const uint32_t DisplayLen = (uint32_t)StringLength(DisplayString); - UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR)) + UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(char16_t)) << TagSpec.Tag(Tag) << TagSpec.Parent(ParentTag) << TagSpec.Display(DisplayString, DisplayLen); }; diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp index 29c76c3ca..bdd22e20c 100644 --- a/src/zencore/mpscqueue.cpp +++ b/src/zencore/mpscqueue.cpp @@ -7,7 +7,8 @@ namespace zen { -#if ZEN_WITH_TESTS && 0 +#if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.mpscqueue"); TEST_CASE("mpsc") { MpscQueue<std::string> Queue; @@ -15,6 +16,7 @@ TEST_CASE("mpsc") std::optional<std::string> Value = Queue.Dequeue(); CHECK_EQ(Value, "hello"); } +TEST_SUITE_END(); #endif void @@ -22,4 +24,4 @@ mpscqueue_forcelink() { } -} // namespace zen
\ No newline at end of file +} // namespace zen diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp index d86d5815f..94696f479 100644 --- a/src/zencore/parallelwork.cpp +++ b/src/zencore/parallelwork.cpp @@ -157,6 +157,8 @@ ParallelWork::RethrowErrors() #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.parallelwork"); + TEST_CASE("parallellwork.nowork") { std::atomic<bool> AbortFlag; @@ -255,6 +257,8 @@ TEST_CASE("parallellwork.limitqueue") Work.Wait(); } +TEST_SUITE_END(); + void parallellwork_forcelink() { diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 56849a10d..f657869dc 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -9,6 +9,7 @@ #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/timer.h> +#include <zencore/trace.h> #include <thread> @@ -490,6 +491,8 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; + const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid(); + DWORD CreationFlags = 0; if (Options.Flags & CreateProcOptions::Flag_NewConsole) { @@ -503,6 +506,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma { CreationFlags |= CREATE_NEW_PROCESS_GROUP; } + if (AssignToJob) + { + CreationFlags |= CREATE_SUSPENDED; + } const wchar_t* WorkingDir = nullptr; if (Options.WorkingDirectory != nullptr) @@ -571,6 +578,15 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma return nullptr; } + if (AssignToJob) + { + if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess)) + { + ZEN_WARN("Failed to assign newly created process to job object"); + } + ResumeThread(ProcessInfo.hThread); + } + CloseHandle(ProcessInfo.hThread); return ProcessInfo.hProcess; } @@ -644,6 +660,8 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C }; 
PROCESS_INFORMATION ProcessInfo = {}; + const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid(); + if (Options.Flags & CreateProcOptions::Flag_NewConsole) { CreateProcFlags |= CREATE_NEW_CONSOLE; @@ -652,6 +670,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C { CreateProcFlags |= CREATE_NO_WINDOW; } + if (AssignToJob) + { + CreateProcFlags |= CREATE_SUSPENDED; + } ExtendableWideStringBuilder<256> CommandLineZ; CommandLineZ << CommandLine; @@ -679,6 +701,15 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C return nullptr; } + if (AssignToJob) + { + if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess)) + { + ZEN_WARN("Failed to assign newly created process to job object"); + } + ResumeThread(ProcessInfo.hThread); + } + CloseHandle(ProcessInfo.hThread); return ProcessInfo.hProcess; } @@ -715,6 +746,8 @@ CreateProcElevated(const std::filesystem::path& Executable, std::string_view Com CreateProcResult CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) { + ZEN_TRACE_CPU("CreateProc"); + #if ZEN_PLATFORM_WINDOWS if (Options.Flags & CreateProcOptions::Flag_Unelevated) { @@ -746,6 +779,17 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine ZEN_UNUSED(Result); } + if (!Options.StdoutFile.empty()) + { + int Fd = open(Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (Fd >= 0) + { + dup2(Fd, STDOUT_FILENO); + dup2(Fd, STDERR_FILENO); + close(Fd); + } + } + if (execv(Executable.c_str(), ArgV.data()) < 0) { ThrowLastError("Failed to exec() a new process image"); @@ -845,6 +889,65 @@ ProcessMonitor::IsActive() const ////////////////////////////////////////////////////////////////////////// +#if ZEN_PLATFORM_WINDOWS +JobObject::JobObject() = default; + +JobObject::~JobObject() +{ + if (m_JobHandle) + { + CloseHandle(m_JobHandle); + m_JobHandle = nullptr; + } +} + 
+void +JobObject::Initialize() +{ + ZEN_ASSERT(m_JobHandle == nullptr, "JobObject already initialized"); + + m_JobHandle = CreateJobObjectW(nullptr, nullptr); + if (!m_JobHandle) + { + ZEN_WARN("Failed to create job object: {}", zen::GetLastError()); + return; + } + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {}; + LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + + if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo))) + { + ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError()); + CloseHandle(m_JobHandle); + m_JobHandle = nullptr; + } +} + +bool +JobObject::AssignProcess(void* ProcessHandle) +{ + ZEN_ASSERT(m_JobHandle != nullptr, "JobObject not initialized"); + ZEN_ASSERT(ProcessHandle != nullptr, "ProcessHandle is null"); + + if (!AssignProcessToJobObject(m_JobHandle, ProcessHandle)) + { + ZEN_WARN("Failed to assign process to job object: {}", zen::GetLastError()); + return false; + } + + return true; +} + +bool +JobObject::IsValid() const +{ + return m_JobHandle != nullptr; +} +#endif // ZEN_PLATFORM_WINDOWS + +////////////////////////////////////////////////////////////////////////// + bool IsProcessRunning(int pid, std::error_code& OutEc) { @@ -1001,6 +1104,232 @@ GetProcessExecutablePath(int Pid, std::error_code& OutEc) #endif // ZEN_PLATFORM_LINUX } +std::string +GetProcessCommandLine(int Pid, std::error_code& OutEc) +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, static_cast<DWORD>(Pid)); + if (!hProcess) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + auto _ = MakeGuard([hProcess] { CloseHandle(hProcess); }); + + // NtQueryInformationProcess is an undocumented NT API; load it dynamically. + // Info class 60 = ProcessCommandLine, available since Windows 8.1. 
+ using PFN_NtQIP = LONG(WINAPI*)(HANDLE, UINT, PVOID, ULONG, PULONG); + static const PFN_NtQIP s_NtQIP = + reinterpret_cast<PFN_NtQIP>(GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "NtQueryInformationProcess")); + if (!s_NtQIP) + { + return {}; + } + + constexpr UINT ProcessCommandLineClass = 60; + constexpr LONG StatusInfoLengthMismatch = static_cast<LONG>(0xC0000004L); + + ULONG ReturnLength = 0; + LONG Status = s_NtQIP(hProcess, ProcessCommandLineClass, nullptr, 0, &ReturnLength); + if (Status != StatusInfoLengthMismatch || ReturnLength == 0) + { + return {}; + } + + std::vector<char> Buf(ReturnLength); + Status = s_NtQIP(hProcess, ProcessCommandLineClass, Buf.data(), ReturnLength, &ReturnLength); + if (Status < 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + // Output: UNICODE_STRING header immediately followed by the UTF-16 string data. + // The UNICODE_STRING.Buffer field points into our Buf. + struct LocalUnicodeString + { + USHORT Length; + USHORT MaximumLength; + WCHAR* Buffer; + }; + if (ReturnLength < sizeof(LocalUnicodeString)) + { + return {}; + } + const auto* Us = reinterpret_cast<const LocalUnicodeString*>(Buf.data()); + if (Us->Length == 0 || Us->Buffer == nullptr) + { + return {}; + } + + // Skip argv[0]: may be a quoted path ("C:\...\exe.exe") or a bare path + const WCHAR* p = Us->Buffer; + const WCHAR* End = Us->Buffer + Us->Length / sizeof(WCHAR); + if (p < End && *p == L'"') + { + ++p; + while (p < End && *p != L'"') + { + ++p; + } + if (p < End) + { + ++p; // skip closing quote + } + } + else + { + while (p < End && *p != L' ') + { + ++p; + } + } + while (p < End && *p == L' ') + { + ++p; + } + if (p >= End) + { + return {}; + } + + int Utf8Size = WideCharToMultiByte(CP_UTF8, 0, p, static_cast<int>(End - p), nullptr, 0, nullptr, nullptr); + if (Utf8Size <= 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + std::string Result(Utf8Size, '\0'); + WideCharToMultiByte(CP_UTF8, 0, p, static_cast<int>(End - p), 
Result.data(), Utf8Size, nullptr, nullptr); + return Result; + +#elif ZEN_PLATFORM_LINUX + std::string CmdlinePath = fmt::format("/proc/{}/cmdline", Pid); + FILE* F = fopen(CmdlinePath.c_str(), "rb"); + if (!F) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + auto FGuard = MakeGuard([F] { fclose(F); }); + + // /proc/{pid}/cmdline contains null-separated argv entries; read it all + std::string Raw; + char Chunk[4096]; + size_t BytesRead; + while ((BytesRead = fread(Chunk, 1, sizeof(Chunk), F)) > 0) + { + Raw.append(Chunk, BytesRead); + } + if (Raw.empty()) + { + return {}; + } + + // Skip argv[0] (first null-terminated entry) + const char* p = Raw.data(); + const char* End = Raw.data() + Raw.size(); + while (p < End && *p != '\0') + { + ++p; + } + if (p < End) + { + ++p; // skip null terminator of argv[0] + } + + // Build result: remaining entries joined by spaces (inter-arg nulls → spaces) + std::string Result; + Result.reserve(static_cast<size_t>(End - p)); + for (const char* q = p; q < End; ++q) + { + Result += (*q == '\0') ? ' ' : *q; + } + while (!Result.empty() && Result.back() == ' ') + { + Result.pop_back(); + } + return Result; + +#elif ZEN_PLATFORM_MAC + int Mib[3] = {CTL_KERN, KERN_PROCARGS2, Pid}; + size_t BufSize = 0; + if (sysctl(Mib, 3, nullptr, &BufSize, nullptr, 0) != 0 || BufSize == 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + std::vector<char> Buf(BufSize); + if (sysctl(Mib, 3, Buf.data(), &BufSize, nullptr, 0) != 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + // Layout: [int argc][exec_path\0][null padding][argv[0]\0][argv[1]\0]...[envp\0]... 
+ if (BufSize < sizeof(int)) + { + return {}; + } + int Argc = 0; + memcpy(&Argc, Buf.data(), sizeof(int)); + if (Argc <= 1) + { + return {}; + } + + const char* p = Buf.data() + sizeof(int); + const char* End = Buf.data() + BufSize; + + // Skip exec_path and any trailing null padding that follows it + while (p < End && *p != '\0') + { + ++p; + } + while (p < End && *p == '\0') + { + ++p; + } + + // Skip argv[0] + while (p < End && *p != '\0') + { + ++p; + } + if (p < End) + { + ++p; + } + + // Collect argv[1..Argc-1] + std::string Result; + for (int i = 1; i < Argc && p < End; ++i) + { + if (i > 1) + { + Result += ' '; + } + const char* ArgStart = p; + while (p < End && *p != '\0') + { + ++p; + } + Result.append(ArgStart, p); + if (p < End) + { + ++p; + } + } + return Result; + +#else + ZEN_UNUSED(Pid); + ZEN_UNUSED(OutEc); + return {}; +#endif +} + std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf) { diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp index a6a86ee12..f19afe715 100644 --- a/src/zencore/refcount.cpp +++ b/src/zencore/refcount.cpp @@ -33,6 +33,8 @@ refcount_forcelink() { } +TEST_SUITE_BEGIN("core.refcount"); + TEST_CASE("RefPtr") { RefPtr<TestRefClass> Ref; @@ -60,6 +62,8 @@ TEST_CASE("RefPtr") CHECK(IsDestroyed == true); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 00e67dc85..8d087e8c6 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -4,29 +4,23 @@ #include <zencore/config.h> #include <zencore/logging.h> +#include <zencore/logging/registry.h> +#include <zencore/logging/sink.h> #include <zencore/session.h> #include <zencore/uid.h> #include <stdarg.h> #include <stdio.h> -#if ZEN_PLATFORM_LINUX +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC # include <pwd.h> +# include <unistd.h> #endif -#if ZEN_PLATFORM_MAC -# include <pwd.h> -#endif - 
-ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/spdlog.h> -ZEN_THIRD_PARTY_INCLUDES_END - #if ZEN_USE_SENTRY # define SENTRY_BUILD_STATIC 1 ZEN_THIRD_PARTY_INCLUDES_START # include <sentry.h> -# include <spdlog/sinks/base_sink.h> ZEN_THIRD_PARTY_INCLUDES_END namespace sentry { @@ -44,71 +38,58 @@ struct SentryAssertImpl : zen::AssertImpl const zen::CallstackFrames* Callstack) override; }; -class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex> +static constexpr sentry_level_t MapToSentryLevel[zen::logging::LogLevelCount] = {SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_INFO, + SENTRY_LEVEL_WARNING, + SENTRY_LEVEL_ERROR, + SENTRY_LEVEL_FATAL, + SENTRY_LEVEL_DEBUG}; + +class SentrySink final : public zen::logging::Sink { public: - sentry_sink(); - ~sentry_sink(); + SentrySink() = default; + ~SentrySink() = default; -protected: - void sink_it_(const spdlog::details::log_msg& msg) override; - void flush_() override; + void Log(const zen::logging::LogMessage& Msg) override + { + if (Msg.GetLevel() != zen::logging::Err && Msg.GetLevel() != zen::logging::Critical) + { + return; + } + try + { + std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Msg.GetSource().Filename, Msg.GetSource().Line); + sentry_value_t Event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[Msg.GetLevel()], + /* logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(Event, NULL, 0); + sentry_capture_event(Event); + } + catch (const std::exception&) + { + // If our logging with Message formatting fails we do a non-allocating version and just post the payload raw + char TmpBuffer[256]; + size_t MaxCopy = zen::Min<size_t>(Msg.GetPayload().size(), size_t(255)); + memcpy(TmpBuffer, Msg.GetPayload().data(), MaxCopy); + TmpBuffer[MaxCopy] = '\0'; + sentry_value_t Event = sentry_value_new_message_event( + /* level */ SENTRY_LEVEL_ERROR, + /* logger */ nullptr, + /* message */ TmpBuffer); + 
sentry_event_value_add_stacktrace(Event, NULL, 0); + sentry_capture_event(Event); + } + } + + void Flush() override {} + void SetFormatter(std::unique_ptr<zen::logging::Formatter>) override {} }; ////////////////////////////////////////////////////////////////////////// -static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG, - SENTRY_LEVEL_DEBUG, - SENTRY_LEVEL_INFO, - SENTRY_LEVEL_WARNING, - SENTRY_LEVEL_ERROR, - SENTRY_LEVEL_FATAL, - SENTRY_LEVEL_DEBUG}; - -sentry_sink::sentry_sink() -{ -} -sentry_sink::~sentry_sink() -{ -} - -void -sentry_sink::sink_it_(const spdlog::details::log_msg& msg) -{ - if (msg.level != spdlog::level::err && msg.level != spdlog::level::critical) - { - return; - } - try - { - std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname); - sentry_value_t event = sentry_value_new_message_event( - /* level */ MapToSentryLevel[msg.level], - /* logger */ nullptr, - /* message */ Message.c_str()); - sentry_event_value_add_stacktrace(event, NULL, 0); - sentry_capture_event(event); - } - catch (const std::exception&) - { - // If our logging with Message formatting fails we do a non-allocating version and just post the msg.payload raw - char TmpBuffer[256]; - size_t MaxCopy = zen::Min<size_t>(msg.payload.size(), size_t(255)); - memcpy(TmpBuffer, msg.payload.data(), MaxCopy); - TmpBuffer[MaxCopy] = '\0'; - sentry_value_t event = sentry_value_new_message_event( - /* level */ SENTRY_LEVEL_ERROR, - /* logger */ nullptr, - /* message */ TmpBuffer); - sentry_event_value_add_stacktrace(event, NULL, 0); - sentry_capture_event(event); - } -} -void -sentry_sink::flush_() -{ -} - void SentryAssertImpl::OnAssert(const char* Filename, int LineNumber, @@ -145,6 +126,10 @@ SentryAssertImpl::OnAssert(const char* Filename, namespace zen { # if ZEN_USE_SENTRY +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogSentry, "sentry-sdk"); + +static std::atomic<bool> 
s_SentryLogEnabled{true}; + static void SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) { @@ -163,26 +148,62 @@ SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[may MessagePtr = LogMessage.c_str(); } + // SentryLogFunction can be called before the logging system is initialized + // (during sentry_init which runs before InitializeLogging), or after it has + // been shut down (during sentry_close on a background worker thread). Fall + // back to console logging when the category logger is not available. + // + // Since we want to default to WARN level but this runs before logging has + // been configured, we ignore the callbacks for DEBUG/INFO explicitly here + // which means users don't see every possible log message if they're trying + // to configure the levels using --log-debug=sentry-sdk + if (!TheDefaultLogger || !s_SentryLogEnabled.load(std::memory_order_acquire)) + { + switch (Level) + { + case SENTRY_LEVEL_DEBUG: + // ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_INFO: + // ZEN_CONSOLE_INFO("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_WARNING: + ZEN_CONSOLE_WARN("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_ERROR: + ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_FATAL: + ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr); + break; + } + return; + } + switch (Level) { case SENTRY_LEVEL_DEBUG: - ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr); + ZEN_LOG_DEBUG(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_INFO: - ZEN_CONSOLE_INFO("sentry: {}", MessagePtr); + ZEN_LOG_INFO(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_WARNING: - ZEN_CONSOLE_WARN("sentry: {}", MessagePtr); + ZEN_LOG_WARN(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_ERROR: - ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr); + ZEN_LOG_ERROR(LogSentry, "sentry: {}", MessagePtr); break; case 
SENTRY_LEVEL_FATAL: - ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr); + ZEN_LOG_CRITICAL(LogSentry, "sentry: {}", MessagePtr); break; } } @@ -194,11 +215,21 @@ SentryIntegration::SentryIntegration() SentryIntegration::~SentryIntegration() { + Close(); +} + +void +SentryIntegration::Close() +{ if (m_IsInitialized && m_SentryErrorCode == 0) { logging::SetErrorLog(""); m_SentryAssert.reset(); + // Disable spdlog forwarding before sentry_close() since its background + // worker thread may still log during shutdown via SentryLogFunction + s_SentryLogEnabled.store(false, std::memory_order_release); sentry_close(); + m_IsInitialized = false; } } @@ -298,7 +329,9 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine sentry_set_user(SentryUserObject); - m_SentryLogger = spdlog::create<sentry::sentry_sink>("sentry"); + logging::SinkPtr SentrySink(new sentry::SentrySink()); + m_SentryLogger = Ref<logging::Logger>(new logging::Logger("sentry", std::vector<logging::SinkPtr>{SentrySink})); + logging::Registry::Instance().Register(m_SentryLogger); logging::SetErrorLog("sentry"); m_SentryAssert = std::make_unique<sentry::SentryAssertImpl>(); @@ -310,22 +343,31 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine void SentryIntegration::LogStartupInformation() { + // Initialize the sentry-sdk log category at Warn level to reduce startup noise. 
+ // The level can be overridden via --log-debug=sentry-sdk or --log-info=sentry-sdk + LogSentry.Logger().SetLogLevel(logging::Warn); + if (m_IsInitialized) { if (m_SentryErrorCode == 0) { if (m_AllowPII) { - ZEN_INFO("sentry initialized, username: '{}', hostname: '{}', id: '{}'", m_SentryUserName, m_SentryHostName, m_SentryId); + ZEN_LOG_INFO(LogSentry, + "sentry initialized, username: '{}', hostname: '{}', id: '{}'", + m_SentryUserName, + m_SentryHostName, + m_SentryId); } else { - ZEN_INFO("sentry initialized with anonymous reports"); + ZEN_LOG_INFO(LogSentry, "sentry initialized with anonymous reports"); } } else { - ZEN_WARN( + ZEN_LOG_WARN( + LogSentry, "sentry_init returned failure! (error code: {}) note that sentry expects crashpad_handler to exist alongside the running " "executable", m_SentryErrorCode); diff --git a/src/zencore/sha1.cpp b/src/zencore/sha1.cpp index 3ee74d7d8..807ae4c30 100644 --- a/src/zencore/sha1.cpp +++ b/src/zencore/sha1.cpp @@ -373,6 +373,8 @@ sha1_forcelink() // return sha1text; // } +TEST_SUITE_BEGIN("core.sha1"); + TEST_CASE("SHA1") { uint8_t sha1_empty[20] = {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, @@ -438,6 +440,8 @@ TEST_CASE("SHA1") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp index 78efb9d42..8dc6d49d8 100644 --- a/src/zencore/sharedbuffer.cpp +++ b/src/zencore/sharedbuffer.cpp @@ -152,10 +152,14 @@ sharedbuffer_forcelink() { } +TEST_SUITE_BEGIN("core.sharedbuffer"); + TEST_CASE("SharedBuffer") { } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp index a800ce121..de67303a4 100644 --- a/src/zencore/stream.cpp +++ b/src/zencore/stream.cpp @@ -79,6 +79,8 @@ BufferReader::Serialize(void* V, int64_t Length) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.stream"); + TEST_CASE("binary.writer.span") { BinaryWriter Writer; @@ -91,6 +93,8 @@ TEST_CASE("binary.writer.span") 
CHECK(memcmp(Result.GetData(), "apa banan", 9) == 0); } +TEST_SUITE_END(); + void stream_forcelink() { diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index 0ee863b74..ed0ba6f46 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -4,6 +4,7 @@ #include <zencore/memoryview.h> #include <zencore/string.h> #include <zencore/testing.h> +#include <zencore/testutils.h> #include <inttypes.h> #include <math.h> @@ -24,6 +25,10 @@ utf16to8_impl(u16bit_iterator StartIt, u16bit_iterator EndIt, ::zen::StringBuild // Take care of surrogate pairs first if (utf8::internal::is_lead_surrogate(cp)) { + if (StartIt == EndIt) + { + break; + } uint32_t trail_surrogate = utf8::internal::mask16(*StartIt++); cp = (cp << 10) + trail_surrogate + utf8::internal::SURROGATE_OFFSET; } @@ -180,7 +185,21 @@ Utf8ToWide(const std::u8string_view& Str8, WideStringBuilderBase& OutString) if (!ByteCount) { +#if ZEN_SIZEOF_WCHAR_T == 2 + if (CurrentOutChar > 0xFFFF) + { + // Supplementary plane: emit a UTF-16 surrogate pair + uint32_t Adjusted = uint32_t(CurrentOutChar - 0x10000); + OutString.Append(wchar_t(0xD800 + (Adjusted >> 10))); + OutString.Append(wchar_t(0xDC00 + (Adjusted & 0x3FF))); + } + else + { + OutString.Append(wchar_t(CurrentOutChar)); + } +#else OutString.Append(wchar_t(CurrentOutChar)); +#endif CurrentOutChar = 0; } } @@ -249,6 +268,17 @@ namespace { /* kNicenumTime */ 1000}; } // namespace +uint64_t +IntPow(uint64_t Base, int Exp) +{ + uint64_t Result = 1; + for (int I = 0; I < Exp; ++I) + { + Result *= Base; + } + return Result; +} + /* * Convert a number to an appropriately human-readable output. 
*/ @@ -296,7 +326,7 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format) const char* u = UnitStrings[Format][Index]; - if ((Index == 0) || ((Num % (uint64_t)powl((int)KiloUnit[Format], Index)) == 0)) + if ((Index == 0) || ((Num % IntPow(KiloUnit[Format], Index)) == 0)) { /* * If this is an even multiple of the base, always display @@ -320,7 +350,7 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format) for (int i = 2; i >= 0; i--) { - double Value = (double)Num / (uint64_t)powl((int)KiloUnit[Format], Index); + double Value = (double)Num / IntPow(KiloUnit[Format], Index); /* * Don't print floating point values for time. Note, @@ -520,13 +550,38 @@ UrlDecode(std::string_view InUrl) return std::string(Url.ToView()); } -////////////////////////////////////////////////////////////////////////// -// -// Unit tests -// +std::string +HideSensitiveString(std::string_view String) +{ + const size_t Length = String.length(); + const size_t SourceLength = Length > 16 ? 
4 : 0; + const size_t PadLength = Min(Length - SourceLength, 4u); + const bool AddEllipsis = (SourceLength + PadLength) < Length; + StringBuilder<16> SB; + if (SourceLength > 0) + { + SB << String.substr(0, SourceLength); + } + if (PadLength > 0) + { + SB << std::string(PadLength, 'X'); + } + if (AddEllipsis) + { + SB << "..."; + } + return SB.ToString(); +}; + + ////////////////////////////////////////////////////////////////////////// + // + // Unit tests + // #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.string"); + TEST_CASE("url") { using namespace std::literals; @@ -793,11 +848,6 @@ TEST_CASE("niceNum") } } -void -string_forcelink() -{ -} - TEST_CASE("StringBuilder") { StringBuilder<64> sb; @@ -963,33 +1013,131 @@ TEST_CASE("ExtendableWideStringBuilder") TEST_CASE("utf8") { + using namespace utf8test; + SUBCASE("utf8towide") { - // TODO: add more extensive testing here - this covers a very small space - WideStringBuilder<32> wout; Utf8ToWide(u8"abcdefghi", wout); CHECK(StringEquals(L"abcdefghi", wout.c_str())); wout.Reset(); + Utf8ToWide(u8"abc\xC3\xA4\xC3\xB6\xC3\xBC", wout); + CHECK(StringEquals(L"abc\u00E4\u00F6\u00FC", wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kLatin), wout); + CHECK(StringEquals(kLatinW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kCyrillic), wout); + CHECK(StringEquals(kCyrillicW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kCJK), wout); + CHECK(StringEquals(kCJKW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kMixed), wout); + CHECK(StringEquals(kMixedW, wout.c_str())); - Utf8ToWide(u8"abc���", wout); - CHECK(StringEquals(L"abc���", wout.c_str())); + wout.Reset(); + Utf8ToWide(std::string_view(kEmoji), wout); + CHECK(StringEquals(kEmojiW, wout.c_str())); } SUBCASE("widetoutf8") { - // TODO: add more extensive testing here - this covers a very small space - - StringBuilder<32> out; + StringBuilder<64> out; WideToUtf8(L"abcdefghi", out); 
CHECK(StringEquals("abcdefghi", out.c_str())); out.Reset(); + WideToUtf8(kLatinW, out); + CHECK(StringEquals(kLatin, out.c_str())); + + out.Reset(); + WideToUtf8(kCyrillicW, out); + CHECK(StringEquals(kCyrillic, out.c_str())); + + out.Reset(); + WideToUtf8(kCJKW, out); + CHECK(StringEquals(kCJK, out.c_str())); + + out.Reset(); + WideToUtf8(kMixedW, out); + CHECK(StringEquals(kMixed, out.c_str())); - WideToUtf8(L"abc���", out); - CHECK(StringEquals(u8"abc���", out.c_str())); + out.Reset(); + WideToUtf8(kEmojiW, out); + CHECK(StringEquals(kEmoji, out.c_str())); + } + + SUBCASE("roundtrip") + { + // UTF-8 -> Wide -> UTF-8 identity + const char* Utf8Strings[] = {kLatin, kCyrillic, kCJK, kMixed, kEmoji}; + for (const char* Utf8Str : Utf8Strings) + { + ExtendableWideStringBuilder<64> Wide; + Utf8ToWide(std::string_view(Utf8Str), Wide); + + ExtendableStringBuilder<64> Back; + WideToUtf8(std::wstring_view(Wide.c_str()), Back); + CHECK(StringEquals(Utf8Str, Back.c_str())); + } + + // Wide -> UTF-8 -> Wide identity + const wchar_t* WideStrings[] = {kLatinW, kCyrillicW, kCJKW, kMixedW, kEmojiW}; + for (const wchar_t* WideStr : WideStrings) + { + ExtendableStringBuilder<64> Utf8; + WideToUtf8(std::wstring_view(WideStr), Utf8); + + ExtendableWideStringBuilder<64> Back; + Utf8ToWide(std::string_view(Utf8.c_str()), Back); + CHECK(StringEquals(WideStr, Back.c_str())); + } + + // Empty string round-trip + { + ExtendableWideStringBuilder<8> Wide; + Utf8ToWide(std::string_view(""), Wide); + CHECK(Wide.Size() == 0); + + ExtendableStringBuilder<8> Narrow; + WideToUtf8(std::wstring_view(L""), Narrow); + CHECK(Narrow.Size() == 0); + } + } + + SUBCASE("IsValidUtf8") + { + // Valid inputs + CHECK(IsValidUtf8("")); + CHECK(IsValidUtf8("hello world")); + CHECK(IsValidUtf8(kLatin)); + CHECK(IsValidUtf8(kCyrillic)); + CHECK(IsValidUtf8(kCJK)); + CHECK(IsValidUtf8(kMixed)); + CHECK(IsValidUtf8(kEmoji)); + + // Invalid: truncated 2-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xC3", 
1))); + + // Invalid: truncated 3-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xE6\x97", 2))); + + // Invalid: truncated 4-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xF0\x9F\x93", 3))); + + // Invalid: bad start byte + CHECK(!IsValidUtf8(std::string_view("\xFF", 1))); + CHECK(!IsValidUtf8(std::string_view("\xFE", 1))); + + // Invalid: overlong encoding of '/' (U+002F) + CHECK(!IsValidUtf8(std::string_view("\xC0\xAF", 2))); } } @@ -1105,6 +1253,28 @@ TEST_CASE("string") } } +TEST_CASE("hidesensitivestring") +{ + using namespace std::literals; + + CHECK_EQ(HideSensitiveString(""sv), ""sv); + CHECK_EQ(HideSensitiveString("A"sv), "X"sv); + CHECK_EQ(HideSensitiveString("ABCD"sv), "XXXX"sv); + CHECK_EQ(HideSensitiveString("ABCDE"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGH"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOP"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQ"sv), "ABCDXXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"sv), "ABCDXXXX..."sv); + CHECK_EQ(HideSensitiveString("1234567890123456789"sv), "1234XXXX..."sv); +} + +TEST_SUITE_END(); + +void +string_forcelink() +{ +} + #endif } // namespace zen diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index b9ac3bdee..141450b84 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -4,15 +4,20 @@ #include <zencore/compactbinarybuilder.h> #include <zencore/except.h> +#include <zencore/fmtutils.h> #include <zencore/memory/memory.h> #include <zencore/string.h> +#include <mutex> + #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> ZEN_THIRD_PARTY_INCLUDES_START # include <iphlpapi.h> # include <winsock2.h> +# include <pdh.h> +# pragma comment(lib, "pdh.lib") ZEN_THIRD_PARTY_INCLUDES_END #elif ZEN_PLATFORM_LINUX # include <sys/utsname.h> @@ -65,55 +70,73 @@ GetSystemMetrics() // Determine physical core count - DWORD BufferSize = 0; - BOOL Result = GetLogicalProcessorInformation(nullptr, 
&BufferSize); - if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER) { - ThrowSystemError(Error, "Failed to get buffer size for logical processor information"); - } + DWORD BufferSize = 0; + BOOL Result = GetLogicalProcessorInformationEx(RelationAll, nullptr, &BufferSize); + if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER) + { + ThrowSystemError(Error, "Failed to get buffer size for logical processor information"); + } - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)Memory::Alloc(BufferSize); + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)Memory::Alloc(BufferSize); - Result = GetLogicalProcessorInformation(Buffer, &BufferSize); - if (!Result) - { - Memory::Free(Buffer); - throw std::runtime_error("Failed to get logical processor information"); - } - - DWORD ProcessorPkgCount = 0; - DWORD ProcessorCoreCount = 0; - DWORD ByteOffset = 0; - while (ByteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) <= BufferSize) - { - const SYSTEM_LOGICAL_PROCESSOR_INFORMATION& Slpi = Buffer[ByteOffset / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)]; - if (Slpi.Relationship == RelationProcessorCore) + Result = GetLogicalProcessorInformationEx(RelationAll, Buffer, &BufferSize); + if (!Result) { - ProcessorCoreCount++; + Memory::Free(Buffer); + throw std::runtime_error("Failed to get logical processor information"); } - else if (Slpi.Relationship == RelationProcessorPackage) + + DWORD ProcessorPkgCount = 0; + DWORD ProcessorCoreCount = 0; + DWORD LogicalProcessorCount = 0; + + BYTE* Ptr = reinterpret_cast<BYTE*>(Buffer); + BYTE* const End = Ptr + BufferSize; + while (Ptr < End) { - ProcessorPkgCount++; + const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX& Slpi = *reinterpret_cast<const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*>(Ptr); + if (Slpi.Relationship == RelationProcessorCore) + { + ++ProcessorCoreCount; + + // Count logical processors (threads) across 
all processor groups for this core. + // Each core entry lists one GROUP_AFFINITY per group it spans; each set bit + // in the Mask represents one logical processor (HyperThreading sibling). + for (WORD g = 0; g < Slpi.Processor.GroupCount; ++g) + { + LogicalProcessorCount += static_cast<DWORD>(__popcnt64(Slpi.Processor.GroupMask[g].Mask)); + } + } + else if (Slpi.Relationship == RelationProcessorPackage) + { + ++ProcessorPkgCount; + } + Ptr += Slpi.Size; } - ByteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION); - } - Metrics.CoreCount = ProcessorCoreCount; - Metrics.CpuCount = ProcessorPkgCount; + Metrics.CoreCount = ProcessorCoreCount; + Metrics.CpuCount = ProcessorPkgCount; + Metrics.LogicalProcessorCount = LogicalProcessorCount; - Memory::Free(Buffer); + Memory::Free(Buffer); + } // Query memory status - MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)}; - GlobalMemoryStatusEx(&MemStatus); + { + MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)}; + GlobalMemoryStatusEx(&MemStatus); + + Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024; + Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024; + Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024; + Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024; + Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024; + Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; + } - Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024; - Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024; - Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024; - Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024; - Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024; - Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; + Metrics.UptimeSeconds = GetTickCount64() / 1000; return Metrics; } @@ -206,6 +229,17 @@ GetSystemMetrics() 
Metrics.VirtualMemoryMiB = Metrics.SystemMemoryMiB; Metrics.AvailVirtualMemoryMiB = Metrics.AvailSystemMemoryMiB; + // System uptime + if (FILE* UptimeFile = fopen("/proc/uptime", "r")) + { + double UptimeSec = 0; + if (fscanf(UptimeFile, "%lf", &UptimeSec) == 1) + { + Metrics.UptimeSeconds = static_cast<uint64_t>(UptimeSec); + } + fclose(UptimeFile); + } + // Parse /proc/meminfo for swap/page file information Metrics.PageFileMiB = 0; Metrics.AvailPageFileMiB = 0; @@ -298,12 +332,35 @@ GetSystemMetrics() Metrics.PageFileMiB = SwapUsage.xsu_total / 1024 / 1024; Metrics.AvailPageFileMiB = (SwapUsage.xsu_total - SwapUsage.xsu_used) / 1024 / 1024; + // System uptime via boot time + { + struct timeval BootTime + { + }; + Size = sizeof(BootTime); + if (sysctlbyname("kern.boottime", &BootTime, &Size, nullptr, 0) == 0) + { + Metrics.UptimeSeconds = static_cast<uint64_t>(time(nullptr) - BootTime.tv_sec); + } + } + return Metrics; } #else # error "Unknown platform" #endif +ExtendedSystemMetrics +ApplyReportingOverrides(ExtendedSystemMetrics Metrics) +{ + if (g_FakeCpuCount) + { + Metrics.CoreCount = g_FakeCpuCount; + Metrics.LogicalProcessorCount = g_FakeCpuCount; + } + return Metrics; +} + SystemMetrics GetSystemMetricsForReporting() { @@ -318,12 +375,281 @@ GetSystemMetricsForReporting() return Sm; } +/////////////////////////////////////////////////////////////////////////// +// SystemMetricsTracker +/////////////////////////////////////////////////////////////////////////// + +// Per-platform CPU sampling helper. Called with m_Mutex held. + +#if ZEN_PLATFORM_WINDOWS || ZEN_PLATFORM_LINUX + +// Samples CPU usage by reading /proc/stat. Used natively on Linux and as a +// Wine fallback on Windows (where /proc/stat is accessible via the Z: drive). 
+struct ProcStatCpuSampler +{ + const char* Path = "/proc/stat"; + unsigned long PrevUser = 0; + unsigned long PrevNice = 0; + unsigned long PrevSystem = 0; + unsigned long PrevIdle = 0; + unsigned long PrevIoWait = 0; + unsigned long PrevIrq = 0; + unsigned long PrevSoftIrq = 0; + + explicit ProcStatCpuSampler(const char* InPath = "/proc/stat") : Path(InPath) {} + + float Sample() + { + float CpuUsage = 0.0f; + + if (FILE* Stat = fopen(Path, "r")) + { + char Line[256]; + unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq; + + if (fgets(Line, sizeof(Line), Stat)) + { + if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7) + { + unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) - + (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq); + unsigned long IdleDelta = Idle - PrevIdle; + + if (TotalDelta > 0) + { + CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; + } + + PrevUser = User; + PrevNice = Nice; + PrevSystem = System; + PrevIdle = Idle; + PrevIoWait = IoWait; + PrevIrq = Irq; + PrevSoftIrq = SoftIrq; + } + } + fclose(Stat); + } + + return CpuUsage; + } +}; + +#endif + +#if ZEN_PLATFORM_WINDOWS + +struct CpuSampler +{ + PDH_HQUERY QueryHandle = nullptr; + PDH_HCOUNTER CounterHandle = nullptr; + bool HasPreviousSample = false; + bool IsWine = false; + ProcStatCpuSampler ProcStat{"Z:\\proc\\stat"}; + + CpuSampler() + { + IsWine = zen::windows::IsRunningOnWine(); + + if (!IsWine) + { + if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS) + { + if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) != ERROR_SUCCESS) + { + CounterHandle = nullptr; + } + } + } + } + + ~CpuSampler() + { + if (QueryHandle) + { + PdhCloseQuery(QueryHandle); + } + } + + float Sample() + { + if (IsWine) + { + return ProcStat.Sample(); + } + + if (!QueryHandle || !CounterHandle) + { + return 0.0f; + 
} + + PdhCollectQueryData(QueryHandle); + + if (!HasPreviousSample) + { + HasPreviousSample = true; + return 0.0f; + } + + PDH_FMT_COUNTERVALUE CounterValue; + if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS) + { + return static_cast<float>(CounterValue.doubleValue); + } + + return 0.0f; + } +}; + +#elif ZEN_PLATFORM_LINUX + +struct CpuSampler +{ + ProcStatCpuSampler ProcStat; + + float Sample() { return ProcStat.Sample(); } +}; + +#elif ZEN_PLATFORM_MAC + +struct CpuSampler +{ + unsigned long PrevTotalTicks = 0; + unsigned long PrevIdleTicks = 0; + + float Sample() + { + float CpuUsage = 0.0f; + + host_cpu_load_info_data_t CpuLoad; + mach_msg_type_number_t Count = sizeof(CpuLoad) / sizeof(natural_t); + if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &Count) == KERN_SUCCESS) + { + unsigned long TotalTicks = 0; + for (int i = 0; i < CPU_STATE_MAX; ++i) + { + TotalTicks += CpuLoad.cpu_ticks[i]; + } + unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE]; + + unsigned long TotalDelta = TotalTicks - PrevTotalTicks; + unsigned long IdleDelta = IdleTicks - PrevIdleTicks; + + if (TotalDelta > 0 && PrevTotalTicks > 0) + { + CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; + } + + PrevTotalTicks = TotalTicks; + PrevIdleTicks = IdleTicks; + } + + return CpuUsage; + } +}; + +#endif + +struct SystemMetricsTracker::Impl +{ + using Clock = std::chrono::steady_clock; + + std::mutex Mutex; + CpuSampler Sampler; + float CachedCpuPercent = 0.0f; + Clock::time_point NextSampleTime = Clock::now(); + std::chrono::milliseconds MinInterval; + + explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) {} + + float SampleCpu() + { + const auto Now = Clock::now(); + if (Now >= NextSampleTime) + { + CachedCpuPercent = Sampler.Sample(); + NextSampleTime = Now + MinInterval; + } + return CachedCpuPercent; + } +}; + 
+SystemMetricsTracker::SystemMetricsTracker(std::chrono::milliseconds MinInterval) : m_Impl(std::make_unique<Impl>(MinInterval)) +{ +} + +SystemMetricsTracker::~SystemMetricsTracker() = default; + +ExtendedSystemMetrics +SystemMetricsTracker::Query() +{ + ExtendedSystemMetrics Metrics; + static_cast<SystemMetrics&>(Metrics) = GetSystemMetrics(); + + std::lock_guard Lock(m_Impl->Mutex); + Metrics.CpuUsagePercent = m_Impl->SampleCpu(); + return Metrics; +} + +/////////////////////////////////////////////////////////////////////////// + std::string_view GetOperatingSystemName() { return ZEN_PLATFORM_NAME; } +std::string +GetOperatingSystemVersion() +{ +#if ZEN_PLATFORM_WINDOWS + // Use RtlGetVersion to avoid the compatibility shim that GetVersionEx applies + using RtlGetVersionFn = LONG(WINAPI*)(PRTL_OSVERSIONINFOW); + RTL_OSVERSIONINFOW OsVer{.dwOSVersionInfoSize = sizeof(OsVer)}; + if (auto Fn = (RtlGetVersionFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "RtlGetVersion")) + { + Fn(&OsVer); + } + return fmt::format("Windows {}.{} Build {}", OsVer.dwMajorVersion, OsVer.dwMinorVersion, OsVer.dwBuildNumber); +#elif ZEN_PLATFORM_LINUX + struct utsname Info + { + }; + if (uname(&Info) == 0) + { + return fmt::format("{} {}", Info.sysname, Info.release); + } + return "Linux"; +#elif ZEN_PLATFORM_MAC + char OsVersion[64] = ""; + size_t Size = sizeof(OsVersion); + if (sysctlbyname("kern.osproductversion", OsVersion, &Size, nullptr, 0) == 0) + { + return fmt::format("macOS {}", OsVersion); + } + return "macOS"; +#endif +} + +std::string_view +GetRuntimePlatformName() +{ +#if ZEN_PLATFORM_WINDOWS + if (zen::windows::IsRunningOnWine()) + { + return "wine"sv; + } + return "windows"sv; +#elif ZEN_PLATFORM_LINUX + return "linux"sv; +#elif ZEN_PLATFORM_MAC + return "macos"sv; +#else + return "unknown"sv; +#endif +} + std::string_view GetCpuName() { @@ -340,7 +666,14 @@ Describe(const SystemMetrics& Metrics, CbWriter& Writer) Writer << "cpu_count" << Metrics.CpuCount << 
"core_count" << Metrics.CoreCount << "lp_count" << Metrics.LogicalProcessorCount << "total_memory_mb" << Metrics.SystemMemoryMiB << "avail_memory_mb" << Metrics.AvailSystemMemoryMiB << "total_virtual_mb" << Metrics.VirtualMemoryMiB << "avail_virtual_mb" << Metrics.AvailVirtualMemoryMiB << "total_pagefile_mb" << Metrics.PageFileMiB - << "avail_pagefile_mb" << Metrics.AvailPageFileMiB; + << "avail_pagefile_mb" << Metrics.AvailPageFileMiB << "uptime_seconds" << Metrics.UptimeSeconds; +} + +void +Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer) +{ + Describe(static_cast<const SystemMetrics&>(Metrics), Writer); + Writer << "cpu_usage_percent" << Metrics.CpuUsagePercent; } } // namespace zen diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 936424e0f..089e376bb 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -1,11 +1,22 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#define ZEN_TEST_WITH_RUNNER 1 + #include "zencore/testing.h" + +#include "zencore/filesystem.h" #include "zencore/logging.h" +#include "zencore/process.h" +#include "zencore/trace.h" #if ZEN_WITH_TESTS -# include <doctest/doctest.h> +# include <chrono> +# include <clocale> +# include <cstdlib> +# include <cstdio> +# include <string> +# include <vector> namespace zen::testing { @@ -21,9 +32,35 @@ struct TestListener : public doctest::IReporter void report_query(const doctest::QueryData& /*in*/) override {} - void test_run_start() override {} + void test_run_start() override { RunStart = std::chrono::steady_clock::now(); } - void test_run_end(const doctest::TestRunStats& /*in*/) override {} + void test_run_end(const doctest::TestRunStats& in) override + { + auto elapsed = std::chrono::steady_clock::now() - RunStart; + double elapsedSeconds = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count() / 1000.0; + + // Write machine-readable summary to file if requested (used by xmake test summary table) + const char* summaryFile = 
std::getenv("ZEN_TEST_SUMMARY_FILE"); + if (summaryFile && summaryFile[0] != '\0') + { + if (FILE* f = std::fopen(summaryFile, "w")) + { + std::fprintf(f, + "cases_total=%u\ncases_passed=%u\nassertions_total=%d\nassertions_passed=%d\n" + "elapsed_seconds=%.3f\n", + in.numTestCasesPassingFilters, + in.numTestCasesPassingFilters - in.numTestCasesFailed, + in.numAsserts, + in.numAsserts - in.numAssertsFailed, + elapsedSeconds); + for (const auto& failure : FailedTests) + { + std::fprintf(f, "failed=%s|%s|%u\n", failure.Name.c_str(), failure.File.c_str(), failure.Line); + } + std::fclose(f); + } + } + } void test_case_start(const doctest::TestCaseData& in) override { @@ -37,7 +74,14 @@ struct TestListener : public doctest::IReporter ZEN_CONSOLE("{}-------------------------------------------------------------------------------{}", ColorYellow, ColorNone); } - void test_case_end(const doctest::CurrentTestCaseStats& /*in*/) override { Current = nullptr; } + void test_case_end(const doctest::CurrentTestCaseStats& in) override + { + if (!in.testCaseSuccess && Current) + { + FailedTests.push_back({Current->m_name, Current->m_file.c_str(), Current->m_line}); + } + Current = nullptr; + } void test_case_exception(const doctest::TestCaseException& /*in*/) override {} @@ -57,7 +101,16 @@ struct TestListener : public doctest::IReporter void test_case_skipped(const doctest::TestCaseData& /*in*/) override {} - const doctest::TestCaseData* Current = nullptr; + const doctest::TestCaseData* Current = nullptr; + std::chrono::steady_clock::time_point RunStart = {}; + + struct FailedTestInfo + { + std::string Name; + std::string File; + unsigned Line; + }; + std::vector<FailedTestInfo> FailedTests; }; struct TestRunner::Impl @@ -75,20 +128,26 @@ TestRunner::~TestRunner() { } +void +TestRunner::SetDefaultSuiteFilter(const char* Pattern) +{ + m_Impl->Session.setOption("test-suite", Pattern); +} + int -TestRunner::ApplyCommandLine(int argc, char const* const* argv) 
+TestRunner::ApplyCommandLine(int Argc, char const* const* Argv) { - m_Impl->Session.applyCommandLine(argc, argv); + m_Impl->Session.applyCommandLine(Argc, Argv); - for (int i = 1; i < argc; ++i) + for (int i = 1; i < Argc; ++i) { - if (argv[i] == "--debug"sv) + if (Argv[i] == "--debug"sv) { - zen::logging::SetLogLevel(zen::logging::level::Debug); + zen::logging::SetLogLevel(zen::logging::Debug); } - else if (argv[i] == "--verbose"sv) + else if (Argv[i] == "--verbose"sv) { - zen::logging::SetLogLevel(zen::logging::level::Trace); + zen::logging::SetLogLevel(zen::logging::Trace); } } @@ -101,6 +160,57 @@ TestRunner::Run() return m_Impl->Session.run(); } +int +RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)()) +{ +# if ZEN_PLATFORM_WINDOWS + setlocale(LC_ALL, "en_us.UTF8"); +# endif + + ForceLink(); + +# if ZEN_PLATFORM_LINUX + zen::IgnoreChildSignals(); +# endif + +# if ZEN_WITH_TRACE + zen::TraceInit(ExecutableName); + zen::TraceOptions TraceCommandlineOptions; + if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) + { + TraceConfigure(TraceCommandlineOptions); + } +# endif + + zen::logging::InitializeLogging(); + zen::MaximizeOpenFileCount(); + + TestRunner Runner; + + // Derive default suite filter from ExecutableName: "zencore-test" -> "core.*" + if (ExecutableName) + { + std::string_view Name = ExecutableName; + if (Name.starts_with("zen")) + { + Name.remove_prefix(3); + } + if (Name.ends_with("-test")) + { + Name.remove_suffix(5); + } + if (!Name.empty()) + { + std::string Filter(Name); + Filter += ".*"; + Runner.SetDefaultSuiteFilter(Filter.c_str()); + } + } + + Runner.ApplyCommandLine(Argc, Argv); + return Runner.Run(); +} + } // namespace zen::testing #endif // ZEN_WITH_TESTS diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp index 5bc2841ae..0cd3f8121 100644 --- a/src/zencore/testutils.cpp +++ b/src/zencore/testutils.cpp @@ -46,7 +46,7 @@ ScopedTemporaryDirectory::~ScopedTemporaryDirectory() IoBuffer 
CreateRandomBlob(uint64_t Size) { - static FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1}; + thread_local FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1}; return CreateRandomBlob(Rand, Size); }; diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp index 9e3486e49..54459cbaa 100644 --- a/src/zencore/thread.cpp +++ b/src/zencore/thread.cpp @@ -133,7 +133,10 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) #elif ZEN_PLATFORM_MAC pthread_setname_np(ThreadNameZ.c_str()); #else - pthread_setname_np(pthread_self(), ThreadNameZ.c_str()); + // Linux pthread_setname_np has a 16-byte limit (15 chars + NUL) + StringBuilder<16> LinuxThreadName; + LinuxThreadName << LimitedThreadName.substr(0, 15); + pthread_setname_np(pthread_self(), LinuxThreadName.c_str()); #endif } // namespace zen @@ -233,12 +236,15 @@ Event::Close() #else std::atomic_thread_fence(std::memory_order_acquire); auto* Inner = (EventInner*)m_EventHandle.load(); + if (Inner) { - std::unique_lock Lock(Inner->Mutex); - Inner->bSet.store(true); - m_EventHandle = nullptr; + { + std::unique_lock Lock(Inner->Mutex); + Inner->bSet.store(true); + m_EventHandle = nullptr; + } + delete Inner; } - delete Inner; #endif } @@ -351,7 +357,7 @@ NamedEvent::NamedEvent(std::string_view EventName) intptr_t Packed; Packed = intptr_t(Sem) << 32; Packed |= intptr_t(Fd) & 0xffff'ffff; - m_EventHandle = (void*)Packed; + m_EventHandle = (void*)Packed; #endif ZEN_ASSERT(m_EventHandle != nullptr); } @@ -372,7 +378,9 @@ NamedEvent::Close() #if ZEN_PLATFORM_WINDOWS CloseHandle(m_EventHandle); #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - int Fd = int(intptr_t(m_EventHandle.load()) & 0xffff'ffff); + const intptr_t Handle = intptr_t(m_EventHandle.load()); + const int Fd = int(Handle & 0xffff'ffff); + const int Sem = int(Handle >> 32); if (flock(Fd, LOCK_EX | LOCK_NB) == 0) { @@ -388,11 +396,10 @@ NamedEvent::Close() } flock(Fd, LOCK_UN | LOCK_NB); - close(Fd); - - int Sem = int(intptr_t(m_EventHandle.load()) >> 32); 
semctl(Sem, 0, IPC_RMID); } + + close(Fd); #endif m_EventHandle = nullptr; @@ -481,9 +488,12 @@ NamedMutex::~NamedMutex() CloseHandle(m_MutexHandle); } #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - int Inner = int(intptr_t(m_MutexHandle)); - flock(Inner, LOCK_UN); - close(Inner); + if (m_MutexHandle) + { + int Inner = int(intptr_t(m_MutexHandle)); + flock(Inner, LOCK_UN); + close(Inner); + } #endif } @@ -516,7 +526,6 @@ NamedMutex::Create(std::string_view MutexName) if (flock(Inner, LOCK_EX) != 0) { close(Inner); - Inner = 0; return false; } @@ -583,6 +592,11 @@ GetCurrentThreadId() void Sleep(int ms) { + if (ms <= 0) + { + return; + } + #if ZEN_PLATFORM_WINDOWS ::Sleep(ms); #else diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp index 87035554f..7c195e69f 100644 --- a/src/zencore/trace.cpp +++ b/src/zencore/trace.cpp @@ -10,7 +10,16 @@ # define TRACE_IMPLEMENT 1 # undef _WINSOCK_DEPRECATED_NO_WARNINGS +// GCC false positives in thirdparty trace.h (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) +# if ZEN_COMPILER_GCC +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wstringop-overread" +# pragma GCC diagnostic ignored "-Wdangling-pointer" +# endif # include <zencore/trace.h> +# if ZEN_COMPILER_GCC +# pragma GCC diagnostic pop +# endif # include <zencore/memory/fmalloc.h> # include <zencore/memory/memorytrace.h> @@ -165,10 +174,17 @@ GetTraceOptionsFromCommandline(TraceOptions& OutOptions) auto MatchesArg = [](std::string_view Option, std::string_view Arg) -> std::optional<std::string_view> { if (Arg.starts_with(Option)) { - std::string_view::value_type DelimChar = Arg[Option.length()]; - if (DelimChar == ' ' || DelimChar == '=') + if (Arg.length() > Option.length()) + { + std::string_view::value_type DelimChar = Arg[Option.length()]; + if (DelimChar == ' ' || DelimChar == '=') + { + return Arg.substr(Option.size() + 1); + } + } + else { - return Arg.substr(Option.size() + 1); + return ""sv; } } return {}; diff --git 
a/src/zencore/uid.cpp b/src/zencore/uid.cpp index d7636f2ad..971683721 100644 --- a/src/zencore/uid.cpp +++ b/src/zencore/uid.cpp @@ -156,6 +156,8 @@ Oid::FromMemory(const void* Ptr) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.uid"); + TEST_CASE("Oid") { SUBCASE("Basic") @@ -185,6 +187,8 @@ TEST_CASE("Oid") } } +TEST_SUITE_END(); + void uid_forcelink() { diff --git a/src/zencore/windows.cpp b/src/zencore/windows.cpp index d02fcd35e..87f854b90 100644 --- a/src/zencore/windows.cpp +++ b/src/zencore/windows.cpp @@ -12,14 +12,12 @@ namespace zen::windows { bool IsRunningOnWine() { - HMODULE NtDll = GetModuleHandleA("ntdll.dll"); + static bool s_Result = [] { + HMODULE NtDll = GetModuleHandleA("ntdll.dll"); + return NtDll && !!GetProcAddress(NtDll, "wine_get_version"); + }(); - if (NtDll) - { - return !!GetProcAddress(NtDll, "wine_get_version"); - } - - return false; + return s_Result; } FileMapping::FileMapping(_In_ FileMapping& orig) diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp index cb84bbe06..1cb338c66 100644 --- a/src/zencore/workthreadpool.cpp +++ b/src/zencore/workthreadpool.cpp @@ -354,6 +354,8 @@ workthreadpool_forcelink() using namespace std::literals; +TEST_SUITE_BEGIN("core.workthreadpool"); + TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; @@ -368,6 +370,8 @@ TEST_CASE("threadpool.basic") CHECK_THROWS(FutureThrow.get()); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index a3fd4dacb..171f4c533 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -15,6 +15,7 @@ target('zencore') set_configdir("include/zencore") add_files("**.cpp") add_files("trace.cpp", {unity_ignored = true }) + add_files("testing.cpp", {unity_ignored = true }) if has_config("zenrpmalloc") then add_deps("rpmalloc") @@ -25,7 +26,6 @@ target('zencore') end add_deps("zenbase") - add_deps("spdlog") add_deps("utfcpp") add_deps("oodle") add_deps("blake3") @@ -33,8 +33,6 @@ 
target('zencore') add_deps("timesinceprocessstart") add_deps("doctest") add_deps("fmt") - add_deps("ryml") - add_packages("json11") if is_plat("linux", "macosx") then diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp index 6d1050531..88a48dd68 100644 --- a/src/zencore/xxhash.cpp +++ b/src/zencore/xxhash.cpp @@ -59,6 +59,8 @@ xxhash_forcelink() { } +TEST_SUITE_BEGIN("core.xxhash"); + TEST_CASE("XXH3_128") { using namespace std::literals; @@ -96,6 +98,8 @@ TEST_CASE("XXH3_128") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index 4ff79edc7..8c29a8962 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -147,7 +147,7 @@ AssertImpl::OnAssert(const char* Filename, int LineNumber, const char* FunctionN Message.push_back('\0'); // We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log - ZEN_LOG(Log(), zen::logging::level::Err, "{}", Message.data()); + ZEN_LOG(Log(), zen::logging::Err, "{}", Message.data()); zen::logging::FlushLogging(); } @@ -285,7 +285,7 @@ zencore_forcelinktests() namespace zen { -TEST_SUITE_BEGIN("core.assert"); +TEST_SUITE_BEGIN("core.zencore"); TEST_CASE("Assert.Default") { diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp new file mode 100644 index 000000000..819b2d0cb --- /dev/null +++ b/src/zenhorde/hordeagent.cpp @@ -0,0 +1,297 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "hordeagent.h" +#include "hordetransportaes.h" + +#include <zencore/basicfile.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <cstring> +#include <unordered_map> + +namespace zen::horde { + +HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) +{ + ZEN_TRACE_CPU("HordeAgent::Connect"); + + auto Transport = std::make_unique<TcpComputeTransport>(Info); + if (!Transport->IsValid()) + { + ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); + return; + } + + // The 64-byte nonce is always sent unencrypted as the first thing on the wire. + // The Horde agent uses this to identify which lease this connection belongs to. + Transport->Send(Info.Nonce, sizeof(Info.Nonce)); + + std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); + if (Info.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); + if (!FinalTransport->IsValid()) + { + ZEN_WARN("failed to create AES transport"); + return; + } + } + + // Create multiplexed socket and channels + m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); + + // Channel 0 is the agent control channel (handles Attach/Fork handshake). + // Channel 100 is the child I/O channel (handles file upload and remote execution). 
+ Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); + Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + + if (!AgentComputeChannel || !ChildComputeChannel) + { + ZEN_WARN("failed to create compute channels"); + return; + } + + m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); + m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + + m_IsValid = true; +} + +HordeAgent::~HordeAgent() +{ + CloseConnection(); +} + +bool +HordeAgent::BeginCommunication() +{ + ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); + + if (!m_IsValid) + { + return false; + } + + // Start the send/recv pump threads + m_Socket->StartCommunication(); + + // Wait for Attach on agent channel + AgentMessageType Type = m_AgentChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on agent channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + // Fork tells the remote agent to create child channel 100 with a 4MB buffer. + // After this, the agent will send an Attach on the child channel. 
+ m_AgentChannel->Fork(100, 4 * 1024 * 1024); + + // Wait for Attach on child channel + Type = m_ChildChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on child channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + return true; +} + +bool +HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +{ + ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + + m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + + std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + + auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { + std::string Key(Locator); + + if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + { + return It->second.get(); + } + + const std::filesystem::path Path = BundleDir / (Key + ".blob"); + std::error_code Ec; + auto File = std::make_unique<BasicFile>(); + File->Open(Path, BasicFile::Mode::kRead, Ec); + + if (Ec) + { + ZEN_ERROR("cannot read blob file: '{}'", Path); + return nullptr; + } + + BasicFile* Ptr = File.get(); + BlobFiles.emplace(std::move(Key), std::move(File)); + return Ptr; + }; + + // The upload protocol is request-driven: we send WriteFiles, then the remote agent + // sends ReadBlob requests for each blob it needs. We respond with Blob data until + // the agent sends WriteFilesResponse indicating the upload is complete. 
+ constexpr int32_t ReadResponseTimeoutMs = 1000; + + for (;;) + { + bool TimedOut = false; + + if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) + { + if (TimedOut) + { + continue; + } + // End of stream - check if it was a successful upload + if (Type == AgentMessageType::WriteFilesResponse) + { + return true; + } + else if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + } + else + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + } + return false; + } + + BlobRequest Req; + m_ChildChannel->ReadBlobRequest(Req); + + BasicFile* File = FindOrOpenBlob(Req.Locator); + if (!File) + { + return false; + } + + // Read from offset to end of file + const uint64_t TotalSize = File->FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + continue; + } + + const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); + } +} + +void +HordeAgent::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + bool UseWine) +{ + ZEN_TRACE_CPU("HordeAgent::Execute"); + m_ChildChannel + ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? 
ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); +} + +bool +HordeAgent::Poll(bool LogOutput) +{ + constexpr int32_t ReadResponseTimeoutMs = 100; + AgentMessageType Type; + + while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + { + switch (Type) + { + case AgentMessageType::ExecuteOutput: + { + if (LogOutput && m_ChildChannel->GetResponseSize() > 0) + { + const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); + size_t ResponseSize = m_ChildChannel->GetResponseSize(); + + // Trim trailing newlines + while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) + { + --ResponseSize; + } + + if (ResponseSize > 0) + { + const std::string_view Output(ResponseData, ResponseSize); + ZEN_INFO("[remote] {}", Output); + } + } + break; + } + + case AgentMessageType::ExecuteResult: + { + if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) + { + int32_t ExitCode; + memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); + ZEN_INFO("remote process exited with code {}", ExitCode); + } + m_IsValid = false; + return false; + } + + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + m_HasErrors = true; + break; + } + + default: + break; + } + } + + return m_IsValid && !m_HasErrors; +} + +void +HordeAgent::CloseConnection() +{ + if (m_ChildChannel) + { + m_ChildChannel->Close(); + } + if (m_AgentChannel) + { + m_AgentChannel->Close(); + } +} + +bool +HordeAgent::IsValid() const +{ + return m_IsValid && !m_HasErrors; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h new file mode 100644 index 000000000..e0ae89ead --- /dev/null +++ b/src/zenhorde/hordeagent.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "hordeagentmessage.h" +#include "hordecomputesocket.h" + +#include <zenhorde/hordeclient.h> + +#include <zencore/logbase.h> + +#include <filesystem> +#include <memory> +#include <string> + +namespace zen::horde { + +/** Manages the lifecycle of a single Horde compute agent. + * + * Handles the full connection sequence for one provisioned machine: + * 1. Connect via TCP transport (with optional AES encryption wrapping) + * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) + * 3. Perform the Attach/Fork handshake to establish the child channel + * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol + * 5. Execute zenserver remotely via ExecuteV2 + * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + */ +class HordeAgent +{ +public: + explicit HordeAgent(const MachineInfo& Info); + ~HordeAgent(); + + HordeAgent(const HordeAgent&) = delete; + HordeAgent& operator=(const HordeAgent&) = delete; + + /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). + * Returns false if the handshake times out or receives an unexpected message. */ + bool BeginCommunication(); + + /** Upload binary files to the remote agent. + * @param BundleDir Directory containing .blob files. + * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ + bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + + /** Execute a command on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir = nullptr, + const char* const* EnvVars = nullptr, + size_t NumEnvVars = 0, + bool UseWine = false); + + /** Poll for output and results. Returns true if the agent is still running. + * When LogOutput is true, remote stdout is logged via ZEN_INFO. 
*/ + bool Poll(bool LogOutput = true); + + void CloseConnection(); + bool IsValid() const; + + const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + +private: + LoggerRef Log() { return m_Log; } + + std::unique_ptr<ComputeSocket> m_Socket; + std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control + std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O + + LoggerRef m_Log; + bool m_IsValid = false; + bool m_HasErrors = false; + MachineInfo m_MachineInfo; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp new file mode 100644 index 000000000..998134a96 --- /dev/null +++ b/src/zenhorde/hordeagentmessage.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordeagentmessage.h" + +#include <zencore/intmath.h> + +#include <cassert> +#include <cstring> + +namespace zen::horde { + +AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) +{ +} + +AgentMessageChannel::~AgentMessageChannel() = default; + +void +AgentMessageChannel::Close() +{ + CreateMessage(AgentMessageType::None, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Ping() +{ + CreateMessage(AgentMessageType::Ping, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Fork(int ChannelId, int BufferSize) +{ + CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(ChannelId); + WriteInt32(BufferSize); + FlushMessage(); +} + +void +AgentMessageChannel::Attach() +{ + CreateMessage(AgentMessageType::Attach, 0); + FlushMessage(); +} + +void +AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +{ + CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Path); + WriteString(Locator); + FlushMessage(); +} + +void +AgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* 
WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) +{ + size_t RequiredSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) + { + RequiredSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + RequiredSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + RequiredSize += strlen(EnvVars[i]) + 20; + } + + CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); + WriteString(Exe); + + WriteUnsignedVarInt(NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Args[i]); + } + + WriteOptionalString(WorkingDir); + + // ExecuteV2 protocol requires env vars as separate key/value pairs. + // Callers pass "KEY=VALUE" strings; we split on the first '=' here. + WriteUnsignedVarInt(NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(nullptr); + } + else + { + WriteOptionalString(Eq + 1); + } + } + + WriteInt32(static_cast<int>(Flags)); + FlushMessage(); +} + +void +AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +{ + // Blob responses are chunked to fit within the compute buffer's chunk size. + // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). 
+ const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); + + CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(static_cast<int>(ChunkOffset)); + WriteInt32(static_cast<int>(Length)); + WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); + FlushMessage(); + + ChunkOffset += ChunkLength; + } +} + +AgentMessageType +AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +{ + // Deferred advance: the previous response's buffer is only released when the next + // ReadResponse is called. This allows callers to read response data between calls + // without copying, since the pointer comes directly from the ring buffer. + if (m_ResponseData) + { + m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); + m_ResponseData = nullptr; + m_ResponseLength = 0; + } + + const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + uint32_t Length; + memcpy(&Length, Header + 1, sizeof(uint32_t)); + + Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + m_ResponseType = static_cast<AgentMessageType>(Header[0]); + m_ResponseData = Header + MessageHeaderLength; + m_ResponseLength = Length; + + return m_ResponseType; +} + +void +AgentMessageChannel::ReadException(ExceptionInfo& Ex) +{ + assert(m_ResponseType == AgentMessageType::Exception); + const uint8_t* Pos = m_ResponseData; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); +} + +int +AgentMessageChannel::ReadExecuteResult() +{ + assert(m_ResponseType == AgentMessageType::ExecuteResult); + const uint8_t* Pos = m_ResponseData; + return ReadInt32(&Pos); +} + +void 
+AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +{ + assert(m_ResponseType == AgentMessageType::ReadBlob); + const uint8_t* Pos = m_ResponseData; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); +} + +void +AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +{ + m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); + m_RequestData[0] = static_cast<uint8_t>(Type); + m_MaxRequestSize = MaxLength; + m_RequestSize = 0; +} + +void +AgentMessageChannel::FlushMessage() +{ + const uint32_t Size = static_cast<uint32_t>(m_RequestSize); + memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); + m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); + m_RequestSize = 0; + m_MaxRequestSize = 0; + m_RequestData = nullptr; +} + +void +AgentMessageChannel::WriteInt32(int Value) +{ + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); +} + +int +AgentMessageChannel::ReadInt32(const uint8_t** Pos) +{ + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; +} + +void +AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +{ + assert(m_RequestSize + Length <= m_MaxRequestSize); + memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); + m_RequestSize += Length; +} + +const uint8_t* +AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; +} + +size_t +AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; +} + +void +AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +{ + const size_t ByteCount = MeasureUnsignedVarInt(Value); + assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + + uint8_t* Output = m_RequestData + MessageHeaderLength + 
m_RequestSize; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); + + m_RequestSize += ByteCount; +} + +size_t +AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) + { + Value <<= 8; + Value |= Data[i]; + } + + *Pos += NumBytes; + return Value; +} + +size_t +AgentMessageChannel::MeasureString(const char* Text) const +{ + const size_t Length = strlen(Text); + return MeasureUnsignedVarInt(Length) + Length; +} + +void +AgentMessageChannel::WriteString(const char* Text) +{ + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); +} + +void +AgentMessageChannel::WriteString(std::string_view Text) +{ + WriteUnsignedVarInt(Text.size()); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); +} + +std::string_view +AgentMessageChannel::ReadString(const uint8_t** Pos) +{ + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); +} + +void +AgentMessageChannel::WriteOptionalString(const char* Text) +{ + // Optional strings use length+1 encoding: 0 means null/absent, + // N>0 means a string of length N-1 follows. This matches the UE + // FAgentMessageChannel serialization convention. 
+ if (!Text) + { + WriteUnsignedVarInt(0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length + 1); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h new file mode 100644 index 000000000..38c4375fd --- /dev/null +++ b/src/zenhorde/hordeagentmessage.h @@ -0,0 +1,161 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#include "hordecomputechannel.h" + +#include <cstddef> +#include <cstdint> +#include <string> +#include <string_view> +#include <vector> + +namespace zen::horde { + +/** Agent message types matching the UE EAgentMessageType byte values. + * These are the message opcodes exchanged over the agent/child channels. */ +enum class AgentMessageType : uint8_t +{ + None = 0x00, + Ping = 0x01, + Exception = 0x02, + Fork = 0x03, + Attach = 0x04, + WriteFiles = 0x10, + WriteFilesResponse = 0x11, + DeleteFiles = 0x12, + ExecuteV2 = 0x22, + ExecuteOutput = 0x17, + ExecuteResult = 0x18, + ReadBlob = 0x20, + ReadBlobResponse = 0x21, +}; + +/** Flags for the ExecuteV2 message. */ +enum class ExecuteProcessFlags : uint8_t +{ + None = 0, + UseWine = 1, ///< Run the executable under Wine on Linux agents +}; + +/** Parsed exception information from an Exception message. */ +struct ExceptionInfo +{ + std::string_view Message; + std::string_view Description; +}; + +/** Parsed blob read request from a ReadBlob message. */ +struct BlobRequest +{ + std::string_view Locator; + size_t Offset = 0; + size_t Length = 0; +}; + +/** Channel for sending and receiving agent messages over a ComputeChannel. + * + * Implements the Horde agent message protocol, matching the UE + * FAgentMessageChannel serialization format exactly. Messages are framed as + * [type (1B)][payload length (4B)][payload]. 
Strings use length-prefixed UTF-8; + * integers use variable-length encoding. + * + * The protocol has two directions: + * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob + * - Responses (remote -> initiator): ReadResponse returns the type, then call the + * appropriate Read* method to parse the payload. + */ +class AgentMessageChannel +{ +public: + explicit AgentMessageChannel(Ref<ComputeChannel> Channel); + ~AgentMessageChannel(); + + AgentMessageChannel(const AgentMessageChannel&) = delete; + AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + + // --- Requests (Initiator -> Remote) --- + + /** Close the channel. */ + void Close(); + + /** Send a keepalive ping. */ + void Ping(); + + /** Fork communication to a new channel with the given ID and buffer size. */ + void Fork(int ChannelId, int BufferSize); + + /** Send an attach request (used during channel setup handshake). */ + void Attach(); + + /** Request the remote agent to write files from the given bundle locator. */ + void UploadFiles(const char* Path, const char* Locator); + + /** Execute a process on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags = ExecuteProcessFlags::None); + + /** Send blob data in response to a ReadBlob request. */ + void Blob(const uint8_t* Data, size_t Length); + + // --- Responses (Remote -> Initiator) --- + + /** Read the next response message. Returns the message type, or None on timeout. + * After this returns, use GetResponseData()/GetResponseSize() or the typed + * Read* methods to access the payload. */ + AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + + const void* GetResponseData() const { return m_ResponseData; } + size_t GetResponseSize() const { return m_ResponseLength; } + + /** Parse an Exception response payload. 
*/ + void ReadException(ExceptionInfo& Ex); + + /** Parse an ExecuteResult response payload. Returns the exit code. */ + int ReadExecuteResult(); + + /** Parse a ReadBlob response payload into a BlobRequest. */ + void ReadBlobRequest(BlobRequest& Req); + +private: + static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + + Ref<ComputeChannel> m_Channel; + + uint8_t* m_RequestData = nullptr; + size_t m_RequestSize = 0; + size_t m_MaxRequestSize = 0; + + AgentMessageType m_ResponseType = AgentMessageType::None; + const uint8_t* m_ResponseData = nullptr; + size_t m_ResponseLength = 0; + + void CreateMessage(AgentMessageType Type, size_t MaxLength); + void FlushMessage(); + + void WriteInt32(int Value); + static int ReadInt32(const uint8_t** Pos); + + void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + + static size_t MeasureUnsignedVarInt(size_t Value); + void WriteUnsignedVarInt(size_t Value); + static size_t ReadUnsignedVarInt(const uint8_t** Pos); + + size_t MeasureString(const char* Text) const; + void WriteString(const char* Text); + void WriteString(std::string_view Text); + static std::string_view ReadString(const uint8_t** Pos); + + void WriteOptionalString(const char* Text); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp new file mode 100644 index 000000000..d3974bc28 --- /dev/null +++ b/src/zenhorde/hordebundle.cpp @@ -0,0 +1,619 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "hordebundle.h" + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/intmath.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <chrono> +#include <cstring> + +namespace zen::horde { + +static LoggerRef +Log() +{ + static auto s_Logger = zen::logging::Get("horde.bundle"); + return s_Logger; +} + +static constexpr uint8_t PacketSignature[3] = {'U', 'B', 'N'}; +static constexpr uint8_t PacketVersion = 5; +static constexpr int32_t CurrentPacketBaseIdx = -2; +static constexpr int ImportBias = 3; +static constexpr uint32_t ChunkSize = 64 * 1024; // 64KB fixed chunks +static constexpr uint32_t LargeFileThreshold = 128 * 1024; // 128KB + +// BlobType: 20 bytes each = FGuid (16 bytes, 4x uint32 LE) + Version (int32 LE) +// Values from UE SDK: GUIDs stored as 4 uint32 LE values. + +// ChunkLeaf v1: {0xB27AFB68, 0x4A4B9E20, 0x8A78D8A4, 0x39D49840} +static constexpr uint8_t BlobType_ChunkLeafV1[20] = {0x68, 0xFB, 0x7A, 0xB2, 0x20, 0x9E, 0x4B, 0x4A, 0xA4, 0xD8, + 0x78, 0x8A, 0x40, 0x98, 0xD4, 0x39, 0x01, 0x00, 0x00, 0x00}; // version 1 + +// ChunkInterior v2: {0xF4DEDDBC, 0x4C7A70CB, 0x11F04783, 0xB9CDCCAF} +static constexpr uint8_t BlobType_ChunkInteriorV2[20] = {0xBC, 0xDD, 0xDE, 0xF4, 0xCB, 0x70, 0x7A, 0x4C, 0x83, 0x47, + 0xF0, 0x11, 0xAF, 0xCC, 0xCD, 0xB9, 0x02, 0x00, 0x00, 0x00}; // version 2 + +// Directory v1: {0x0714EC11, 0x4D07291A, 0x8AE77F86, 0x799980D6} +static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1A, 0x29, 0x07, 0x4D, 0x86, 0x7F, + 0xE7, 0x8A, 0xD6, 0x80, 0x99, 0x79, 0x01, 0x00, 0x00, 0x00}; // version 1 + +static constexpr size_t BlobTypeSize = 20; + +// ─── VarInt helpers (UE format) ───────────────────────────────────────────── + +static size_t +MeasureVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return 
(FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; +} + +static void +WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value) +{ + const size_t ByteCount = MeasureVarInt(Value); + const size_t Offset = Buffer.size(); + Buffer.resize(Offset + ByteCount); + + uint8_t* Output = Buffer.data() + Offset; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); +} + +// ─── Binary helpers ───────────────────────────────────────────────────────── + +static void +WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value) +{ + uint8_t Bytes[4]; + memcpy(Bytes, &Value, 4); + Buffer.insert(Buffer.end(), Bytes, Bytes + 4); +} + +static void +WriteByte(std::vector<uint8_t>& Buffer, uint8_t Value) +{ + Buffer.push_back(Value); +} + +static void +WriteBytes(std::vector<uint8_t>& Buffer, const void* Data, size_t Size) +{ + auto* Ptr = static_cast<const uint8_t*>(Data); + Buffer.insert(Buffer.end(), Ptr, Ptr + Size); +} + +static void +WriteString(std::vector<uint8_t>& Buffer, std::string_view Str) +{ + WriteVarInt(Buffer, Str.size()); + WriteBytes(Buffer, Str.data(), Str.size()); +} + +static void +AlignTo4(std::vector<uint8_t>& Buffer) +{ + while (Buffer.size() % 4 != 0) + { + Buffer.push_back(0); + } +} + +static void +PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value) +{ + memcpy(Buffer.data() + Offset, &Value, 4); +} + +// ─── Packet builder ───────────────────────────────────────────────────────── + +// Builds a single uncompressed Horde V2 packet. Layout: +// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header) +// [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)] 12 bytes +// [Export data...] 
+// [Type table: count(4) + count * 20 bytes] +// [Import table: count(4) + (count+1) offset entries(4 each) + import data] +// [Export table: count(4) + (count+1) offset entries(4 each)] +// +// ALL offsets are absolute from byte 0 of the full packet (including the 8-byte header). +// PacketLength in the header = total packet size including the 8-byte header. + +struct PacketBuilder +{ + std::vector<uint8_t> Data; + std::vector<int32_t> ExportOffsets; // Absolute byte offset of each export from byte 0 + + // Type table: unique 20-byte BlobType entries + std::vector<const uint8_t*> Types; + + // Import table entries: (baseIdx, fragment) + struct ImportEntry + { + int32_t BaseIdx; + std::string Fragment; + }; + std::vector<ImportEntry> Imports; + + // Current export's start offset (absolute from byte 0) + size_t CurrentExportStart = 0; + + PacketBuilder() + { + // Reserve packet header (8 bytes) + table offsets (12 bytes) = 20 bytes + Data.resize(20, 0); + + // Write signature + Data[0] = PacketSignature[0]; + Data[1] = PacketSignature[1]; + Data[2] = PacketSignature[2]; + Data[3] = PacketVersion; + // PacketLength, TypeTableOffset, ImportTableOffset, ExportTableOffset + // will be patched in Finish() + } + + int AddType(const uint8_t* BlobType) + { + for (size_t i = 0; i < Types.size(); ++i) + { + if (memcmp(Types[i], BlobType, BlobTypeSize) == 0) + { + return static_cast<int>(i); + } + } + Types.push_back(BlobType); + return static_cast<int>(Types.size() - 1); + } + + int AddImport(int32_t BaseIdx, std::string Fragment) + { + Imports.push_back({BaseIdx, std::move(Fragment)}); + return static_cast<int>(Imports.size() - 1); + } + + void BeginExport() + { + AlignTo4(Data); + CurrentExportStart = Data.size(); + // Reserve space for payload length + WriteLE32(Data, 0); + } + + // Write raw payload data into the current export + void WritePayload(const void* Payload, size_t Size) { WriteBytes(Data, Payload, Size); } + + // Complete the current export: patches payload 
length, writes type+imports metadata + int CompleteExport(const uint8_t* BlobType, const std::vector<int>& ImportIndices) + { + const int ExportIndex = static_cast<int>(ExportOffsets.size()); + + // Patch payload length (does not include the 4-byte length field itself) + const size_t PayloadStart = CurrentExportStart + 4; + const int32_t PayloadLen = static_cast<int32_t>(Data.size() - PayloadStart); + PatchLE32(Data, CurrentExportStart, PayloadLen); + + // Write type index (varint) + const int TypeIdx = AddType(BlobType); + WriteVarInt(Data, static_cast<size_t>(TypeIdx)); + + // Write import count + indices + WriteVarInt(Data, ImportIndices.size()); + for (int Idx : ImportIndices) + { + WriteVarInt(Data, static_cast<size_t>(Idx)); + } + + // Record export offset (absolute from byte 0) + ExportOffsets.push_back(static_cast<int32_t>(CurrentExportStart)); + + return ExportIndex; + } + + // Finalize the packet: write type/import/export tables, patch header. + std::vector<uint8_t> Finish() + { + AlignTo4(Data); + + // ── Type table: count(int32) + count * BlobTypeSize bytes ── + const int32_t TypeTableOffset = static_cast<int32_t>(Data.size()); + WriteLE32(Data, static_cast<int32_t>(Types.size())); + for (const uint8_t* TypeEntry : Types) + { + WriteBytes(Data, TypeEntry, BlobTypeSize); + } + + // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ── + const int32_t ImportTableOffset = static_cast<int32_t>(Data.size()); + const int32_t ImportCount = static_cast<int32_t>(Imports.size()); + WriteLE32(Data, ImportCount); + + // Reserve space for (count+1) offset entries — will be patched below + const size_t ImportOffsetsStart = Data.size(); + for (int32_t i = 0; i <= ImportCount; ++i) + { + WriteLE32(Data, 0); // placeholder + } + + // Write import data and record offsets + for (int32_t i = 0; i < ImportCount; ++i) + { + // Record absolute offset of this import's data + PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(i) * 4, 
static_cast<int32_t>(Data.size())); + + ImportEntry& Imp = Imports[static_cast<size_t>(i)]; + // BaseIdx encoded as unsigned VarInt with bias: VarInt(BaseIdx + ImportBias) + const size_t EncodedBaseIdx = static_cast<size_t>(static_cast<int64_t>(Imp.BaseIdx) + ImportBias); + WriteVarInt(Data, EncodedBaseIdx); + // Fragment: raw UTF-8 bytes, NO length prefix (length determined by offset table) + WriteBytes(Data, Imp.Fragment.data(), Imp.Fragment.size()); + } + + // Sentinel offset (points past the last import's data) + PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size())); + + // ── Export table: count(int32) + (count+1) offsets(int32 each) ── + const int32_t ExportTableOffset = static_cast<int32_t>(Data.size()); + const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size()); + WriteLE32(Data, ExportCount); + + for (int32_t Off : ExportOffsets) + { + WriteLE32(Data, Off); + } + // Sentinel: points to the start of the type table (end of export data region) + WriteLE32(Data, TypeTableOffset); + + // ── Patch header ── + // PacketLength = total packet size including the 8-byte header + const int32_t PacketLength = static_cast<int32_t>(Data.size()); + PatchLE32(Data, 4, PacketLength); + PatchLE32(Data, 8, TypeTableOffset); + PatchLE32(Data, 12, ImportTableOffset); + PatchLE32(Data, 16, ExportTableOffset); + + return std::move(Data); + } +}; + +// ─── Encoded packet wrapper ───────────────────────────────────────────────── + +// Wraps an uncompressed packet with the encoded header: +// [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes +// [DecompressedLength(4)] 4 bytes +// [CompressionFormat(1): 0=None] 1 byte +// [PacketData...] +// +// HeaderLength = total encoded packet size INCLUDING the 8-byte outer header. 
+ +static std::vector<uint8_t> +EncodePacket(std::vector<uint8_t> UncompressedPacket) +{ + const int32_t DecompressedLen = static_cast<int32_t>(UncompressedPacket.size()); + // HeaderLength includes the 8-byte outer signature header itself + const int32_t HeaderLength = 8 + 4 + 1 + DecompressedLen; + + std::vector<uint8_t> Encoded; + Encoded.reserve(static_cast<size_t>(HeaderLength)); + + // Outer signature: 'U','B','N', version=5, HeaderLength (LE int32) + WriteByte(Encoded, PacketSignature[0]); // 'U' + WriteByte(Encoded, PacketSignature[1]); // 'B' + WriteByte(Encoded, PacketSignature[2]); // 'N' + WriteByte(Encoded, PacketVersion); // 5 + WriteLE32(Encoded, HeaderLength); + + // Decompressed length + compression format + WriteLE32(Encoded, DecompressedLen); + WriteByte(Encoded, 0); // CompressionFormat::None + + // Packet data + WriteBytes(Encoded, UncompressedPacket.data(), UncompressedPacket.size()); + + return Encoded; +} + +// ─── Bundle blob name generation ──────────────────────────────────────────── + +static std::string +GenerateBlobName() +{ + static std::atomic<uint32_t> s_Counter{0}; + + const int Pid = GetCurrentProcessId(); + + auto Now = std::chrono::steady_clock::now().time_since_epoch(); + auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count(); + + ExtendableStringBuilder<64> Name; + Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1); + return std::string(Name.ToView()); +} + +// ─── File info for bundling ───────────────────────────────────────────────── + +struct FileInfo +{ + std::filesystem::path Path; + std::string Name; // Filename only (for directory entry) + uint64_t FileSize; + IoHash ContentHash; // IoHash of file content + BLAKE3 StreamHash; // Full BLAKE3 for stream hash + int DirectoryExportImportIndex; // Import index referencing this file's root export + IoHash RootExportHash; // IoHash of the root export for this file +}; + +// ─── CreateBundle implementation 
──────────────────────────────────────────── + +bool +BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult) +{ + ZEN_TRACE_CPU("BundleCreator::CreateBundle"); + + std::error_code Ec; + + // Collect files that exist + std::vector<FileInfo> ValidFiles; + for (const BundleFile& F : Files) + { + if (!std::filesystem::exists(F.Path, Ec)) + { + if (F.Optional) + { + continue; + } + ZEN_ERROR("required bundle file does not exist: {}", F.Path.string()); + return false; + } + FileInfo Info; + Info.Path = F.Path; + Info.Name = F.Path.filename().string(); + Info.FileSize = std::filesystem::file_size(F.Path, Ec); + if (Ec) + { + ZEN_ERROR("failed to get file size: {}", F.Path.string()); + return false; + } + ValidFiles.push_back(std::move(Info)); + } + + if (ValidFiles.empty()) + { + ZEN_ERROR("no valid files to bundle"); + return false; + } + + std::filesystem::create_directories(OutputDir, Ec); + if (Ec) + { + ZEN_ERROR("failed to create output directory: {}", OutputDir.string()); + return false; + } + + const std::string BlobName = GenerateBlobName(); + PacketBuilder Packet; + + // Process each file: create chunk exports + for (FileInfo& Info : ValidFiles) + { + BasicFile File; + File.Open(Info.Path, BasicFile::Mode::kRead, Ec); + if (Ec) + { + ZEN_ERROR("failed to open file: {}", Info.Path.string()); + return false; + } + + // Compute stream hash (full BLAKE3) and content hash (IoHash) while reading + BLAKE3Stream StreamHasher; + IoHashStream ContentHasher; + + if (Info.FileSize <= LargeFileThreshold) + { + // Small file: single chunk leaf export + IoBuffer Content = File.ReadAll(); + const auto* Data = static_cast<const uint8_t*>(Content.GetData()); + const size_t Size = Content.GetSize(); + + StreamHasher.Append(Data, Size); + ContentHasher.Append(Data, Size); + + Packet.BeginExport(); + Packet.WritePayload(Data, Size); + + const IoHash ChunkHash = IoHash::HashBuffer(Data, Size); + const int 
ExportIndex = Packet.CompleteExport(BlobType_ChunkLeafV1, {}); + Info.RootExportHash = ChunkHash; + Info.ContentHash = ContentHasher.GetHash(); + Info.StreamHash = StreamHasher.GetHash(); + + // Add import for this file's root export (references export within same packet) + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << ExportIndex; + Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + } + else + { + // Large file: split into fixed 64KB chunks, then create interior node + std::vector<int> ChunkExportIndices; + std::vector<IoHash> ChunkHashes; + + uint64_t Remaining = Info.FileSize; + uint64_t Offset = 0; + + while (Remaining > 0) + { + const uint64_t ReadSize = std::min(static_cast<uint64_t>(ChunkSize), Remaining); + IoBuffer Chunk = File.ReadRange(Offset, ReadSize); + const auto* Data = static_cast<const uint8_t*>(Chunk.GetData()); + const size_t Size = Chunk.GetSize(); + + StreamHasher.Append(Data, Size); + ContentHasher.Append(Data, Size); + + Packet.BeginExport(); + Packet.WritePayload(Data, Size); + + const IoHash ChunkHash = IoHash::HashBuffer(Data, Size); + const int ExpIdx = Packet.CompleteExport(BlobType_ChunkLeafV1, {}); + + ChunkExportIndices.push_back(ExpIdx); + ChunkHashes.push_back(ChunkHash); + + Offset += ReadSize; + Remaining -= ReadSize; + } + + Info.ContentHash = ContentHasher.GetHash(); + Info.StreamHash = StreamHasher.GetHash(); + + // Create interior node referencing all chunk leaves + // Interior payload: for each child: [IoHash(20)][node_type=1(1)] + imports + std::vector<int> InteriorImports; + for (size_t i = 0; i < ChunkExportIndices.size(); ++i) + { + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << ChunkExportIndices[i]; + const int ImportIdx = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + InteriorImports.push_back(ImportIdx); + } + + Packet.BeginExport(); + + // Write interior payload: [hash(20)][type(1)] per child + for 
(size_t i = 0; i < ChunkHashes.size(); ++i) + { + Packet.WritePayload(ChunkHashes[i].Hash, sizeof(IoHash)); + const uint8_t NodeType = 1; // ChunkNode type + Packet.WritePayload(&NodeType, 1); + } + + // Hash the interior payload to get the interior node hash + const IoHash InteriorHash = IoHash::HashBuffer(Packet.Data.data() + (Packet.CurrentExportStart + 4), + Packet.Data.size() - (Packet.CurrentExportStart + 4)); + + const int InteriorExportIndex = Packet.CompleteExport(BlobType_ChunkInteriorV2, InteriorImports); + + Info.RootExportHash = InteriorHash; + + // Add import for directory to reference this interior node + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << InteriorExportIndex; + Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + } + } + + // Create directory node export + // Payload: [flags(varint=0)] [file_count(varint)] [file_entries...] [dir_count(varint=0)] + // FileEntry: [import(varint)] [IoHash(20)] [name(string)] [flags(varint)] [length(varint)] [IoHash_stream(20)] + + Packet.BeginExport(); + + // Build directory payload into a temporary buffer, then write it + std::vector<uint8_t> DirPayload; + WriteVarInt(DirPayload, 0); // flags + WriteVarInt(DirPayload, ValidFiles.size()); // file_count + + std::vector<int> DirImports; + for (size_t i = 0; i < ValidFiles.size(); ++i) + { + FileInfo& Info = ValidFiles[i]; + DirImports.push_back(Info.DirectoryExportImportIndex); + + // IoHash of target (20 bytes) — import is consumed sequentially from the + // export's import list by ReadBlobRef, not encoded in the payload + WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash)); + // name (string) + WriteString(DirPayload, Info.Name); + // flags (varint): 1 = Executable + WriteVarInt(DirPayload, 1); + // length (varint) + WriteVarInt(DirPayload, static_cast<size_t>(Info.FileSize)); + // stream hash: IoHash from full BLAKE3, truncated to 20 bytes + const IoHash StreamIoHash = 
IoHash::FromBLAKE3(Info.StreamHash); + WriteBytes(DirPayload, StreamIoHash.Hash, sizeof(IoHash)); + } + + WriteVarInt(DirPayload, 0); // dir_count + + Packet.WritePayload(DirPayload.data(), DirPayload.size()); + const int DirExportIndex = Packet.CompleteExport(BlobType_DirectoryV1, DirImports); + + // Finalize packet and encode + std::vector<uint8_t> UncompressedPacket = Packet.Finish(); + std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket)); + + // Write .blob file + const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob"); + { + BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec); + if (Ec) + { + ZEN_ERROR("failed to create blob file: {}", BlobFilePath.string()); + return false; + } + BlobFile.Write(EncodedPacket.data(), EncodedPacket.size(), 0); + } + + // Build locator: <blob_name>#pkt=0,<encoded_len>&exp=<dir_export_index> + ExtendableStringBuilder<256> Locator; + Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex; + const std::string LocatorStr(Locator.ToView()); + + // Write .ref file (use first file's name as the ref base) + const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref"); + { + BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec); + if (Ec) + { + ZEN_ERROR("failed to create ref file: {}", RefFilePath.string()); + return false; + } + RefFile.Write(LocatorStr.data(), LocatorStr.size(), 0); + } + + OutResult.Locator = LocatorStr; + OutResult.BundleDir = OutputDir; + + ZEN_INFO("created V2 bundle: blob={}.blob locator={} files={}", BlobName, LocatorStr, ValidFiles.size()); + return true; +} + +bool +BundleCreator::ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator) +{ + BasicFile File; + std::error_code Ec; + File.Open(RefFile, BasicFile::Mode::kRead, Ec); + if (Ec) + { + return false; + } + + IoBuffer Content = File.ReadAll(); + OutLocator.assign(static_cast<const 
char*>(Content.GetData()), Content.GetSize()); + + // Strip trailing whitespace/newlines + while (!OutLocator.empty() && (OutLocator.back() == '\n' || OutLocator.back() == '\r' || OutLocator.back() == '\0')) + { + OutLocator.pop_back(); + } + + return !OutLocator.empty(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.h b/src/zenhorde/hordebundle.h new file mode 100644 index 000000000..052f60435 --- /dev/null +++ b/src/zenhorde/hordebundle.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> +#include <string> +#include <vector> + +namespace zen::horde { + +/** Describes a file to include in a Horde bundle. */ +struct BundleFile +{ + std::filesystem::path Path; ///< Local file path + bool Optional; ///< If true, skip without error if missing +}; + +/** Result of a successful bundle creation. */ +struct BundleResult +{ + std::string Locator; ///< Root directory locator for WriteFiles + std::filesystem::path BundleDir; ///< Directory containing .blob files +}; + +/** Creates Horde V2 bundles from local files for upload to remote agents. + * + * Produces a proper Horde storage V2 bundle containing: + * - Chunk leaf exports for file data (split into 64KB chunks for large files) + * - Optional interior chunk nodes referencing leaf chunks + * - A directory node listing all bundled files with metadata + * + * The bundle is written as a single .blob file with a corresponding .ref file + * containing the locator string. The locator format is: + * <blob_name>#pkt=0,<encoded_len>&exp=<directory_export_index> + */ +struct BundleCreator +{ + /** Create a V2 bundle from one or more input files. + * @param Files Files to include in the bundle. + * @param OutputDir Directory where .blob and .ref files will be written. + * @param OutResult Receives the locator and output directory on success. + * @return True on success. 
	 */
	static bool CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult);

	/** Read a locator string from a .ref file. Strips trailing whitespace/newlines. */
	static bool ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator);
};

} // namespace zen::horde

// ---- file: src/zenhorde/hordeclient.cpp ----
// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/fmtutils.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/memoryview.h>
#include <zencore/trace.h>
#include <zenhorde/hordeclient.h>
#include <zenhttp/httpclient.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <json11.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen::horde {

// Stores the configuration by value and creates a dedicated log channel.
HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
{
}

HordeClient::~HordeClient() = default;

/** Set up the HTTP client used to talk to the Horde server.
 *
 * Configures timeouts/retry, installs a bearer-token provider when an auth
 * token is configured, and performs an initial Authenticate() round-trip so
 * that a bad token fails fast.
 *
 * @return True on success; false if authentication with the server fails.
 */
bool
HordeClient::Initialize()
{
	ZEN_TRACE_CPU("HordeClient::Initialize");

	HttpClientSettings Settings;
	Settings.LogCategory    = "horde.http";
	Settings.ConnectTimeout = std::chrono::milliseconds{10000};
	Settings.Timeout        = std::chrono::milliseconds{60000};
	Settings.RetryCount     = 1;
	// 503/429 are expected "no capacity" answers from Horde, not hard errors.
	Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};

	if (!m_Config.AuthToken.empty())
	{
		// The configured token is static, so the provider just re-issues it with
		// a 24h synthetic expiry (NOTE(review): token is never refreshed — confirm
		// this matches the expected token lifetime).
		Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
			HttpClientAccessToken Token;
			Token.Value      = token;
			Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24};
			return Token;
		};
	}

	m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);

	if (!m_Config.AuthToken.empty())
	{
		if (!m_Http->Authenticate())
		{
			ZEN_WARN("failed to authenticate with Horde server");
			return false;
		}
	}

	return true;
}

/** Build the JSON request body for Horde compute allocation.
 *
 * Emits "requirements" (pool, platform condition, exclusive) and "connection"
 * (mode preference, encryption, forwarded ports) objects.
 *
 * @return The serialized JSON document.
 */
std::string
HordeClient::BuildRequestBody() const
{
	json11::Json::object Requirements;

	if (m_Config.Mode == ConnectionMode::Direct && !m_Config.Pool.empty())
	{
		Requirements["pool"] = m_Config.Pool;
	}

	// Platform condition is fixed at compile time for the host OS.
	std::string Condition;
#if ZEN_PLATFORM_WINDOWS
	ExtendableStringBuilder<256> CondBuf;
	CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
	Condition = std::string(CondBuf);
#elif ZEN_PLATFORM_MAC
	Condition = "OSFamily == 'MacOS'";
#else
	Condition = "OSFamily == 'Linux'";
#endif

	if (!m_Config.Condition.empty())
	{
		// NOTE(review): user condition is appended with only a space, no logical
		// operator (e.g. "&&") — verify Horde's condition grammar accepts this.
		Condition += " ";
		Condition += m_Config.Condition;
	}

	Requirements["condition"] = Condition;
	Requirements["exclusive"] = true;

	json11::Json::object Connection;
	Connection["modePreference"] = ToString(m_Config.Mode);

	if (m_Config.EncryptionMode != Encryption::None)
	{
		Connection["encryption"] = ToString(m_Config.EncryptionMode);
	}

	// Request configured zen service port to be forwarded. The Horde agent will map this
	// to a local port on the provisioned machine and report it back in the response.
	json11::Json::object PortsObj;
	PortsObj["ZenPort"]  = json11::Json(m_Config.ZenServicePort);
	Connection["ports"]  = PortsObj;

	json11::Json::object Root;
	Root["requirements"] = Requirements;
	Root["connection"]   = Connection;

	return json11::Json(Root).dump();
}

/** Resolve which Horde cluster should serve the given compute request.
 *
 * POSTs the request body to api/v2/compute/_cluster and extracts "clusterId".
 * 503/429 (no resources) are logged at debug level and reported as failure;
 * 401 indicates an expired token.
 *
 * @param RequestBody JSON body as built by BuildRequestBody().
 * @param OutCluster  Receives the resolved cluster id on success.
 * @return True if a non-empty clusterId was returned.
 */
bool
HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster)
{
	ZEN_TRACE_CPU("HordeClient::ResolveCluster");

	const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);

	const HttpClient::Response Response = m_Http->Post("api/v2/compute/_cluster", Payload);

	if (Response.Error)
	{
		ZEN_WARN("cluster resolution failed: {}", Response.Error->ErrorMessage);
		return false;
	}

	const int StatusCode = static_cast<int>(Response.StatusCode);

	// Capacity exhaustion is an expected condition — keep it quiet.
	if (StatusCode == 503 || StatusCode == 429)
	{
		ZEN_DEBUG("cluster resolution returned HTTP/{}: no resources", StatusCode);
		return false;
	}

	if (StatusCode == 401)
	{
		ZEN_WARN("cluster resolution returned HTTP/401: token expired");
		return false;
	}

	if (!Response.IsSuccess())
	{
		ZEN_WARN("cluster resolution failed with HTTP/{}", StatusCode);
		return false;
	}

	const std::string  Body(Response.AsText());
	std::string        Err;
	const json11::Json Json = json11::Json::parse(Body, Err);

	if (!Err.empty())
	{
		ZEN_WARN("invalid JSON response for cluster resolution: {}", Err);
		return false;
	}

	const json11::Json ClusterIdVal = Json["clusterId"];
	if (!ClusterIdVal.is_string() || ClusterIdVal.string_value().empty())
	{
		ZEN_WARN("missing 'clusterId' in cluster resolution response");
		return false;
	}

	OutCluster.ClusterId = ClusterIdVal.string_value();
	return true;
}

/** Decode a fixed-length hex string into raw bytes.
 *
 * @param Hex     Input; must be exactly OutSize*2 hex digits (upper or lower case).
 * @param Out     Destination buffer of OutSize bytes.
 * @param OutSize Number of bytes to produce.
 * @return False on length mismatch or any non-hex character (Out may be partially written).
 */
bool
HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
{
	if (Hex.size() != OutSize * 2)
	{
		return false;
	}

	for (size_t i = 0; i < OutSize; ++i)
	{
		// Single-digit decode; returns -1 for non-hex characters.
		auto HexToByte = [](char c) -> int {
			if (c >= '0' && c <= '9')
				return c - '0';
			if (c >= 'a' && c <= 'f')
				return c - 'a' + 10;
			if (c >= 'A' && c <= 'F')
				return c - 'A' + 10;
			return -1;
		};

		const int Hi = HexToByte(Hex[i * 2]);
		const int Lo = HexToByte(Hex[i * 2 + 1]);
		if (Hi < 0 || Lo < 0)
		{
			return false;
		}
		Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
	}

	return true;
}

/** Request a compute machine lease from Horde.
 *
 * POSTs to api/v2/compute/<cluster> and parses the machine description
 * (address, nonce, forwarded ports, connection mode, core counts, encryption,
 * lease id) into OutMachine.
 *
 * @param RequestBody JSON body as built by BuildRequestBody().
 * @param ClusterId   Cluster to request from; empty selects "default".
 * @param OutMachine  Receives the machine description; reset to an invalid
 *                    state (Port=0xFFFF) before any parsing.
 * @return True if a machine was assigned and the response parsed.
 */
bool
HordeClient::RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine)
{
	ZEN_TRACE_CPU("HordeClient::RequestMachine");

	ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());

	ExtendableStringBuilder<128> ResourcePath;
	ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());

	const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);
	const HttpClient::Response Response = m_Http->Post(ResourcePath.ToView(), Payload);

	// Reset output to invalid state
	OutMachine      = {};
	OutMachine.Port = 0xFFFF;

	if (Response.Error)
	{
		ZEN_WARN("machine request failed: {}", Response.Error->ErrorMessage);
		return false;
	}

	const int StatusCode = static_cast<int>(Response.StatusCode);

	if (StatusCode == 404 || StatusCode == 503 || StatusCode == 429)
	{
		ZEN_DEBUG("machine request returned HTTP/{}: no resources", StatusCode);
		return false;
	}

	if (StatusCode == 401)
	{
		ZEN_WARN("machine request returned HTTP/401: token expired");
		return false;
	}

	if (!Response.IsSuccess())
	{
		ZEN_WARN("machine request failed with HTTP/{}", StatusCode);
		return false;
	}

	const std::string  Body(Response.AsText());
	std::string        Err;
	const json11::Json Json = json11::Json::parse(Body, Err);

	if (!Err.empty())
	{
		ZEN_WARN("invalid JSON response for machine request: {}", Err);
		return false;
	}

	// Required fields
	const json11::Json NonceVal = Json["nonce"];
	const json11::Json IpVal   = Json["ip"];
	const json11::Json PortVal = Json["port"];

	if (!NonceVal.is_string() || !IpVal.is_string() || !PortVal.is_number())
	{
		ZEN_WARN("missing 'nonce', 'ip', or 'port' in machine response");
		return false;
	}

	OutMachine.Ip   = IpVal.string_value();
	OutMachine.Port = static_cast<uint16_t>(PortVal.int_value());

	// The nonce arrives as a hex string and must decode to exactly NonceSize bytes.
	if (!ParseHexBytes(NonceVal.string_value(), OutMachine.Nonce, NonceSize))
	{
		ZEN_WARN("invalid nonce hex string in machine response");
		return false;
	}

	// Optional: map of forwarded ports ("ports": {name: {port, agentPort}}).
	if (const json11::Json PortsVal = Json["ports"]; PortsVal.is_object())
	{
		for (const auto& [Key, Val] : PortsVal.object_items())
		{
			PortInfo Info;
			if (Val["port"].is_number())
			{
				Info.Port = static_cast<uint16_t>(Val["port"].int_value());
			}
			if (Val["agentPort"].is_number())
			{
				Info.AgentPort = static_cast<uint16_t>(Val["agentPort"].int_value());
			}
			OutMachine.Ports[Key] = Info;
		}
	}

	// Optional: connection mode/address override. The address is only honored
	// when the mode string parses successfully.
	if (const json11::Json ConnectionModeVal = Json["connectionMode"]; ConnectionModeVal.is_string())
	{
		if (FromString(OutMachine.Mode, ConnectionModeVal.string_value()))
		{
			if (const json11::Json ConnectionAddressVal = Json["connectionAddress"]; ConnectionAddressVal.is_string())
			{
				OutMachine.ConnectionAddress = ConnectionAddressVal.string_value();
			}
		}
	}

	// Properties are a flat string array of "Key=Value" pairs describing the machine.
	// We extract OS family and core counts for sizing decisions. If neither core count
	// is available, we fall back to 16 as a conservative default.
	uint16_t LogicalCores  = 0;
	uint16_t PhysicalCores = 0;

	if (const json11::Json PropertiesVal = Json["properties"]; PropertiesVal.is_array())
	{
		for (const json11::Json& PropVal : PropertiesVal.array_items())
		{
			if (!PropVal.is_string())
			{
				continue;
			}

			const std::string Prop = PropVal.string_value();
			if (Prop.starts_with("OSFamily="))
			{
				if (Prop.substr(9) == "Windows")
				{
					OutMachine.IsWindows = true;
				}
			}
			else if (Prop.starts_with("LogicalCores="))
			{
				// Offsets 13/14 skip the "Key=" prefixes; atoi yields 0 on garbage,
				// which falls through to the defaults below.
				LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13));
			}
			else if (Prop.starts_with("PhysicalCores="))
			{
				PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
			}
		}
	}

	if (LogicalCores > 0)
	{
		OutMachine.LogicalCores = LogicalCores;
	}
	else if (PhysicalCores > 0)
	{
		// Assume 2-way SMT when only physical cores are reported.
		OutMachine.LogicalCores = PhysicalCores * 2;
	}
	else
	{
		OutMachine.LogicalCores = 16;
	}

	// Optional: encryption negotiation. For AES the server must also supply a key.
	if (const json11::Json EncryptionVal = Json["encryption"]; EncryptionVal.is_string())
	{
		if (FromString(OutMachine.EncryptionMode, EncryptionVal.string_value()))
		{
			if (OutMachine.EncryptionMode == Encryption::AES)
			{
				const json11::Json KeyVal = Json["key"];
				if (KeyVal.is_string() && !KeyVal.string_value().empty())
				{
					if (!ParseHexBytes(KeyVal.string_value(), OutMachine.Key, KeySize))
					{
						ZEN_WARN("invalid AES key in machine response");
					}
				}
				else
				{
					ZEN_WARN("AES encryption requested but no key provided");
				}
			}
		}
	}

	if (const json11::Json LeaseIdVal = Json["leaseId"]; LeaseIdVal.is_string())
	{
		OutMachine.LeaseId = LeaseIdVal.string_value();
	}

	ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
	         OutMachine.GetConnectionAddress(),
	         OutMachine.GetConnectionPort(),
	         OutMachine.LogicalCores,
	         OutMachine.LeaseId);

	return true;
}

} // namespace zen::horde
// ---- file: src/zenhorde/hordecomputebuffer.cpp ----
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordecomputebuffer.h"

#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstring>

namespace zen::horde {

// Simplified ring buffer implementation for in-process use only.
// Uses a single contiguous buffer with write/read cursors and
// mutex+condvar for synchronization. This is simpler than the UE version
// which uses lock-free atomics and shared memory, but sufficient for our
// use case where we're the initiator side of the compute protocol.

// Shared state between ComputeBuffer and its reader/writer endpoints.
// All mutable fields are guarded by Mutex; lifetime is managed by refcount
// so endpoints can outlive the ComputeBuffer that created them.
struct ComputeBuffer::Detail : TRefCounted<Detail>
{
	std::vector<uint8_t> Data;            // NumChunks * ChunkLength bytes, contiguous
	size_t               NumChunks   = 0;
	size_t               ChunkLength = 0;

	// Current write state
	size_t WriteChunkIdx = 0;
	size_t WriteOffset   = 0;
	bool   WriteComplete = false; // writer signaled end of stream

	// Current read state
	size_t ReadChunkIdx = 0;
	size_t ReadOffset   = 0;
	bool   Detached     = false; // reader abandoned the stream

	// Per-chunk written length
	std::vector<size_t> ChunkWrittenLength;
	std::vector<bool>   ChunkFinished; // Writer moved to next chunk

	std::mutex              Mutex;
	std::condition_variable ReadCV;  ///< Signaled when new data is written or stream completes
	std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space

	bool HasWriter = false;
	bool HasReader = false;

	// Address of the start of a chunk within the flat Data buffer.
	uint8_t*       ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
	const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
};

// ComputeBuffer

ComputeBuffer::ComputeBuffer()
{
}
ComputeBuffer::~ComputeBuffer()
{
}

// Allocate the backing storage and per-chunk bookkeeping. Always returns true.
bool
ComputeBuffer::CreateNew(const Params& InParams)
{
	auto* NewDetail        = new Detail();
	NewDetail->NumChunks   = InParams.NumChunks;
	NewDetail->ChunkLength = InParams.ChunkLength;
	NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
	NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
	NewDetail->ChunkFinished.resize(InParams.NumChunks, false);

	m_Detail = NewDetail;
	return true;
}

// Drops this handle's reference; endpoints keep the shared state alive.
void
ComputeBuffer::Close()
{
	m_Detail = nullptr;
}

bool
ComputeBuffer::IsValid() const
{
	return static_cast<bool>(m_Detail);
}

ComputeBufferReader
ComputeBuffer::CreateReader()
{
	assert(m_Detail);
	m_Detail->HasReader = true;
	return ComputeBufferReader(m_Detail);
}

ComputeBufferWriter
ComputeBuffer::CreateWriter()
{
	assert(m_Detail);
	m_Detail->HasWriter = true;
	return ComputeBufferWriter(m_Detail);
}

// ComputeBufferReader

ComputeBufferReader::ComputeBufferReader()
{
}
ComputeBufferReader::~ComputeBufferReader()
{
}

ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other)                = default;
ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept            = default;
ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other)     = default;
ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;

ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
{
}

void
ComputeBufferReader::Close()
{
	m_Detail = nullptr;
}

// Mark the stream as abandoned; wakes any blocked WaitToRead callers.
void
ComputeBufferReader::Detach()
{
	if (m_Detail)
	{
		std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
		m_Detail->Detached = true;
		m_Detail->ReadCV.notify_all();
	}
}

bool
ComputeBufferReader::IsValid() const
{
	return static_cast<bool>(m_Detail);
}

// True once the writer finished and the reader has consumed everything
// (or the reader detached / the handle is closed).
bool
ComputeBufferReader::IsComplete() const
{
	if (!m_Detail)
	{
		return true;
	}
	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	if (m_Detail->Detached)
	{
		return true;
	}
	return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
	       m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
}

// Consume Size bytes previously returned by WaitToRead. If the current chunk
// is finished and fully consumed, advance to the next chunk and wake the writer.
void
ComputeBufferReader::AdvanceReadPosition(size_t Size)
{
	if (!m_Detail)
	{
		return;
	}

	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);

	m_Detail->ReadOffset += Size;

	// Check if we need to move to next chunk
	const size_t ReadChunk = m_Detail->ReadChunkIdx;
	if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
	{
		const size_t NextChunk  = (ReadChunk + 1) % m_Detail->NumChunks;
		m_Detail->ReadChunkIdx  = NextChunk;
		m_Detail->ReadOffset    = 0;
		m_Detail->WriteCV.notify_all();
	}

	m_Detail->ReadCV.notify_all();
}

// Bytes currently readable from the active chunk without blocking.
size_t
ComputeBufferReader::GetMaxReadSize() const
{
	if (!m_Detail)
	{
		return 0;
	}
	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	const size_t ReadChunk = m_Detail->ReadChunkIdx;
	return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
}

// Block until at least MinSize bytes are readable (or detach/end-of-stream/timeout).
// Returns a zero-copy pointer into the chunk; nullptr on timeout, detach, or EOS.
const uint8_t*
ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
{
	if (!m_Detail)
	{
		return nullptr;
	}

	std::unique_lock<std::mutex> Lock(m_Detail->Mutex);

	// NOTE: this predicate deliberately mutates state (advancing ReadChunkIdx past
	// finished chunks) — safe here because it only runs with Mutex held.
	auto Predicate = [&]() -> bool {
		if (m_Detail->Detached)
		{
			return true;
		}

		const size_t ReadChunk = m_Detail->ReadChunkIdx;
		const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;

		if (Available >= MinSize)
		{
			return true;
		}

		// If chunk is finished and we've read everything, try to move to next
		if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
		{
			if (m_Detail->WriteComplete)
			{
				return true; // End of stream
			}
			// Move to next chunk
			const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
			m_Detail->ReadChunkIdx = NextChunk;
			m_Detail->ReadOffset   = 0;
			m_Detail->WriteCV.notify_all();
			return false; // Re-check with new chunk
		}

		if (m_Detail->WriteComplete)
		{
			return true; // End of stream
		}

		return false;
	};

	if (TimeoutMs < 0)
	{
		m_Detail->ReadCV.wait(Lock, Predicate);
	}
	else
	{
		if (!m_Detail->ReadCV.wait_for(Lock,
		                               std::chrono::milliseconds(TimeoutMs), Predicate))
		{
			// Timed wait elapsed without the predicate becoming true.
			if (OutTimedOut)
			{
				*OutTimedOut = true;
			}
			return nullptr;
		}
	}

	if (m_Detail->Detached)
	{
		return nullptr;
	}

	const size_t ReadChunk = m_Detail->ReadChunkIdx;
	const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;

	if (Available < MinSize)
	{
		return nullptr; // End of stream
	}

	return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
}

// Copying convenience wrapper over WaitToRead/AdvanceReadPosition.
// Returns the number of bytes copied; 0 on timeout, detach, or end of stream.
size_t
ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
{
	const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
	if (!Data)
	{
		return 0;
	}

	const size_t Available = GetMaxReadSize();
	const size_t ToCopy    = std::min(Available, MaxSize);
	memcpy(Buffer, Data, ToCopy);
	AdvanceReadPosition(ToCopy);
	return ToCopy;
}

// ComputeBufferWriter

ComputeBufferWriter::ComputeBufferWriter()                                                = default;
ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other)                = default;
ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept            = default;
ComputeBufferWriter::~ComputeBufferWriter()                                               = default;
ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other)     = default;
ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;

ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
{
}

// Closing the writer implicitly marks the stream complete so readers drain and stop.
void
ComputeBufferWriter::Close()
{
	if (m_Detail)
	{
		{
			std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
			if (!m_Detail->WriteComplete)
			{
				m_Detail->WriteComplete = true;
				m_Detail->ReadCV.notify_all();
			}
		}
		m_Detail = nullptr;
	}
}

bool
ComputeBufferWriter::IsValid() const
{
	return static_cast<bool>(m_Detail);
}

// Signal end of stream without releasing the writer handle.
void
ComputeBufferWriter::MarkComplete()
{
	if (m_Detail)
	{
		std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
		m_Detail->WriteComplete = true;
		m_Detail->ReadCV.notify_all();
	}
}

// Publish Size bytes previously written via a WaitToWrite pointer.
void
ComputeBufferWriter::AdvanceWritePosition(size_t Size)
{
	if (!m_Detail || Size == 0)
	{
		return;
	}

	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	m_Detail->ChunkWrittenLength[WriteChunk] += Size;
	m_Detail->WriteOffset += Size;
	m_Detail->ReadCV.notify_all();
}

// Remaining capacity in the active write chunk.
size_t
ComputeBufferWriter::GetMaxWriteSize() const
{
	if (!m_Detail)
	{
		return 0;
	}
	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
}

// ChunkLength is immutable after CreateNew, so no lock is needed here.
size_t
ComputeBufferWriter::GetChunkMaxLength() const
{
	if (!m_Detail)
	{
		return 0;
	}
	return m_Detail->ChunkLength;
}

// Copying convenience wrapper over WaitToWrite/AdvanceWritePosition.
// Returns bytes accepted (possibly < MaxSize); 0 on timeout or completed stream.
size_t
ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
{
	uint8_t* Dest = WaitToWrite(1, TimeoutMs);
	if (!Dest)
	{
		return 0;
	}

	const size_t Available = GetMaxWriteSize();
	const size_t ToCopy    = std::min(Available, MaxSize);
	memcpy(Dest, Buffer, ToCopy);
	AdvanceWritePosition(ToCopy);
	return ToCopy;
}

// Block until MinSize bytes of contiguous write space are available and return
// a zero-copy destination pointer; nullptr on timeout, detach, or completed stream.
// NOTE(review): MinSize > ChunkLength can never be satisfied — a fresh chunk is
// returned regardless; confirm callers never request more than one chunk.
uint8_t*
ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
{
	if (!m_Detail)
	{
		return nullptr;
	}

	std::unique_lock<std::mutex> Lock(m_Detail->Mutex);

	if (m_Detail->WriteComplete)
	{
		return nullptr;
	}

	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	const size_t Available  = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];

	// If current chunk has enough space, return pointer
	if (Available >= MinSize)
	{
		return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
	}

	// Current chunk is full - mark it as finished and move to next.
	// The writer cannot advance until the reader has fully consumed the next chunk,
	// preventing the writer from overwriting data the reader hasn't processed yet.
	m_Detail->ChunkFinished[WriteChunk] = true;
	m_Detail->ReadCV.notify_all();

	const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;

	// Wait until reader has consumed the next chunk
	auto Predicate = [&]() -> bool {
		// Check if read has moved past this chunk
		return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
	};

	if (TimeoutMs < 0)
	{
		m_Detail->WriteCV.wait(Lock, Predicate);
	}
	else
	{
		if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
		{
			return nullptr;
		}
	}

	if (m_Detail->Detached)
	{
		return nullptr;
	}

	// Reset next chunk
	m_Detail->ChunkWrittenLength[NextChunk] = 0;
	m_Detail->ChunkFinished[NextChunk]      = false;
	m_Detail->WriteChunkIdx                 = NextChunk;
	m_Detail->WriteOffset                   = 0;

	return m_Detail->ChunkPtr(NextChunk);
}

} // namespace zen::horde

// ---- file: src/zenhorde/hordecomputebuffer.h ----
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zenbase/refcount.h>

#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

namespace zen::horde {

class ComputeBufferReader;
class ComputeBufferWriter;

/** Simplified in-process ring buffer for the Horde compute protocol.
 *
 * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files,
 * this implementation uses plain heap-allocated memory since we only need in-process
 * communication between channel and transport threads. The buffer is divided into
 * fixed-size chunks; readers and writers block when no space is available.
 */
class ComputeBuffer
{
public:
	// Defaults: 2 chunks of 512 KiB each.
	struct Params
	{
		size_t NumChunks   = 2;
		size_t ChunkLength = 512 * 1024;
	};

	ComputeBuffer();
	~ComputeBuffer();

	// Non-copyable; reader/writer endpoints share state via refcounted Detail.
	ComputeBuffer(const ComputeBuffer&)            = delete;
	ComputeBuffer& operator=(const ComputeBuffer&) = delete;

	bool CreateNew(const Params& InParams);
	void Close();

	bool IsValid() const;

	ComputeBufferReader CreateReader();
	ComputeBufferWriter CreateWriter();

private:
	struct Detail;
	Ref<Detail> m_Detail;

	friend class ComputeBufferReader;
	friend class ComputeBufferWriter;
};

/** Read endpoint for a ComputeBuffer.
 *
 * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer
 * directly into the buffer memory (zero-copy); the caller must call
 * AdvanceReadPosition() after consuming the data.
 */
class ComputeBufferReader
{
public:
	ComputeBufferReader();
	ComputeBufferReader(const ComputeBufferReader&);
	ComputeBufferReader(ComputeBufferReader&&) noexcept;
	~ComputeBufferReader();

	ComputeBufferReader& operator=(const ComputeBufferReader&);
	ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept;

	void Close();
	void Detach();
	bool IsValid() const;
	bool IsComplete() const;

	void   AdvanceReadPosition(size_t Size);
	size_t GetMaxReadSize() const;

	/** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */
	size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);

	/** Wait until at least MinSize bytes are available and return a direct pointer.
	 * Returns nullptr on timeout or if the writer has completed. */
	const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);

private:
	friend class ComputeBuffer;
	explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail);

	Ref<ComputeBuffer::Detail> m_Detail;
};

/** Write endpoint for a ComputeBuffer.
 *
 * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer
 * directly into the buffer memory (zero-copy); the caller must call
 * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
 * that no more data will be written.
 */
class ComputeBufferWriter
{
public:
	ComputeBufferWriter();
	ComputeBufferWriter(const ComputeBufferWriter&);
	ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
	~ComputeBufferWriter();

	ComputeBufferWriter& operator=(const ComputeBufferWriter&);
	ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;

	void Close();
	bool IsValid() const;

	/** Signal that no more data will be written. Unblocks any waiting readers. */
	void MarkComplete();

	void   AdvanceWritePosition(size_t Size);
	size_t GetMaxWriteSize() const;
	size_t GetChunkMaxLength() const;

	/** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
	size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);

	/** Wait until at least MinSize bytes of write space are available and return a direct pointer.
	 * Returns nullptr on timeout. */
	uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);

private:
	friend class ComputeBuffer;
	explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);

	Ref<ComputeBuffer::Detail> m_Detail;
};

} // namespace zen::horde

// ---- file: src/zenhorde/hordecomputechannel.cpp ----
// Copyright Epic Games, Inc. All Rights Reserved.
+ +#include "hordecomputechannel.h" + +namespace zen::horde { + +ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) +: Reader(std::move(InReader)) +, Writer(std::move(InWriter)) +{ +} + +bool +ComputeChannel::IsValid() const +{ + return Reader.IsValid() && Writer.IsValid(); +} + +size_t +ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) +{ + return Writer.Write(Data, Size, TimeoutMs); +} + +size_t +ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) +{ + return Reader.Read(Data, Size, TimeoutMs); +} + +void +ComputeChannel::MarkComplete() +{ + Writer.MarkComplete(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h new file mode 100644 index 000000000..c1dff20e4 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" + +namespace zen::horde { + +/** Bidirectional communication channel using a pair of compute buffers. + * + * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter + * (for sending data). Used by ComputeSocket to represent one logical channel + * within a multiplexed connection. + */ +class ComputeChannel : public TRefCounted<ComputeChannel> +{ +public: + ComputeBufferReader Reader; + ComputeBufferWriter Writer; + + ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); + + bool IsValid() const; + + size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); + size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); + + /** Signal that no more data will be sent on this channel. 
*/ + void MarkComplete(); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp new file mode 100644 index 000000000..6ef67760c --- /dev/null +++ b/src/zenhorde/hordecomputesocket.cpp @@ -0,0 +1,204 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputesocket.h" + +#include <zencore/logging.h> + +namespace zen::horde { + +ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) +: m_Log(zen::logging::Get("horde.socket")) +, m_Transport(std::move(Transport)) +{ +} + +ComputeSocket::~ComputeSocket() +{ + // Shutdown order matters: first stop the ping thread, then unblock send threads + // by detaching readers, then join send threads, and finally close the transport + // to unblock the recv thread (which is blocked on RecvMessage). + { + std::lock_guard<std::mutex> Lock(m_PingMutex); + m_PingShouldStop = true; + m_PingCV.notify_all(); + } + + for (auto& Reader : m_Readers) + { + Reader.Detach(); + } + + for (auto& [Id, Thread] : m_SendThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Transport->Close(); + + if (m_RecvThread.joinable()) + { + m_RecvThread.join(); + } + if (m_PingThread.joinable()) + { + m_PingThread.join(); + } +} + +Ref<ComputeChannel> +ComputeSocket::CreateChannel(int ChannelId) +{ + ComputeBuffer::Params Params; + + ComputeBuffer RecvBuffer; + if (!RecvBuffer.CreateNew(Params)) + { + return {}; + } + + ComputeBuffer SendBuffer; + if (!SendBuffer.CreateNew(Params)) + { + return {}; + } + + Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); + + // Attach recv buffer writer (transport recv thread writes into this) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); + } + + // Attach send buffer reader (send thread reads from this) + { + ComputeBufferReader Reader = SendBuffer.CreateReader(); + 
m_Readers.push_back(Reader); + m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); + } + + return Channel; +} + +void +ComputeSocket::StartCommunication() +{ + m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); + m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); +} + +void +ComputeSocket::PingThreadProc() +{ + while (true) + { + { + std::unique_lock<std::mutex> Lock(m_PingMutex); + if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + { + break; + } + } + + std::lock_guard<std::mutex> Lock(m_SendMutex); + FrameHeader Header; + Header.Channel = 0; + Header.Size = ControlPing; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +void +ComputeSocket::RecvThreadProc() +{ + // Writers are cached locally to avoid taking m_WritersMutex on every frame. + // The shared m_Writers map is only accessed when a channel is seen for the first time. + std::unordered_map<int, ComputeBufferWriter> CachedWriters; + + FrameHeader Header; + while (m_Transport->RecvMessage(&Header, sizeof(Header))) + { + if (Header.Size >= 0) + { + // Data frame + auto It = CachedWriters.find(Header.Channel); + if (It == CachedWriters.end()) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto WIt = m_Writers.find(Header.Channel); + if (WIt == m_Writers.end()) + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + // Skip the data + std::vector<uint8_t> Discard(Header.Size); + m_Transport->RecvMessage(Discard.data(), Header.Size); + continue; + } + It = CachedWriters.emplace(Header.Channel, WIt->second).first; + } + + ComputeBufferWriter& Writer = It->second; + uint8_t* Dest = Writer.WaitToWrite(Header.Size); + if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) + { + ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); + return; + } + Writer.AdvanceWritePosition(Header.Size); + } + else 
if (Header.Size == ControlDetach) + { + // Detach the recv buffer for this channel + CachedWriters.erase(Header.Channel); + + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto It = m_Writers.find(Header.Channel); + if (It != m_Writers.end()) + { + It->second.MarkComplete(); + m_Writers.erase(It); + } + } + else if (Header.Size == ControlPing) + { + // Ping response - ignore + } + else + { + ZEN_WARN("invalid frame header size: {}", Header.Size); + return; + } + } +} + +void +ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +{ + // Each channel has its own send thread. All send threads share m_SendMutex + // to serialize writes to the transport, since TCP requires atomic frame writes. + FrameHeader Header; + Header.Channel = Channel; + + const uint8_t* Data; + while ((Data = Reader.WaitToRead(1)) != nullptr) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + + Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); + m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->SendMessage(Data, Header.Size); + Reader.AdvanceReadPosition(Header.Size); + } + + if (Reader.IsComplete()) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + Header.Size = ControlDetach; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h new file mode 100644 index 000000000..0c3cb4195 --- /dev/null +++ b/src/zenhorde/hordecomputesocket.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" +#include "hordecomputechannel.h" +#include "hordetransport.h" + +#include <zencore/logbase.h> + +#include <condition_variable> +#include <memory> +#include <mutex> +#include <thread> +#include <unordered_map> +#include <vector> + +namespace zen::horde { + +/** Multiplexed socket that routes data between multiple channels over a single transport. 
+ * + * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. + * A recv thread demultiplexes incoming frames to channel-specific buffers, while + * per-channel send threads multiplex outgoing data onto the shared transport. + * + * Wire format per frame: [channelId (4B)][size (4B)][data] + * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + */ +class ComputeSocket +{ +public: + explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport); + ~ComputeSocket(); + + ComputeSocket(const ComputeSocket&) = delete; + ComputeSocket& operator=(const ComputeSocket&) = delete; + + /** Create a channel with the given ID. + * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ + Ref<ComputeChannel> CreateChannel(int ChannelId); + + /** Start the recv pump and ping threads. Must be called after all channels are created. */ + void StartCommunication(); + +private: + struct FrameHeader + { + int32_t Channel = 0; + int32_t Size = 0; + }; + + static constexpr int32_t ControlDetach = -2; + static constexpr int32_t ControlPing = -3; + + LoggerRef Log() { return m_Log; } + + void RecvThreadProc(); + void SendThreadProc(int Channel, ComputeBufferReader Reader); + void PingThreadProc(); + + LoggerRef m_Log; + std::unique_ptr<ComputeTransport> m_Transport; + std::mutex m_SendMutex; ///< Serializes writes to the transport + + std::mutex m_WritersMutex; + std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID + + std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction + std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel + + std::thread m_RecvThread; + std::thread m_PingThread; + + bool m_PingShouldStop = false; + std::mutex m_PingMutex; + std::condition_variable m_PingCV; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp new 
file mode 100644 index 000000000..2dca228d9 --- /dev/null +++ b/src/zenhorde/hordeconfig.cpp @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhorde/hordeconfig.h> + +namespace zen::horde { + +bool +HordeConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + // Relay mode implies AES encryption + if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) + { + return false; + } + + return true; +} + +const char* +ToString(ConnectionMode Mode) +{ + switch (Mode) + { + case ConnectionMode::Direct: + return "direct"; + case ConnectionMode::Tunnel: + return "tunnel"; + case ConnectionMode::Relay: + return "relay"; + } + return "direct"; +} + +const char* +ToString(Encryption Enc) +{ + switch (Enc) + { + case Encryption::None: + return "none"; + case Encryption::AES: + return "aes"; + } + return "none"; +} + +bool +FromString(ConnectionMode& OutMode, std::string_view Str) +{ + if (Str == "direct") + { + OutMode = ConnectionMode::Direct; + return true; + } + if (Str == "tunnel") + { + OutMode = ConnectionMode::Tunnel; + return true; + } + if (Str == "relay") + { + OutMode = ConnectionMode::Relay; + return true; + } + return false; +} + +bool +FromString(Encryption& OutEnc, std::string_view Str) +{ + if (Str == "none") + { + OutEnc = Encryption::None; + return true; + } + if (Str == "aes") + { + OutEnc = Encryption::AES; + return true; + } + return false; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp new file mode 100644 index 000000000..f88c95da2 --- /dev/null +++ b/src/zenhorde/hordeprovisioner.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenhorde/hordeclient.h> +#include <zenhorde/hordeprovisioner.h> + +#include "hordeagent.h" +#include "hordebundle.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +#include <chrono> +#include <thread> + +namespace zen::horde { + +struct HordeProvisioner::AgentWrapper +{ + std::thread Thread; + std::atomic<bool> ShouldExit{false}; +}; + +HordeProvisioner::HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_BinariesPath(BinariesPath) +, m_WorkingDir(WorkingDir) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_Log(zen::logging::Get("horde.provisioner")) +{ +} + +HordeProvisioner::~HordeProvisioner() +{ + std::lock_guard<std::mutex> Lock(m_AgentsLock); + for (auto& Agent : m_Agents) + { + Agent->ShouldExit.store(true); + } + for (auto& Agent : m_Agents) + { + if (Agent->Thread.joinable()) + { + Agent->Thread.join(); + } + } +} + +void +HordeProvisioner::SetTargetCoreCount(uint32_t Count) +{ + ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); + + m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores))); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + if (!m_AskForAgents.load()) + { + return; + } + RequestAgent(); + } + + // Clean up finished agent threads + std::lock_guard<std::mutex> Lock(m_AgentsLock); + for (auto It = m_Agents.begin(); It != m_Agents.end();) + { + if ((*It)->ShouldExit.load()) + { + if ((*It)->Thread.joinable()) + { + (*It)->Thread.join(); + } + It = m_Agents.erase(It); + } + else + { + ++It; + } + } +} + +ProvisioningStats +HordeProvisioner::GetStats() const +{ + ProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.ActiveCoreCount = 
m_ActiveCoreCount.load(); + Stats.AgentsActive = m_AgentsActive.load(); + Stats.AgentsRequesting = m_AgentsRequesting.load(); + return Stats; +} + +uint32_t +HordeProvisioner::GetAgentCount() const +{ + std::lock_guard<std::mutex> Lock(m_AgentsLock); + return static_cast<uint32_t>(m_Agents.size()); +} + +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + + std::lock_guard<std::mutex> Lock(m_AgentsLock); + + auto Wrapper = std::make_unique<AgentWrapper>(); + AgentWrapper& Ref = *Wrapper; + Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + + m_Agents.push_back(std::move(Wrapper)); +} + +void +HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +{ + ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + + static std::atomic<uint32_t> ThreadIndex{0}; + const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + + zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + + std::unique_ptr<HordeAgent> Agent; + uint32_t MachineCoreCount = 0; + + auto _ = MakeGuard([&] { + if (Agent) + { + Agent->CloseConnection(); + } + Wrapper.ShouldExit.store(true); + }); + + { + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. 
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + + { + ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + + std::lock_guard<std::mutex> BundleLock(m_BundleLock); + + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector<BundleFile> Files; + +#if ZEN_PLATFORM_WINDOWS + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); +#elif ZEN_PLATFORM_LINUX + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); +#elif ZEN_PLATFORM_MAC + Files.emplace_back(m_BinariesPath / "zenserver", false); +#endif + + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique<HordeClient>(m_Config); + if (!m_HordeClient->Initialize()) + { + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + } + } + + if (!m_AskForAgents.load()) + { + return; + } + + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. 
+ + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) + { + const uint64_t WaitMs = 5000 - ElapsedMs; + for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + if (Wrapper.ShouldExit.load()) + { + return; + } + } + } + + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) + { + ZEN_WARN("failed to resolve cluster"); + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + ClusterId = Cluster.ClusterId; + } + + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + + m_LastRequestFailTime.store(0); + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Connect to agent and perform handshake + Agent = std::make_unique<HordeAgent>(Machine); + if (!Agent->IsValid()) + { + ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); + return; + } + + if (!Agent->BeginCommunication()) + { + ZEN_WARN("BeginCommunication failed"); + return; + } + + for (auto& [Locator, BundleDir] : m_Bundles) + { + if (Wrapper.ShouldExit.load()) + { + return; + } + + if (!Agent->UploadBinaries(BundleDir, Locator)) + { + 
ZEN_WARN("UploadBinaries failed"); + return; + } + } + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Build command line for remote zenserver + std::vector<std::string> ArgStrings; + ArgStrings.push_back("compute"); + ArgStrings.push_back("--http=asio"); + + // TEMP HACK - these should be made fully dynamic + // these are currently here to allow spawning the compute agent locally + // for debugging purposes (i.e with a local Horde Server+Agent setup) + ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + ArgStrings.push_back("--data-dir=c:\\temp\\123"); + + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + ArgStrings.emplace_back(CoordArg.ToView()); + } + + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + ArgStrings.emplace_back(IdArg.ToView()); + } + + std::vector<const char*> Args; + Args.reserve(ArgStrings.size()); + for (const std::string& Arg : ArgStrings) + { + Args.push_back(Arg.c_str()); + } + +#if ZEN_PLATFORM_WINDOWS + const bool UseWine = !Machine.IsWindows; + const char* AppName = "zenserver.exe"; +#else + const bool UseWine = false; + const char* AppName = "zenserver"; +#endif + + Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + Machine.GetConnectionAddress(), + Machine.GetConnectionPort(), + Machine.LeaseId); + + MachineCoreCount = Machine.LogicalCores; + m_EstimatedCoreCount.fetch_add(MachineCoreCount); + m_ActiveCoreCount.fetch_add(MachineCoreCount); + m_AgentsActive.fetch_add(1); + } + + // Agent poll loop + + auto ActiveGuard = MakeGuard([&]() { + m_EstimatedCoreCount.fetch_sub(MachineCoreCount); + m_ActiveCoreCount.fetch_sub(MachineCoreCount); + m_AgentsActive.fetch_sub(1); + }); + + while (Agent->IsValid() && !Wrapper.ShouldExit.load()) + { + const bool LogOutput = false; + 
if (!Agent->Poll(LogOutput)) + { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp new file mode 100644 index 000000000..69766e73e --- /dev/null +++ b/src/zenhorde/hordetransport.cpp @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransport.h" + +#include <zencore/logging.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +// ComputeTransport base + +bool +ComputeTransport::SendMessage(const void* Data, size_t Size) +{ + const uint8_t* Ptr = static_cast<const uint8_t*>(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Sent = Send(Ptr, Remaining); + if (Sent == 0) + { + return false; + } + Ptr += Sent; + Remaining -= Sent; + } + + return true; +} + +bool +ComputeTransport::RecvMessage(void* Data, size_t Size) +{ + uint8_t* Ptr = static_cast<uint8_t*>(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Received = Recv(Ptr, Remaining); + if (Received == 0) + { + return false; + } + Ptr += Received; + Remaining -= Received; + } + + return true; +} + +// TcpComputeTransport - ASIO pimpl + +struct TcpComputeTransport::Impl +{ + asio::io_context IoContext; + asio::ip::tcp::socket Socket; + + Impl() : Socket(IoContext) {} +}; + +// Uses ASIO in synchronous mode only — no async operations or io_context::run(). +// The io_context is only needed because ASIO sockets require one to be constructed. 
+TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) +: m_Impl(std::make_unique<Impl>()) +, m_Log(zen::logging::Get("horde.transport")) +{ + ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + + asio::error_code Ec; + + const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec); + if (Ec) + { + ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); + m_HasErrors = true; + return; + } + + const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); + + m_Impl->Socket.connect(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); + m_HasErrors = true; + return; + } + + // Disable Nagle's algorithm for lower latency + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); +} + +TcpComputeTransport::~TcpComputeTransport() +{ + Close(); +} + +bool +TcpComputeTransport::IsValid() const +{ + return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; +} + +size_t +TcpComputeTransport::Send(const void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + m_HasErrors = true; + return 0; + } + + return Sent; +} + +size_t +TcpComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + return 0; + } + + return Received; +} + +void +TcpComputeTransport::MarkComplete() +{ +} + +void +TcpComputeTransport::Close() +{ + if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) + { + asio::error_code Ec; + m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + m_Impl->Socket.close(Ec); + } + m_IsClosed = true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.h 
b/src/zenhorde/hordetransport.h new file mode 100644 index 000000000..1b178dc0f --- /dev/null +++ b/src/zenhorde/hordetransport.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeclient.h> + +#include <zencore/logbase.h> + +#include <cstddef> +#include <cstdint> +#include <memory> + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +/** Abstract base interface for compute transports. + * + * Matches the UE FComputeTransport pattern. Concrete implementations handle + * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides + * blocking message helpers on top. + */ +class ComputeTransport +{ +public: + virtual ~ComputeTransport() = default; + + virtual bool IsValid() const = 0; + virtual size_t Send(const void* Data, size_t Size) = 0; + virtual size_t Recv(void* Data, size_t Size) = 0; + virtual void MarkComplete() = 0; + virtual void Close() = 0; + + /** Blocking send that loops until all bytes are transferred. Returns false on error. */ + bool SendMessage(const void* Data, size_t Size); + + /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ + bool RecvMessage(void* Data, size_t Size); +}; + +/** TCP socket transport using ASIO. + * + * Connects to the Horde compute endpoint specified by MachineInfo and provides + * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the + * header clean. 
+ */ +class TcpComputeTransport final : public ComputeTransport +{ +public: + explicit TcpComputeTransport(const MachineInfo& Info); + ~TcpComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + LoggerRef Log() { return m_Log; } + + struct Impl; + std::unique_ptr<Impl> m_Impl; + LoggerRef m_Log; + bool m_IsClosed = false; + bool m_HasErrors = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp new file mode 100644 index 000000000..986dd3705 --- /dev/null +++ b/src/zenhorde/hordetransportaes.cpp @@ -0,0 +1,425 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransportaes.h" + +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <cstring> +#include <random> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <bcrypt.h> +# pragma comment(lib, "Bcrypt.lib") +#else +ZEN_THIRD_PARTY_INCLUDES_START +# include <openssl/evp.h> +# include <openssl/err.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +namespace zen::horde { + +struct AesComputeTransport::CryptoContext +{ + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; + +#if !ZEN_PLATFORM_WINDOWS + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif + + CryptoContext(const uint8_t (&InKey)[KeySize]) + { + memcpy(Key, InKey, KeySize); + + // The encrypt nonce is randomly initialized and then deterministically mutated + // per message via UpdateNonce(). The decrypt nonce is not used — it comes from + // the wire (each received message carries its own nonce in the header). 
+ std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution<int> Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast<uint8_t>(Dist(Gen)); + } + +#if !ZEN_PLATFORM_WINDOWS + // Drain any stale OpenSSL errors + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); +#endif + } + + ~CryptoContext() + { +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif + } + + void UpdateNonce() + { + uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; + } + + // Returns total encrypted message size, or 0 on failure + // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); + + // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than + // caching but has some overhead. For our use case (relatively large, infrequent messages) + // this is acceptable. 
+#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = + BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + // Write header: length + nonce + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + // Write tag after ciphertext + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#else + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + // Write length + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + // Write nonce + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + // Encrypt + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + // Finalize + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + // Get tag + if (EVP_CIPHER_CTX_ctrl(EncCtx, 
EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif + } + + // Decrypt a message. Returns decrypted data length, or 0 on failure. + // Input must be [ciphertext][tag], with nonce provided separately. + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) + { +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } + + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + // Set the tag for verification + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) 
+ { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif + } +}; + +AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +: m_Crypto(std::make_unique<CryptoContext>(Key)) +, m_Inner(std::move(InnerTransport)) +{ +} + +AesComputeTransport::~AesComputeTransport() +{ + Close(); +} + +bool +AesComputeTransport::IsValid() const +{ + return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; +} + +size_t +AesComputeTransport::Send(const void* Data, size_t Size) +{ + ZEN_TRACE_CPU("AesComputeTransport::Send"); + + if (!IsValid()) + { + return 0; + } + + std::lock_guard<std::mutex> Lock(m_Lock); + + const int32_t DataLength = static_cast<int32_t>(Size); + const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + if (EncryptedLen == 0) + { + return 0; + } + + if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) + { + return 0; + } + + return Size; +} + +size_t +AesComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes + // than the decrypted message contains. Excess bytes are buffered in m_RemainingData + // and returned on subsequent Recv calls without another decryption round-trip. 
+ ZEN_TRACE_CPU("AesComputeTransport::Recv"); + + std::lock_guard<std::mutex> Lock(m_Lock); + + if (!m_RemainingData.empty()) + { + const size_t Available = m_RemainingData.size() - m_RemainingOffset; + const size_t ToCopy = std::min(Available, Size); + + memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + m_RemainingOffset += ToCopy; + + if (m_RemainingOffset >= m_RemainingData.size()) + { + m_RemainingData.clear(); + m_RemainingOffset = 0; + } + + return ToCopy; + } + + // Receive packet header: [length(4B)][nonce(12B)] + struct PacketHeader + { + int32_t DataLength = 0; + uint8_t Nonce[NonceBytes] = {}; + } Header; + + if (!m_Inner->RecvMessage(&Header, sizeof(Header))) + { + return 0; + } + + // Validate DataLength to prevent OOM from malicious/corrupt peers + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB + + if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) + { + ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); + return 0; + } + + // Receive ciphertext + tag + const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + { + return 0; + } + + // Decrypt + const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); + + // We need a temporary buffer for decryption if we can't decrypt directly into output + std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); + + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + if (Decrypted == 0) + { + return 0; + } + + memcpy(Data, DecryptedBuf.data(), BytesToReturn); + + // Store remaining data if we couldn't return everything + if (static_cast<size_t>(Header.DataLength) > BytesToReturn) + { + m_RemainingOffset = 0; + 
m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + } + + return BytesToReturn; +} + +void +AesComputeTransport::MarkComplete() +{ + if (IsValid()) + { + m_Inner->MarkComplete(); + } +} + +void +AesComputeTransport::Close() +{ + if (!m_IsClosed) + { + if (m_Inner && m_Inner->IsValid()) + { + m_Inner->Close(); + } + m_IsClosed = true; + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h new file mode 100644 index 000000000..efcad9835 --- /dev/null +++ b/src/zenhorde/hordetransportaes.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordetransport.h" + +#include <cstdint> +#include <memory> +#include <mutex> +#include <vector> + +namespace zen::horde { + +/** AES-256-GCM encrypted transport wrapper. + * + * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting + * all incoming data using AES-256-GCM. The nonce is mutated per message using + * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * + * Wire format per encrypted message: + * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] + * + * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). 
+ */ +class AesComputeTransport final : public ComputeTransport +{ +public: + AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); + ~AesComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size + static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + + struct CryptoContext; + + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<ComputeTransport> m_Inner; + std::vector<uint8_t> m_EncryptBuffer; + std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv + size_t m_RemainingOffset = 0; + std::mutex m_Lock; + bool m_IsClosed = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h new file mode 100644 index 000000000..201d68b83 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <cstdint> +#include <map> +#include <memory> +#include <string> +#include <vector> + +namespace zen { +class HttpClient; +} + +namespace zen::horde { + +static constexpr size_t NonceSize = 64; +static constexpr size_t KeySize = 32; + +/** Port mapping information returned by Horde for a provisioned machine. */ +struct PortInfo +{ + uint16_t Port = 0; + uint16_t AgentPort = 0; +}; + +/** Describes a provisioned compute machine returned by the Horde API. + * + * Contains the network address, encryption credentials, and capabilities + * needed to establish a compute transport connection to the machine. 
+ */ +struct MachineInfo +{ + std::string Ip; + ConnectionMode Mode = ConnectionMode::Direct; + std::string ConnectionAddress; ///< Relay/tunnel address (used when Mode != Direct) + uint16_t Port = 0; + uint16_t LogicalCores = 0; + Encryption EncryptionMode = Encryption::None; + uint8_t Nonce[NonceSize] = {}; ///< 64-byte nonce sent during TCP handshake + uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) + bool IsWindows = false; + std::string LeaseId; + + std::map<std::string, PortInfo> Ports; + + /** Return the address to connect to, accounting for connection mode. */ + const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + + /** Return the port to connect to, accounting for connection mode and port mapping. */ + uint16_t GetConnectionPort() const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("_horde_compute"); + if (It != Ports.end()) + { + return It->second.Port; + } + } + return Port; + } + + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } +}; + +/** Result of cluster auto-resolution via the Horde API. */ +struct ClusterInfo +{ + std::string ClusterId = "default"; +}; + +/** HTTP client for the Horde compute REST API. + * + * Handles cluster resolution and machine provisioning requests. Each call + * is synchronous and returns success/failure. Thread safety: individual + * methods are not thread-safe; callers must synchronize access. + */ +class HordeClient +{ +public: + explicit HordeClient(const HordeConfig& Config); + ~HordeClient(); + + HordeClient(const HordeClient&) = delete; + HordeClient& operator=(const HordeClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the JSON request body for cluster resolution and machine requests. + * Encodes pool, condition, connection mode, encryption, and port requirements. 
*/ + std::string BuildRequestBody() const; + + /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ + bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + + /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. + * On success, populates OutMachine with connection details and credentials. */ + bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + + LoggerRef Log() { return m_Log; } + +private: + bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); + + HordeConfig m_Config; + std::unique_ptr<zen::HttpClient> m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h new file mode 100644 index 000000000..dd70f9832 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/zenhorde.h> + +#include <string> + +namespace zen::horde { + +/** Transport connection mode for Horde compute agents. */ +enum class ConnectionMode +{ + Direct, ///< Connect directly to the agent IP + Tunnel, ///< Connect through a Horde tunnel relay + Relay, ///< Connect through a Horde relay with port mapping +}; + +/** Transport encryption mode for Horde compute channels. */ +enum class Encryption +{ + None, ///< No encryption + AES, ///< AES-256-GCM encryption (required for Relay mode) +}; + +/** Configuration for connecting to an Epic Horde compute cluster. + * + * Specifies the Horde server URL, authentication token, pool selection, + * connection mode, and resource limits. Used by HordeClient and HordeProvisioner. 
+ */ +struct HordeConfig +{ + static constexpr const char* ClusterDefault = "default"; + static constexpr const char* ClusterAuto = "_auto"; + + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; + + /** Validate the configuration. Returns false if the configuration is invalid + * (e.g. Relay mode without AES encryption). */ + bool Validate() const; +}; + +const char* ToString(ConnectionMode Mode); +const char* ToString(Encryption Enc); + +bool FromString(ConnectionMode& OutMode, std::string_view Str); +bool FromString(Encryption& OutEnc, std::string_view Str); + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h new file mode 100644 index 000000000..4e2e63bbd --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <atomic> +#include <cstdint> +#include <filesystem> +#include <memory> +#include <mutex> +#include <string> +#include <vector> + +namespace zen::horde { + +class HordeClient; + +/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ +struct ProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected once pending requests complete + uint32_t ActiveCoreCount = 0; ///< Cores on machines that are currently running zenserver + uint32_t AgentsActive = 0; ///< Number of agents with a running remote process + uint32_t AgentsRequesting = 0; ///< Number of agents currently requesting a machine from Horde +}; + +/** Multi-agent lifecycle manager for Horde worker provisioning. + * + * Provisions remote compute workers by requesting machines from the Horde API, + * connecting via the Horde compute transport protocol, uploading the zenserver + * binary, and executing it remotely. Each provisioned machine runs zenserver + * in compute mode, which announces itself back to the orchestrator. + * + * Spawns one thread per agent. Each thread handles the full lifecycle: + * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> + * channel setup -> binary upload -> remote execution -> poll until exit. + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class HordeProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Horde connection and pool configuration. + * @param BinariesPath Directory containing the zenserver binary to upload. + * @param WorkingDir Local directory for bundle staging and working files. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. 
*/ + HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint); + + /** Signals all agent threads to exit and joins them. */ + ~HordeProvisioner(); + + HordeProvisioner(const HordeProvisioner&) = delete; + HordeProvisioner& operator=(const HordeProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the + * estimated core count is below the target. Also joins any finished + * agent threads. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + ProvisioningStats GetStats() const; + + uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const; + +private: + LoggerRef Log() { return m_Log; } + + struct AgentWrapper; + + void RequestAgent(); + void ThreadAgent(AgentWrapper& Wrapper); + + HordeConfig m_Config; + std::filesystem::path m_BinariesPath; + std::filesystem::path m_WorkingDir; + std::string m_OrchestratorEndpoint; + + std::unique_ptr<HordeClient> m_HordeClient; + + std::mutex m_BundleLock; + std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs + bool m_BundlesCreated = false; + + mutable std::mutex m_AgentsLock; + std::vector<std::unique_ptr<AgentWrapper>> m_Agents; + + std::atomic<uint64_t> m_LastRequestFailTime{0}; + std::atomic<uint32_t> m_TargetCoreCount{0}; + std::atomic<uint32_t> m_EstimatedCoreCount{0}; + std::atomic<uint32_t> m_ActiveCoreCount{0}; + std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsRequesting{0}; + std::atomic<bool> m_AskForAgents{true}; + + LoggerRef m_Log; + + static constexpr uint32_t EstimatedCoresPerAgent = 32; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/zenhorde.h b/src/zenhorde/include/zenhorde/zenhorde.h new file mode 
100644 index 000000000..35147ff75 --- /dev/null +++ b/src/zenhorde/include/zenhorde/zenhorde.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_HORDE) +# define ZEN_WITH_HORDE 1 +#endif diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua new file mode 100644 index 000000000..48d028e86 --- /dev/null +++ b/src/zenhorde/xmake.lua @@ -0,0 +1,22 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhorde') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zencompute", "zenutil") + add_packages("asio", "json11") + + if is_plat("windows") then + add_syslinks("Ws2_32", "Bcrypt") + end + + if is_plat("linux") or is_plat("macosx") then + add_packages("openssl") + end + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end diff --git a/src/zenhttp-test/zenhttp-test.cpp b/src/zenhttp-test/zenhttp-test.cpp index c18759beb..b4b406ac8 100644 --- a/src/zenhttp-test/zenhttp-test.cpp +++ b/src/zenhttp-test/zenhttp-test.cpp @@ -1,44 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/memory/newdelete.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zenhttp/zenhttp.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif +#include <zencore/memory/newdelete.h> int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenhttp_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenhttp-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenhttp-test", zen::zenhttp_forcelinktests); #else return 0; #endif diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 38e7586ad..23bbc17e8 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -32,6 +32,25 @@ namespace details { using namespace std::literals; +static std::string +FormUrlEncode(std::string_view Input) +{ + std::string Result; + Result.reserve(Input.size()); + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' 
|| C == '~') + { + Result.push_back(C); + } + else + { + Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C))); + } + } + return Result; +} + OidcClient::OidcClient(const OidcClient::Options& Options) { m_BaseUrl = std::string(Options.BaseUrl); @@ -67,6 +86,8 @@ OidcClient::Initialize() .TokenEndpoint = Json["token_endpoint"].string_value(), .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .EndSessionEndpoint = Json["end_session_endpoint"].string_value(), + .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(), .JwksUri = Json["jwks_uri"].string_value(), .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), @@ -81,7 +102,8 @@ OidcClient::Initialize() OidcClient::RefreshTokenResult OidcClient::RefreshToken(std::string_view RefreshToken) { - const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + const std::string Body = + fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId)); HttpClient Http{m_Config.TokenEndpoint}; diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 47425e014..6f4c67dd0 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -142,7 +142,10 @@ namespace detail { DataSize -= CopySize; if (m_CacheBufferOffset == CacheBufferSize) { - AppendData(m_CacheBuffer, CacheBufferSize); + if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize)) + { + return Ec; + } if (DataSize > 0) { ZEN_ASSERT(DataSize < CacheBufferSize); @@ -382,6 +385,177 @@ namespace detail { return Result; } + MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), 
HeaderEndMatcher("\r\n\r\n") {} + + bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue) + { + std::string LowerCaseValue = ToLower(ContentTypeHeaderValue); + if (LowerCaseValue.starts_with("multipart/byteranges")) + { + size_t BoundaryPos = LowerCaseValue.find("boundary="); + if (BoundaryPos != std::string::npos) + { + // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string + std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') + { + BoundaryName = BoundaryName.substr(1); + } + if (!BoundaryName.empty()) + { + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + { + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) + { + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); + return true; + } + } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } + } + } + } + return false; + } + + void MultipartBoundaryParser::ParseInput(std::string_view data) + { + const char* InputPtr = data.data(); + size_t InputLength = data.length(); + size_t ScanPos = 0; + while (ScanPos < InputLength) + { + const char ScanChar = InputPtr[ScanPos]; + if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) + { + BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + break; + } + } + + 
BoundaryHeader.Append(ScanChar); + + HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + + if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset(); + const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset(); + const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset; + std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength)); + + uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1; + + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType = HttpContentType::kBinary; + + ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) { + const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line); + const std::string_view Key = KeyAndValue.first; + const std::string_view Value = KeyAndValue.second; + if (Key == "Content-Range") + { + std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value); + if (ContentRange.second != 0) + { + RangeOffset = ContentRange.first; + RangeLength = ContentRange.second; + } + } + else if (Key == "Content-Type") + { + ContentType = ParseContentType(Value); + } + + return true; + }); + + if (RangeLength > 0) + { + Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload, + .RangeOffset = RangeOffset, + .RangeLength = RangeLength, + .ContentType = ContentType}); + } + + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + } + } + else + { + BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar); + } + ScanPos++; + } + PayloadOffset += InputLength; + } + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString) + { + size_t DelimiterPos = HeaderString.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string_view Key = HeaderString.substr(0, 
DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string_view Value = HeaderString.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + return std::make_pair(Key, Value); + } + return std::make_pair(HeaderString, std::string_view{}); + } + + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value) + { + if (Value.starts_with("bytes ")) + { + size_t RangeSplitPos = Value.find('-', 6); + if (RangeSplitPos != std::string::npos) + { + size_t RangeEndLength = Value.find('/', RangeSplitPos + 1); + if (RangeEndLength == std::string::npos) + { + RangeEndLength = Value.length() - (RangeSplitPos + 1); + } + else + { + RangeEndLength = RangeEndLength - (RangeSplitPos + 1); + } + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + uint64_t RangeOffset = RequestedRangeStart.value(); + uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1; + return std::make_pair(RangeOffset, RangeLength); + } + } + } + return {0, 0}; + } + } // namespace detail } // namespace zen @@ -423,6 +597,8 @@ namespace testutil { } // namespace testutil +TEST_SUITE_BEGIN("http.httpclientcommon"); + TEST_CASE("BufferedReadFileStream") { ScopedTemporaryDirectory TmpDir; @@ -470,5 +646,150 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("MultipartBoundaryParser") +{ + uint64_t Range1Offset = 2638; + uint64_t Range1Length = (5111437 - Range1Offset) + 1; + + uint64_t Range2Offset = 5118199; + uint64_t Range2Length = (9147741 
- Range2Offset) + 1; + + std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229"; + std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\""; + + { + std::string_view Example1 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/44369878\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample1; + ParserExample1.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 7; + for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow) + { + ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow))); + } + + CHECK(ParserExample1.Boundaries.size() == 2); + + CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example2 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample2; + ParserExample2.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow) + { + std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow)); + ParserExample2.ParseInput(Window); + } + + 
CHECK(ParserExample2.Boundaries.size() == 2); + + CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example3 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita"; + + detail::MultipartBoundaryParser ParserExample3; + ParserExample3.Init(ContentTypeHeaderValue2); + + const size_t InputWindow = 31; + for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow) + { + ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow))); + } + + CHECK(ParserExample3.Boundaries.size() == 2); + + CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example4 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "Not: really\r\n" + "\r\n" + "datadatadatadata" + "\r\n--000000000bait0019229\r\n" + "\r\n--00\r\n--000000000bait001922\r\n" + "\r\n\r\n\r\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "Content-Type: application/x-ue-comp\r\n" + "ditaditadita" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n---\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample4; + 
ParserExample4.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow) + { + std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow)); + ParserExample4.ParseInput(Window); + } + + CHECK(ParserExample4.Boundaries.size() == 2); + + CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length); + } +} + +TEST_SUITE_END(); + } // namespace zen #endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 1d0b7f9ea..5ed946541 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/compositebuffer.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhttp/httpclient.h> @@ -87,7 +88,7 @@ namespace detail { std::error_code Write(std::string_view DataString); IoBuffer DetachToIoBuffer(); IoBuffer BorrowIoBuffer(); - inline uint64_t GetSize() const { return m_WriteOffset; } + inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; } void ResetWritePos(uint64_t WriteOffset); private: @@ -143,6 +144,118 @@ namespace detail { uint64_t m_BytesLeftInSegment; }; + class IncrementalStringMatcher + { + public: + enum class EMatchState + { + None, + Partial, + Complete + }; + + EMatchState MatchState = EMatchState::None; + + IncrementalStringMatcher() {} + + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) + { + RawMatchString = MatchString.data(); + } + + void Init(std::string&& InMatchString) + { + MatchString = std::move(InMatchString); + RawMatchString = MatchString.data(); + } + + inline void Reset() + { + MatchLength = 0; + MatchStartOffset = 
0; + MatchState = EMatchState::None; + } + + inline uint64_t GetMatchEndOffset() const + { + if (MatchState == EMatchState::Complete) + { + return MatchStartOffset + MatchString.length(); + } + return 0; + } + + inline uint64_t GetMatchStartOffset() const + { + ZEN_ASSERT(MatchState == EMatchState::Complete); + return MatchStartOffset; + } + + void Match(uint64_t Offset, char C) + { + ZEN_ASSERT_SLOW(RawMatchString != nullptr); + + if (MatchState == EMatchState::Complete) + { + Reset(); + } + if (C == RawMatchString[MatchLength]) + { + if (MatchLength == 0) + { + MatchStartOffset = Offset; + } + MatchLength++; + if (MatchLength == MatchString.length()) + { + MatchState = EMatchState::Complete; + } + else + { + MatchState = EMatchState::Partial; + } + } + else if (MatchLength != 0) + { + Reset(); + Match(Offset, C); + } + else + { + Reset(); + } + } + inline const std::string& GetMatchString() const { return MatchString; } + + private: + std::string MatchString; + const char* RawMatchString = nullptr; + uint64_t MatchLength = 0; + + uint64_t MatchStartOffset = 0; + }; + + class MultipartBoundaryParser + { + public: + std::vector<HttpClient::Response::MultipartBoundary> Boundaries; + + MultipartBoundaryParser(); + bool Init(const std::string_view ContentTypeHeaderValue); + void ParseInput(std::string_view data); + + private: + IncrementalStringMatcher BoundaryBeginMatcher; + IncrementalStringMatcher BoundaryEndMatcher; + IncrementalStringMatcher HeaderEndMatcher; + + ExtendableStringBuilder<64> BoundaryHeader; + uint64_t PayloadOffset = 0; + }; + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString); + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value); + } // namespace detail } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 5d92b3b6b..14e40b02a 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ 
b/src/zenhttp/clients/httpclientcpr.cpp @@ -12,6 +12,7 @@ #include <zencore/session.h> #include <zencore/stream.h> #include <zenhttp/packageformat.h> +#include <algorithm> namespace zen { @@ -23,6 +24,21 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; +bool +HttpClient::ErrorContext::IsConnectionError() const +{ + switch (static_cast<cpr::ErrorCode>(ErrorCode)) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return true; + default: + return false; + } +} + // If we want to support different HTTP client implementations then we'll need to make this more abstract HttpClientError::ResponseClass @@ -149,6 +165,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri, { } +bool +CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + // Quiet + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + CprHttpClient::~CprHttpClient() { ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); @@ -162,10 +190,11 @@ CprHttpClient::~CprHttpClient() } HttpClient::Response -CprHttpClient::ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload) +CprHttpClient::ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { // This ends up doing a memcpy, would be good to get rid of it by streaming results // into buffer directly @@ -174,30 +203,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, 
if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) { const HttpContentType ContentType = ParseContentType(It->second); - ResponseBuffer.SetContentType(ContentType); } - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - - if (!Quiet) + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + if (ShouldLogErrorCode(WorkResponseCode)) { ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); } } + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + .ElapsedSeconds = HttpResponse.elapsed, + .Ranges = std::move(BoundaryPositions)}; } HttpClient::Response -CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload) +CprHttpClient::CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); if (HttpResponse.error) @@ -235,7 +271,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe } else { - return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload)); + return ResponseWithPayload(SessionId, 
std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); } } @@ -346,8 +382,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -385,8 +420,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -621,7 +655,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke ResponseBuffer.SetContentType(ContentType); } - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -896,236 +930,287 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::string PayloadString; std::unique_ptr<detail::TempPayloadFile> PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - 
Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return false; - } - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional<size_t> RequestedRangeStart = - ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) 
- { - PayloadFile = std::make_unique<detail::TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (!Header.first.empty()) - { - 
ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); - if (Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile); - - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? 
PayloadFile->DetachToIoBuffer() : IoBuffer{}); + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (IsMultiRangeResponse) + { + BoundaryParser.ParseInput(data); + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = 
ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. 
Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + else if (Header.first == "Content-Type") + { + IsMultiRangeResponse = BoundaryParser.Init(Header.second); + if (!IsMultiRangeResponse) + { + ContentType = ParseContentType(Header.second); + } + } + else if (Header.first == "Content-Range") + { + if (!IsMultiRangeResponse) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second); + if (Range.second != 0) + { + BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = ContentType}); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + uint64_t ContentLength = 
RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? 
PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Response), + PayloadFile ? 
PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); } } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 40af53b5d..752d91add 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -155,14 +155,19 @@ private: std::function<cpr::Response()>&& Func, std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }); + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); - HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload); + HttpClient::Response CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); - HttpClient::Response ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload); + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); }; } // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..9497dadb8 --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,566 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenhttp/httpwsclient.h> + +#include "../servers/wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <random> +#include <thread> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + if (m_Socket) + { + asio::error_code Ec; + m_Socket->close(Ec); + } + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if 
(m_OwnedIoContext) + { + m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + + void DoResolve() + { + m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + if (m_Socket) + { + asio::error_code CloseEc; + m_Socket->close(CloseEc); + } + } + }); + + asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + 
m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); + + asio::async_write(*m_Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + } + + void DoReadHandshakeResponse() + { + asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected 
(no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // 
Auto-respond with masked pong + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector<uint8_t> Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void FlushWriteQueue() + { + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + } + + void OnWriteComplete(const 
asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span<const uint8_t> Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + std::unique_ptr<asio::io_context::work> m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; 
+ std::shared_ptr<std::string> m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{false}; + std::atomic<bool> m_CloseSent{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span<const uint8_t> Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 43e9fb468..281d512cf 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -21,9 +21,17 @@ #include "clients/httpclientcommon.h" +#include <numeric> + #if ZEN_WITH_TESTS +# include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zenhttp/security/passwordsecurityfilter.h> +# include "servers/httpasio.h" +# include "servers/httpsys.h" + +# include <thread> #endif // ZEN_WITH_TESTS namespace zen { @@ -96,6 +104,44 @@ HttpClientBase::GetAccessToken() ////////////////////////////////////////////////////////////////////////// +std::vector<std::pair<uint64_t, uint64_t>> 
+HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const +{ + if (Ranges.empty()) + { + return {}; + } + + std::vector<std::pair<uint64_t, uint64_t>> Result; + Result.reserve(OffsetAndLengthPairs.size()); + + auto BoundaryIt = Ranges.begin(); + auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); + while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) + { + uint64_t Offset = OffsetAndLengthPairIt->first; + uint64_t Length = OffsetAndLengthPairIt->second; + while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength) + { + BoundaryIt++; + if (BoundaryIt == Ranges.end()) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + } + if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset; + uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange; + Result.emplace_back(std::make_pair(RangePayloadOffset, Length)); + + OffsetAndLengthPairIt++; + } + return Result; +} + CbObject HttpClient::Response::AsObject() const { @@ -334,10 +380,55 @@ HttpClient::Authenticate() return m_Inner->Authenticate(); } +LatencyTestResult +MeasureLatency(HttpClient& Client, std::string_view Url) +{ + std::vector<double> MeasurementTimes; + std::string ErrorMessage; + + for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++) + { + HttpClient::Response MeasureResponse = Client.Get(Url); + if (MeasureResponse.IsSuccess()) + { + MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds); + Sleep(5); + } + else + { + ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); + + // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. 
+ // Bail out immediately — retrying will just burn the connect timeout each time. + if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) + { + break; + } + } + } + + if (MeasurementTimes.empty()) + { + return {.Success = false, .FailureReason = ErrorMessage}; + } + + if (MeasurementTimes.size() > 2) + { + std::sort(MeasurementTimes.begin(), MeasurementTimes.end()); + MeasurementTimes.pop_back(); // Remove the worst time + } + + double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size(); + + return {.Success = true, .LatencySeconds = AverageLatency}; +} + ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("responseformat") { using namespace std::literals; @@ -388,8 +479,366 @@ TEST_CASE("httpclient") { using namespace std::literals; - SUBCASE("client") {} + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + if (HttpServiceRequest.IsLocalMachineRequest()) + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + else + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger"); + } + } + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK); + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + SUBCASE("asio") + { + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if 
(ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + { + HttpClient Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + Sleep(20000); +# endif // 0 + AsioServer->RequestExit(); + } + } + +# if ZEN_PLATFORM_WINDOWS + SUBCASE("httpsys") + { + Ref<HttpServer> HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false}); + + int Port = HttpSysServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + HttpSysServer->RegisterService(TestService); + + std::thread ServerThread([&]() { HttpSysServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + HttpSysServer->Close(); + }); + + if (true) + { + HttpClient 
Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + Sleep(20000); +# endif // 0 + HttpSysServer->RequestExit(); + } + } +# endif // ZEN_PLATFORM_WINDOWS +} + +TEST_CASE("httpclient.requestfilter") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return 
HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + class MyFilterImpl : public IHttpRequestFilter + { + public: + virtual Result FilterRequest(HttpServerRequest& Request) + { + if (Request.RelativeUri() == "should_filter") + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); + return Result::ResponseSent; + } + else if (Request.RelativeUri() == "should_forbid") + { + return Result::Forbidden; + } + return Result::Accepted; + } + }; + + MyFilterImpl MyFilter; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response YoResponse = Client.Get("/test/yo"); + CHECK(YoResponse.IsSuccess()); + CHECK_EQ(YoResponse.AsText(), "hey family"); + + HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); + CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); + CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); + + HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); + CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); + + AsioServer->RequestExit(); + } +} + +TEST_CASE("httpclient.password") +{ + using namespace 
std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + SUBCASE("usernamepassword") + { + CbObjectWriter Writer; + { + Writer.BeginObject("basic"); + { + Writer << "username"sv + << "me"; + Writer << "password"sv + << "456123789"; + } + Writer.EndObject(); + Writer << "protect-machine-local-requests" << true; + } + + PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save()); + + PasswordHttpFilter MyFilter(PasswordFilterOptions); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response ForbiddenResponse = Client.Get("/test/yo"); + CHECK(!ForbiddenResponse.IsSuccess()); + CHECK_EQ(ForbiddenResponse.StatusCode, 
HttpResponseCode::Forbidden); + + HttpClient::Response WithBasicResponse = + Client.Get("/test/yo", + std::pair<std::string, std::string>("Authorization", + fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password))); + CHECK(WithBasicResponse.IsSuccess()); + AsioServer->SetHttpRequestFilter(nullptr); + } + AsioServer->RequestExit(); + } } +TEST_SUITE_END(); void httpclient_forcelink() diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp new file mode 100644 index 000000000..52bf149a7 --- /dev/null +++ b/src/zenhttp/httpclient_test.cpp @@ -0,0 +1,1366 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryutil.h> +# include <zencore/compositebuffer.h> +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/session.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Test service + +class HttpClientTestService : public HttpService +{ +public: + HttpClientTestService() + { + m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool { + for (char C : Str) + { + if (C < '0' || C > '9') + { + return false; + } + } + return !Str.empty(); + }); + + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + 
HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/headers", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + CbObjectWriter Writer; + if (!Auth.empty()) + { + Writer.AddString("Authorization", Auth); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddString("message", "test"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "created", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created"); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "content-type/text", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}"); + }, + 
HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/binary", + [](HttpRouterRequest& Req) { + uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF}; + IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data)); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/cbobject", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddString("type", "cbobject"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "auth/bearer", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + if (Auth.starts_with("Bearer ") && Auth.size() > 7) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized"); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Sleep(2000); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + constexpr size_t Size = 64 * 1024; + IoBuffer Buf(Size); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < Size; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status/{statuscode}", + [](HttpRouterRequest& Req) { + std::string_view CodeStr = Req.GetCapture(1); + int Code = std::stoi(std::string{CodeStr}); + const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code); + Req.ServerRequest().WriteResponse(ResponseCode); + 
}, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "attempt-counter", + [this](HttpRouterRequest& Req) { + uint32_t Count = m_AttemptCounter.fetch_add(1); + if (Count < m_FailCount) + { + Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries"); + } + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + + void ResetAttemptCounter(uint32_t FailCount) + { + m_AttemptCounter.store(0); + m_FailCount = FailCount; + } + +private: + HttpRequestRouter m_Router; + std::atomic<uint32_t> m_AttemptCounter{0}; + uint32_t m_FailCount = 2; +}; + +////////////////////////////////////////////////////////////////////////// +// Test server fixture + +struct TestServerFixture +{ + HttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + TestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(7600, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~TestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.httpclient"); + +TEST_CASE("httpclient.verbs") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET 
returns 200 with expected body") + { + HttpClient::Response Resp = Client.Get("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + HttpClient::Response Resp = Client.Put("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + HttpClient::Response Resp = Client.Delete("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + HttpClient::Response Resp = Client.Head("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("httpclient.get") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET with auth header via echo") + { + HttpClient::Response Resp = + Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123")); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123"); + } + + SUBCASE("GET returning CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("GET large payload") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 
1024u); + + const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + bool Valid = true; + for (size_t i = 0; i < 64 * 1024; ++i) + { + if (Data[i] != static_cast<uint8_t>(i & 0xFF)) + { + Valid = false; + break; + } + } + CHECK(Valid); + } +} + +TEST_CASE("httpclient.post") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("POST with IoBuffer payload echo round-trip") + { + const char* Payload = "test payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload data"); + } + + SUBCASE("POST with IoBuffer and explicit content type") + { + const char* Payload = "{\"key\":\"value\"}"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); + } + + SUBCASE("POST with CbObject payload round-trip") + { + CbObjectWriter Writer; + Writer.AddBool("enabled", true); + Writer.AddString("name", "testobj"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj); + CHECK(Resp.IsSuccess()); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["enabled"].AsBool() == true); + CHECK_EQ(RoundTripped["name"].AsString(), "testobj"); + } + + SUBCASE("POST with CompositeBuffer payload") + { + const char* Part1 = "hello "; + const char* Part2 = "composite"; + IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1)); + IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2)); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText); + CHECK(Resp.IsSuccess()); + 
CHECK_EQ(Resp.AsText(), "hello composite"); + } + + SUBCASE("POST with custom headers") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}); + CHECK(Resp.IsSuccess()); + } + + SUBCASE("POST with empty body to nocontent endpoint") + { + HttpClient::Response Resp = Client.Post("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("httpclient.put") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("PUT with IoBuffer payload echo round-trip") + { + const char* Payload = "put payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload data"); + } + + SUBCASE("PUT with parameters only") + { + HttpClient::Response Resp = Client.Put("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } + + SUBCASE("PUT to created endpoint") + { + const char* Payload = "new resource"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/created", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created); + CHECK_EQ(Resp.AsText(), "resource created"); + } +} + +TEST_CASE("httpclient.upload") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("Upload IoBuffer") + { + constexpr size_t Size = 128 * 1024; + IoBuffer Blob = CreateSemiRandomBlob(Size); + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), Size); + } + + SUBCASE("Upload CompositeBuffer") + { + IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024); + 
IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.download") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("Download small payload stays in memory") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.status-codes") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("2xx are success") + { + CHECK(Client.Get("/api/test/status/200").IsSuccess()); + CHECK(Client.Get("/api/test/status/201").IsSuccess()); + CHECK(Client.Get("/api/test/status/204").IsSuccess()); + } + + SUBCASE("4xx are not success") + { + CHECK(!Client.Get("/api/test/status/400").IsSuccess()); + CHECK(!Client.Get("/api/test/status/401").IsSuccess()); + CHECK(!Client.Get("/api/test/status/403").IsSuccess()); + CHECK(!Client.Get("/api/test/status/404").IsSuccess()); + CHECK(!Client.Get("/api/test/status/409").IsSuccess()); + } + + SUBCASE("5xx are not success") + { + CHECK(!Client.Get("/api/test/status/500").IsSuccess()); + CHECK(!Client.Get("/api/test/status/502").IsSuccess()); + CHECK(!Client.Get("/api/test/status/503").IsSuccess()); + } 
+ + SUBCASE("status code values match") + { + CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK); + CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created); + CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent); + CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest); + CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized); + CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden); + CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound); + CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict); + CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError); + CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway); + CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.response") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("IsSuccess and operator bool for success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(static_cast<bool>(Resp)); + } + + SUBCASE("IsSuccess and operator bool for failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + CHECK(!Resp.IsSuccess()); + CHECK(!static_cast<bool>(Resp)); + } + + SUBCASE("AsText returns body") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("AsText returns empty for no-content") + { + HttpClient::Response Resp = Client.Get("/api/test/nocontent"); + CHECK(Resp.AsText().empty()); + } + + SUBCASE("AsObject parses CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + 
CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("AsObject returns empty for non-CB content") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CbObject Obj = Resp.AsObject(); + CHECK(!Obj); + } + + SUBCASE("ToText for text content") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK_EQ(Resp.ToText(), "plain text"); + } + + SUBCASE("ToText for CbObject content") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + std::string Text = Resp.ToText(); + CHECK(!Text.empty()); + // ToText for CbObject converts to JSON string representation + CHECK(Text.find("ok") != std::string::npos); + CHECK(Text.find("test") != std::string::npos); + } + + SUBCASE("ErrorMessage includes status code on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + std::string Msg = Resp.ErrorMessage("test-prefix"); + CHECK(Msg.find("test-prefix") != std::string::npos); + CHECK(Msg.find("404") != std::string::npos); + } + + SUBCASE("ThrowError throws on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/500"); + CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError); + } + + SUBCASE("ThrowError does not throw on success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_NOTHROW(Resp.ThrowError("test")); + } + + SUBCASE("HttpClientError carries response code") + { + HttpClient::Response Resp = Client.Get("/api/test/status/403"); + try + { + Resp.ThrowError("test"); + CHECK(false); // should not reach + } + catch (const HttpClientError& Err) + { + CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden); + } + } +} + +TEST_CASE("httpclient.error-handling") +{ + SUBCASE("Connection refused") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("Request timeout") + 
{ + TestServerFixture Fixture; + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/slow"); + CHECK(!Resp.IsSuccess()); + } + + SUBCASE("Nonexistent endpoint returns failure") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/does-not-exist"); + CHECK(!Resp.IsSuccess()); + } +} + +TEST_CASE("httpclient.session") +{ + TestServerFixture Fixture; + + SUBCASE("Default session ID is non-empty") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.GetSessionId().empty()); + } + + SUBCASE("SetSessionId changes ID") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + std::string OldId = std::string(Client.GetSessionId()); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + CHECK_NE(Client.GetSessionId(), OldId); + } + + SUBCASE("SetSessionId with Zero resets") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + Client.SetSessionId(Oid::Zero); + // After resetting, should get a session string (not empty, not the custom one) + CHECK(!Client.GetSessionId().empty()); + CHECK_NE(Client.GetSessionId(), NewId.ToString()); + } +} + +TEST_CASE("httpclient.authentication") +{ + TestServerFixture Fixture; + + SUBCASE("Authenticate returns false without provider") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Authenticate returns true with valid token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "valid-token", + .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), + }; + }; + HttpClient Client = 
Fixture.MakeClient(Settings); + CHECK(Client.Authenticate()); + } + + SUBCASE("Authenticate returns false with expired token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "expired-token", + .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Bearer token verified by auth endpoint") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response AuthResp = + Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token")); + CHECK(AuthResp.IsSuccess()); + CHECK_EQ(AuthResp.AsText(), "authenticated"); + } + + SUBCASE("Request without token to auth endpoint gets 401") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/auth/bearer"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized); + } +} + +TEST_CASE("httpclient.content-types") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("text content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + + SUBCASE("JSON content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/json"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON); + } + + SUBCASE("binary content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/binary"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary); + } + + SUBCASE("CbObject content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject"); + CHECK(Resp.IsSuccess()); + 
CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject); + } +} + +TEST_CASE("httpclient.metadata") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("ElapsedSeconds is positive") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.ElapsedSeconds > 0.0); + } + + SUBCASE("DownloadedBytes populated for GET") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.DownloadedBytes > 0); + } + + SUBCASE("UploadedBytes populated for POST with payload") + { + const char* Payload = "some upload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK(Resp.UploadedBytes > 0); + } +} + +TEST_CASE("httpclient.retry") +{ + TestServerFixture Fixture; + + SUBCASE("Retry succeeds after transient failures") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "success after retries"); + } + + SUBCASE("No retry returns 503 immediately") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 0; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.measurelatency") +{ + SUBCASE("Successful measurement against live server") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(Result.Success); + 
CHECK(Result.LatencySeconds > 0.0); + } + + SUBCASE("Failed measurement against unreachable port") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(!Result.Success); + CHECK(!Result.FailureReason.empty()); + } +} + +TEST_CASE("httpclient.keyvaluemap") +{ + SUBCASE("Default construction is empty") + { + HttpClient::KeyValueMap Map; + CHECK(Map->empty()); + } + + SUBCASE("Construction from pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value")); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from string_view pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv)); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from initializer list") + { + HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}}); + CHECK_EQ(Map->size(), 2u); + CHECK_EQ(Map->at("a"), "1"); + CHECK_EQ(Map->at("b"), "2"); + } +} + +////////////////////////////////////////////////////////////////////////// +// Transport fault testing + +static std::string +MakeRawHttpResponse(int StatusCode, std::string_view Body) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + StatusCode, + Body.size(), + Body); +} + +static std::string +MakeRawHttpHeaders(int StatusCode, size_t ContentLength) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Length: {}\r\n" + "\r\n", + StatusCode, + ContentLength); +} + +static void +DrainHttpRequest(asio::ip::tcp::socket& Socket) +{ + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); +} + +static void +DrainFullHttpRequest(asio::ip::tcp::socket& Socket) +{ + // Read until end of headers + 
asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Extract headers to find Content-Length + std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + + size_t ContentLength = 0; + auto Pos = Headers.find("Content-Length: "); + if (Pos == std::string::npos) + { + Pos = Headers.find("content-length: "); + } + if (Pos != std::string::npos) + { + size_t ValStart = Pos + 16; // length of "Content-Length: " + size_t ValEnd = Headers.find("\r\n", ValStart); + if (ValEnd != std::string::npos) + { + ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart)); + } + } + + // Calculate how many body bytes were already read past the header boundary. + // asio::read_until may read past the delimiter, so Buf.data() contains everything read. + size_t HeaderEnd = Headers.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0; + size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0; + + if (Remaining > 0) + { + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +static void +DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead) +{ + // Read headers first + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Determine how many body bytes were already buffered past headers + std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + size_t HeaderEnd = All.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = All.size() > HeaderEnd ? 
All.size() - HeaderEnd : 0; + + if (BodyBytesInBuf < BytesToRead) + { + size_t Remaining = BytesToRead - BodyBytesInBuf; + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +struct FaultTcpServer +{ + using FaultHandler = std::function<void(asio::ip::tcp::socket&)>; + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + FaultHandler m_Handler; + std::thread m_Thread; + int m_Port; + + explicit FaultTcpServer(FaultHandler Handler) + : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0)) + , m_Handler(std::move(Handler)) + { + m_Port = m_Acceptor.local_endpoint().port(); + StartAccept(); + m_Thread = std::thread([this]() { m_IoContext.run(); }); + } + + ~FaultTcpServer() + { + std::error_code Ec; + m_Acceptor.close(Ec); + m_IoContext.stop(); + if (m_Thread.joinable()) + { + m_Thread.join(); + } + } + + FaultTcpServer(const FaultTcpServer&) = delete; + FaultTcpServer& operator=(const FaultTcpServer&) = delete; + + void StartAccept() + { + m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) { + if (!Ec) + { + m_Handler(Socket); + } + if (m_Acceptor.is_open()) + { + StartAccept(); + } + }); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +TEST_CASE("httpclient.transport-faults" * doctest::skip()) +{ + SUBCASE("connection reset before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection closed before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + 
DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("partial headers then close") + { + // libcurl parses the status line (200 OK) and accepts the response even though + // headers are truncated mid-field. It reports success with an empty body instead + // of an error. Ideally this should be detected as a transport failure. + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Partial = "HTTP/1.1 200 OK\r\nContent-"; + std::error_code Ec; + asio::write(Socket, asio::buffer(Partial), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + WARN(!Resp.IsSuccess()); + WARN(Resp.Error.has_value()); + } + + SUBCASE("truncated body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(100, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection reset mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 10000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(1000, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + 
HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("stalled response triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("retry succeeds after transient failures") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + DrainHttpRequest(Socket); + if (N < 2) + { + // Connection reset produces NETWORK_SEND_FAILURE which is retryable + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + std::string Response = MakeRawHttpResponse(200, "recovered"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "recovered"); + } +} + +TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) +{ + constexpr size_t kPostBodySize = 256 * 1024; + + auto MakePostBody = []() -> IoBuffer { + IoBuffer Buf(kPostBodySize); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < kPostBodySize; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + 
Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + + SUBCASE("POST: server resets before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server closes before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server resets mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainPartialBody(Socket, 8 * 1024); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: early error response before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(503, "service busy"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + // With a large 
upload body, the server may RST the connection before the client + // reads the 503 response. Either outcome is valid: the client sees the HTTP 503 + // status, or it sees a transport-level error from the RST. + CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value())); + } + + SUBCASE("POST: stalled upload triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + // Stop reading body — TCP window will fill and client send will stall + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(2000); + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("POST: retry with large body after transient failure") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + if (N < 2) + { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + DrainFullHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "upload-ok"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "upload-ok"); + } +} + +TEST_SUITE_END(); + +void +httpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/httpclientauth.cpp 
b/src/zenhttp/httpclientauth.cpp index 72df12d02..02e1b57e2 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -170,7 +170,7 @@ namespace zen { namespace httpclientauth { time_t UTCTime = timegm(&Time); HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); - ExpireTime += std::chrono::microseconds(Millisecond); + ExpireTime += std::chrono::milliseconds(Millisecond); return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; } diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index c4e67d4ed..9bae95690 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -23,10 +23,12 @@ #include <zencore/logging.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zenhttp/packageformat.h> #include <zentelemetry/otlptrace.h> +#include <zentelemetry/stats.h> #include <charconv> #include <mutex> @@ -463,7 +465,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) ////////////////////////////////////////////////////////////////////////// -HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri()) +HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) { } @@ -745,6 +747,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { if (UriPattern[i] == '}') { + if (i == PatternStart) + { + throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); + } std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { @@ -910,8 +916,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) CapturedSegments.emplace_back(Uri); - for (int MatcherIndex : Matchers) + for (size_t 
MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++) { + int MatcherIndex = Matchers[MatcherOffset]; if (UriPos >= UriLen) { IsMatch = false; @@ -921,9 +928,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (MatcherIndex < 0) { // Literal match - int LitIndex = -MatcherIndex - 1; - const std::string& LitStr = m_Literals[LitIndex]; - size_t LitLen = LitStr.length(); + int LitIndex = -MatcherIndex - 1; + std::string_view LitStr = m_Literals[LitIndex]; + size_t LitLen = LitStr.length(); if (Uri.substr(UriPos, LitLen) == LitStr) { @@ -939,9 +946,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { // Matcher function size_t SegmentStart = UriPos; - while (UriPos < UriLen && Uri[UriPos] != '/') + + if (MatcherOffset == (Matchers.size() - 1)) + { + // Last matcher, use the remaining part of the uri + UriPos = UriLen; + } + else { - ++UriPos; + while (UriPos < UriLen && Uri[UriPos] != '/') + { + ++UriPos; + } } std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); @@ -970,7 +986,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -994,7 +1010,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -1014,7 +1030,28 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) int HttpServer::Initialize(int BasePort, std::filesystem::path DataDir) { - return OnInitialize(BasePort, 
std::move(DataDir)); + m_EffectivePort = OnInitialize(BasePort, std::move(DataDir)); + m_ExternalHost = OnGetExternalHost(); + return m_EffectivePort; +} + +std::string +HttpServer::OnGetExternalHost() const +{ + return GetMachineName(); +} + +std::string +HttpServer::GetServiceUri(const HttpService* Service) const +{ + if (Service) + { + return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri()); + } + else + { + return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort); + } } void @@ -1052,6 +1089,45 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba } } +void +HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + OnSetHttpRequestFilter(RequestFilter); +} + +CbObject +HttpServer::CollectStats() +{ + CbObjectWriter Cbo; + + metrics::EmitSnapshot("requests", m_RequestMeter, Cbo); + + Cbo.BeginObject("bytes"); + { + Cbo << "received" << GetTotalBytesReceived(); + Cbo << "sent" << GetTotalBytesSent(); + } + Cbo.EndObject(); + + Cbo.BeginObject("websockets"); + { + Cbo << "active_connections" << GetActiveWebSocketConnectionCount(); + Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed); + Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed); + Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed); + Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpServer::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() @@ -1294,6 +1370,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpserver"); + TEST_CASE("http.common") { using namespace std::literals; @@ -1310,7 +1388,11 @@ 
TEST_CASE("http.common") { TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override + + virtual bool IsLocalMachineRequest() const override { return false; } + virtual std::string_view GetAuthorizationHeader() const override { return {}; } + + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override { ZEN_UNUSED(ResponseCode, ContentType, Blobs); } @@ -1395,20 +1477,33 @@ TEST_CASE("http.common") SUBCASE("router-matcher") { - bool HandledA = false; - bool HandledAA = false; - bool HandledAB = false; - bool HandledAandB = false; + bool HandledA = false; + bool HandledAA = false; + bool HandledAB = false; + bool HandledAandB = false; + bool HandledAandPath = false; std::vector<std::string> Captures; auto Reset = [&] { - HandledA = HandledAA = HandledAB = HandledAandB = false; + HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false; Captures.clear(); }; TestHttpService Service; HttpRequestRouter r; - r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); - r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); + + r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; }); + r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; }); + static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); }); + + r.RegisterRoute( + "path/{a}/{path}", + [&](auto& Req) 
{ + HandledAandPath = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + r.RegisterRoute( "{a}", [&](auto& Req) { @@ -1437,7 +1532,6 @@ TEST_CASE("http.common") Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); - { Reset(); TestHttpServerRequest req{Service, "ab"sv}; @@ -1445,6 +1539,7 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "ab"sv); @@ -1457,6 +1552,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1470,6 +1566,7 @@ TEST_CASE("http.common") CHECK(!HandledAA); CHECK(!HandledAB); CHECK(HandledAandB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1482,6 +1579,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); } { @@ -1491,6 +1589,35 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "a123"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "simple_path.txt"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "directory/and/path.txt"sv); } } @@ 
-1508,6 +1635,8 @@ TEST_CASE("http.common") } } +TEST_SUITE_END(); + void http_forcelink() { diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index a988346e0..c252a5d99 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -66,10 +66,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Json); } else @@ -82,10 +82,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Body.GetText()); } } diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index addb00cb8..57ab01158 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response> if (Response.IsSuccess()) { return fmt::format_to(Ctx.out(), - "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}", ToString(Response.StatusCode), Response.UploadedBytes, Response.DownloadedBytes, diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h index 0270973bf..2d384d1d8 100644 --- a/src/zenhttp/include/zenhttp/httpapiservice.h +++ b/src/zenhttp/include/zenhttp/httpapiservice.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once #include <zenhttp/httpserver.h> diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 9a9b74d72..1bb36a298 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -13,6 +13,7 @@ #include <functional> #include <optional> #include <unordered_map> +#include <vector> namespace zen { @@ -58,6 +59,10 @@ struct HttpClientSettings Oid SessionId = Oid::Zero; bool Verbose = false; uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u; + + /// HTTP status codes that are expected and should not be logged as warnings. + /// 404 is always treated as expected regardless of this list. + std::vector<HttpResponseCode> ExpectedErrorCodes; }; class HttpClientError : public std::runtime_error @@ -113,6 +118,15 @@ private: class HttpClientBase; +/** HTTP Client + * + * This is safe for use on multiple threads simultaneously, as each + * instance maintains an internal connection pool and will synchronize + * access to it as needed. + * + * Uses libcurl under the hood. We currently only use HTTP 1.1 features. 
+ * + */ class HttpClient { public: @@ -123,8 +137,11 @@ public: struct ErrorContext { - int ErrorCode; + int ErrorCode = 0; std::string ErrorMessage; + + /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ + bool IsConnectionError() const; }; struct KeyValueMap @@ -171,13 +188,29 @@ public: KeyValueMap Header; // The number of bytes sent as part of the request - int64_t UploadedBytes; + int64_t UploadedBytes = 0; // The number of bytes received as part of the response - int64_t DownloadedBytes; + int64_t DownloadedBytes = 0; // The elapsed time in seconds for the request to execute - double ElapsedSeconds; + double ElapsedSeconds = 0.0; + + struct MultipartBoundary + { + uint64_t OffsetInPayload = 0; + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType; + }; + + // Ranges will map out all received ranges, both single and multi-range responses + // If no range was requested Ranges will be empty + std::vector<MultipartBoundary> Ranges; + + // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges). + // If the response was not a partial response, an empty vector will be returned + std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const; // This contains any errors from the HTTP stack. 
It won't contain information on // why the server responded with a non-success HTTP status, that may be gleaned @@ -260,6 +293,16 @@ private: const HttpClientSettings m_ConnectionSettings; }; -void httpclient_forcelink(); // internal +struct LatencyTestResult +{ + bool Success = false; + std::string FailureReason; + double LatencySeconds = -1.0; +}; + +LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url); + +void httpclient_forcelink(); // internal +void httpclient_test_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index bc18549c9..8fca35ac5 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept return IsHttpSuccessCode(int(HttpCode)); } +[[nodiscard]] inline bool +IsHttpOk(HttpResponseCode HttpCode) noexcept +{ + return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted || + HttpCode == HttpResponseCode::NoContent; +} + std::string_view ToString(HttpResponseCode HttpCode); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 3438a1471..0e1714669 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -13,6 +13,8 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zentelemetry/stats.h> + #include <functional> #include <gsl/gsl-lite.hpp> #include <list> @@ -30,16 +32,18 @@ class HttpService; */ class HttpServerRequest { -public: +protected: explicit HttpServerRequest(HttpService& Service); + +public: ~HttpServerRequest(); // Synchronous operations [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] std::string_view RelativeUriWithExtension() const { return 
m_UriWithExtension; } + [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix + [[nodiscard]] inline HttpService& Service() const { return m_Service; } struct QueryParams { @@ -79,6 +83,18 @@ public: inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + inline void SetLogRequest(bool ShouldLog) + { + if (ShouldLog) + { + m_Flags |= kLogRequest; + } + else + { + m_Flags &= ~kLogRequest; + } + } + inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); } /** Read POST/PUT payload for request body, which is always available without delay */ @@ -87,6 +103,10 @@ public: CbObject ReadPayloadObject(); CbPackage ReadPayloadPackage(); + virtual bool IsLocalMachineRequest() const = 0; + virtual std::string_view GetAuthorizationHeader() const = 0; + virtual std::string_view GetRemoteAddress() const { return {}; } + /** Respond with payload No data will have been sent when any of these functions return. 
Instead, the response will be transmitted @@ -115,15 +135,17 @@ protected: kSuppressBody = 1 << 1, kHaveRequestId = 1 << 2, kHaveSessionId = 1 << 3, + kLogRequest = 1 << 4, }; - mutable uint32_t m_Flags = 0; + mutable uint32_t m_Flags = 0; + + HttpService& m_Service; // Service handling this request HttpVerb m_Verb = HttpVerb::kGet; HttpContentType m_ContentType = HttpContentType::kBinary; HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; uint64_t m_ContentLength = ~0ull; - std::string_view m_BaseUri; // Base URI path of the service handling this request - std::string_view m_Uri; // URI without service prefix + std::string_view m_Uri; // URI without service prefix std::string_view m_UriWithExtension; std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); @@ -144,6 +166,19 @@ public: virtual void OnRequestComplete() = 0; }; +class IHttpRequestFilter +{ +public: + virtual ~IHttpRequestFilter() {} + enum class Result + { + Forbidden, + ResponseSent, + Accepted + }; + virtual Result FilterRequest(HttpServerRequest& Request) = 0; +}; + /** * Base class for implementing an HTTP "service" * @@ -170,30 +205,110 @@ private: int m_UriPrefixLength = 0; }; +struct IHttpStatsProvider +{ + /** Handle an HTTP stats request, writing the response directly. + * Implementations may inspect query parameters on the request + * to include optional detailed breakdowns. + */ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; + + /** Return the provider's current stats as a CbObject snapshot. + * Used by the WebSocket push thread to broadcast live updates + * without requiring an HttpServerRequest. Providers that do + * not override this will be skipped in WebSocket broadcasts. 
+ */ + virtual CbObject CollectStats() { return {}; } +}; + +struct IHttpStatsService +{ + virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; + virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; +}; + /** HTTP server * * Implements the main event loop to service HTTP requests, and handles routing * requests to the appropriate handler as registered via RegisterService */ -class HttpServer : public RefCounted +class HttpServer : public RefCounted, public IHttpStatsProvider { public: void RegisterService(HttpService& Service); void EnumerateServices(std::function<void(HttpService&)>&& Callback); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); int Initialize(int BasePort, std::filesystem::path DataDir); void Run(bool IsInteractiveSession); void RequestExit(); void Close(); + /** Returns a canonical http:// URI for the given service, using the external + * IP and the port the server is actually listening on. Only valid + * after Initialize() has returned successfully. + */ + std::string GetServiceUri(const HttpService* Service) const; + + /** Returns the external host string (IP or hostname) determined during Initialize(). + * Only valid after Initialize() has returned successfully. + */ + std::string_view GetExternalHost() const { return m_ExternalHost; } + + /** Returns total bytes received and sent across all connections since server start. */ + virtual uint64_t GetTotalBytesReceived() const { return 0; } + virtual uint64_t GetTotalBytesSent() const { return 0; } + + /** Mark that a request has been handled. Called by server implementations. 
*/ + void MarkRequest() { m_RequestMeter.Mark(); } + + /** Set a default redirect path for root requests */ + void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; } + + std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } + + /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } + void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } + uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } + + /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + void OnWebSocketFrameReceived(uint64_t Bytes) + { + m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); + m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed); + } + void OnWebSocketFrameSent(uint64_t Bytes) + { + m_WsFramesSent.fetch_add(1, std::memory_order_relaxed); + m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed); + } + + // IHttpStatsProvider + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + private: std::vector<HttpService*> m_KnownServices; + int m_EffectivePort = 0; + std::string m_ExternalHost; + metrics::Meter m_RequestMeter; + std::string m_DefaultRedirect; + std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; + std::atomic<uint64_t> m_WsFramesReceived{0}; + std::atomic<uint64_t> m_WsFramesSent{0}; + std::atomic<uint64_t> m_WsBytesReceived{0}; + std::atomic<uint64_t> m_WsBytesSent{0}; virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0; virtual void OnRun(bool IsInteractiveSession) = 0; virtual void 
OnRequestExit() = 0; virtual void OnClose() = 0; + +protected: + virtual std::string OnGetExternalHost() const; }; struct HttpServerPluginConfig @@ -236,7 +351,7 @@ public: inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} ~HttpRouterRequest() = default; HttpRouterRequest(const HttpRouterRequest&) = delete; @@ -385,7 +500,7 @@ public: ~HttpRpcHandler(); HttpRpcHandler(const HttpRpcHandler&) = delete; - HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + HttpRpcHandler& operator=(const HttpRpcHandler&) = delete; void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); @@ -401,17 +516,7 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -struct IHttpStatsProvider -{ - virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; -}; - -struct IHttpStatsService -{ - virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; - virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; -}; - -void http_forcelink(); // internal +void http_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index e6fea6765..460315faf 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -3,23 +3,50 @@ #pragma once #include <zencore/logging.h> +#include <zencore/thread.h> #include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> #include <map> +#include <memory> +#include <thread> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END 
namespace zen { -class HttpStatsService : public HttpService, public IHttpStatsService +class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - HttpStatsService(); + /// Construct without an io_context — optionally uses a dedicated push thread + /// for WebSocket stats broadcasting. + explicit HttpStatsService(bool EnableWebSockets = false); + + /// Construct with an external io_context — uses an asio timer instead + /// of a dedicated thread for WebSocket stats broadcasting. + /// The caller must ensure the io_context outlives this service and that + /// its run loop is active. + HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true); + ~HttpStatsService(); + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + private: LoggerRef m_Log; HttpRequestRouter m_Router; @@ -28,6 +55,22 @@ private: RwLock m_Lock; std::map<std::string, IHttpStatsProvider*> m_Providers; + + // WebSocket push + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::atomic<bool> m_PushEnabled{false}; + + void BroadcastStats(); + + // Thread-based push (when no io_context is provided) + std::thread m_PushThread; + Event m_PushEvent; + void PushThreadFunction(); + + // Timer-based push (when an io_context is provided) + std::unique_ptr<asio::steady_timer> m_PushTimer; + void EnqueuePushTimer(); }; } // namespace zen diff --git 
a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h new file mode 100644 index 000000000..926ec1e3d --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> +#include <zenhttp/websocket.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <cstdint> +#include <functional> +#include <memory> +#include <optional> +#include <span> +#include <string> +#include <string_view> + +namespace zen { + +/** + * Callback interface for WebSocket client events + * + * Separate from the server-side IWebSocketHandler because the caller + * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + */ +class IWsClientHandler +{ +public: + virtual ~IWsClientHandler() = default; + + virtual void OnWsOpen() = 0; + virtual void OnWsMessage(const WebSocketMessage& Msg) = 0; + virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0; +}; + +struct HttpWsClientSettings +{ + std::string LogCategory = "wsclient"; + std::chrono::milliseconds ConnectTimeout{5000}; + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; +}; + +/** + * WebSocket client over TCP (ws:// scheme) + * + * Uses ASIO for async I/O. Two construction modes: + * - Internal io_context + background thread (standalone use) + * - External io_context (shared event loop, no internal thread) + * + * Thread-safe for SendText/SendBinary/Close. 
+ */ +class HttpWsClient +{ +public: + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {}); + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {}); + + ~HttpWsClient(); + + HttpWsClient(const HttpWsClient&) = delete; + HttpWsClient& operator=(const HttpWsClient&) = delete; + + void Connect(); + void SendText(std::string_view Text); + void SendBinary(std::span<const uint8_t> Data); + void Close(uint16_t Code = 1000, std::string_view Reason = {}); + bool IsOpen() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index c90b840da..1a5068580 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -68,7 +68,7 @@ struct CbAttachmentEntry struct CbAttachmentReferenceHeader { uint64_t PayloadByteOffset = 0; - uint64_t PayloadByteSize = ~0u; + uint64_t PayloadByteSize = ~uint64_t(0); uint16_t AbsolutePathLength = 0; // This header will be followed by UTF8 encoded absolute path to backing file diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h new file mode 100644 index 000000000..6b2b548a6 --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/compactbinary.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class PasswordSecurity +{ +public: + struct Configuration + { + std::string Password; + bool ProtectMachineLocalRequests = false; + std::vector<std::string> UnprotectedUris; + }; + + explicit PasswordSecurity(const Configuration& Config); + + [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; } + [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; } + [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const; + + bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest); + +private: + const Configuration m_Config; + tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes; +}; + +void passwordsecurity_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h new file mode 100644 index 000000000..c098f05ad --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h @@ -0,0 +1,51 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpserver.h> +#include <zenhttp/security/passwordsecurity.h> + +namespace zen { + +class PasswordHttpFilter : public IHttpRequestFilter +{ +public: + static constexpr std::string_view TypeName = "password"; + + struct Configuration + { + PasswordSecurity::Configuration PasswordConfig; + std::string AuthenticationTypeString; + }; + + /** + * Expected format (Json) + * { + * "password": { # "Authorization: Basic <username:password base64 encoded>" style + * "username": "<username>", + * "password": "<password>" + * }, + * "protect-machine-local-requests": false, + * "unprotected-uris": [ + * "/health/", + * "/health/info", + * "/health/version" + * ] + * } + */ + static Configuration ReadConfiguration(CbObjectView Config); + + explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config) + : m_PasswordSecurity(Config.PasswordConfig) + , m_AuthenticationTypeString(Config.AuthenticationTypeString) + { + } + + virtual Result FilterRequest(HttpServerRequest& Request) override; + +private: + PasswordSecurity m_PasswordSecurity; + const std::string m_AuthenticationTypeString; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..bc3293282 --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/refcount.h> +#include <zencore/iobuffer.h> + +#include <cstdint> +#include <span> +#include <string_view> + +namespace zen { + +enum class WebSocketOpcode : uint8_t +{ + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA +}; + +struct WebSocketMessage +{ + WebSocketOpcode Opcode = WebSocketOpcode::kText; + IoBuffer Payload; + uint16_t CloseCode = 0; +}; + +/** + * Represents an active WebSocket connection + * + * Derived classes implement the actual transport (e.g. ASIO sockets). 
+ * Instances are reference-counted so that both the service layer and + * the async read/write loop can share ownership. + */ +class WebSocketConnection : public RefCounted +{ +public: + virtual ~WebSocketConnection() = default; + + virtual void SendText(std::string_view Text) = 0; + virtual void SendBinary(std::span<const uint8_t> Data) = 0; + virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; + virtual bool IsOpen() const = 0; +}; + +/** + * Interface for services that accept WebSocket upgrades + * + * An HttpService may additionally implement this interface to indicate + * it supports WebSocket connections. The HTTP server checks for this + * via dynamic_cast when it sees an Upgrade: websocket request. + */ +class IWebSocketHandler +{ +public: + virtual ~IWebSocketHandler() = default; + + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; + virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index b097a0d3f..2370def0c 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -3,15 +3,57 @@ #include "zenhttp/httpstats.h" #include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> namespace zen { -HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats")) { + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); + } +} + +HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats")) +{ + if (EnableWebSockets) + { + m_PushEnabled.store(true); + 
m_PushTimer = std::make_unique<asio::steady_timer>(IoContext); + EnqueuePushTimer(); + } } HttpStatsService::~HttpStatsService() { + Shutdown(); +} + +void +HttpStatsService::Shutdown() +{ + if (!m_PushEnabled.exchange(false)) + { + return; + } + + if (m_PushTimer) + { + m_PushTimer->cancel(); + m_PushTimer.reset(); + } + + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + m_PushThread.join(); + } + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); } const char* @@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro void HttpStatsService::HandleRequest(HttpServerRequest& Request) { + ZEN_TRACE_CPU("HttpStatsService::HandleRequest"); using namespace std::literals; std::string_view Key = Request.RelativeUri(); @@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); + ZEN_INFO("Stats WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Send initial state immediately + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + } +} + +void +HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/) +{ + // No client-to-server messages expected +} + +void +HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose"); + ZEN_INFO("Stats WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + 
}); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +////////////////////////////////////////////////////////////////////////// +// +// Stats broadcast +// + +void +HttpStatsService::BroadcastStats() +{ + ZEN_TRACE_CPU("HttpStatsService::BroadcastStats"); + std::vector<Ref<WebSocketConnection>> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + return; + } + + // Collect stats from all providers + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + + bool First = true; + { + RwLock::SharedLockScope _(m_Lock); + for (auto& [Id, Provider] : m_Providers) + { + CbObject Stats = Provider->CollectStats(); + if (!Stats) + { + continue; + } + + if (!First) + { + JsonBuilder.Append(","); + } + First = false; + + // Emit as "provider_id": { ... } + JsonBuilder.Append("\""); + JsonBuilder.Append(Id); + JsonBuilder.Append("\":"); + + ExtendableStringBuilder<2048> StatsJson; + Stats.ToJson(StatsJson); + JsonBuilder.Append(StatsJson.ToView()); + } + } + + JsonBuilder.Append("}"); + + std::string_view Json = JsonBuilder.ToView(); + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Thread-based push (fallback when no io_context) +// + +void +HttpStatsService::PushThreadFunction() +{ + SetCurrentThreadName("stats_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(5000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + BroadcastStats(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Timer-based push (when io_context is provided) +// + +void +HttpStatsService::EnqueuePushTimer() +{ + if (!m_PushTimer) + { + return; + } + + m_PushTimer->expires_after(std::chrono::seconds(5)); + m_PushTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; 
+ } + + BroadcastStats(); + EnqueuePushTimer(); + }); +} + } // namespace zen diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 708238224..cbfe4d889 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -581,7 +581,7 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); - Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash)); } else { @@ -805,6 +805,8 @@ CbPackageReader::Finalize() #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.packageformat"); + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -926,6 +928,8 @@ TEST_CASE("CbPackage.LocalRef") Reader.Finalize(); } +TEST_SUITE_END(); + void forcelink_packageformat() { diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp new file mode 100644 index 000000000..0e3a743c3 --- /dev/null +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -0,0 +1,176 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zenhttp/security/passwordsecurity.h" +#include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> +#include <zencore/string.h> + +#if ZEN_WITH_TESTS +# include <zencore/compactbinarybuilder.h> +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { +using namespace std::literals; + +PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config) +{ + m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size()); + for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++) + { + const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index]; + if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) + { + throw std::runtime_error(fmt::format( + "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')", + Index + 1, + UnprotectedUri, + Result.first->second + 1, + m_Config.UnprotectedUris[Result.first->second])); + } + } +} + +bool +PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const +{ + if (!m_Config.UnprotectedUris.empty()) + { + uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri}); + if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end()) + { + const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second]; + if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length()) + { + if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri) + { + return true; + } + } + } + } + return false; +} + +bool +PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest) +{ + if (IsUnprotectedUri(BaseUri, RelativeUri)) + { + return true; + } + if (!ProtectMachineLocalRequests() && IsMachineLocalRequest) + { + return 
true; + } + if (Password().empty()) + { + return true; + } + if (Password() == InPassword) + { + return true; + } + return false; +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("http.passwordsecurity"); + +TEST_CASE("passwordsecurity.allowanything") +{ + PasswordSecurity Anything({}); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); +} + +TEST_CASE("passwordsecurity.allowalllocal") +{ + PasswordSecurity AllLocal({.Password = "123456"}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); +} + +TEST_CASE("passwordsecurity.allowonlypassword") +{ + PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", 
/*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomeexternaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomelocaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + 
CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.conflictingunprotecteduris") +{ + try + { + PasswordSecurity AllLocal({.Password = "123456", + .ProtectMachineLocalRequests = true, + .UnprotectedUris = std::vector<std::string>({"/free/access", "/free/access"})}); + CHECK(false); + } + catch (const std::runtime_error& Ex) + { + CHECK_EQ(Ex.what(), + std::string("password security unprotected uris does not generate unique hashes. 
Uri #2 ('/free/access') collides with " + "uri #1 ('/free/access')")); + } +} + +TEST_SUITE_END(); + +void +passwordsecurity_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp new file mode 100644 index 000000000..87d8cc275 --- /dev/null +++ b/src/zenhttp/security/passwordsecurityfilter.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/security/passwordsecurityfilter.h" + +#include <zencore/base64.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> + +namespace zen { + +using namespace std::literals; + +PasswordHttpFilter::Configuration +PasswordHttpFilter::ReadConfiguration(CbObjectView Config) +{ + Configuration Result; + if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType) + { + Result.AuthenticationTypeString = "Basic "; + std::string_view Username = PasswordType["username"sv].AsString(); + std::string_view Password = PasswordType["password"sv].AsString(); + std::string UsernamePassword = fmt::format("{}:{}", Username, Password); + Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length()))); + Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()), + uint32_t(UsernamePassword.size()), + const_cast<char*>(Result.PasswordConfig.Password.data())); + } + Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool(); + Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config); + return Result; +} + +IHttpRequestFilter::Result +PasswordHttpFilter::FilterRequest(HttpServerRequest& Request) +{ + std::string_view Password; + std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); + size_t AuthorizationHeaderLength = AuthorizationHeader.length(); + if (AuthorizationHeaderLength > 
m_AuthenticationTypeString.length()) + { + if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) + { + Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + } + } + + bool IsAllowed = + m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest()); + if (IsAllowed) + { + return Result::Accepted; + } + return Result::Forbidden; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 18a0f6a40..f5178ebe8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -7,12 +7,15 @@ #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/memory/llm.h> +#include <zencore/system.h> #include <zencore/thread.h> #include <zencore/trace.h> #include <zencore/windows.h> #include <zenhttp/httpserver.h> #include "httpparser.h" +#include "wsasio.h" +#include "wsframecodec.h" #include <EASTL/fixed_vector.h> @@ -89,15 +92,19 @@ IsIPv6AvailableSysctl(void) char buf[16]; if (fgets(buf, sizeof(buf), f)) { - fclose(f); // 0 means IPv6 enabled, 1 means disabled val = atoi(buf); } + fclose(f); } return val == 0; } +#endif // ZEN_PLATFORM_LINUX +namespace zen { + +#if ZEN_PLATFORM_LINUX bool IsIPv6Capable() { @@ -121,8 +128,6 @@ IsIPv6Capable() } #endif -namespace zen { - const FLLMTag& GetHttpasioTag() { @@ -145,7 +150,7 @@ inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // Logger.set_level(spdlog::level::trace); + // Logger.SetLogLevel(logging::Trace); return Logger; } @@ -496,16 +501,21 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); - void Initialize(std::filesystem::path DataDir); - int Start(uint16_t Port, const AsioConfig& Config); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); + void 
Initialize(std::filesystem::path DataDir); + int Start(uint16_t Port, const AsioConfig& Config); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + + bool IsLoopbackOnly() const; asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -518,6 +528,11 @@ public: RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; + + HttpServer* m_HttpServer = nullptr; }; /** @@ -527,12 +542,21 @@ public: class HttpAsioServerRequest : public HttpServerRequest { public: - HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber); + HttpAsioServerRequest(HttpRequestParser& Request, + HttpService& Service, + IoBuffer PayloadBuffer, + uint32_t RequestNumber, + bool IsLocalMachineRequest, + std::string RemoteAddress); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; @@ -548,6 +572,8 @@ public: HttpRequestParser& m_Request; uint32_t m_RequestNumber = 0; // Note: different to request ID which is 
derived from headers IoBuffer m_PayloadBuffer; + bool m_IsLocalMachineRequest; + std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -925,6 +951,7 @@ private: void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop); void CloseConnection(); + void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); HttpAsioServerImpl& m_Server; asio::streambuf m_RequestBuffer; @@ -1025,6 +1052,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } + m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), @@ -1078,6 +1107,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, return; } + m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, RequestNumber, @@ -1139,10 +1170,91 @@ HttpServerConnection::CloseConnection() } void +HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) +{ + ExtendableStringBuilder<256> ResponseBuilder; + ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; + if (!Headers.empty()) + { + ResponseBuilder << Headers; + } + if (!m_RequestData.IsKeepAlive()) + { + ResponseBuilder << "Connection: close\r\n"; + } + ResponseBuilder << "\r\n"; + if (!Body.empty()) + { + ResponseBuilder << Body; + } + auto ResponseView = ResponseBuilder.ToView(); + IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); + auto Buffer = 
asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); + asio::async_write( + *m_Socket.get(), + Buffer, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + }); +} + +void HttpServerConnection::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); + // WebSocket upgrade detection must happen before the keep-alive check below, + // because Upgrade requests have "Connection: Upgrade" which the HTTP parser + // treats as non-keep-alive, causing a premature shutdown of the receive side. + if (m_RequestData.IsWebSocketUpgrade()) + { + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service); + if (WsHandler && !m_RequestData.SecWebSocketKey().empty()) + { + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey()); + + auto ResponseStr = std::make_shared<std::string>(); + ResponseStr->reserve(256); + ResponseStr->append( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "); + ResponseStr->append(AcceptKey); + ResponseStr->append("\r\n\r\n"); + + // Send the 101 response on the current socket, then hand the socket off + // to a WsAsioConnection for the WebSocket protocol. 
+ asio::async_write(*m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + Ref<WsAsioConnection> WsConn( + new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); + + m_RequestState = RequestState::kDone; + return; + } + } + // Service doesn't support WebSocket or missing key — fall through to normal handling + } + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; @@ -1166,14 +1278,24 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber); + m_Server.m_HttpServer->MarkRequest(); + + auto RemoteEndpoint = m_Socket->remote_endpoint(); + bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + + HttpAsioServerRequest Request(m_RequestData, + *Service, + m_RequestData.Body(), + RequestNumber, + IsLocalConnection, + RemoteEndpoint.address().to_string()); ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server.m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server.m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -1188,56 +1310,73 @@ HttpServerConnection::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + 
IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. 
({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. 
({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) { + if (Request.ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength())); + } + // Transmit the response if (m_RequestData.RequestVerb() == HttpVerb::kHead) @@ -1278,51 +1417,24 @@ HttpServerConnection::HandleRequest() } } - if (m_RequestData.RequestVerb() == HttpVerb::kHead) + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect(); + if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty())) { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - 
asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + ExtendableStringBuilder<128> Headers; + Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n"; + SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView()); + } + else if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + SendInlineResponse(RequestNumber, "404 NOT FOUND"sv); } else { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + SendInlineResponse(RequestNumber, + "404 NOT FOUND"sv, + "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv, + "No suitable route found"sv); } } @@ -1348,8 +1460,11 @@ struct HttpAcceptor m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); #else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false)); + // Allow binding to a port in TIME_WAIT so the server can restart immediately + // after a previous instance exits. On Linux this does not allow two processes + // to actively listen on the same port simultaneously. 
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true)); #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); @@ -1512,7 +1627,7 @@ struct HttpAcceptor { ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - return 0; + return {}; } if (EffectivePort != BasePort) @@ -1569,7 +1684,8 @@ struct HttpAcceptor void StopAccepting() { m_IsStopped = true; } - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } bool IsValid() const { return m_IsValid; } @@ -1632,11 +1748,15 @@ private: HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, - uint32_t RequestNumber) + uint32_t RequestNumber, + bool IsLocalMachineRequest, + std::string RemoteAddress) : HttpServerRequest(Service) , m_Request(Request) , m_RequestNumber(RequestNumber) , m_PayloadBuffer(std::move(PayloadBuffer)) +, m_IsLocalMachineRequest(IsLocalMachineRequest) +, m_RemoteAddress(std::move(RemoteAddress)) { const int PrefixLength = Service.UriPrefixLength(); @@ -1708,6 +1828,24 @@ HttpAsioServerRequest::ParseRequestId() const return m_Request.RequestId(); } +bool +HttpAsioServerRequest::IsLocalMachineRequest() const +{ + return m_IsLocalMachineRequest; +} + +std::string_view +HttpAsioServerRequest::GetRemoteAddress() const +{ + return m_RemoteAddress; +} + +std::string_view +HttpAsioServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + IoBuffer HttpAsioServerRequest::ReadPayload() { @@ -1904,6 +2042,37 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +void +HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) 
+{ + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + + return RequestFilter->FilterRequest(Request); +} + +bool +HttpAsioServerImpl::IsLoopbackOnly() const +{ + return m_Acceptor && m_Acceptor->IsLoopbackOnly(); +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1916,11 +2085,15 @@ public: HttpAsioServer(const AsioConfig& Config); ~HttpAsioServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; private: Event m_ShutdownEvent; @@ -1934,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config) : m_InitialConfig(Config) , m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) { + m_Impl->m_HttpServer = this; ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), 
sizeof(HttpRequestParser)); } @@ -1965,6 +2139,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service) m_Impl->RegisterService(Service.BaseUri(), Service); } +void +HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + m_Impl->SetHttpRequestFilter(RequestFilter); +} + int HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { @@ -1989,10 +2169,46 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) return m_BasePort; } +std::string +HttpAsioServer::OnGetExternalHost() const +{ + if (m_Impl->IsLoopbackOnly()) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpAsioServer::GetTotalBytesReceived() const +{ + return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpAsioServer::GetTotalBytesSent() const +{ + return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpAsioServer::OnRun(bool IsInteractive) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractive) @@ -2008,12 +2224,13 @@ HttpAsioServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractive) { @@ -2022,8 +2239,8 @@ HttpAsioServer::OnRun(bool IsInteractive) do { - 
m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index c483dfc28..3ec1141a7 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -15,4 +15,6 @@ struct AsioConfig Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config); +bool IsIPv6Capable(); + } // namespace zen diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 31cb04be5..584e06cbf 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -54,9 +54,19 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) } void +HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + for (auto& Server : m_Servers) + { + Server->SetHttpRequestFilter(RequestFilter); + } +} + +void HttpMultiServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -72,12 +82,13 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -86,8 +97,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } @@ -106,6 +117,16 @@ HttpMultiServer::OnClose() } } +std::string +HttpMultiServer::OnGetExternalHost() const +{ + if (!m_Servers.empty()) + { + return 
std::string(m_Servers.front()->GetExternalHost()); + } + return HttpServer::OnGetExternalHost(); +} + void HttpMultiServer::AddServer(Ref<HttpServer> Server) { diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index ae0ed74cf..97699828a 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -15,11 +15,13 @@ public: HttpMultiServer(); ~HttpMultiServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; void AddServer(Ref<HttpServer> Server); diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index 0ec1cb3c4..9bb7ef3bc 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service) ZEN_UNUSED(Service); } +void +HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_UNUSED(RequestFilter); +} + int HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { @@ -34,7 +40,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) void HttpNullServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -50,12 +57,13 @@ HttpNullServer::OnRun(bool 
IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -64,8 +72,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h index ce7230938..52838f012 100644 --- a/src/zenhttp/servers/httpnull.h +++ b/src/zenhttp/servers/httpnull.h @@ -18,6 +18,7 @@ public: ~HttpNullServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 93094e21b..918b55dc6 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -12,13 +12,17 @@ namespace zen { using namespace std::literals; -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashContentLength = 
HashStringAsLowerDjb2("Content-Length"sv); +static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); +static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); +static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // @@ -142,41 +146,62 @@ HttpRequestParser::ParseCurrentHeader() const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); - if (HeaderHash == HashContentLength) + switch (HeaderHash) { - m_ContentLengthHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::TryFromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = 
CurrentHeaderIndex; + case HashContentLength: + m_ContentLengthHeaderIndex = CurrentHeaderIndex; + break; + + case HashAccept: + m_AcceptHeaderIndex = CurrentHeaderIndex; + break; + + case HashContentType: + m_ContentTypeHeaderIndex = CurrentHeaderIndex; + break; + + case HashAuthorization: + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + break; + + case HashSession: + m_SessionId = Oid::TryFromHexString(HeaderValue); + break; + + case HashRequest: + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + break; + + case HashExpect: + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + break; + + case HashRange: + m_RangeHeaderIndex = CurrentHeaderIndex; + break; + + case HashUpgrade: + m_UpgradeHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketKey: + m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketVersion: + m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; + break; + + default: + break; } } @@ -220,11 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } - - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } } else if (!NormalizedUrl.empty()) { @@ -305,6 +325,7 @@ HttpRequestParser::OnHeadersComplete() if (ContentLength) { + // TODO: should sanity-check content length here m_BodyBuffer = IoBuffer(ContentLength); } @@ -324,9 +345,9 @@ HttpRequestParser::OnHeadersComplete() int HttpRequestParser::OnBody(const char* Data, size_t Bytes) { - if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size()) { - ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more 
buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); return 1; } @@ -337,7 +358,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { if (m_BodyPosition != m_BodyBuffer.Size()) { - ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); return 1; } } @@ -353,13 +374,18 @@ HttpRequestParser::ResetState() m_HeaderEntries.clear(); - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; + m_UpgradeHeaderIndex = -1; + m_SecWebSocketKeyHeaderIndex = -1; + m_SecWebSocketVersionHeaderIndex = -1; + m_RequestVerb = HttpVerb::kGet; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); @@ -416,4 +442,21 @@ HttpRequestParser::OnMessageComplete() } } +bool +HttpRequestParser::IsWebSocketUpgrade() const +{ + std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); + if (Upgrade.empty()) + { + return false; + } + + // Case-insensitive check for "websocket" + if (Upgrade.size() != 9) + { + return false; + } + return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; +} + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 0d2664ec5..23ad9d8fb 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -46,6 +46,12 @@ struct HttpRequestParser std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); } + std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + + std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); } 
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); } + bool IsWebSocketUpgrade() const; + private: struct HeaderRange { @@ -83,7 +89,11 @@ private: int8_t m_AcceptHeaderIndex; int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; + int8_t m_AuthorizationHeaderIndex; + int8_t m_UpgradeHeaderIndex; + int8_t m_SecWebSocketKeyHeaderIndex; + int8_t m_SecWebSocketVersionHeaderIndex; + HttpVerb m_RequestVerb = HttpVerb::kGet; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; int m_RequestId = -1; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index b9217ed87..4bf8c61bb 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer // HttpPluginServer virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; @@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer virtual void AddPlugin(Ref<TransportPlugin> Plugin) override; virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override; - HttpService* RouteRequest(std::string_view Url); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); struct ServiceEntry { @@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* Service; }; - bool m_IsInitialized = false; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + bool m_IsInitialized = false; RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; std::vector<Ref<TransportPlugin>> m_Plugins; @@ -120,7 
+123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer bool m_IsRequestLoggingEnabled = false; LoggerRef m_RequestLog; std::atomic_uint32_t m_ConnectionIdCounter{0}; - int m_BasePort; + int m_BasePort = 0; HttpServerTracer m_RequestTracer; @@ -143,8 +146,11 @@ public: HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; + // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection + virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual std::string_view GetAuthorizationHeader() const override; + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug ConnectionName = "anonymous"; } - ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName); + ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName); } uint32_t @@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest() { ZEN_TRACE_CPU("http_plugin::HandleRequest"); + m_Server->MarkRequest(); + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest() 
std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. 
({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + 
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) { @@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest() const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers); @@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest() { } +std::string_view +HttpPluginServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + Oid HttpPluginServerRequest::ParseSessionId() const { @@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir } void +HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +void HttpPluginServerImpl::OnClose() { if (!m_IsInitialized) @@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +IHttpRequestFilter::Result +HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return 
IHttpRequestFilter::Result::Accepted; + } + return RequestFilter->FilterRequest(Request); +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 54cc0c22d..dfe6bb6aa 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/packageformat.h> @@ -25,7 +26,9 @@ # include <zencore/workthreadpool.h> # include "iothreadpool.h" +# include <atomic> # include <http.h> +# include <asio.hpp> // for resolving addresses for GetExternalHost namespace zen { @@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In OutString.Append("unknown"); } +class HttpSysServerRequest; + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -90,17 +97,23 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* 
RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -124,8 +137,8 @@ private: std::unique_ptr<WinIoThreadPool> m_IoThreadPool; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; @@ -137,6 +150,12 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; }; } // namespace zen @@ -144,6 +163,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include <conio.h> # include <mstcpip.h> # pragma comment(lib, "httpapi.lib") @@ -313,6 +336,10 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; @@ -320,16 +347,19 
@@ public: virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; virtual bool TryGetRanges(HttpRanges& Ranges) override; + void LogRequest(HttpMessageResponseRequest* Response); + using HttpServerRequest::WriteResponse; HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -363,7 +393,7 @@ public: PTP_IO Iocp(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -380,8 +410,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -418,7 +448,10 @@ public: virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; void SuppressResponseBody(); // typically used for HEAD requests - inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + inline uint16_t GetResponseCode() const { return m_ResponseCode; } + inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + + void 
SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -429,6 +462,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; + std::string m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) HTTP_CACHE_POLICY CachePolicy; - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.Policy = HttpCachePolicyNocache; CachePolicy.SecondsToLive = 0; // Initial response API call - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // 
RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags + &HttpResponse, // HttpResponse + &CachePolicy, // CachePolicy + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); m_IsInitialResponse = false; } @@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { // Subsequent response API calls - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // BytesSent @@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! 
IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); return 0; } @@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); 
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort) // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> - ZEN_WARN( - "Unable to register handler using '{}' - falling back to local-only. " - "Please ensure the appropriate netsh URL reservation configuration " - "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", - WideToUtf8(WildcardUrlPath)); + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation configuration " + "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", + WideToUtf8(WildcardUrlPath)); + } const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == 
ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort) if (m_BaseUris.empty()) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } @@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort) ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); } - // This is not 
available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build - -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) - { - ZEN_DEBUG("HTTP3 is available"); - } - else - { - ZEN_DEBUG("HTTP3 is NOT available"); - } -# endif - return EffectivePort; } @@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { - m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void @@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { - // int WaitTimeout = -1; int WaitTimeout = 100; if (IsInteractive) @@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } } - m_ShutdownEvent.Wait(WaitTimeout); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); + } while (!ShutdownRequested); } void @@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. 
This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { @@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else { - Service.HandleRequest(ThisRequest); + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -1810,6 +1909,52 @@ 
HttpSysServerRequest::ParseRequestId() const return 0; } +bool +HttpSysServerRequest::IsLocalMachineRequest() const +{ + const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress; + const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + if (LocalAddress->sa_family != RemoteAddress->sa_family) + { + return false; + } + if (LocalAddress->sa_family == AF_INET) + { + const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress); + const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress); + return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr; + } + else if (LocalAddress->sa_family == AF_INET6) + { + const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress); + const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress); + return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0; + } + else + { + return false; + } +} + +std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view +HttpSysServerRequest::GetAuthorizationHeader() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; + return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength); +} + IoBuffer HttpSysServerRequest::ReadPayload() { @@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, 
(uint16_t)ResponseCode); if (SuppressBody()) { @@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) # endif SetIsHandled(); + LogRequest(Response); } void @@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (SuppressBody()) { @@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); +} + +void +HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response) +{ + if (ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", + ToString(RequestVerb()), + m_UriUtf8.c_str(), + Response->GetResponseCode(), + NiceBytes(Response->GetResponseBodySize())); + } } void @@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); } void @@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = 
HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + 
HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders; - ZEN_INFO(""); + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. 
+ OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Transaction().Server().OnWebSocketConnectionOpened(); + Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } + + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); + + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); @@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == 
"/" || RawUrl.empty())) + { + auto* Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); @@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpSysServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* 
RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + return RequestFilter->FilterRequest(Request); +} + Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h new file mode 100644 index 000000000..4f8a97012 --- /dev/null +++ b/src/zenhttp/servers/httpsys_iocontext.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> + +# include <cstdint> + +namespace zen { + +/** + * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch + * + * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection + * (for WebSocket read/write) embed this struct. The single IoCompletionCallback + * bound to the request queue uses the ContextType tag to dispatch to the correct + * handler. + * + * The Overlapped member must be first so that CONTAINING_RECORD works to recover + * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool. + */ +struct HttpSysIoContext +{ + OVERLAPPED Overlapped{}; + + enum class Type : uint8_t + { + kTransaction, + kWebSocketRead, + kWebSocketWrite, + } ContextType = Type::kTransaction; + + void* Owner = nullptr; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h index da72c79c9..a9a45f162 100644 --- a/src/zenhttp/servers/httptracer.h +++ b/src/zenhttp/servers/httptracer.h @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zenhttp/httpserver.h> - #pragma once +#include <zenhttp/httpserver.h> + namespace zen { /** Helper class for HTTP server implementations diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp new file mode 100644 index 000000000..b2543277a --- /dev/null +++ b/src/zenhttp/servers/wsasio.cpp @@ -0,0 +1,311 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsasio.h" +#include "wsframecodec.h" + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +namespace zen::asio_http { + +static LoggerRef +WsLog() +{ + static LoggerRef g_Logger = logging::Get("ws"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server) +: m_Socket(std::move(Socket)) +, m_Handler(Handler) +, m_HttpServer(Server) +{ +} + +WsAsioConnection::~WsAsioConnection() +{ + m_IsOpen.store(false); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +void +WsAsioConnection::Start() +{ + EnqueueRead(); +} + +bool +WsAsioConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Read loop +// + +void +WsAsioConnection::EnqueueRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + Ref<WsAsioConnection> Self(this); + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { + Self->OnDataReceived(Ec, ByteCount); + }); +} + +void +WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + 
m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } +} + +void +WsAsioConnection::ProcessReceivedData() +{ + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* Data = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Shut down the 
socket + std::error_code ShutdownEc; + m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); + m_Socket->close(ShutdownEc); + return; + } + + default: + ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Write queue +// + +void +WsAsioConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); +} + +void +WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsAsioConnection::FlushWriteQueue() +{ + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + 
{ + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + Ref<WsAsioConnection> Self(this); + + // Move Frame into a shared_ptr so we can create the buffer and capture ownership + // in the same async_write call without evaluation order issues. + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); +} + +void +WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + return; + } + + FlushWriteQueue(); +} + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h new file mode 100644 index 000000000..e8bb3b1d2 --- /dev/null +++ b/src/zenhttp/servers/wsasio.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <zencore/thread.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <memory> +#include <vector> + +namespace zen { +class HttpServer; +} // namespace zen + +namespace zen::asio_http { + +/** + * WebSocket connection over an ASIO TCP socket + * + * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * and runs an async read/write loop to exchange WebSocket frames. + * + * Lifetime is managed solely through intrusive reference counting (RefCounted). 
+ * The async read/write callbacks capture Ref<WsAsioConnection> to keep the + * connection alive for the duration of the async operation. The service layer + * also holds a Ref<WebSocketConnection>. + */ + +class WsAsioConnection : public WebSocketConnection +{ +public: + WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server); + ~WsAsioConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and the 101 response has been sent. + */ + void Start(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void ProcessReceivedData(); + + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount); + + void DoClose(uint16_t Code, std::string_view Reason); + + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; + asio::streambuf m_ReadBuffer; + + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..e452141fe --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/sha1.h> + +#include <cstring> +#include <random> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + // Reject frames with unreasonable payload sizes to prevent OOM + static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB + if (PayloadLen > kMaxPayloadSize) + { + return {}; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? 
(Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast<size_t>(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// 
+// +// Frame building (client-to-server, with masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + 
Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h new file mode 100644 index 000000000..2d90b6fa1 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <cstddef> +#include <cstdint> +#include <optional> +#include <span> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/** + * Result of attempting to parse a single WebSocket frame from a byte buffer + */ +struct WsFrameParseResult +{ + bool IsValid = false; // true if a complete frame was successfully parsed + size_t BytesConsumed = 0; // number of bytes consumed from the input buffer + WebSocketOpcode Opcode = WebSocketOpcode::kText; + bool Fin = false; + std::vector<uint8_t> Payload; +}; + +/** + * RFC 6455 WebSocket frame codec + * + * Provides static helpers for parsing client-to-server frames (which are + * always masked) and building server-to-client frames (which are never masked). + */ +struct WsFrameCodec +{ + /** + * Try to parse one complete frame from the front of the buffer. + * + * Returns a result with IsValid == false and BytesConsumed == 0 when + * there is not enough data yet. The caller should accumulate more data + * and retry. 
+ */ + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + + /** + * Build a server-to-client frame (no masking) + */ + static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a close frame with a status code and optional reason string + */ + static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Build a client-to-server frame (with masking per RFC 6455) + */ + static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a masked close frame with status code and optional reason + */ + static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2 + * + * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + */ + static std::string ComputeAcceptKey(std::string_view ClientKey); +}; + +} // namespace zen diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp new file mode 100644 index 000000000..af320172d --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -0,0 +1,485 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "wshttpsys.h" + +#if ZEN_WITH_HTTPSYS + +# include "wsframecodec.h" + +# include <zencore/logging.h> +# include <zenhttp/httpserver.h> + +namespace zen { + +static LoggerRef +WsHttpSysLog() +{ + static LoggerRef g_Logger = logging::Get("ws_httpsys"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, + HTTP_REQUEST_ID RequestId, + IWebSocketHandler& Handler, + PTP_IO Iocp, + HttpServer* Server) +: m_RequestQueueHandle(RequestQueueHandle) +, m_RequestId(RequestId) +, m_Handler(Handler) +, m_Iocp(Iocp) +, m_HttpServer(Server) +, m_ReadBuffer(8192) +{ + m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; + m_ReadIoContext.Owner = this; + m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; + m_WriteIoContext.Owner = this; +} + +WsHttpSysConnection::~WsHttpSysConnection() +{ + ZEN_ASSERT(m_OutstandingOps.load() == 0); + + if (m_IsOpen.exchange(false)) + { + Disconnect(); + } + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +void +WsHttpSysConnection::Start() +{ + m_SelfRef = Ref<WsHttpSysConnection>(this); + IssueAsyncRead(); +} + +void +WsHttpSysConnection::Shutdown() +{ + m_ShutdownRequested.store(true, std::memory_order_relaxed); + + if (!m_IsOpen.exchange(false)) + { + return; + } + + // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +bool +WsHttpSysConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Async read path +// + +void +WsHttpSysConnection::IssueAsyncRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) + { + MaybeReleaseSelfRef(); + return; + } + + m_OutstandingOps.fetch_add(1, 
std::memory_order_relaxed); + + ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, + m_RequestId, + 0, // Flags + m_ReadBuffer.data(), + (ULONG)m_ReadBuffer.size(), + nullptr, // BytesRead (ignored for async) + &m_ReadIoContext.Overlapped); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef + Ref<WsHttpSysConnection> Guard(this); + + if (IoResult != NO_ERROR) + { + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + if (IoResult == ERROR_HANDLE_EOF) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); + } + else if (IoResult != ERROR_OPERATION_ABORTED) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + } + + MaybeReleaseSelfRef(); + return; + } + + if (NumberOfBytesTransferred > 0) + { + m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); + ProcessReceivedData(); + } + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + IssueAsyncRead(); + } + else + { + MaybeReleaseSelfRef(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +void +WsHttpSysConnection::ProcessReceivedData() +{ + while (!m_Accumulated.empty()) + { + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); + if (!Frame.IsValid) + { + break; // not enough data 
yet + } + + // Remove consumed bytes + m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + Disconnect(); + return; + } + + default: + ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Async write path +// + +void +WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if 
(m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.push_back(std::move(Frame)); + + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + } + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsHttpSysConnection::FlushWriteQueue() +{ + { + RwLock::ExclusiveLockScope _(m_WriteLock); + + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + + m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); + m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; + m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); + m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); + + ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA, + 1, + &m_WriteChunk, + nullptr, + nullptr, + 0, + &m_WriteIoContext.Overlapped, + nullptr); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + m_CurrentWriteBuffer.clear(); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + // Hold a transient ref to prevent mid-callback destruction + Ref<WsHttpSysConnection> Guard(this); + + m_OutstandingOps.fetch_sub(1, 
std::memory_order_relaxed); + m_CurrentWriteBuffer.clear(); + + if (IoResult != NO_ERROR) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + return; + } + + FlushWriteQueue(); +} + +////////////////////////////////////////////////////////////////////////// +// +// Send interface +// + +void +WsHttpSysConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + 
+////////////////////////////////////////////////////////////////////////// +// +// Lifetime management +// + +void +WsHttpSysConnection::MaybeReleaseSelfRef() +{ + if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) + { + m_SelfRef = nullptr; + } +} + +void +WsHttpSysConnection::Disconnect() +{ + // Send final empty body with DISCONNECT to tell http.sys the connection is done + HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT, + 0, + nullptr, + nullptr, + nullptr, + 0, + nullptr, + nullptr); +} + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h new file mode 100644 index 000000000..6015e3873 --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include "httpsys_iocontext.h" + +#include <zencore/thread.h> + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <http.h> + +# include <atomic> +# include <deque> +# include <vector> + +namespace zen { + +class HttpServer; + +/** + * WebSocket connection over an http.sys opaque-mode connection + * + * After the 101 Switching Protocols response is sent with + * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the + * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody / + * HttpSendResponseEntityBody using the original RequestId. + * + * All I/O is performed asynchronously via the same IOCP threadpool used + * for normal http.sys traffic, eliminating per-connection threads. + * + * Lifetime is managed through intrusive reference counting (RefCounted). + * A self-reference (m_SelfRef) is held from Start() until all outstanding + * async operations have drained, preventing premature destruction. 
+ */ +class WsHttpSysConnection : public WebSocketConnection +{ +public: + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server); + ~WsHttpSysConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and after the 101 response has been sent. + */ + void Start(); + + /** + * Shut down the connection. Cancels pending I/O; IOCP completions + * will fire with ERROR_OPERATION_ABORTED and drain naturally. + */ + void Shutdown(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + + // Called from IoCompletionCallback via tagged dispatch + void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + +private: + void IssueAsyncRead(); + void ProcessReceivedData(); + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void DoClose(uint16_t Code, std::string_view Reason); + void Disconnect(); + void MaybeReleaseSelfRef(); + + HANDLE m_RequestQueueHandle; + HTTP_REQUEST_ID m_RequestId; + IWebSocketHandler& m_Handler; + PTP_IO m_Iocp; + HttpServer* m_HttpServer; + + // Tagged OVERLAPPED contexts for concurrent read and write + HttpSysIoContext m_ReadIoContext{}; + HttpSysIoContext m_WriteIoContext{}; + + // Read state + std::vector<uint8_t> m_ReadBuffer; + std::vector<uint8_t> m_Accumulated; + + // Write state + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + std::vector<uint8_t> m_CurrentWriteBuffer; + HTTP_DATA_CHUNK m_WriteChunk{}; + bool m_IsWriting = false; + + // Lifetime management + std::atomic<int32_t> m_OutstandingOps{0}; + Ref<WsHttpSysConnection> m_SelfRef; + std::atomic<bool> m_ShutdownRequested{false}; + 
std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp new file mode 100644 index 000000000..2134e4ff1 --- /dev/null +++ b/src/zenhttp/servers/wstest.cpp @@ -0,0 +1,925 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#if ZEN_WITH_TESTS + +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include <zenhttp/httpserver.h> +# include <zenhttp/httpwsclient.h> +# include <zenhttp/websocket.h> + +# include "httpasio.h" +# include "wsframecodec.h" + +ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_PLATFORM_WINDOWS +# include <winsock2.h> +# else +# include <poll.h> +# include <sys/socket.h> +# endif +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +# include <atomic> +# include <chrono> +# include <cstring> +# include <random> +# include <string> +# include <string_view> +# include <thread> +# include <vector> + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests: WsFrameCodec +// + +TEST_SUITE_BEGIN("http.wstest"); + +TEST_CASE("websocket.framecodec") +{ + SUBCASE("ComputeAcceptKey RFC 6455 test vector") + { + // RFC 6455 section 4.2.2 example + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - text") + { + std::string_view Text = "Hello, WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + // Server frames are unmasked — TryParseFrame should handle them + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + 
CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure"); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const 
char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } + + SUBCASE("TryParseFrame - partial data returns invalid") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + // Pass only 1 byte — not enough for a frame header + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); + CHECK_FALSE(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, 0u); + } + + SUBCASE("TryParseFrame - empty payload") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK(Result.Payload.empty()); + } + + SUBCASE("TryParseFrame - masked client frame") + { + // Build a masked frame manually as a client would send + // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello" + uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D}; + uint8_t MaskedPayload[5] = {}; + const char* Original = "Hello"; + for (int i = 0; i < 5; ++i) + { + MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4]; + } + + std::vector<uint8_t> Frame; + Frame.push_back(0x81); // FIN + text + Frame.push_back(0x85); // MASK + len=5 + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), 5u); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv); + } + + SUBCASE("BuildMaskedFrame roundtrip - text") + { + std::string_view Text = "Hello, masked WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), 
Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + // Verify mask bit is set + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildMaskedFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + 
CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildMaskedCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: WebSocket over ASIO +// + +namespace { + + /** + * Helper: Build a masked client-to-server frame per RFC 6455 + */ + std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) + { + std::vector<uint8_t> Frame; + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + // Mask key (use a fixed key for deterministic tests) + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + + std::vector<uint8_t> 
BuildMaskedTextFrame(std::string_view Text) + { + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + return BuildMaskedFrame(WebSocketOpcode::kText, Payload); + } + + std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code) + { + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); + } + + /** + * Test service that implements IWebSocketHandler + */ + struct WsTestService : public HttpService, public IWebSocketHandler + { + const char* BaseUri() const override { return "/wstest/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest"); + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + { + m_OpenCount.fetch_add(1); + + m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); + } + + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + + // Echo the message back + Conn.SendText(Text); + } + } + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + + m_ConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_Connections.erase(It, m_Connections.end()); + }); + } + + void SendToAll(std::string_view Text) + { + RwLock::SharedLockScope _(m_ConnectionsLock); + for (auto& Conn : m_Connections) + { + 
if (Conn->IsOpen()) + { + Conn->SendText(Text); + } + } + } + + std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + + RwLock m_ConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_Connections; + }; + + /** + * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket + * + * Returns true on success (101 response), false otherwise. + */ + bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port) + { + // Send HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << Path << " HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + // Read the response (look for "101") + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + return Response.find("101") != std::string::npos; + } + + /** + * Helper: Read a single server-to-client frame from a socket + * + * Uses a background thread with a synchronous ASIO read and a timeout. 
+ */ + WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000) + { + std::vector<uint8_t> Buffer; + WsFrameParseResult Result; + std::atomic<bool> Done{false}; + + std::thread Reader([&] { + while (!Done.load()) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (Frame.IsValid) + { + Result = std::move(Frame); + Done.store(true); + return; + } + } + }); + + auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs); + while (!Done.load() && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + if (!Done.load()) + { + // Timeout — cancel the read + asio::error_code Ec; + Sock.cancel(Ec); + } + + if (Reader.joinable()) + { + Reader.join(); + } + + return Result; + } + +} // anonymous namespace + +TEST_CASE("websocket.integration") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + // Give server a moment to start accepting + Sleep(100); + + SUBCASE("handshake succeeds with 101") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + CHECK(Ok); + + Sleep(50); + CHECK_EQ(TestService.m_OpenCount.load(), 1); + + Sock.close(); + } + + SUBCASE("normal HTTP 
still works alongside WebSocket service") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Send a normal HTTP GET (not upgrade) + std::string HttpReq = fmt::format( + "GET /wstest/hello HTTP/1.1\r\n" + "Host: 127.0.0.1:{}\r\n" + "Connection: close\r\n" + "\r\n", + Port); + + asio::write(Sock, asio::buffer(HttpReq)); + + asio::streambuf ResponseBuf; + asio::error_code Ec; + asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + CHECK(Response.find("200") != std::string::npos); + } + + SUBCASE("echo message roundtrip") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a text message (masked, as client) + std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test"); + asio::write(Sock, asio::buffer(Frame)); + + // Read the echo reply + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "ping test"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "ping test"); + + Sock.close(); + } + + SUBCASE("server push to client") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Server pushes a message + TestService.SendToAll("server 
says hello"); + + WsFrameParseResult Frame = ReadOneFrame(Sock); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "server says hello"sv); + + Sock.close(); + } + + SUBCASE("client close handshake") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send close frame + std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000); + asio::write(Sock, asio::buffer(CloseFrame)); + + // Server should echo close back + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + + Sleep(50); + CHECK_EQ(TestService.m_CloseCount.load(), 1); + CHECK_EQ(TestService.m_LastCloseCode.load(), 1000); + + Sock.close(); + } + + SUBCASE("multiple concurrent connections") + { + constexpr int NumClients = 5; + + asio::io_context IoCtx; + std::vector<asio::ip::tcp::socket> Sockets; + + for (int i = 0; i < NumClients; ++i) + { + Sockets.emplace_back(IoCtx); + Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port); + REQUIRE(Ok); + } + + Sleep(100); + CHECK_EQ(TestService.m_OpenCount.load(), NumClients); + + // Broadcast from server + TestService.SendToAll("broadcast"); + + // Each client should receive the message + for (int i = 0; i < NumClients; ++i) + { + WsFrameParseResult Frame = ReadOneFrame(Sockets[i]); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "broadcast"sv); + } + 
+ // Close all + for (auto& S : Sockets) + { + S.close(); + } + } + + SUBCASE("service without IWebSocketHandler rejects upgrade") + { + // Register a plain HTTP service (no WebSocket) + struct PlainService : public HttpService + { + const char* BaseUri() const override { return "/plain/"; } + void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); } + }; + + PlainService Plain; + Server->RegisterService(Plain); + + Sleep(50); + + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Attempt WebSocket upgrade on the plain service + ExtendableStringBuilder<512> Request; + Request << "GET /plain/ws HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + // Should NOT get 101 — should fall through to normal request handling + CHECK(Response.find("101") == std::string::npos); + + Sock.close(); + } + + SUBCASE("ping/pong auto-response") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a ping frame with payload "test" + std::string_view PingPayload = "test"; + std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size()); + std::vector<uint8_t> PingFrame = 
BuildMaskedFrame(WebSocketOpcode::kPing, PingData); + asio::write(Sock, asio::buffer(PingFrame)); + + // Should receive a pong with the same payload + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong); + CHECK_EQ(Reply.Payload.size(), 4u); + std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(PongText, "test"sv); + + Sock.close(); + } + + SUBCASE("multiple messages in sequence") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + for (int i = 0; i < 10; ++i) + { + std::string Msg = fmt::format("message {}", i); + std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg); + asio::write(Sock, asio::buffer(Frame)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, Msg); + } + + CHECK_EQ(TestService.m_MessageCount.load(), 10); + + Sock.close(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: HttpWsClient +// + +namespace { + + struct TestWsClientHandler : public IWsClientHandler + { + void OnWsOpen() override { m_OpenCount.fetch_add(1); } + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + } + m_MessageCount.fetch_add(1); + } + + void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + } + 
+ std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + }; + +} // anonymous namespace + +TEST_CASE("websocket.client") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7576, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello from client"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello from client"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + // The server echoes the close frame, which triggers OnWsClose on the client side + // with the server's close code. Allow the connection to settle. 
+ Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } + + SUBCASE("connect to bad port") + { + TestWsClientHandler Handler; + std::string Url = "ws://127.0.0.1:1/wstest/ws"; + + HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1006); + CHECK_EQ(Handler.m_OpenCount.load(), 0); + } + + SUBCASE("server-initiated close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + + // Copy connections then close them outside the lock to avoid deadlocking + // with OnWebSocketClose which acquires an exclusive lock + std::vector<Ref<WebSocketConnection>> Conns; + TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; }); + for (auto& Conn : Conns) + { + Conn->Close(1001, "going away"); + } + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1001); + CHECK_FALSE(Client.IsOpen()); + } +} + +TEST_SUITE_END(); + +void +websocket_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_WITH_TESTS diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp index 9135d5425..489324aba 100644 --- a/src/zenhttp/transports/dlltransport.cpp +++ 
b/src/zenhttp/transports/dlltransport.cpp @@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa void DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { - logging::level::LogLevel Level; - // clang-format off switch (PluginLogLevel) { - case LogLevel::Trace: Level = logging::level::Trace; break; - case LogLevel::Debug: Level = logging::level::Debug; break; - case LogLevel::Info: Level = logging::level::Info; break; - case LogLevel::Warn: Level = logging::level::Warn; break; - case LogLevel::Err: Level = logging::level::Err; break; - case LogLevel::Critical: Level = logging::level::Critical; break; - default: Level = logging::level::Off; break; + case LogLevel::Trace: + ZEN_TRACE("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Debug: + ZEN_DEBUG("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Info: + ZEN_INFO("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Warn: + ZEN_WARN("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Err: + ZEN_ERROR("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Critical: + ZEN_CRITICAL("[{}] {}", m_PluginName, Message); + return; + + default: + ZEN_UNUSED(Message); + break; } - // clang-format on - ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp index c06a50c95..0217ed44e 100644 --- a/src/zenhttp/transports/winsocktransport.cpp +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -322,7 +322,7 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface) else { } - } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + } while (m_KeepRunning.test()); ZEN_INFO("HTTP plugin server accept thread exit"); }); diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 78876d21b..e8f87b668 100644 --- a/src/zenhttp/xmake.lua +++ 
b/src/zenhttp/xmake.lua @@ -6,6 +6,7 @@ target('zenhttp') add_headerfiles("**.h") add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) + add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") add_packages("http_parser", "json11") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index a2679f92e..3ac8eea8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -7,6 +7,7 @@ # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> # include <zenhttp/packageformat.h> +# include <zenhttp/security/passwordsecurity.h> namespace zen { @@ -15,7 +16,10 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpclient_test_forcelink(); forcelink_packageformat(); + passwordsecurity_forcelink(); + websocket_forcelink(); } } // namespace zen diff --git a/src/zennet-test/zennet-test.cpp b/src/zennet-test/zennet-test.cpp index bc3b8e8e9..1283eb820 100644 --- a/src/zennet-test/zennet-test.cpp +++ b/src/zennet-test/zennet-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zennet/zennet.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zennet_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zennet-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zennet-test", zen::zennet_forcelinktests); #else return 0; #endif diff --git a/src/zennet/beacon.cpp b/src/zennet/beacon.cpp new file mode 100644 index 000000000..394a4afbb --- /dev/null +++ b/src/zennet/beacon.cpp @@ -0,0 +1,170 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zennet/beacon.h> + +#include <zencore/basicfile.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/session.h> +#include <zencore/uid.h> + +#include <fmt/format.h> +#include <asio.hpp> +#include <map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct FsBeacon::Impl +{ + Impl(std::filesystem::path ShareRoot); + ~Impl(); + + void EnsureValid(); + + void AddGroup(std::string_view GroupId, CbObject Metadata); + void ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions); + void ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata); + +private: + std::filesystem::path m_ShareRoot; + zen::Oid m_SessionId; + + struct GroupData + { + CbObject Metadata; + BasicFile LockFile; + }; + + std::map<std::string, GroupData> m_Registration; + + std::filesystem::path GetSessionMarkerPath(std::string_view GroupId, const Oid& SessionId) + { + Oid::String_t SessionIdString; + SessionId.ToString(SessionIdString); + + return m_ShareRoot / GroupId / SessionIdString; + } +}; + +FsBeacon::Impl::Impl(std::filesystem::path ShareRoot) : m_ShareRoot(ShareRoot), m_SessionId(GetSessionId()) +{ +} + +FsBeacon::Impl::~Impl() +{ +} + +void +FsBeacon::Impl::EnsureValid() +{ +} + +void +FsBeacon::Impl::AddGroup(std::string_view GroupId, CbObject Metadata) +{ + zen::CreateDirectories(m_ShareRoot / GroupId); + std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, m_SessionId); + + GroupData& Group = m_Registration[std::string(GroupId)]; + + Group.Metadata = Metadata; + + std::error_code Ec; + Group.LockFile.Open(MarkerFile, + BasicFile::Mode::kTruncate | BasicFile::Mode::kPreventDelete | + BasicFile::Mode::kPreventWrite /* | BasicFile::Mode::kDeleteOnClose */, + Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open beacon 
marker file '{}' for write", MarkerFile)); + } + + Group.LockFile.WriteAll(Metadata.GetBuffer().AsIoBuffer(), Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to write to beacon marker file '{}'", MarkerFile)); + } + + Group.LockFile.Flush(); +} + +void +FsBeacon::Impl::ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions) +{ + DirectoryContent Dc; + zen::GetDirectoryContent(m_ShareRoot / GroupId, zen::DirectoryContentFlags::IncludeFiles, /* out */ Dc); + + for (const std::filesystem::path& FilePath : Dc.Files) + { + std::filesystem::path File = FilePath.filename(); + + std::error_code Ec; + if (std::filesystem::remove(FilePath, Ec) == false) + { + auto FileString = File.generic_string(); + + if (FileString.length() != Oid::StringLength) + continue; + + if (const Oid SessionId = Oid::FromHexString(FileString)) + { + if (std::filesystem::file_size(File, Ec) > 0) + { + OutSessions.push_back(SessionId); + } + } + } + } +} + +void +FsBeacon::Impl::ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata) +{ + for (const Oid& SessionId : InSessions) + { + const std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, SessionId); + + if (CbObject Metadata = LoadCompactBinaryObject(MarkerFile).Object) + { + OutMetadata.push_back(std::move(Metadata)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +FsBeacon::FsBeacon(std::filesystem::path ShareRoot) : m_Impl(std::make_unique<Impl>(ShareRoot)) +{ +} + +FsBeacon::~FsBeacon() +{ +} + +void +FsBeacon::AddGroup(std::string_view GroupId, CbObject Metadata) +{ + m_Impl->AddGroup(GroupId, Metadata); +} + +void +FsBeacon::ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions) +{ + m_Impl->ScanGroup(GroupId, OutSessions); +} + +void +FsBeacon::ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata) +{ + 
m_Impl->ReadMetadata(GroupId, InSessions, OutMetadata); +} + +////////////////////////////////////////////////////////////////////////// + +} // namespace zen diff --git a/src/zennet/include/zennet/beacon.h b/src/zennet/include/zennet/beacon.h new file mode 100644 index 000000000..a8d4805cb --- /dev/null +++ b/src/zennet/include/zennet/beacon.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennet/zennet.h> + +#include <zencore/uid.h> + +#include <filesystem> +#include <memory> +#include <string> +#include <vector> + +namespace zen { + +class CbObject; + +/** File-system based peer discovery + + Intended to be used with an SMB file share as the root. + */ + +class FsBeacon +{ +public: + FsBeacon(std::filesystem::path ShareRoot); + ~FsBeacon(); + + void AddGroup(std::string_view GroupId, CbObject Metadata); + void ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions); + void ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zennet/include/zennet/statsdclient.h b/src/zennet/include/zennet/statsdclient.h index c378e49ce..7688c132c 100644 --- a/src/zennet/include/zennet/statsdclient.h +++ b/src/zennet/include/zennet/statsdclient.h @@ -8,6 +8,8 @@ #include <memory> #include <string_view> +#undef SendMessage + namespace zen { class StatsTransportBase diff --git a/src/zennet/statsdclient.cpp b/src/zennet/statsdclient.cpp index fe5ca4dda..8afa2e835 100644 --- a/src/zennet/statsdclient.cpp +++ b/src/zennet/statsdclient.cpp @@ -12,6 +12,7 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <zencore/windows.h> #include <asio.hpp> +#undef SendMessage ZEN_THIRD_PARTY_INCLUDES_END namespace zen { @@ -379,6 +380,8 @@ statsd_forcelink() { } +TEST_SUITE_BEGIN("net.statsdclient"); + TEST_CASE("zennet.statsd.emit") { // auto Client = 
CreateStatsDaemonClient("localhost", 8125); @@ -458,6 +461,8 @@ TEST_CASE("zennet.statsd.batch") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h new file mode 100644 index 000000000..0a3411ace --- /dev/null +++ b/src/zennomad/include/zennomad/nomadclient.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/nomadconfig.h> + +#include <zencore/logbase.h> + +#include <memory> +#include <string> +#include <vector> + +namespace zen { +class HttpClient; +} + +namespace zen::nomad { + +/** Summary of a Nomad job returned by the API. */ +struct NomadJobInfo +{ + std::string Id; + std::string Status; ///< "pending", "running", "dead" + std::string StatusDescription; +}; + +/** Summary of a Nomad allocation returned by the API. */ +struct NomadAllocInfo +{ + std::string Id; + std::string ClientStatus; ///< "pending", "running", "complete", "failed" + std::string TaskState; ///< State of the task within the allocation +}; + +/** HTTP client for the Nomad REST API (v1). + * + * Handles job submission, status polling, and job termination. + * All calls are synchronous. Thread safety: individual methods are + * not thread-safe; callers must synchronize access. + */ +class NomadClient +{ +public: + explicit NomadClient(const NomadConfig& Config); + ~NomadClient(); + + NomadClient(const NomadClient&) = delete; + NomadClient& operator=(const NomadClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. + * The JSON structure varies based on the configured driver and distribution mode. */ + std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + + /** Submit a job via PUT /v1/jobs. 
On success, populates OutJob with the job info. */ + bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); + + /** Get the status of a job via GET /v1/job/{jobId}. */ + bool GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob); + + /** Get allocations for a job via GET /v1/job/{jobId}/allocations. */ + bool GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs); + + /** Stop a job via DELETE /v1/job/{jobId}. */ + bool StopJob(const std::string& JobId); + + LoggerRef Log() { return m_Log; } + +private: + NomadConfig m_Config; + std::unique_ptr<zen::HttpClient> m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadconfig.h b/src/zennomad/include/zennomad/nomadconfig.h new file mode 100644 index 000000000..92d2bbaca --- /dev/null +++ b/src/zennomad/include/zennomad/nomadconfig.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/zennomad.h> + +#include <string> + +namespace zen::nomad { + +/** Nomad task driver type. */ +enum class Driver +{ + RawExec, ///< Use Nomad raw_exec driver (direct process execution) + Docker, ///< Use Nomad Docker driver +}; + +/** How the zenserver binary is made available on Nomad clients. */ +enum class BinaryDistribution +{ + PreDeployed, ///< Binary is already present on Nomad client nodes + Artifact, ///< Download binary via Nomad artifact stanza +}; + +/** Configuration for Nomad worker provisioning. + * + * Specifies the Nomad server URL, authentication, resource limits, and + * job configuration. Used by NomadClient and NomadProvisioner. + */ +struct NomadConfig +{ + bool Enabled = false; ///< Whether Nomad provisioning is active + std::string ServerUrl; ///< Nomad HTTP API URL (e.g. 
"http://localhost:4646") + std::string AclToken; ///< Nomad ACL token (sent as X-Nomad-Token header) + std::string Datacenter = "dc1"; ///< Target datacenter + std::string Namespace = "default"; ///< Nomad namespace + std::string Region; ///< Nomad region (empty = server default) + + Driver TaskDriver = Driver::RawExec; ///< Task driver for job execution + BinaryDistribution BinDistribution = BinaryDistribution::PreDeployed; ///< How to distribute the zenserver binary + + std::string BinaryPath; ///< Path to zenserver on Nomad clients (PreDeployed mode) + std::string ArtifactSource; ///< URL to download zenserver binary (Artifact mode) + std::string DockerImage; ///< Docker image name (Docker driver mode) + + int MaxJobs = 64; ///< Maximum concurrent Nomad jobs + int CpuMhz = 1000; ///< CPU MHz allocated per task + int MemoryMb = 2048; ///< Memory MB allocated per task + int CoresPerJob = 32; ///< Estimated cores per job (for scaling calculations) + int MaxCores = 2048; ///< Maximum total cores to provision + + std::string JobPrefix = "zenserver-worker"; ///< Prefix for generated Nomad job IDs + + /** Validate the configuration. Returns false if required fields are missing + * or incompatible options are set. */ + bool Validate() const; +}; + +const char* ToString(Driver D); +const char* ToString(BinaryDistribution Dist); + +bool FromString(Driver& OutDriver, std::string_view Str); +bool FromString(BinaryDistribution& OutDist, std::string_view Str); + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadprocess.h b/src/zennomad/include/zennomad/nomadprocess.h new file mode 100644 index 000000000..a66c2ce41 --- /dev/null +++ b/src/zennomad/include/zennomad/nomadprocess.h @@ -0,0 +1,78 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpclient.h> + +#include <memory> +#include <string> +#include <string_view> +#include <vector> + +namespace zen::nomad { + +struct NomadJobInfo; +struct NomadAllocInfo; + +/** Manages a Nomad agent process running in dev mode for testing. + * + * Spawns `nomad agent -dev` and polls the HTTP API until the agent + * is ready. On destruction or via StopNomadAgent(), the agent + * process is killed. + */ +class NomadProcess +{ +public: + NomadProcess(); + ~NomadProcess(); + + NomadProcess(const NomadProcess&) = delete; + NomadProcess& operator=(const NomadProcess&) = delete; + + /** Spawn a Nomad dev agent and block until the leader endpoint responds (10 s timeout). */ + void SpawnNomadAgent(); + + /** Kill the Nomad agent process. */ + void StopNomadAgent(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +/** Lightweight HTTP wrapper around the Nomad v1 REST API for use in tests. + * + * Unlike the production NomadClient (which requires a NomadConfig and + * supports all driver/distribution modes), this client exposes a simpler + * interface geared towards test scenarios. + */ +class NomadTestClient +{ +public: + explicit NomadTestClient(std::string_view BaseUri); + ~NomadTestClient(); + + NomadTestClient(const NomadTestClient&) = delete; + NomadTestClient& operator=(const NomadTestClient&) = delete; + + /** Submit a raw_exec batch job. + * Returns the parsed job info on success; Id will be empty on failure. */ + NomadJobInfo SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args); + + /** Query the status of an existing job. */ + NomadJobInfo GetJobStatus(std::string_view JobId); + + /** Stop (deregister) a running job. */ + void StopJob(std::string_view JobId); + + /** Get allocations for a job. */ + std::vector<NomadAllocInfo> GetAllocations(std::string_view JobId); + + /** List all jobs, optionally filtered by prefix. 
*/ + std::vector<NomadJobInfo> ListJobs(std::string_view Prefix = ""); + +private: + HttpClient m_HttpClient; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h new file mode 100644 index 000000000..750693b3f --- /dev/null +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/nomadconfig.h> + +#include <zencore/logbase.h> + +#include <atomic> +#include <condition_variable> +#include <cstdint> +#include <memory> +#include <mutex> +#include <string> +#include <thread> +#include <vector> + +namespace zen::nomad { + +class NomadClient; + +/** Snapshot of the current Nomad provisioning state, returned by NomadProvisioner::GetStats(). */ +struct NomadProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected from submitted jobs + uint32_t ActiveJobCount = 0; ///< Number of currently tracked Nomad jobs + uint32_t RunningJobCount = 0; ///< Number of jobs in "running" status +}; + +/** Job lifecycle manager for Nomad worker provisioning. + * + * Provisions remote compute workers by submitting batch jobs to a Nomad + * cluster via the REST API. Each job runs zenserver in compute mode, which + * announces itself back to the orchestrator. + * + * Uses a single management thread that periodically: + * 1. Submits new jobs when estimated cores < target cores + * 2. Polls existing jobs for status changes + * 3. Cleans up dead/failed jobs and adjusts counters + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class NomadProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Nomad connection and job configuration. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. 
*/ + NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + + /** Signals the management thread to exit and stops all tracked jobs. */ + ~NomadProvisioner(); + + NomadProvisioner(const NomadProvisioner&) = delete; + NomadProvisioner& operator=(const NomadProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to NomadConfig::MaxCores. The management thread will + * submit new jobs to approach this target. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + NomadProvisioningStats GetStats() const; + +private: + LoggerRef Log() { return m_Log; } + + struct TrackedJob + { + std::string JobId; + std::string Status; ///< "pending", "running", "dead" + int Cores = 0; + }; + + void ManagementThread(); + void SubmitNewJobs(); + void PollExistingJobs(); + void CleanupDeadJobs(); + void StopAllJobs(); + + std::string GenerateJobId(); + + NomadConfig m_Config; + std::string m_OrchestratorEndpoint; + + std::unique_ptr<NomadClient> m_Client; + + mutable std::mutex m_JobsLock; + std::vector<TrackedJob> m_Jobs; + std::atomic<uint32_t> m_JobIndex{0}; + + std::atomic<uint32_t> m_TargetCoreCount{0}; + std::atomic<uint32_t> m_EstimatedCoreCount{0}; + std::atomic<uint32_t> m_RunningJobCount{0}; + + std::thread m_Thread; + std::mutex m_WakeMutex; + std::condition_variable m_WakeCV; + std::atomic<bool> m_ShouldExit{false}; + + uint32_t m_ProcessId = 0; + + LoggerRef m_Log; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/zennomad.h b/src/zennomad/include/zennomad/zennomad.h new file mode 100644 index 000000000..09fb98dfe --- /dev/null +++ b/src/zennomad/include/zennomad/zennomad.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_NOMAD) +# define ZEN_WITH_NOMAD 1 +#endif diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp new file mode 100644 index 000000000..9edcde125 --- /dev/null +++ b/src/zennomad/nomadclient.cpp @@ -0,0 +1,366 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memoryview.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> +#include <zennomad/nomadclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::nomad { + +namespace { + + HttpClient::KeyValueMap MakeNomadHeaders(const NomadConfig& Config) + { + HttpClient::KeyValueMap Headers; + if (!Config.AclToken.empty()) + { + Headers->emplace("X-Nomad-Token", Config.AclToken); + } + return Headers; + } + +} // namespace + +NomadClient::NomadClient(const NomadConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("nomad.client")) +{ +} + +NomadClient::~NomadClient() = default; + +bool +NomadClient::Initialize() +{ + ZEN_TRACE_CPU("NomadClient::Initialize"); + + HttpClientSettings Settings; + Settings.LogCategory = "nomad.http"; + Settings.ConnectTimeout = std::chrono::milliseconds{10000}; + Settings.Timeout = std::chrono::milliseconds{60000}; + Settings.RetryCount = 1; + + // Ensure the base URL ends with a slash so path concatenation works correctly + std::string BaseUrl = m_Config.ServerUrl; + if (!BaseUrl.empty() && BaseUrl.back() != '/') + { + BaseUrl += '/'; + } + + m_Http = std::make_unique<zen::HttpClient>(BaseUrl, Settings); + + return true; +} + +std::string +NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +{ + ZEN_TRACE_CPU("NomadClient::BuildJobJson"); + + // Build the task config based on driver and distribution mode + json11::Json::object TaskConfig; + + if (m_Config.TaskDriver == 
Driver::RawExec) + { + std::string Command; + if (m_Config.BinDistribution == BinaryDistribution::PreDeployed) + { + Command = m_Config.BinaryPath; + } + else + { + // Artifact mode: binary is downloaded to local/zenserver + Command = "local/zenserver"; + } + + TaskConfig["command"] = Command; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + else + { + // Docker driver + TaskConfig["image"] = m_Config.DockerImage; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + + // Build resource stanza + json11::Json::object Resources; + Resources["CPU"] = m_Config.CpuMhz; + Resources["MemoryMB"] = m_Config.MemoryMb; + + // Build the task + json11::Json::object Task; + Task["Name"] = "zenserver"; + Task["Driver"] = (m_Config.TaskDriver == Driver::RawExec) ? 
"raw_exec" : "docker"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + // Add artifact stanza if using artifact distribution + if (m_Config.BinDistribution == BinaryDistribution::Artifact && !m_Config.ArtifactSource.empty()) + { + json11::Json::object Artifact; + Artifact["GetterSource"] = m_Config.ArtifactSource; + + json11::Json::array Artifacts; + Artifacts.push_back(Artifact); + Task["Artifacts"] = Artifacts; + } + + json11::Json::array Tasks; + Tasks.push_back(Task); + + // Build the task group + json11::Json::object Group; + Group["Name"] = "zenserver-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + // Build datacenters array + json11::Json::array Datacenters; + Datacenters.push_back(m_Config.Datacenter); + + // Build the job + json11::Json::object Job; + Job["ID"] = JobId; + Job["Name"] = JobId; + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + Job["TaskGroups"] = Groups; + + if (!m_Config.Namespace.empty() && m_Config.Namespace != "default") + { + Job["Namespace"] = m_Config.Namespace; + } + + if (!m_Config.Region.empty()) + { + Job["Region"] = m_Config.Region; + } + + // Wrap in the registration envelope + json11::Json::object Root; + Root["Job"] = Job; + + return json11::Json(Root).dump(); +} + +bool +NomadClient::SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::SubmitJob"); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{JobJson.data(), JobJson.size()}, ZenContentType::kJSON); + + const HttpClient::Response Response = m_Http->Put("v1/jobs", Payload, MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job submit failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job submit failed with HTTP/{}", StatusCode); + return false; + } + + 
const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response from Nomad job submit: {}", Err); + return false; + } + + // The response contains EvalID; the job ID is what we submitted + OutJob.Id = Json["JobModifyIndex"].is_number() ? OutJob.Id : ""; + OutJob.Status = "pending"; + + ZEN_INFO("Nomad job submitted: eval_id={}", Json["EvalID"].string_value()); + + return true; +} + +bool +NomadClient::GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::GetJobStatus"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job status query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (StatusCode == 404) + { + ZEN_INFO("Nomad job '{}' not found", JobId); + OutJob.Status = "dead"; + return true; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job status query failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad job status response: {}", Err); + return false; + } + + OutJob.Id = Json["ID"].string_value(); + OutJob.Status = Json["Status"].string_value(); + if (const json11::Json Desc = Json["StatusDescription"]; Desc.is_string()) + { + OutJob.StatusDescription = Desc.string_value(); + } + + return true; +} + +bool +NomadClient::GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs) +{ + ZEN_TRACE_CPU("NomadClient::GetAllocations"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId << "/allocations"; + + const 
HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad allocation query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad allocation query failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad allocation response: {}", Err); + return false; + } + + OutAllocs.clear(); + if (!Json.is_array()) + { + return true; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + // Extract task state if available + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + OutAllocs.push_back(std::move(Alloc)); + } + + return true; +} + +bool +NomadClient::StopJob(const std::string& JobId) +{ + ZEN_TRACE_CPU("NomadClient::StopJob"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Delete(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job stop failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job stop failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return false; + } + + ZEN_INFO("Nomad job '{}' stopped", JobId); + return true; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadconfig.cpp b/src/zennomad/nomadconfig.cpp new file mode 
100644 index 000000000..d55b3da9a --- /dev/null +++ b/src/zennomad/nomadconfig.cpp @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zennomad/nomadconfig.h> + +namespace zen::nomad { + +bool +NomadConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::PreDeployed && BinaryPath.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::Artifact && ArtifactSource.empty()) + { + return false; + } + + if (TaskDriver == Driver::Docker && DockerImage.empty()) + { + return false; + } + + return true; +} + +const char* +ToString(Driver D) +{ + switch (D) + { + case Driver::RawExec: + return "raw_exec"; + case Driver::Docker: + return "docker"; + } + return "raw_exec"; +} + +const char* +ToString(BinaryDistribution Dist) +{ + switch (Dist) + { + case BinaryDistribution::PreDeployed: + return "predeployed"; + case BinaryDistribution::Artifact: + return "artifact"; + } + return "predeployed"; +} + +bool +FromString(Driver& OutDriver, std::string_view Str) +{ + if (Str == "raw_exec") + { + OutDriver = Driver::RawExec; + return true; + } + if (Str == "docker") + { + OutDriver = Driver::Docker; + return true; + } + return false; +} + +bool +FromString(BinaryDistribution& OutDist, std::string_view Str) +{ + if (Str == "predeployed") + { + OutDist = BinaryDistribution::PreDeployed; + return true; + } + if (Str == "artifact") + { + OutDist = BinaryDistribution::Artifact; + return true; + } + return false; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp new file mode 100644 index 000000000..1ae968fb7 --- /dev/null +++ b/src/zennomad/nomadprocess.cpp @@ -0,0 +1,354 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zennomad/nomadclient.h> +#include <zennomad/nomadprocess.h> + +#include <zenbase/zenbase.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memoryview.h> +#include <zencore/process.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <fmt/format.h> + +namespace zen::nomad { + +////////////////////////////////////////////////////////////////////////// + +struct NomadProcess::Impl +{ + Impl(std::string_view BaseUri) : m_HttpClient(BaseUri) {} + ~Impl() = default; + + void SpawnNomadAgent() + { + ZEN_TRACE_CPU("SpawnNomadAgent"); + + if (m_ProcessHandle.IsValid()) + { + return; + } + + CreateProcOptions Options; + Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + + CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); + + if (Result) + { + m_ProcessHandle.Initialize(Result); + + Stopwatch Timer; + + // Poll to check when the agent is ready + + do + { + Sleep(100); + HttpClient::Response Resp = m_HttpClient.Get("v1/status/leader"); + if (Resp) + { + ZEN_INFO("Nomad agent started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + + return; + } + } while (Timer.GetElapsedTimeMs() < 30000); + } + + // Report failure! 
+ + ZEN_WARN("Nomad agent failed to start within timeout period"); + } + + void StopNomadAgent() + { + if (!m_ProcessHandle.IsValid()) + { + return; + } + + // This waits for the process to exit and also resets the handle + m_ProcessHandle.Kill(); + } + +private: + ProcessHandle m_ProcessHandle; + HttpClient m_HttpClient; +}; + +NomadProcess::NomadProcess() : m_Impl(std::make_unique<Impl>("http://localhost:4646/")) +{ +} + +NomadProcess::~NomadProcess() +{ +} + +void +NomadProcess::SpawnNomadAgent() +{ + m_Impl->SpawnNomadAgent(); +} + +void +NomadProcess::StopNomadAgent() +{ + m_Impl->StopNomadAgent(); +} + +////////////////////////////////////////////////////////////////////////// + +NomadTestClient::NomadTestClient(std::string_view BaseUri) : m_HttpClient(BaseUri) +{ +} + +NomadTestClient::~NomadTestClient() +{ +} + +NomadJobInfo +NomadTestClient::SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args) +{ + ZEN_TRACE_CPU("SubmitNomadJob"); + + NomadJobInfo Result; + + // Build the job JSON for a raw_exec batch job + json11::Json::object TaskConfig; + TaskConfig["command"] = std::string(Command); + + json11::Json::array JsonArgs; + for (const auto& Arg : Args) + { + JsonArgs.push_back(Arg); + } + TaskConfig["args"] = JsonArgs; + + json11::Json::object Resources; + Resources["CPU"] = 100; + Resources["MemoryMB"] = 64; + + json11::Json::object Task; + Task["Name"] = "test-task"; + Task["Driver"] = "raw_exec"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + json11::Json::array Tasks; + Tasks.push_back(Task); + + json11::Json::object Group; + Group["Name"] = "test-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + json11::Json::array Datacenters; + Datacenters.push_back("dc1"); + + json11::Json::object Job; + Job["ID"] = std::string(JobId); + Job["Name"] = std::string(JobId); + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + 
Job["TaskGroups"] = Groups; + + json11::Json::object Root; + Root["Job"] = Job; + + std::string Body = json11::Json(Root).dump(); + + IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}, ZenContentType::kJSON); + + HttpClient::Response Response = + m_HttpClient.Put("v1/jobs", Payload, {{"Content-Type", "application/json"}, {"Accept", "application/json"}}); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: SubmitJob failed for '{}'", JobId); + return Result; + } + + std::string ResponseBody(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(ResponseBody, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in SubmitJob response: {}", Err); + return Result; + } + + Result.Id = std::string(JobId); + Result.Status = "pending"; + + ZEN_INFO("NomadTestClient: job '{}' submitted (eval_id={})", JobId, Json["EvalID"].string_value()); + + return Result; +} + +NomadJobInfo +NomadTestClient::GetJobStatus(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadJobStatus"); + + NomadJobInfo Result; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}", JobId)); + + if (Response.Error) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return Result; + } + + if (static_cast<int>(Response.StatusCode) == 404) + { + Result.Status = "dead"; + return Result; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return Result; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in GetJobStatus response: {}", Err); + return Result; + } + + Result.Id = Json["ID"].string_value(); + Result.Status = Json["Status"].string_value(); + if (const json11::Json Desc = 
Json["StatusDescription"]; Desc.is_string()) + { + Result.StatusDescription = Desc.string_value(); + } + + return Result; +} + +void +NomadTestClient::StopJob(std::string_view JobId) +{ + ZEN_TRACE_CPU("StopNomadJob"); + + HttpClient::Response Response = m_HttpClient.Delete(fmt::format("v1/job/{}", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: StopJob failed for '{}'", JobId); + return; + } + + ZEN_INFO("NomadTestClient: job '{}' stopped", JobId); +} + +std::vector<NomadAllocInfo> +NomadTestClient::GetAllocations(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadAllocations"); + + std::vector<NomadAllocInfo> Allocs; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}/allocations", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetAllocations failed for '{}'", JobId); + return Allocs; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Allocs; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + Allocs.push_back(std::move(Alloc)); + } + + return Allocs; +} + +std::vector<NomadJobInfo> +NomadTestClient::ListJobs(std::string_view Prefix) +{ + ZEN_TRACE_CPU("ListNomadJobs"); + + std::vector<NomadJobInfo> Jobs; + + std::string Url = "v1/jobs"; + if (!Prefix.empty()) + { + Url = fmt::format("v1/jobs?prefix={}", Prefix); + } + + HttpClient::Response Response = m_HttpClient.Get(Url); + + if (!Response || !Response.IsSuccess()) 
+ { + ZEN_WARN("NomadTestClient: ListJobs failed"); + return Jobs; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Jobs; + } + + for (const json11::Json& JobVal : Json.array_items()) + { + NomadJobInfo Job; + Job.Id = JobVal["ID"].string_value(); + Job.Status = JobVal["Status"].string_value(); + if (const json11::Json Desc = JobVal["StatusDescription"]; Desc.is_string()) + { + Job.StatusDescription = Desc.string_value(); + } + Jobs.push_back(std::move(Job)); + } + + return Jobs; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp new file mode 100644 index 000000000..3fe9c0ac3 --- /dev/null +++ b/src/zennomad/nomadprovisioner.cpp @@ -0,0 +1,264 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zennomad/nomadclient.h> +#include <zennomad/nomadprovisioner.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +#include <chrono> + +namespace zen::nomad { + +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId())) +, m_Log(zen::logging::Get("nomad.provisioner")) +{ + ZEN_DEBUG("initializing provisioner (server: {}, driver: {}, max_cores: {}, cores_per_job: {}, max_jobs: {})", + m_Config.ServerUrl, + ToString(m_Config.TaskDriver), + m_Config.MaxCores, + m_Config.CoresPerJob, + m_Config.MaxJobs); + + m_Client = std::make_unique<NomadClient>(m_Config); + if (!m_Client->Initialize()) + { + ZEN_ERROR("failed to initialize Nomad HTTP client"); + return; + } + + ZEN_DEBUG("Nomad HTTP client initialized, starting management thread"); + + m_Thread = 
std::thread([this] { ManagementThread(); }); +} + +NomadProvisioner::~NomadProvisioner() +{ + ZEN_DEBUG("provisioner shutting down"); + + m_ShouldExit.store(true); + m_WakeCV.notify_all(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } + + StopAllJobs(); + + ZEN_DEBUG("provisioner shutdown complete"); +} + +void +NomadProvisioner::SetTargetCoreCount(uint32_t Count) +{ + const uint32_t Clamped = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)); + const uint32_t Previous = m_TargetCoreCount.exchange(Clamped); + + if (Clamped != Previous) + { + ZEN_DEBUG("target core count changed: {} -> {}", Previous, Clamped); + } + + m_WakeCV.notify_all(); +} + +NomadProvisioningStats +NomadProvisioner::GetStats() const +{ + NomadProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.RunningJobCount = m_RunningJobCount.load(); + + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + Stats.ActiveJobCount = static_cast<uint32_t>(m_Jobs.size()); + } + + return Stats; +} + +std::string +NomadProvisioner::GenerateJobId() +{ + const uint32_t Index = m_JobIndex.fetch_add(1); + + ExtendableStringBuilder<128> Builder; + Builder << m_Config.JobPrefix << "-" << m_ProcessId << "-" << Index; + return std::string(Builder.ToView()); +} + +void +NomadProvisioner::ManagementThread() +{ + ZEN_TRACE_CPU("Nomad_Mgmt"); + zen::SetCurrentThreadName("nomad_mgmt"); + + ZEN_INFO("Nomad management thread started"); + + while (!m_ShouldExit.load()) + { + ZEN_DEBUG("management cycle: target={} estimated={} running={} active={}", + m_TargetCoreCount.load(), + m_EstimatedCoreCount.load(), + m_RunningJobCount.load(), + [this] { + std::lock_guard<std::mutex> Lock(m_JobsLock); + return m_Jobs.size(); + }()); + + SubmitNewJobs(); + PollExistingJobs(); + CleanupDeadJobs(); + + // Wait up to 5 seconds or until woken + std::unique_lock<std::mutex> Lock(m_WakeMutex); + m_WakeCV.wait_for(Lock, 
std::chrono::seconds(5), [this] { return m_ShouldExit.load(); }); + } + + ZEN_INFO("Nomad management thread exiting"); +} + +void +NomadProvisioner::SubmitNewJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::SubmitNewJobs"); + + const uint32_t CoresPerJob = static_cast<uint32_t>(m_Config.CoresPerJob); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + if (static_cast<int>(m_Jobs.size()) >= m_Config.MaxJobs) + { + ZEN_INFO("Nomad max jobs limit reached ({})", m_Config.MaxJobs); + break; + } + } + + if (m_ShouldExit.load()) + { + break; + } + + const std::string JobId = GenerateJobId(); + + ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); + + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + + NomadJobInfo JobInfo; + JobInfo.Id = JobId; + + if (!m_Client->SubmitJob(JobJson, JobInfo)) + { + ZEN_WARN("failed to submit Nomad job '{}'", JobId); + break; + } + + TrackedJob Tracked; + Tracked.JobId = JobId; + Tracked.Status = "pending"; + Tracked.Cores = static_cast<int>(CoresPerJob); + + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + m_Jobs.push_back(std::move(Tracked)); + } + + m_EstimatedCoreCount.fetch_add(CoresPerJob); + + ZEN_INFO("Nomad job '{}' submitted (estimated cores: {})", JobId, m_EstimatedCoreCount.load()); + } +} + +void +NomadProvisioner::PollExistingJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::PollExistingJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (auto& Job : m_Jobs) + { + if (m_ShouldExit.load()) + { + break; + } + + NomadJobInfo Info; + if (!m_Client->GetJobStatus(Job.JobId, Info)) + { + ZEN_DEBUG("failed to poll status for job '{}'", Job.JobId); + continue; + } + + const std::string PrevStatus = Job.Status; + Job.Status = Info.Status; + + if (PrevStatus != Job.Status) + { + ZEN_INFO("Nomad job '{}' status changed: {} -> {}", Job.JobId, PrevStatus, 
Job.Status); + + if (Job.Status == "running" && PrevStatus != "running") + { + m_RunningJobCount.fetch_add(1); + } + else if (Job.Status != "running" && PrevStatus == "running") + { + m_RunningJobCount.fetch_sub(1); + } + } + } +} + +void +NomadProvisioner::CleanupDeadJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::CleanupDeadJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (auto It = m_Jobs.begin(); It != m_Jobs.end();) + { + if (It->Status == "dead") + { + ZEN_INFO("Nomad job '{}' is dead, removing from tracked jobs", It->JobId); + m_EstimatedCoreCount.fetch_sub(static_cast<uint32_t>(It->Cores)); + It = m_Jobs.erase(It); + } + else + { + ++It; + } + } +} + +void +NomadProvisioner::StopAllJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::StopAllJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (const auto& Job : m_Jobs) + { + ZEN_INFO("stopping Nomad job '{}' during shutdown", Job.JobId); + m_Client->StopJob(Job.JobId); + } + + m_Jobs.clear(); + m_EstimatedCoreCount.store(0); + m_RunningJobCount.store(0); +} + +} // namespace zen::nomad diff --git a/src/zennomad/xmake.lua b/src/zennomad/xmake.lua new file mode 100644 index 000000000..ef1a8b201 --- /dev/null +++ b/src/zennomad/xmake.lua @@ -0,0 +1,10 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zennomad') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zenutil") + add_packages("json11") diff --git a/src/zenremotestore-test/zenremotestore-test.cpp b/src/zenremotestore-test/zenremotestore-test.cpp index 5db185041..dc47c5aed 100644 --- a/src/zenremotestore-test/zenremotestore-test.cpp +++ b/src/zenremotestore-test/zenremotestore-test.cpp @@ -1,46 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> -#include <zenremotestore/projectstore/remoteprojectstore.h> +#include <zencore/testing.h> #include <zenremotestore/zenremotestore.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenremotestore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenremotestore-test", zen::zenremotestore_forcelinktests); #else return 0; #endif diff --git a/src/zenremotestore/builds/buildmanifest.cpp b/src/zenremotestore/builds/buildmanifest.cpp index 051436e96..738e4b33b 100644 --- a/src/zenremotestore/builds/buildmanifest.cpp +++ b/src/zenremotestore/builds/buildmanifest.cpp @@ -97,6 +97,8 @@ ParseBuildManifest(const std::filesystem::path& ManifestPath) } #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("remotestore.buildmanifest"); + TEST_CASE("buildmanifest.unstructured") { ScopedTemporaryDirectory Root; @@ -163,6 +165,8 @@ TEST_CASE("buildmanifest.structured") CHECK_EQ(Manifest.Parts[1].Files[0].generic_string(), "baz.pdb"); } +TEST_SUITE_END(); + void buildmanifest_forcelink() { diff --git a/src/zenremotestore/builds/buildsavedstate.cpp b/src/zenremotestore/builds/buildsavedstate.cpp index 1d1f4605f..0685bf679 100644 --- a/src/zenremotestore/builds/buildsavedstate.cpp +++ 
b/src/zenremotestore/builds/buildsavedstate.cpp @@ -588,6 +588,8 @@ namespace buildsavestate_test { } } // namespace buildsavestate_test +TEST_SUITE_BEGIN("remotestore.buildsavedstate"); + TEST_CASE("buildsavestate.BuildsSelection") { using namespace buildsavestate_test; @@ -696,6 +698,8 @@ TEST_CASE("buildsavestate.DownloadedPaths") } } +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS } // namespace zen diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp index 07fcd62ba..00765903d 100644 --- a/src/zenremotestore/builds/buildstoragecache.cpp +++ b/src/zenremotestore/builds/buildstoragecache.cpp @@ -151,7 +151,7 @@ public: auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); HttpClient::Response CacheResponse = - m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()), + m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), Payload, ContentType); @@ -180,7 +180,7 @@ public: } CreateDirectories(m_TempFolderPath); HttpClient::Response CacheResponse = - m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()), + m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), m_TempFolderPath, Headers); AddStatistic(CacheResponse); @@ -191,6 +191,74 @@ public: return {}; } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_TRACE_CPU("ZenBuildStorageCache::GetBuildBlobRanges"); + + Stopwatch ExecutionTimer; + auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + 
Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + CreateDirectories(m_TempFolderPath); + HttpClient::Response CacheResponse = + m_HttpClient.Post(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + AddStatistic(CacheResponse); + if (CacheResponse.IsSuccess()) + { + CbPackage ResponsePackage = ParsePackageMessage(CacheResponse.ResponsePayload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + + std::vector<std::pair<uint64_t, uint64_t>> ReceivedRanges; + ReceivedRanges.reserve(RangeArray.Num()); + + uint64_t OffsetInPayloadRanges = 0; + + for (CbFieldView View : RangeArray) + { + CbObjectView RangeView = View.AsObjectView(); + uint64_t Offset = RangeView["offset"sv].AsUInt64(); + uint64_t Length = RangeView["length"sv].AsUInt64(); + + const std::pair<uint64_t, uint64_t>& Range = Ranges[ReceivedRanges.size()]; + + if (Offset != Range.first || Length != Range.second) + { + return {}; + } + ReceivedRanges.push_back(std::make_pair(OffsetInPayloadRanges, Length)); + OffsetInPayloadRanges += Length; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + if (DataAttachment) + { + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + return BuildBlobRanges{.PayloadBuffer = PayloadRanges.AsIoBuffer(), .Ranges = std::move(ReceivedRanges)}; + } + } + return {}; + } + virtual void PutBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes, std::span<const CbObject> MetaDatas) override { ZEN_ASSERT(!IsFlushed); @@ -460,6 +528,192 @@ CreateZenBuildStorageCache(HttpClient& HttpClient, return std::make_unique<ZenBuildStorageCache>(HttpClient, Stats, Namespace, Bucket, TempFolderPath, BackgroundWorkerPool); } +#if 
ZEN_WITH_TESTS + +class InMemoryBuildStorageCache : public BuildStorageCache +{ +public: + // MaxRangeSupported == 0 : no range requests are accepted, always return full blob + // MaxRangeSupported == 1 : single range is supported, multi range returns full blob + // MaxRangeSupported > 1 : multirange is supported up to MaxRangeSupported, more ranges returns empty blob (bad request) + explicit InMemoryBuildStorageCache(uint64_t MaxRangeSupported, + BuildStorageCache::Statistics& Stats, + double LatencySec = 0.0, + double DelayPerKBSec = 0.0) + : m_MaxRangeSupported(MaxRangeSupported) + , m_Stats(Stats) + , m_LatencySec(LatencySec) + , m_DelayPerKBSec(DelayPerKBSec) + { + } + void PutBuildBlob(const Oid&, const IoHash& RawHash, ZenContentType, const CompositeBuffer& Payload) override + { + IoBuffer Buf = Payload.Flatten().AsIoBuffer(); + Buf.MakeOwned(); + const uint64_t SentBytes = Buf.Size(); + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + { + std::lock_guard Lock(m_Mutex); + m_Entries[RawHash] = std::move(Buf); + } + m_Stats.PutBlobCount.fetch_add(1); + m_Stats.PutBlobByteCount.fetch_add(SentBytes); + } + + IoBuffer GetBuildBlob(const Oid&, const IoHash& RawHash, uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) override + { + uint64_t SentBytes = 0; + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + IoBuffer FullPayload; + { + std::lock_guard Lock(m_Mutex); + auto It = m_Entries.find(RawHash); + if (It == m_Entries.end()) + { + return {}; + } + FullPayload = It->second; + } + + if (RangeOffset != 0 || RangeBytes != 
(uint64_t)-1) + { + if (m_MaxRangeSupported == 0) + { + ReceivedBytes = FullPayload.Size(); + return FullPayload; + } + else + { + ReceivedBytes = (RangeBytes == (uint64_t)-1) ? FullPayload.Size() - RangeOffset : RangeBytes; + return IoBuffer(FullPayload, RangeOffset, RangeBytes); + } + } + else + { + ReceivedBytes = FullPayload.Size(); + return FullPayload; + } + } + + BuildBlobRanges GetBuildBlobRanges(const Oid&, const IoHash& RawHash, std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + uint64_t SentBytes = 0; + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + if (m_MaxRangeSupported > 1 && Ranges.size() > m_MaxRangeSupported) + { + return {}; + } + IoBuffer FullPayload; + { + std::lock_guard Lock(m_Mutex); + auto It = m_Entries.find(RawHash); + if (It == m_Entries.end()) + { + return {}; + } + FullPayload = It->second; + } + + if (Ranges.size() > m_MaxRangeSupported) + { + // An empty Ranges signals to the caller: "full buffer given, use it for all requested ranges". 
+ ReceivedBytes = FullPayload.Size(); + return {.PayloadBuffer = FullPayload}; + } + else + { + uint64_t PayloadStart = Ranges.front().first; + uint64_t PayloadSize = Ranges.back().first + Ranges.back().second - PayloadStart; + IoBuffer RangeBuffer = IoBuffer(FullPayload, PayloadStart, PayloadSize); + std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges; + PayloadRanges.reserve(Ranges.size()); + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + PayloadRanges.push_back(std::make_pair(Range.first - PayloadStart, Range.second)); + } + ReceivedBytes = PayloadSize; + return {.PayloadBuffer = RangeBuffer, .Ranges = std::move(PayloadRanges)}; + } + } + + void PutBlobMetadatas(const Oid&, std::span<const IoHash>, std::span<const CbObject>) override {} + + std::vector<CbObject> GetBlobMetadatas(const Oid&, std::span<const IoHash> Hashes) override + { + return std::vector<CbObject>(Hashes.size()); + } + + std::vector<BlobExistsResult> BlobsExists(const Oid&, std::span<const IoHash> Hashes) override + { + std::lock_guard Lock(m_Mutex); + std::vector<BlobExistsResult> Result; + Result.reserve(Hashes.size()); + for (const IoHash& Hash : Hashes) + { + auto It = m_Entries.find(Hash); + Result.push_back({.HasBody = (It != m_Entries.end() && It->second)}); + } + return Result; + } + + void Flush(int32_t, std::function<bool(intptr_t)>&&) override {} + +private: + void AddStatistic(uint64_t ElapsedTimeUs, uint64_t ReceivedBytes, uint64_t SentBytes) + { + m_Stats.TotalBytesWritten += SentBytes; + m_Stats.TotalBytesRead += ReceivedBytes; + m_Stats.TotalExecutionTimeUs += ElapsedTimeUs; + m_Stats.TotalRequestCount++; + SetAtomicMax(m_Stats.PeakSentBytes, SentBytes); + SetAtomicMax(m_Stats.PeakReceivedBytes, ReceivedBytes); + if (ElapsedTimeUs > 0) + { + SetAtomicMax(m_Stats.PeakBytesPerSec, (ReceivedBytes + SentBytes) * 1000000 / ElapsedTimeUs); + } + } + + void SimulateLatency(uint64_t SendBytes, uint64_t ReceiveBytes) + { + double SleepSec = m_LatencySec; + if 
(m_DelayPerKBSec > 0.0) + { + SleepSec += m_DelayPerKBSec * (double(SendBytes + ReceiveBytes) / 1024u); + } + if (SleepSec > 0) + { + Sleep(int(SleepSec * 1000)); + } + } + + uint64_t m_MaxRangeSupported = 0; + BuildStorageCache::Statistics& m_Stats; + const double m_LatencySec = 0.0; + const double m_DelayPerKBSec = 0.0; + std::mutex m_Mutex; + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> m_Entries; +}; + +std::unique_ptr<BuildStorageCache> +CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported, BuildStorageCache::Statistics& Stats, double LatencySec, double DelayPerKBSec) +{ + return std::make_unique<InMemoryBuildStorageCache>(MaxRangeSupported, Stats, LatencySec, DelayPerKBSec); +} +#endif // ZEN_WITH_TESTS + ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose) { @@ -474,9 +728,28 @@ TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const boo HttpClient::Response TestResponse = TestHttpClient.Get("/status/builds"); if (TestResponse.IsSuccess()) { - return {.Success = true}; + uint64_t MaxRangeCountPerRequest = 1; + CbObject StatusResponse = TestResponse.AsObject(); + if (StatusResponse["ok"].AsBool()) + { + MaxRangeCountPerRequest = StatusResponse["capabilities"].AsObjectView()["maxrangecountperrequest"].AsUInt64(1); + + LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health"); + + if (!LatencyResult.Success) + { + return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + } + + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest}; + } + else + { + return {.Success = false, + .FailureReason = fmt::format("ZenCache endpoint {}/status/builds did not respond with \"ok\"", BaseUrl)}; + } } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; -}; +} } // namespace zen diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp 
b/src/zenremotestore/builds/buildstorageoperations.cpp index 2319ad66d..f4b167b73 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -38,6 +38,7 @@ ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zenhttp/httpclientauth.h> # include <zenremotestore/builds/filebuildstorage.h> #endif // ZEN_WITH_TESTS @@ -484,24 +485,6 @@ private: uint64_t FilteredPerSecond = 0; }; -EPartialBlockRequestMode -PartialBlockRequestModeFromString(const std::string_view ModeString) -{ - switch (HashStringAsLowerDjb2(ModeString)) - { - case HashStringDjb2("false"): - return EPartialBlockRequestMode::Off; - case HashStringDjb2("zencacheonly"): - return EPartialBlockRequestMode::ZenCacheOnly; - case HashStringDjb2("mixed"): - return EPartialBlockRequestMode::Mixed; - case HashStringDjb2("true"): - return EPartialBlockRequestMode::All; - default: - return EPartialBlockRequestMode::Invalid; - } -} - std::filesystem::path ZenStateFilePath(const std::filesystem::path& ZenFolderPath) { @@ -579,13 +562,6 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) CreateDirectories(m_TempDownloadFolderPath); CreateDirectories(m_TempBlockFolderPath); - Stopwatch IndexTimer; - - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); - } - Stopwatch CacheMappingTimer; std::vector<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters(m_RemoteContent.ChunkedContent.SequenceRawHashes.size()); @@ -906,343 +882,240 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) CheckRequiredDiskSpace(RemotePathToRemoteIndex); + BlobsExistsResult ExistsResult; { - ZEN_TRACE_CPU("WriteChunks"); - - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); - - Stopwatch WriteTimer; - 
- FilteredRate FilteredDownloadedBytesPerSecond; - FilteredRate FilteredWrittenBytesPerSecond; - - std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr( - m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing")); - OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + ChunkBlockAnalyser BlockAnalyser( + m_LogOutput, + m_BlockDescriptions, + ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, + .IsVerbose = m_Options.IsVerbose, + .HostLatencySec = m_Storage.BuildStorageHost.LatencySec, + .HostHighSpeedLatencySec = m_Storage.CacheHost.LatencySec, + .HostMaxRangeCountPerRequest = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest, + .HostHighSpeedMaxRangeCountPerRequest = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest}); - struct LooseChunkHashWorkData - { - std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs; - uint32_t RemoteChunkIndex = (uint32_t)-1; - }; + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = BlockAnalyser.GetNeeded( + m_RemoteLookup.ChunkHashToChunkIndex, + [&](uint32_t RemoteChunkIndex) -> bool { return RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex]; }); - std::vector<LooseChunkHashWorkData> LooseChunkHashWorks; - TotalPartWriteCount += CopyChunkDatas.size(); - TotalPartWriteCount += ScavengedSequenceCopyOperations.size(); + std::vector<uint32_t> FetchBlockIndexes; + std::vector<uint32_t> CachedChunkBlockIndexes; - for (const IoHash ChunkHash : m_LooseChunkHashes) { - auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); - ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); - const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) + ZEN_TRACE_CPU("BlockCacheFileExists"); + for (const 
ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks) { - if (m_Options.IsVerbose) + if (m_Options.PrimeCacheOnly) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash); - } - continue; - } - bool NeedsCopy = true; - if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) - { - std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs = - GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - - if (ChunkTargetPtrs.empty()) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash); - } + FetchBlockIndexes.push_back(NeededBlock.BlockIndex); } else { - TotalRequestCount++; - TotalPartWriteCount++; - LooseChunkHashWorks.push_back( - LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); - } - } - } - - uint32_t BlockCount = gsl::narrow<uint32_t>(m_BlockDescriptions.size()); - - std::vector<bool> ChunkIsPickedUpByBlock(m_RemoteContent.ChunkedContent.ChunkHashes.size(), false); - auto GetNeededChunkBlockIndexes = [this, &RemoteChunkIndexNeedsCopyFromSourceFlags, &ChunkIsPickedUpByBlock]( - const ChunkBlockDescription& BlockDescription) { - ZEN_TRACE_CPU("GetNeededChunkBlockIndexes"); - std::vector<uint32_t> NeededBlockChunkIndexes; - for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) - { - const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; - if (auto It = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); It != m_RemoteLookup.ChunkHashToChunkIndex.end()) - { - const uint32_t RemoteChunkIndex = It->second; - if (!ChunkIsPickedUpByBlock[RemoteChunkIndex]) + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; + bool UsingCachedBlock = false; + if (auto It = 
CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) { - if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex]) + TotalPartWriteCount++; + + std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); + if (IsFile(BlockPath)) { - ChunkIsPickedUpByBlock[RemoteChunkIndex] = true; - NeededBlockChunkIndexes.push_back(ChunkBlockIndex); + CachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex); + UsingCachedBlock = true; } } - } - else - { - ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash); + if (!UsingCachedBlock) + { + FetchBlockIndexes.push_back(NeededBlock.BlockIndex); + } } } - return NeededBlockChunkIndexes; - }; + } - std::vector<uint32_t> CachedChunkBlockIndexes; - std::vector<uint32_t> FetchBlockIndexes; - std::vector<std::vector<uint32_t>> AllBlockChunkIndexNeeded; + std::vector<uint32_t> NeededLooseChunkIndexes; - for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; BlockIndex++) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - std::vector<uint32_t> BlockChunkIndexNeeded = GetNeededChunkBlockIndexes(BlockDescription); - if (!BlockChunkIndexNeeded.empty()) + NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size()); + for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++) { - if (m_Options.PrimeCacheOnly) + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; + + if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) { - FetchBlockIndexes.push_back(BlockIndex); + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Skipping chunk {} due to cache reuse", + 
m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); + } + continue; } - else + + bool NeedsCopy = true; + if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) { - bool UsingCachedBlock = false; - if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) + uint64_t WriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); + if (WriteCount == 0) { - TotalPartWriteCount++; - - std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - if (IsFile(BlockPath)) + if (m_Options.IsVerbose) { - CachedChunkBlockIndexes.push_back(BlockIndex); - UsingCachedBlock = true; + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Skipping chunk {} due to cache reuse", + m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); } } - if (!UsingCachedBlock) + else { - FetchBlockIndexes.push_back(BlockIndex); + NeededLooseChunkIndexes.push_back(LooseChunkIndex); } } } - AllBlockChunkIndexNeeded.emplace_back(std::move(BlockChunkIndexNeeded)); } - BlobsExistsResult ExistsResult; - - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { ZEN_TRACE_CPU("BlobCacheExistCheck"); Stopwatch Timer; - tsl::robin_set<IoHash> BlobHashesSet; + std::vector<IoHash> BlobHashes; + BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size()); - BlobHashesSet.reserve(LooseChunkHashWorks.size() + FetchBlockIndexes.size()); - for (LooseChunkHashWorkData& LooseChunkHashWork : LooseChunkHashWorks) + for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes) { - BlobHashesSet.insert(m_RemoteContent.ChunkedContent.ChunkHashes[LooseChunkHashWork.RemoteChunkIndex]); + BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]); } + for (uint32_t BlockIndex : FetchBlockIndexes) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - BlobHashesSet.insert(BlockDescription.BlockHash); + 
BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash); } - if (!BlobHashesSet.empty()) - { - const std::vector<IoHash> BlobHashes(BlobHashesSet.begin(), BlobHashesSet.end()); - const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = - m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); + const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = + m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); - if (CacheExistsResult.size() == BlobHashes.size()) + if (CacheExistsResult.size() == BlobHashes.size()) + { + ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size()); + for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) { - ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size()); - for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) + if (CacheExistsResult[BlobIndex].HasBody) { - if (CacheExistsResult[BlobIndex].HasBody) - { - ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]); - } + ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]); } } - ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs(); - if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Remote cache : Found {} out of {} needed blobs in {}", - ExistsResult.ExistingBlobs.size(), - BlobHashes.size(), - NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); - } + } + ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs(); + if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Remote cache : Found {} out of {} needed blobs in {}", + ExistsResult.ExistingBlobs.size(), + BlobHashes.size(), + NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); } } - std::vector<BlockRangeDescriptor> BlockRangeWorks; - std::vector<uint32_t> FullBlockWorks; + std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes; + + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off) 
{ - Stopwatch Timer; + BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else + { + ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; - std::vector<uint32_t> PartialBlockIndexes; + switch (m_Options.PartialBlockRequestMode) + { + case EPartialBlockRequestMode::Off: + break; + case EPartialBlockRequestMode::ZenCacheOnly: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + break; + case EPartialBlockRequestMode::Mixed: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + case EPartialBlockRequestMode::All: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1 + ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + default: + ZEN_ASSERT(false); + break; + } - for (uint32_t BlockIndex : FetchBlockIndexes) + BlockPartialDownloadModes.reserve(m_BlockDescriptions.size()); + for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); + BlockPartialDownloadModes.push_back(BlockExistInCache ? CachePartialDownloadMode : CloudPartialDownloadMode); + } + } - const std::vector<uint32_t> BlockChunkIndexNeeded = std::move(AllBlockChunkIndexNeeded[BlockIndex]); - if (!BlockChunkIndexNeeded.empty()) - { - bool WantsToDoPartialBlockDownload = BlockChunkIndexNeeded.size() < BlockDescription.ChunkRawHashes.size(); - bool CanDoPartialBlockDownload = - (BlockDescription.HeaderSize > 0) && - (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size()); - - bool AllowedToDoPartialRequest = false; - bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); - switch (m_Options.PartialBlockRequestMode) - { - case EPartialBlockRequestMode::Off: - break; - case EPartialBlockRequestMode::ZenCacheOnly: - AllowedToDoPartialRequest = BlockExistInCache; - break; - case EPartialBlockRequestMode::Mixed: - case EPartialBlockRequestMode::All: - AllowedToDoPartialRequest = true; - break; - default: - ZEN_ASSERT(false); - break; - } + ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size()); - const uint32_t ChunkStartOffsetInBlock = - gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + ChunkBlockAnalyser::BlockResult PartialBlocks = + BlockAnalyser.CalculatePartialBlockDownloads(NeededBlocks, BlockPartialDownloadModes); - const 
uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(ChunkStartOffsetInBlock)); + struct LooseChunkHashWorkData + { + std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs; + uint32_t RemoteChunkIndex = (uint32_t)-1; + }; - if (AllowedToDoPartialRequest && WantsToDoPartialBlockDownload && CanDoPartialBlockDownload) - { - ZEN_TRACE_CPU("PartialBlockAnalysis"); - - bool LimitToSingleRange = - BlockExistInCache ? false : m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed; - uint64_t TotalWantedChunksSize = 0; - std::optional<std::vector<BlockRangeDescriptor>> MaybeBlockRanges = - CalculateBlockRanges(BlockIndex, - BlockDescription, - BlockChunkIndexNeeded, - LimitToSingleRange, - ChunkStartOffsetInBlock, - TotalBlockSize, - TotalWantedChunksSize); - ZEN_ASSERT(TotalWantedChunksSize <= TotalBlockSize); - - if (MaybeBlockRanges.has_value()) - { - const std::vector<BlockRangeDescriptor>& BlockRanges = MaybeBlockRanges.value(); - ZEN_ASSERT(!BlockRanges.empty()); - BlockRangeWorks.insert(BlockRangeWorks.end(), BlockRanges.begin(), BlockRanges.end()); - TotalRequestCount += BlockRanges.size(); - TotalPartWriteCount += BlockRanges.size(); - - uint64_t RequestedSize = std::accumulate( - BlockRanges.begin(), - BlockRanges.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - PartialBlockIndexes.push_back(BlockIndex); - - if (RequestedSize > TotalWantedChunksSize) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes {})", - BlockChunkIndexNeeded.size(), - NiceBytes(RequestedSize), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - BlockRanges.size(), - NiceBytes(RequestedSize - TotalWantedChunksSize)); - } - } - } - else - { - 
FullBlockWorks.push_back(BlockIndex); - TotalRequestCount++; - TotalPartWriteCount++; - } - } - else - { - FullBlockWorks.push_back(BlockIndex); - TotalRequestCount++; - TotalPartWriteCount++; - } - } - } + TotalRequestCount += NeededLooseChunkIndexes.size(); + TotalPartWriteCount += NeededLooseChunkIndexes.size(); + TotalRequestCount += PartialBlocks.BlockRanges.size(); + TotalPartWriteCount += PartialBlocks.BlockRanges.size(); + TotalRequestCount += PartialBlocks.FullBlockIndexes.size(); + TotalPartWriteCount += PartialBlocks.FullBlockIndexes.size(); - if (!PartialBlockIndexes.empty()) - { - uint64_t TotalFullBlockRequestBytes = 0; - for (uint32_t BlockIndex : FullBlockWorks) - { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - uint32_t CurrentOffset = - gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + std::vector<LooseChunkHashWorkData> LooseChunkHashWorks; + for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes) + { + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - TotalFullBlockRequestBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(CurrentOffset)); - } + std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs = + GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - uint64_t TotalPartialBlockBytes = 0; - for (uint32_t BlockIndex : PartialBlockIndexes) - { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - uint32_t CurrentOffset = - gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + 
ZEN_ASSERT(!ChunkTargetPtrs.empty()); + LooseChunkHashWorks.push_back( + LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); + } - TotalPartialBlockBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(CurrentOffset)); - } + ZEN_TRACE_CPU("WriteChunks"); - uint64_t NonPartialTotalBlockBytes = TotalFullBlockRequestBytes + TotalPartialBlockBytes; + m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); - const uint64_t TotalPartialBlockRequestBytes = - std::accumulate(BlockRangeWorks.begin(), - BlockRangeWorks.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - uint64_t TotalExtraPartialBlocksRequests = BlockRangeWorks.size() - PartialBlockIndexes.size(); + Stopwatch WriteTimer; - uint64_t TotalSavedBlocksSize = TotalPartialBlockBytes - TotalPartialBlockRequestBytes; - double SavedSizePercent = (TotalSavedBlocksSize * 100.0) / NonPartialTotalBlockBytes; + FilteredRate FilteredDownloadedBytesPerSecond; + FilteredRate FilteredWrittenBytesPerSecond; - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Analysis of partial block requests saves download of {} out of {} ({:.1f}%) using {} extra " - "requests. Completed in {}", - NiceBytes(TotalSavedBlocksSize), - NiceBytes(NonPartialTotalBlockBytes), - SavedSizePercent, - TotalExtraPartialBlocksRequests, - NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); - } - } - } + std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr( + m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? 
"Downloading" : "Writing")); + OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + TotalPartWriteCount += CopyChunkDatas.size(); + TotalPartWriteCount += ScavengedSequenceCopyOperations.size(); BufferedWriteFileCache WriteCache; @@ -1472,13 +1345,23 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) }); } - for (size_t BlockRangeIndex = 0; BlockRangeIndex < BlockRangeWorks.size(); BlockRangeIndex++) + for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();) { ZEN_ASSERT(!m_Options.PrimeCacheOnly); if (m_AbortFlag) { break; } + + size_t RangeCount = 1; + size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex; + const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; + while (RangeCount < RangesLeft && + CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) + { + RangeCount++; + } + Work.ScheduleWork( m_NetworkPool, [this, @@ -1492,18 +1375,19 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) TotalPartWriteCount, &FilteredWrittenBytesPerSecond, &Work, - &BlockRangeWorks, - BlockRangeIndex](std::atomic<bool>&) { + &PartialBlocks, + BlockRangeStartIndex = BlockRangeIndex, + RangeCount = RangeCount](std::atomic<bool>&) { if (!m_AbortFlag) { - ZEN_TRACE_CPU("Async_GetPartialBlock"); - - const BlockRangeDescriptor& BlockRange = BlockRangeWorks[BlockRangeIndex]; + ZEN_TRACE_CPU("Async_GetPartialBlockRanges"); FilteredDownloadedBytesPerSecond.Start(); DownloadPartialBlock( - BlockRange, + PartialBlocks.BlockRanges, + BlockRangeStartIndex, + RangeCount, ExistsResult, [this, &RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -1515,7 +1399,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) TotalPartWriteCount, &FilteredDownloadedBytesPerSecond, 
&FilteredWrittenBytesPerSecond, - &BlockRange](IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath) { + &PartialBlocks](IoBuffer&& InMemoryBuffer, + const std::filesystem::path& OnDiskPath, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) { if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) { FilteredDownloadedBytesPerSecond.Stop(); @@ -1533,14 +1420,18 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) &Work, TotalPartWriteCount, &FilteredWrittenBytesPerSecond, - &BlockRange, + &PartialBlocks, + BlockRangeStartIndex, BlockChunkPath = std::filesystem::path(OnDiskPath), - BlockPartialBuffer = std::move(InMemoryBuffer)](std::atomic<bool>&) mutable { + BlockPartialBuffer = std::move(InMemoryBuffer), + OffsetAndLengths = std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), + OffsetAndLengths.end())]( + std::atomic<bool>&) mutable { if (!m_AbortFlag) { ZEN_TRACE_CPU("Async_WritePartialBlock"); - const uint32_t BlockIndex = BlockRange.BlockIndex; + const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex; const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; @@ -1563,22 +1454,41 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) FilteredWrittenBytesPerSecond.Start(); - if (!WritePartialBlockChunksToCache( - BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockPartialBuffer)), - BlockRange.ChunkBlockIndexStart, - BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount - 1, - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) + size_t RangeCount = OffsetAndLengths.size(); + + for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++) { - std::error_code DummyEc; - RemoveFile(BlockChunkPath, DummyEc); - throw std::runtime_error( - fmt::format("Partial block {} is malformed", 
BlockDescription.BlockHash)); - } + const std::pair<uint64_t, uint64_t>& OffsetAndLength = + OffsetAndLengths[PartialRangeIndex]; + IoBuffer BlockRangeBuffer(BlockPartialBuffer, + OffsetAndLength.first, + OffsetAndLength.second); + + const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor = + PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex]; + + if (!WritePartialBlockChunksToCache(BlockDescription, + SequenceIndexChunksLeftToWriteCounters, + Work, + CompositeBuffer(std::move(BlockRangeBuffer)), + RangeDescriptor.ChunkBlockIndexStart, + RangeDescriptor.ChunkBlockIndexStart + + RangeDescriptor.ChunkBlockIndexCount - 1, + RemoteChunkIndexNeedsCopyFromSourceFlags, + WriteCache)) + { + std::error_code DummyEc; + RemoveFile(BlockChunkPath, DummyEc); + throw std::runtime_error( + fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); + } + WritePartsComplete++; + if (WritePartsComplete == TotalPartWriteCount) + { + FilteredWrittenBytesPerSecond.Stop(); + } + } std::error_code Ec = TryRemoveFile(BlockChunkPath); if (Ec) { @@ -1588,12 +1498,6 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) Ec.value(), Ec.message()); } - - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); - } } }, OnDiskPath.empty() ? 
WorkerThreadPool::EMode::DisableBacklog @@ -1602,9 +1506,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) }); } }); + BlockRangeIndex += RangeCount; } - for (uint32_t BlockIndex : FullBlockWorks) + for (uint32_t BlockIndex : PartialBlocks.FullBlockIndexes) { if (m_AbortFlag) { @@ -1641,20 +1546,20 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) IoBuffer BlockBuffer; const bool ExistsInCache = - m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); + m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); if (ExistsInCache) { - BlockBuffer = m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); + BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); } if (!BlockBuffer) { BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); - if (BlockBuffer && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BlockBuffer))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(BlockBuffer))); } } if (!BlockBuffer) @@ -3217,10 +3122,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; // FilteredDownloadedBytesPerSecond.Start(); IoBuffer BuildBlob; - const bool ExistsInCache = m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash); + const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash); if (ExistsInCache) { - BuildBlob = 
m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, ChunkHash); + BuildBlob = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, ChunkHash); } if (BuildBlob) { @@ -3248,12 +3153,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.RequestsCompleteCount++; - if (Payload && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); } OnDownloaded(std::move(Payload)); @@ -3262,12 +3167,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde else { BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); - if (BuildBlob && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BuildBlob))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(BuildBlob))); } if (!BuildBlob) { @@ -3289,347 +3194,241 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde } } -BuildsOperationUpdateFolder::BlockRangeDescriptor -BuildsOperationUpdateFolder::MergeBlockRanges(std::span<const BlockRangeDescriptor> Ranges) +void +BuildsOperationUpdateFolder::DownloadPartialBlock( + std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, + size_t BlockRangeStartIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + std::function<void(IoBuffer&& InMemoryBuffer, + const std::filesystem::path& 
OnDiskPath, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded) { - ZEN_ASSERT(Ranges.size() > 1); - const BlockRangeDescriptor& First = Ranges.front(); - const BlockRangeDescriptor& Last = Ranges.back(); - - return BlockRangeDescriptor{.BlockIndex = First.BlockIndex, - .RangeStart = First.RangeStart, - .RangeLength = Last.RangeStart + Last.RangeLength - First.RangeStart, - .ChunkBlockIndexStart = First.ChunkBlockIndexStart, - .ChunkBlockIndexCount = Last.ChunkBlockIndexStart + Last.ChunkBlockIndexCount - First.ChunkBlockIndexStart}; -} + const uint32_t BlockIndex = BlockRanges[BlockRangeStartIndex].BlockIndex; -std::optional<std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor>> -BuildsOperationUpdateFolder::MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, const BlockRangeDescriptor& Range) -{ - if (Range.RangeLength == TotalBlockSize) - { - return {}; - } - else - { - return std::vector<BlockRangeDescriptor>{Range}; - } -}; + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; -const BuildsOperationUpdateFolder::BlockRangeLimit* -BuildsOperationUpdateFolder::GetBlockRangeLimitForRange(std::span<const BlockRangeLimit> Limits, - uint64_t TotalBlockSize, - std::span<const BlockRangeDescriptor> Ranges) -{ - if (Ranges.size() > 1) - { - const std::uint64_t WantedSize = - std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { - return Current + Range.RangeLength; - }); + auto ProcessDownload = [this]( + const ChunkBlockDescription& BlockDescription, + IoBuffer&& BlockRangeBuffer, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths, + const std::function<void(IoBuffer && InMemoryBuffer, + const std::filesystem::path& OnDiskPath, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>& OnDownloaded) { + uint64_t 
BlockRangeBufferSize = BlockRangeBuffer.GetSize(); + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize; + m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size(); - const double RangeRequestedPercent = (WantedSize * 100.0) / TotalBlockSize; + std::filesystem::path BlockChunkPath; - for (const BlockRangeLimit& Limit : Limits) + // Check if the dowloaded block is file based and we can move it directly without rewriting it { - if (RangeRequestedPercent >= Limit.SizePercent && Ranges.size() > Limit.MaxRangeCount) + IoBufferFileReference FileRef; + if (BlockRangeBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && + (FileRef.FileChunkSize == BlockRangeBufferSize)) { - return &Limit; - } - } - } - return nullptr; -}; + ZEN_TRACE_CPU("MoveTempPartialBlock"); -std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor> -BuildsOperationUpdateFolder::CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std::span<const BlockRangeDescriptor> BlockRanges) -{ - ZEN_ASSERT(BlockRanges.size() > 1); - std::vector<BlockRangeDescriptor> CollapsedBlockRanges; + std::error_code Ec; + std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); + if (!Ec) + { + BlockRangeBuffer.SetDeleteOnClose(false); + BlockRangeBuffer = {}; - auto BlockRangesIt = BlockRanges.begin(); - CollapsedBlockRanges.push_back(*BlockRangesIt++); - for (; BlockRangesIt != BlockRanges.end(); BlockRangesIt++) - { - BlockRangeDescriptor& LastRange = CollapsedBlockRanges.back(); + IoHashStream RangeId; + for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths) + { + RangeId.Append(&Range.first, sizeof(uint64_t)); + RangeId.Append(&Range.second, sizeof(uint64_t)); + } + + BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); + RenameFile(TempBlobPath, BlockChunkPath, Ec); + if (Ec) + { + BlockChunkPath = std::filesystem::path{}; - const 
uint64_t BothRangeSize = BlockRangesIt->RangeLength + LastRange.RangeLength; + // Re-open the temp file again + BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); + BlockRangeBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockRangeBufferSize, true); + BlockRangeBuffer.SetDeleteOnClose(true); + } + } + } + } - const uint64_t Gap = BlockRangesIt->RangeStart - (LastRange.RangeStart + LastRange.RangeLength); - if (Gap <= Max(BothRangeSize / 16, AlwaysAcceptableGap)) + if (BlockChunkPath.empty() && (BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize)) { - LastRange.ChunkBlockIndexCount = - (BlockRangesIt->ChunkBlockIndexStart + BlockRangesIt->ChunkBlockIndexCount) - LastRange.ChunkBlockIndexStart; - LastRange.RangeLength = (BlockRangesIt->RangeStart + BlockRangesIt->RangeLength) - LastRange.RangeStart; + ZEN_TRACE_CPU("WriteTempPartialBlock"); + + IoHashStream RangeId; + for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths) + { + RangeId.Append(&Range.first, sizeof(uint64_t)); + RangeId.Append(&Range.second, sizeof(uint64_t)); + } + + // Could not be moved and rather large, lets store it on disk + BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); + TemporaryFile::SafeWriteFile(BlockChunkPath, BlockRangeBuffer); + BlockRangeBuffer = {}; } - else + if (!m_AbortFlag) { - CollapsedBlockRanges.push_back(*BlockRangesIt); + OnDownloaded(std::move(BlockRangeBuffer), std::move(BlockChunkPath), BlockRangeStartIndex, BlockOffsetAndLengths); } - } - - return CollapsedBlockRanges; -}; + }; -uint64_t -BuildsOperationUpdateFolder::CalculateNextGap(std::span<const BlockRangeDescriptor> BlockRanges) -{ - ZEN_ASSERT(BlockRanges.size() > 1); - uint64_t AcceptableGap = (uint64_t)-1; - for (size_t RangeIndex = 0; RangeIndex < BlockRanges.size() - 1; RangeIndex++) + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + Ranges.reserve(BlockRangeCount); + for (size_t BlockRangeIndex = 
BlockRangeStartIndex; BlockRangeIndex < BlockRangeStartIndex + BlockRangeCount; BlockRangeIndex++) { - const BlockRangeDescriptor& Range = BlockRanges[RangeIndex]; - const BlockRangeDescriptor& NextRange = BlockRanges[RangeIndex + 1]; - - const uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength); - AcceptableGap = Min(Gap, AcceptableGap); + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRanges[BlockRangeIndex]; + Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); } - AcceptableGap = RoundUp(AcceptableGap, 16u * 1024u); - return AcceptableGap; -}; -std::optional<std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor>> -BuildsOperationUpdateFolder::CalculateBlockRanges(uint32_t BlockIndex, - const ChunkBlockDescription& BlockDescription, - std::span<const uint32_t> BlockChunkIndexNeeded, - bool LimitToSingleRange, - const uint64_t ChunkStartOffsetInBlock, - const uint64_t TotalBlockSize, - uint64_t& OutTotalWantedChunksSize) -{ - ZEN_TRACE_CPU("CalculateBlockRanges"); + const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); - std::vector<BlockRangeDescriptor> BlockRanges; + size_t SubBlockRangeCount = BlockRangeCount; + size_t SubRangeCountComplete = 0; + std::span<const std::pair<uint64_t, uint64_t>> RangesSpan(Ranges); + while (SubRangeCountComplete < SubBlockRangeCount) { - uint64_t CurrentOffset = ChunkStartOffsetInBlock; - uint32_t ChunkBlockIndex = 0; - uint32_t NeedBlockChunkIndexOffset = 0; - BlockRangeDescriptor NextRange{.BlockIndex = BlockIndex}; - while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size()) + if (m_AbortFlag) + { + break; + } + + // First try to get subrange from cache. + // If not successful, try to get the ranges from the build store and adapt SubRangeCount... 
+ + size_t SubRangeStartIndex = BlockRangeStartIndex + SubRangeCountComplete; + if (ExistsInCache) { - const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; - if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.CacheHost.Caps.MaxRangeCountPerRequest); + + if (SubRangeCount == 1) { - if (NextRange.RangeLength > 0) + // Legacy single-range path, prefer that for max compatibility + + const std::pair<uint64_t, uint64_t> SubRange = RangesSpan[SubRangeCountComplete]; + IoBuffer PayloadBuffer = + m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, SubRange.first, SubRange.second); + if (m_AbortFlag) { - BlockRanges.push_back(NextRange); - NextRange = {.BlockIndex = BlockIndex}; + break; } - ChunkBlockIndex++; - CurrentOffset += ChunkCompressedLength; - } - else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) - { - if (NextRange.RangeLength == 0) + if (PayloadBuffer) { - NextRange.RangeStart = CurrentOffset; - NextRange.ChunkBlockIndexStart = ChunkBlockIndex; + ProcessDownload(BlockDescription, + std::move(PayloadBuffer), + SubRangeStartIndex, + std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)}, + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; } - NextRange.RangeLength += ChunkCompressedLength; - NextRange.ChunkBlockIndexCount++; - ChunkBlockIndex++; - CurrentOffset += ChunkCompressedLength; - NeedBlockChunkIndexOffset++; } else { - ZEN_ASSERT(false); - } - } - if (NextRange.RangeLength > 0) - { - BlockRanges.push_back(NextRange); - } - } - ZEN_ASSERT(!BlockRanges.empty()); - - OutTotalWantedChunksSize = - std::accumulate(BlockRanges.begin(), BlockRanges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { - return Current + Range.RangeLength; - }); + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, 
SubRangeCount); - double RangeWantedPercent = (OutTotalWantedChunksSize * 100.0) / TotalBlockSize; - - if (BlockRanges.size() == 1) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Range request of {} ({:.2f}%) using single range from block {} ({}) as is", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize)); + BuildStorageCache::BuildBlobRanges RangeBuffers = + m_Storage.CacheStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); + if (m_AbortFlag) + { + break; + } + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) + { + SubRangeCount = Ranges.size() - SubRangeCountComplete; + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; + } + else if (RangeBuffers.Ranges.size() == SubRangeCount) + { + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + } } - return BlockRanges; - } - if (LimitToSingleRange) - { - const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); - if (m_Options.IsVerbose) - { - const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; - const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest); - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) limited to single block range {} ({:.2f}%) wasting " - "{:.2f}% ({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - 
BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - NiceBytes(MergedRange.RangeLength), - RangeRequestedPercent, - WastedPercent, - NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); - } - return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); - } + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); - if (RangeWantedPercent > FullBlockRangePercentLimit) - { - const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); - if (m_Options.IsVerbose) + BuildStorageBase::BuildBlobRanges RangeBuffers = + m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); + if (m_AbortFlag) { - const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; - const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; - - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) exceeds {}%. 
Merged to single block range {} " - "({:.2f}%) wasting {:.2f}% ({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - FullBlockRangePercentLimit, - NiceBytes(MergedRange.RangeLength), - RangeRequestedPercent, - WastedPercent, - NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); + break; } - return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); - } - - std::vector<BlockRangeDescriptor> CollapsedBlockRanges = CollapseBlockRanges(16u * 1024u, BlockRanges); - while (GetBlockRangeLimitForRange(ForceMergeLimits, TotalBlockSize, CollapsedBlockRanges)) - { - CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(CollapsedBlockRanges), CollapsedBlockRanges); - } - - const std::uint64_t WantedCollapsedSize = - std::accumulate(CollapsedBlockRanges.begin(), - CollapsedBlockRanges.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - - const double CollapsedRangeRequestedPercent = (WantedCollapsedSize * 100.0) / TotalBlockSize; - - if (m_Options.IsVerbose) - { - const double WastedPercent = ((WantedCollapsedSize - OutTotalWantedChunksSize) * 100.0) / WantedCollapsedSize; - - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) collapsed to {} {:.2f}% using {} ranges wasting {:.2f}% " - "({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - NiceBytes(WantedCollapsedSize), - CollapsedRangeRequestedPercent, - CollapsedBlockRanges.size(), - WastedPercent, - NiceBytes(WantedCollapsedSize - OutTotalWantedChunksSize)); - } - return CollapsedBlockRanges; -} - -void -BuildsOperationUpdateFolder::DownloadPartialBlock( - const BlockRangeDescriptor BlockRange, - const BlobsExistsResult& ExistsResult, - std::function<void(IoBuffer&& InMemoryBuffer, 
const std::filesystem::path& OnDiskPath)>&& OnDownloaded) -{ - const uint32_t BlockIndex = BlockRange.BlockIndex; - - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - IoBuffer BlockBuffer; - if (m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash)) - { - BlockBuffer = - m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - } - if (!BlockBuffer) - { - BlockBuffer = - m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - } - if (!BlockBuffer) - { - throw std::runtime_error(fmt::format("Block {} is missing when fetching range {} -> {}", - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeStart + BlockRange.RangeLength)); - } - if (!m_AbortFlag) - { - uint64_t BlockSize = BlockBuffer.GetSize(); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += BlockSize; - m_DownloadStats.RequestsCompleteCount++; - - std::filesystem::path BlockChunkPath; - - // Check if the dowloaded block is file based and we can move it directly without rewriting it + if (RangeBuffers.PayloadBuffer) { - IoBufferFileReference FileRef; - if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == BlockSize)) + if (RangeBuffers.Ranges.empty()) { - ZEN_TRACE_CPU("MoveTempPartialBlock"); + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Upload to cache (if enabled) and use the whole payload for the remaining ranges - std::error_code Ec; - std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); - if (!Ec) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - BlockBuffer.SetDeleteOnClose(false); - BlockBuffer = {}; - BlockChunkPath = m_TempBlockFolderPath / - fmt::format("{}_{:x}_{:x}", 
BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - RenameFile(TempBlobPath, BlockChunkPath, Ec); - if (Ec) + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer})); + if (m_AbortFlag) { - BlockChunkPath = std::filesystem::path{}; - - // Re-open the temp file again - BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete); - BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true); - BlockBuffer.SetDeleteOnClose(true); + break; } } - } - } - if (BlockChunkPath.empty() && (BlockSize > m_Options.MaximumInMemoryPayloadSize)) - { - ZEN_TRACE_CPU("WriteTempPartialBlock"); - // Could not be moved and rather large, lets store it on disk - BlockChunkPath = m_TempBlockFolderPath / - fmt::format("{}_{:x}_{:x}", BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - TemporaryFile::SafeWriteFile(BlockChunkPath, BlockBuffer); - BlockBuffer = {}; + SubRangeCount = Ranges.size() - SubRangeCountComplete; + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + OnDownloaded); + } + else + { + if (RangeBuffers.Ranges.size() != SubRanges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + SubRanges.size(), + BlockDescription.BlockHash, + RangeBuffers.Ranges.size())); + } + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + OnDownloaded); + } } - if (!m_AbortFlag) + else { - OnDownloaded(std::move(BlockBuffer), std::move(BlockChunkPath)); + throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); } + + SubRangeCountComplete += SubRangeCount; } } @@ -4083,7 +3882,8 @@ 
BuildsOperationUpdateFolder::WriteSequenceChunkToCache(BufferedWriteFileCache::L } bool -BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkRawHashes, +BuildsOperationUpdateFolder::GetBlockWriteOps(const IoHash& BlockRawHash, + std::span<const IoHash> ChunkRawHashes, std::span<const uint32_t> ChunkCompressedLengths, std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters, std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4115,9 +3915,34 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkR uint64_t VerifyChunkSize; CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(ChunkMemoryView), VerifyChunkHash, VerifyChunkSize); - ZEN_ASSERT(CompressedChunk); - ZEN_ASSERT(VerifyChunkHash == ChunkHash); - ZEN_ASSERT(VerifyChunkSize == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]); + if (!CompressedChunk) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} is not a valid compressed buffer", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash)); + } + if (VerifyChunkHash != ChunkHash) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} has a mismatching content hash {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + VerifyChunkHash)); + } + if (VerifyChunkSize != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]) + { + throw std::runtime_error( + fmt::format("Chunk {} at {}, size {} in block {} has a mismatching raw size {}, expected {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + VerifyChunkSize, + m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])); + } OodleCompressor ChunkCompressor; OodleCompressionLevel ChunkCompressionLevel; @@ -4138,7 +3963,18 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkR { Decompressed = CompressedChunk.Decompress().AsIoBuffer(); } - 
ZEN_ASSERT(Decompressed.GetSize() == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]); + + if (Decompressed.GetSize() != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} decompressed to size {}, expected {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + Decompressed.GetSize(), + m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])); + } + ZEN_ASSERT_SLOW(ChunkHash == IoHash::HashBuffer(Decompressed)); for (const ChunkedContentLookup::ChunkSequenceLocation* Target : ChunkTargetPtrs) { @@ -4237,7 +4073,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription const std::vector<uint32_t> ChunkCompressedLengths = ReadChunkBlockHeader(BlockView.Mid(CompressedBuffer::GetHeaderSizeForNoneEncoder()), HeaderSize); - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4252,7 +4089,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription return false; } - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, BlockDescription.ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4283,7 +4121,8 @@ BuildsOperationUpdateFolder::WritePartialBlockChunksToCache(const ChunkBlockDesc const MemoryView BlockView = BlockMemoryBuffer.GetView(); BlockWriteOps Ops; - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, BlockDescription.ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -5156,12 +4995,12 @@ 
BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& const IoHash& BlockHash = OutBlocks.BlockDescriptions[BlockIndex].BlockHash; const uint64_t CompressedBlockSize = Payload.GetCompressedSize(); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlockHash, - ZenContentType::kCompressedBinary, - Payload.GetCompressed()); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockHash, + ZenContentType::kCompressedBinary, + Payload.GetCompressed()); } m_Storage.BuildStorage->PutBuildBlob(m_BuildId, @@ -5179,11 +5018,11 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); } - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); } bool MetadataSucceeded = @@ -5334,6 +5173,13 @@ BuildsOperationUploadFolder::FetchChunk(const ChunkedFolderContent& Content, ZEN_ASSERT(!ChunkLocations.empty()); CompositeBuffer Chunk = OpenFileCache.GetRange(ChunkLocations[0].SequenceIndex, ChunkLocations[0].Offset, Content.ChunkedContent.ChunkRawSizes[ChunkIndex]); + if (!Chunk) + { + throw std::runtime_error(fmt::format("Unable to read chunk at {}, size {} from '{}'", + ChunkLocations[0].Offset, + Content.ChunkedContent.ChunkRawSizes[ChunkIndex], + Content.Paths[Lookup.SequenceIndexFirstPathIndex[ChunkLocations[0].SequenceIndex]])); + } ZEN_ASSERT_SLOW(IoHash::HashBuffer(Chunk) == ChunkHash); return Chunk; }; @@ -5362,10 +5208,7 @@ BuildsOperationUploadFolder::GenerateBlock(const ChunkedFolderContent& Content, 
Content.ChunkedContent.ChunkHashes[ChunkIndex], [this, &Content, &Lookup, &OpenFileCache, ChunkIndex](const IoHash& ChunkHash) -> std::pair<uint64_t, CompressedBuffer> { CompositeBuffer Chunk = FetchChunk(Content, Lookup, ChunkHash, OpenFileCache); - if (!Chunk) - { - ZEN_ASSERT(false); - } + ZEN_ASSERT(Chunk); uint64_t RawSize = Chunk.GetSize(); const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock && @@ -6023,11 +5866,11 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController { const CbObject BlockMetaData = BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); } bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); if (MetadataSucceeded) @@ -6221,9 +6064,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co const CbObject BlockMetaData = BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); } m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); if (m_Options.IsVerbose) @@ -6237,11 +6080,11 @@ BuildsOperationUploadFolder::UploadPartBlobs(const 
ChunkedFolderContent& Co UploadedBlockSize += PayloadSize; TempUploadStats.BlocksBytes += PayloadSize; - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); } bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); if (MetadataSucceeded) @@ -6304,9 +6147,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co const uint64_t PayloadSize = Payload.GetSize(); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); } if (PayloadSize >= LargeAttachmentSize) @@ -7050,14 +6893,14 @@ BuildsOperationPrimeCache::Execute() std::vector<IoHash> BlobsToDownload; BlobsToDownload.reserve(BuildBlobs.size()); - if (m_Storage.BuildCacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload) + if (m_Storage.CacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload) { ZEN_TRACE_CPU("BlobCacheExistCheck"); Stopwatch Timer; const std::vector<IoHash> BlobHashes(BuildBlobs.begin(), BuildBlobs.end()); const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = - m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); + m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); if (CacheExistsResult.size() == BlobHashes.size()) { @@ -7104,33 +6947,33 @@ BuildsOperationPrimeCache::Execute() for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++) { - Work.ScheduleWork( - 
m_NetworkPool, - [this, - &Work, - &BlobsToDownload, - BlobCount, - &LooseChunkRawSizes, - &CompletedDownloadCount, - &FilteredDownloadedBytesPerSecond, - &MultipartAttachmentCount, - BlobIndex](std::atomic<bool>&) { - if (!m_AbortFlag) - { - const IoHash& BlobHash = BlobsToDownload[BlobIndex]; + Work.ScheduleWork(m_NetworkPool, + [this, + &Work, + &BlobsToDownload, + BlobCount, + &LooseChunkRawSizes, + &CompletedDownloadCount, + &FilteredDownloadedBytesPerSecond, + &MultipartAttachmentCount, + BlobIndex](std::atomic<bool>&) { + if (!m_AbortFlag) + { + const IoHash& BlobHash = BlobsToDownload[BlobIndex]; - bool IsLargeBlob = false; + bool IsLargeBlob = false; - if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) - { - IsLargeBlob = It->second >= m_Options.LargeAttachmentSize; - } + if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) + { + IsLargeBlob = It->second >= m_Options.LargeAttachmentSize; + } - FilteredDownloadedBytesPerSecond.Start(); + FilteredDownloadedBytesPerSecond.Start(); - if (IsLargeBlob) - { - DownloadLargeBlob(*m_Storage.BuildStorage, + if (IsLargeBlob) + { + DownloadLargeBlob( + *m_Storage.BuildStorage, m_TempPath, m_BuildId, BlobHash, @@ -7146,12 +6989,12 @@ BuildsOperationPrimeCache::Execute() if (!m_AbortFlag) { - if (Payload && m_Storage.BuildCacheStorage) + if (Payload && m_Storage.CacheStorage) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); } } CompletedDownloadCount++; @@ -7160,32 +7003,32 @@ BuildsOperationPrimeCache::Execute() FilteredDownloadedBytesPerSecond.Stop(); } }); - } - else - { - IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); - m_DownloadStats.DownloadedBlockCount++; - 
m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); - m_DownloadStats.RequestsCompleteCount++; + } + else + { + IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); + m_DownloadStats.RequestsCompleteCount++; - if (!m_AbortFlag) - { - if (Payload && m_Storage.BuildCacheStorage) - { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(std::move(Payload)))); - } - } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - } - } - }); + if (!m_AbortFlag) + { + if (Payload && m_Storage.CacheStorage) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(std::move(Payload)))); + } + } + CompletedDownloadCount++; + if (CompletedDownloadCount == BlobCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } + } + } + }); } Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { @@ -7197,10 +7040,10 @@ BuildsOperationPrimeCache::Execute() std::string DownloadRateString = (CompletedDownloadCount == BlobCount) ? "" : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8)); - std::string UploadDetails = m_Storage.BuildCacheStorage ? fmt::format(" {} ({}) uploaded.", - m_StorageCacheStats.PutBlobCount.load(), - NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) - : ""; + std::string UploadDetails = m_Storage.CacheStorage ? 
fmt::format(" {} ({}) uploaded.", + m_StorageCacheStats.PutBlobCount.load(), + NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) + : ""; std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}", CompletedDownloadCount.load(), @@ -7225,13 +7068,13 @@ BuildsOperationPrimeCache::Execute() return; } - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { - m_Storage.BuildCacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { + m_Storage.CacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { ZEN_UNUSED(Remaining); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheName); + ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name); } return !m_AbortFlag; }); @@ -7431,16 +7274,31 @@ GetRemoteContent(OperationLogOutput& Output, // TODO: GetBlockDescriptions for all BlockRawHashes in one go - check for local block descriptions when we cache them { + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size()); + } + + Stopwatch GetBlockMetadataTimer; + bool AttemptFallback = false; OutBlockDescriptions = GetBlockDescriptions(Output, *Storage.BuildStorage, - Storage.BuildCacheStorage.get(), + Storage.CacheStorage.get(), BuildId, - BuildPartId, BlockRawHashes, AttemptFallback, IsQuiet, IsVerbose); + + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, + "GetBlockMetadata for {} took {}. 
Found {} blocks", + BuildPartId, + NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), + OutBlockDescriptions.size()); + } } CalculateLocalChunkOrders(AbsoluteChunkOrders, @@ -7989,6 +7847,8 @@ namespace buildstorageoperations_testutils { } // namespace buildstorageoperations_testutils +TEST_SUITE_BEGIN("remotestore.buildstorageoperations"); + TEST_CASE("buildstorageoperations.upload.folder") { using namespace buildstorageoperations_testutils; @@ -8176,106 +8036,270 @@ TEST_CASE("buildstorageoperations.memorychunkingcache") TEST_CASE("buildstorageoperations.upload.multipart") { - using namespace buildstorageoperations_testutils; + // Disabled since it relies on authentication and specific block being present in cloud storage + if (false) + { + using namespace buildstorageoperations_testutils; - FastRandom BaseRandom; + FastRandom BaseRandom; - const size_t FileCount = 11; + const size_t FileCount = 11; - const std::string Paths[FileCount] = {{"file_1"}, - {"file_2.exe"}, - {"file_3.txt"}, - {"dir_1/dir1_file_1.exe"}, - {"dir_1/dir1_file_2.pdb"}, - {"dir_1/dir1_file_3.txt"}, - {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"}, - {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"}, - {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"}, - {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"}, - {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}}; - const uint64_t Sizes[FileCount] = - {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 1024u, 3u * 1024u}; + const std::string Paths[FileCount] = {{"file_1"}, + {"file_2.exe"}, + {"file_3.txt"}, + {"dir_1/dir1_file_1.exe"}, + {"dir_1/dir1_file_2.pdb"}, + {"dir_1/dir1_file_3.txt"}, + {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"}, + {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"}, + {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"}, + {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"}, + {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}}; + const uint64_t Sizes[FileCount] = + {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 
1024u, 3u * 1024u}; - ScopedTemporaryDirectory SourceFolder; - TestState State(SourceFolder.Path()); - State.Initialize(); - State.CreateSourceData("source", Paths, Sizes); + ScopedTemporaryDirectory SourceFolder; + TestState State(SourceFolder.Path()); + State.Initialize(); + State.CreateSourceData("source", Paths, Sizes); - std::span<const std::string> ManifestFiles1(Paths); - ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2); + std::span<const std::string> ManifestFiles1(Paths); + ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2); - std::span<const uint64_t> ManifestSizes1(Sizes); - ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2); + std::span<const uint64_t> ManifestSizes1(Sizes); + ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2); - std::span<const std::string> ManifestFiles2(Paths); - ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1); + std::span<const std::string> ManifestFiles2(Paths); + ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1); - std::span<const uint64_t> ManifestSizes2(Sizes); - ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1); + std::span<const uint64_t> ManifestSizes2(Sizes); + ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1); - const Oid BuildPart1Id = Oid::NewOid(); - const std::string BuildPart1Name = "part1"; - const Oid BuildPart2Id = Oid::NewOid(); - const std::string BuildPart2Name = "part2"; - { - CbObjectWriter Writer; - Writer.BeginObject("parts"sv); + const Oid BuildPart1Id = Oid::NewOid(); + const std::string BuildPart1Name = "part1"; + const Oid BuildPart2Id = Oid::NewOid(); + const std::string BuildPart2Name = "part2"; { - Writer.BeginObject(BuildPart1Name); + CbObjectWriter Writer; + Writer.BeginObject("parts"sv); { - Writer.AddObjectId("partId"sv, BuildPart1Id); - Writer.BeginArray("files"sv); - for (const std::string& ManifestFile : ManifestFiles1) + Writer.BeginObject(BuildPart1Name); { - Writer.AddString(ManifestFile); + Writer.AddObjectId("partId"sv, 
BuildPart1Id); + Writer.BeginArray("files"sv); + for (const std::string& ManifestFile : ManifestFiles1) + { + Writer.AddString(ManifestFile); + } + Writer.EndArray(); // files + } + Writer.EndObject(); // part1 + + Writer.BeginObject(BuildPart2Name); + { + Writer.AddObjectId("partId"sv, BuildPart2Id); + Writer.BeginArray("files"sv); + for (const std::string& ManifestFile : ManifestFiles2) + { + Writer.AddString(ManifestFile); + } + Writer.EndArray(); // files } - Writer.EndArray(); // files + Writer.EndObject(); // part2 + } + Writer.EndObject(); // parts + + ExtendableStringBuilder<1024> Manifest; + CompactBinaryToJson(Writer.Save(), Manifest); + WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size())); + } + + const Oid BuildId = Oid::NewOid(); + + auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json"); + + CHECK_EQ(Result.size(), 2u); + CHECK_EQ(Result[0].first, BuildPart1Id); + CHECK_EQ(Result[0].second, BuildPart1Name); + CHECK_EQ(Result[1].first, BuildPart2Id); + CHECK_EQ(Result[1].second, BuildPart2Name); + State.ValidateUpload(BuildId, Result); + + FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, "download", /* Append */ false); + State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent); + + FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false); + State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent); + + FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false); + State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent); + + (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false); + FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, 
"download_part1+2", /* Append */ true); + State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent); + } +} + +TEST_CASE("buildstorageoperations.partial.block.download" * doctest::skip(true)) +{ + const std::string OidcExecutableName = "OidcToken" ZEN_EXE_SUFFIX_LITERAL; + std::filesystem::path OidcTokenExePath = (GetRunningExecutablePath().parent_path() / OidcExecutableName).make_preferred(); + + HttpClientSettings ClientSettings{ + .LogCategory = "httpbuildsclient", + .AccessTokenProvider = + httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, "https://jupiter.devtools.epicgames.com", true, false, false), + .AssumeHttp2 = false, + .AllowResume = true, + .RetryCount = 0, + .Verbose = false}; + + HttpClient HttpClient("https://euc.jupiter.devtools.epicgames.com", ClientSettings); + + const std::string_view Namespace = "fortnite.oplog"; + const std::string_view Bucket = "fortnitegame.staged-build.fortnite-main.ps4-client"; + const Oid BuildId = Oid::FromHexString("09a76ea92ad301d4724fafad"); + + { + HttpClient::Response Response = HttpClient.Get(fmt::format("/api/v2/builds/{}/{}/{}", Namespace, Bucket, BuildId), + HttpClient::Accept(ZenContentType::kCbObject)); + CbValidateError ValidateResult = CbValidateError::None; + CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(Response.ResponsePayload), ValidateResult); + REQUIRE(ValidateResult == CbValidateError::None); + } + + std::vector<ChunkBlockDescription> BlockDescriptions; + { + CbObjectWriter Request; + + Request.BeginArray("blocks"sv); + { + Request.AddHash(IoHash::FromHexString("7c353ed782675a5e8f968e61e51fc797ecdc2882")); + } + Request.EndArray(); + + IoBuffer Payload = Request.Save().GetBuffer().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCbObject); + + HttpClient::Response BlockDescriptionsResponse = + HttpClient.Post(fmt::format("/api/v2/builds/{}/{}/{}/blocks/getBlockMetadata", Namespace, Bucket, BuildId), + Payload, + 
HttpClient::Accept(ZenContentType::kCbObject)); + REQUIRE(BlockDescriptionsResponse.IsSuccess()); + + CbValidateError ValidateResult = CbValidateError::None; + CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(BlockDescriptionsResponse.ResponsePayload), ValidateResult); + REQUIRE(ValidateResult == CbValidateError::None); + + { + CbArrayView BlocksArray = Object["blocks"sv].AsArrayView(); + for (CbFieldView Block : BlocksArray) + { + ChunkBlockDescription Description = ParseChunkBlockDescription(Block.AsObjectView()); + BlockDescriptions.emplace_back(std::move(Description)); } - Writer.EndObject(); // part1 + } + } + + REQUIRE(!BlockDescriptions.empty()); - Writer.BeginObject(BuildPart2Name); + const IoHash BlockHash = BlockDescriptions.back().BlockHash; + + const ChunkBlockDescription& BlockDescription = BlockDescriptions.front(); + REQUIRE(!BlockDescription.ChunkRawHashes.empty()); + REQUIRE(!BlockDescription.ChunkCompressedLengths.empty()); + + std::vector<std::pair<uint64_t, uint64_t>> ChunkOffsetAndSizes; + uint64_t Offset = gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + + for (uint32_t ChunkCompressedSize : BlockDescription.ChunkCompressedLengths) + { + ChunkOffsetAndSizes.push_back(std::make_pair(Offset, ChunkCompressedSize)); + Offset += ChunkCompressedSize; + } + + ScopedTemporaryDirectory SourceFolder; + + auto Validate = [&](std::span<const uint32_t> ChunkIndexesToFetch) { + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + for (uint32_t ChunkIndex : ChunkIndexesToFetch) + { + Ranges.push_back(ChunkOffsetAndSizes[ChunkIndex]); + } + + HttpClient::KeyValueMap Headers; + if (!Ranges.empty()) + { + ExtendableStringBuilder<512> SB; + for (const std::pair<uint64_t, uint64_t>& R : Ranges) { - Writer.AddObjectId("partId"sv, BuildPart2Id); - Writer.BeginArray("files"sv); - for (const std::string& ManifestFile : ManifestFiles2) + if (SB.Size() > 0) { - Writer.AddString(ManifestFile); + SB << 
", "; } - Writer.EndArray(); // files + SB << R.first << "-" << R.first + R.second - 1; } - Writer.EndObject(); // part2 + Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())}); } - Writer.EndObject(); // parts - ExtendableStringBuilder<1024> Manifest; - CompactBinaryToJson(Writer.Save(), Manifest); - WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size())); - } + HttpClient::Response GetBlobRangesResponse = HttpClient.Download( + fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect=false", Namespace, Bucket, BuildId, BlockHash), + SourceFolder.Path(), + Headers); - const Oid BuildId = Oid::NewOid(); + REQUIRE(GetBlobRangesResponse.IsSuccess()); + [[maybe_unused]] MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); - auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json"); + std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = GetBlobRangesResponse.GetRanges(Ranges); + if (PayloadRanges.empty()) + { + // We got the whole blob, use the ranges as is + PayloadRanges = Ranges; + } - CHECK_EQ(Result.size(), 2u); - CHECK_EQ(Result[0].first, BuildPart1Id); - CHECK_EQ(Result[0].second, BuildPart1Name); - CHECK_EQ(Result[1].first, BuildPart2Id); - CHECK_EQ(Result[1].second, BuildPart2Name); - State.ValidateUpload(BuildId, Result); + REQUIRE(PayloadRanges.size() == Ranges.size()); - FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, "download", /* Append */ false); - State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent); + for (uint32_t RangeIndex = 0; RangeIndex < PayloadRanges.size(); RangeIndex++) + { + const std::pair<uint64_t, uint64_t>& PayloadRange = PayloadRanges[RangeIndex]; + + CHECK_EQ(PayloadRange.second, Ranges[RangeIndex].second); - FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false); - 
State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent); + IoBuffer ChunkPayload(GetBlobRangesResponse.ResponsePayload, PayloadRange.first, PayloadRange.second); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(ChunkPayload), RawHash, RawSize); + CHECK(CompressedChunk); + CHECK_EQ(RawHash, BlockDescription.ChunkRawHashes[ChunkIndexesToFetch[RangeIndex]]); + CHECK_EQ(RawSize, BlockDescription.ChunkRawLengths[ChunkIndexesToFetch[RangeIndex]]); + } + }; - FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false); - State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent); + { + // Single + std::vector<uint32_t> ChunkIndexesToFetch{uint32_t(BlockDescription.ChunkCompressedLengths.size() / 2)}; + Validate(ChunkIndexesToFetch); + } + { + // Many + std::vector<uint32_t> ChunkIndexesToFetch; + for (uint32_t Index = 0; Index < BlockDescription.ChunkCompressedLengths.size() / 16; Index++) + { + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7)); + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 1)); + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 3)); + } + Validate(ChunkIndexesToFetch); + } - (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false); - FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, "download_part1+2", /* Append */ true); - State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent); + { + // First and last + std::vector<uint32_t> ChunkIndexesToFetch{0, uint32_t(BlockDescription.ChunkCompressedLengths.size() - 1)}; + Validate(ChunkIndexesToFetch); + } } 
+TEST_SUITE_END(); void buildstorageoperations_forcelink() diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp index 36b45e800..2ae726e29 100644 --- a/src/zenremotestore/builds/buildstorageutil.cpp +++ b/src/zenremotestore/builds/buildstorageutil.cpp @@ -63,11 +63,15 @@ ResolveBuildStorage(OperationLogOutput& Output, std::string HostUrl; std::string HostName; + double HostLatencySec = -1.0; + uint64_t HostMaxRangeCountPerRequest = 1; std::string CacheUrl; std::string CacheName; - bool HostAssumeHttp2 = ClientSettings.AssumeHttp2; - bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2; + bool HostAssumeHttp2 = ClientSettings.AssumeHttp2; + bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2; + double CacheLatencySec = -1.0; + uint64_t CacheMaxRangeCountPerRequest = 1; JupiterServerDiscovery DiscoveryResponse; const std::string_view DiscoveryHost = Host.empty() ? OverrideHost : Host; @@ -98,8 +102,10 @@ ResolveBuildStorage(OperationLogOutput& Output, { ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost); } - HostUrl = OverrideHost; - HostName = GetHostNameFromUrl(OverrideHost); + HostUrl = OverrideHost; + HostName = GetHostNameFromUrl(OverrideHost); + HostLatencySec = TestResult.LatencySeconds; + HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; } else { @@ -134,9 +140,11 @@ ResolveBuildStorage(OperationLogOutput& Output, ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl); } - HostUrl = ServerEndpoint.BaseUrl; - HostAssumeHttp2 = ServerEndpoint.AssumeHttp2; - HostName = ServerEndpoint.Name; + HostUrl = ServerEndpoint.BaseUrl; + HostAssumeHttp2 = ServerEndpoint.AssumeHttp2; + HostName = ServerEndpoint.Name; + HostLatencySec = TestResult.LatencySeconds; + HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; break; } else @@ -180,9 +188,11 @@ 
ResolveBuildStorage(OperationLogOutput& Output, ZEN_OPERATION_LOG_INFO(Output, "Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl); } - CacheUrl = CacheEndpoint.BaseUrl; - CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2; - CacheName = CacheEndpoint.Name; + CacheUrl = CacheEndpoint.BaseUrl; + CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2; + CacheName = CacheEndpoint.Name; + CacheLatencySec = TestResult.LatencySeconds; + CacheMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; break; } } @@ -204,6 +214,7 @@ ResolveBuildStorage(OperationLogOutput& Output, CacheUrl = ZenServerLocalHostUrl; CacheAssumeHttp2 = false; CacheName = "localhost"; + CacheLatencySec = TestResult.LatencySeconds; } } }); @@ -219,8 +230,10 @@ ResolveBuildStorage(OperationLogOutput& Output, if (ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose); TestResult.Success) { - CacheUrl = ZenCacheHost; - CacheName = GetHostNameFromUrl(ZenCacheHost); + CacheUrl = ZenCacheHost; + CacheName = GetHostNameFromUrl(ZenCacheHost); + CacheLatencySec = TestResult.LatencySeconds; + CacheMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; } else { @@ -228,13 +241,34 @@ ResolveBuildStorage(OperationLogOutput& Output, } } - return BuildStorageResolveResult{.HostUrl = HostUrl, - .HostName = HostName, - .HostAssumeHttp2 = HostAssumeHttp2, + return BuildStorageResolveResult{ + .Cloud = {.Address = HostUrl, + .Name = HostName, + .AssumeHttp2 = HostAssumeHttp2, + .LatencySec = HostLatencySec, + .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = HostMaxRangeCountPerRequest}}, + .Cache = {.Address = CacheUrl, + .Name = CacheName, + .AssumeHttp2 = CacheAssumeHttp2, + .LatencySec = CacheLatencySec, + .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = CacheMaxRangeCountPerRequest}}}; +} - .CacheUrl = CacheUrl, - .CacheName = CacheName, - .CacheAssumeHttp2 = CacheAssumeHttp2}; 
+std::vector<ChunkBlockDescription> +ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas) +{ + std::vector<ChunkBlockDescription> UnorderedList; + UnorderedList.reserve(BlockMetadatas.size()); + for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < BlockMetadatas.size(); CacheBlockMetadataIndex++) + { + const CbObject& CacheBlockMetadata = BlockMetadatas[CacheBlockMetadataIndex]; + ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata); + if (Description.BlockHash != IoHash::Zero) + { + UnorderedList.emplace_back(std::move(Description)); + } + } + return UnorderedList; } std::vector<ChunkBlockDescription> @@ -242,7 +276,6 @@ GetBlockDescriptions(OperationLogOutput& Output, BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, - const Oid& BuildPartId, std::span<const IoHash> BlockRawHashes, bool AttemptFallback, bool IsQuiet, @@ -250,37 +283,20 @@ GetBlockDescriptions(OperationLogOutput& Output, { using namespace std::literals; - if (!IsQuiet) - { - ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size()); - } - - Stopwatch GetBlockMetadataTimer; - std::vector<ChunkBlockDescription> UnorderedList; tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup; if (OptionalCacheStorage && !BlockRawHashes.empty()) { std::vector<CbObject> CacheBlockMetadatas = OptionalCacheStorage->GetBlobMetadatas(BuildId, BlockRawHashes); - UnorderedList.reserve(CacheBlockMetadatas.size()); - for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < CacheBlockMetadatas.size(); CacheBlockMetadataIndex++) + if (!CacheBlockMetadatas.empty()) { - const CbObject& CacheBlockMetadata = CacheBlockMetadatas[CacheBlockMetadataIndex]; - ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata); - if (Description.BlockHash == IoHash::Zero) + UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas); + for (size_t DescriptionIndex = 0; 
DescriptionIndex < UnorderedList.size(); DescriptionIndex++) { - ZEN_OPERATION_LOG_WARN(Output, "Unexpected/invalid block metadata received from remote cache, skipping block"); - } - else - { - UnorderedList.emplace_back(std::move(Description)); + const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); } } - for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++) - { - const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; - BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); - } } if (UnorderedList.size() < BlockRawHashes.size()) @@ -346,15 +362,6 @@ GetBlockDescriptions(OperationLogOutput& Output, } } - if (!IsQuiet) - { - ZEN_OPERATION_LOG_INFO(Output, - "GetBlockMetadata for {} took {}. Found {} blocks", - BuildPartId, - NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), - Result.size()); - } - if (Result.size() != BlockRawHashes.size()) { std::string ErrorDescription = diff --git a/src/zenremotestore/builds/filebuildstorage.cpp b/src/zenremotestore/builds/filebuildstorage.cpp index 55e69de61..2f4904449 100644 --- a/src/zenremotestore/builds/filebuildstorage.cpp +++ b/src/zenremotestore/builds/filebuildstorage.cpp @@ -432,6 +432,45 @@ public: return IoBuffer{}; } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_TRACE_CPU("FileBuildStorage::GetBuildBlobRanges"); + ZEN_UNUSED(BuildId); + ZEN_ASSERT(!Ranges.empty()); + + uint64_t ReceivedBytes = 0; + uint64_t SentBytes = Ranges.size() * 2 * 8; + + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer, SentBytes, ReceivedBytes); }); + + BuildBlobRanges Result; + + const 
std::filesystem::path BlockPath = GetBlobPayloadPath(RawHash); + if (IsFile(BlockPath)) + { + BasicFile File(BlockPath, BasicFile::Mode::kRead); + + uint64_t RangeOffset = Ranges.front().first; + uint64_t RangeBytes = Ranges.back().first + Ranges.back().second - RangeOffset; + Result.PayloadBuffer = IoBufferBuilder::MakeFromFileHandle(File.Detach(), RangeOffset, RangeBytes); + + Result.Ranges.reserve(Ranges.size()); + + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + Result.Ranges.push_back(std::make_pair(Range.first - RangeOffset, Range.second)); + } + ReceivedBytes = Result.PayloadBuffer.GetSize(); + } + return Result; + } + virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp index 23d0ddd4c..8e16da1a9 100644 --- a/src/zenremotestore/builds/jupiterbuildstorage.cpp +++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp @@ -21,7 +21,7 @@ namespace zen { using namespace std::literals; namespace { - void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix) + [[noreturn]] void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix) { int Error = Result.ErrorCode < (int)HttpResponseCode::Continue ? 
Result.ErrorCode : 0; HttpResponseCode Status = @@ -295,6 +295,26 @@ public: return std::move(GetBuildBlobResult.Response); } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_TRACE_CPU("Jupiter::GetBuildBlob"); + + Stopwatch ExecutionTimer; + auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); + CreateDirectories(m_TempFolderPath); + + BuildBlobRangesResult GetBuildBlobResult = + m_Session.GetBuildBlob(m_Namespace, m_Bucket, BuildId, RawHash, m_TempFolderPath, Ranges); + AddStatistic(GetBuildBlobResult); + if (!GetBuildBlobResult.Success) + { + ThrowFromJupiterResult(GetBuildBlobResult, "Failed fetching build blob ranges"sv); + } + return BuildBlobRanges{.PayloadBuffer = std::move(GetBuildBlobResult.Response), .Ranges = std::move(GetBuildBlobResult.Ranges)}; + } + virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index c4d8653f4..cca32c17d 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -7,27 +7,201 @@ #include <zencore/logging.h> #include <zencore/timer.h> #include <zencore/trace.h> - #include <zenremotestore/operationlogoutput.h> -#include <vector> +#include <numeric> ZEN_THIRD_PARTY_INCLUDES_START -#include <tsl/robin_map.h> +#include <tsl/robin_set.h> ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS # include <zencore/testing.h> # include <zencore/testutils.h> - -# include <unordered_map> -# include <numeric> #endif // ZEN_WITH_TESTS namespace zen { using namespace std::literals; +namespace chunkblock_impl { + + struct RangeDescriptor + { + uint64_t RangeStart = 0; + uint64_t RangeLength = 0; + uint32_t ChunkBlockIndexStart = 0; + uint32_t ChunkBlockIndexCount = 0; + }; + 
+ void MergeCheapestRange(std::vector<RangeDescriptor>& InOutRanges) + { + ZEN_ASSERT(InOutRanges.size() > 1); + + size_t BestRangeIndexToCollapse = SIZE_MAX; + uint64_t BestGap = (uint64_t)-1; + + for (size_t RangeIndex = 0; RangeIndex < InOutRanges.size() - 1; RangeIndex++) + { + const RangeDescriptor& Range = InOutRanges[RangeIndex]; + const RangeDescriptor& NextRange = InOutRanges[RangeIndex + 1]; + uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength); + if (Gap < BestGap) + { + BestRangeIndexToCollapse = RangeIndex; + BestGap = Gap; + } + else if (Gap == BestGap) + { + const RangeDescriptor& BestRange = InOutRanges[BestRangeIndexToCollapse]; + const RangeDescriptor& BestNextRange = InOutRanges[BestRangeIndexToCollapse + 1]; + uint64_t BestMergedSize = (BestNextRange.RangeStart + BestNextRange.RangeLength) - BestRange.RangeStart; + uint64_t MergedSize = (NextRange.RangeStart + NextRange.RangeLength) - Range.RangeStart; + if (MergedSize < BestMergedSize) + { + BestRangeIndexToCollapse = RangeIndex; + } + } + } + + ZEN_ASSERT(BestRangeIndexToCollapse != SIZE_MAX); + ZEN_ASSERT(BestRangeIndexToCollapse < InOutRanges.size() - 1); + ZEN_ASSERT(BestGap != (uint64_t)-1); + + RangeDescriptor& BestRange = InOutRanges[BestRangeIndexToCollapse]; + const RangeDescriptor& BestNextRange = InOutRanges[BestRangeIndexToCollapse + 1]; + BestRange.RangeLength = BestNextRange.RangeStart - BestRange.RangeStart + BestNextRange.RangeLength; + BestRange.ChunkBlockIndexCount = + BestNextRange.ChunkBlockIndexStart - BestRange.ChunkBlockIndexStart + BestNextRange.ChunkBlockIndexCount; + InOutRanges.erase(InOutRanges.begin() + BestRangeIndexToCollapse + 1); + } + + std::vector<RangeDescriptor> GetBlockRanges(const ChunkBlockDescription& BlockDescription, + const uint64_t ChunkStartOffsetInBlock, + std::span<const uint32_t> BlockChunkIndexNeeded) + { + ZEN_TRACE_CPU("GetBlockRanges"); + std::vector<RangeDescriptor> BlockRanges; + { + uint64_t CurrentOffset = 
ChunkStartOffsetInBlock; + uint32_t ChunkBlockIndex = 0; + uint32_t NeedBlockChunkIndexOffset = 0; + RangeDescriptor NextRange; + while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size()) + { + const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) + { + if (NextRange.RangeLength > 0) + { + BlockRanges.push_back(NextRange); + NextRange = {}; + } + ChunkBlockIndex++; + CurrentOffset += ChunkCompressedLength; + } + else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) + { + if (NextRange.RangeLength == 0) + { + NextRange.RangeStart = CurrentOffset; + NextRange.ChunkBlockIndexStart = ChunkBlockIndex; + } + NextRange.RangeLength += ChunkCompressedLength; + NextRange.ChunkBlockIndexCount++; + ChunkBlockIndex++; + CurrentOffset += ChunkCompressedLength; + NeedBlockChunkIndexOffset++; + } + else + { + ZEN_ASSERT(false); + } + } + if (NextRange.RangeLength > 0) + { + BlockRanges.push_back(NextRange); + } + } + ZEN_ASSERT(!BlockRanges.empty()); + return BlockRanges; + } + + std::vector<RangeDescriptor> OptimizeRanges(uint64_t TotalBlockSize, + std::span<const RangeDescriptor> ExactRanges, + double LatencySec, + uint64_t SpeedBytesPerSec, + uint64_t MaxRangeCountPerRequest, + uint64_t MaxRangesPerBlock) + { + ZEN_TRACE_CPU("OptimizeRanges"); + ZEN_ASSERT(MaxRangesPerBlock > 0); + std::vector<RangeDescriptor> Ranges(ExactRanges.begin(), ExactRanges.end()); + + while (Ranges.size() > MaxRangesPerBlock) + { + MergeCheapestRange(Ranges); + } + + while (true) + { + const std::uint64_t RangeTotalSize = + std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0u), [](uint64_t Current, const RangeDescriptor& Value) { + return Current + Value.RangeLength; + }); + + const size_t RangeCount = Ranges.size(); + const uint64_t RequestCount = + MaxRangeCountPerRequest == (uint64_t)-1 ? 
1 : (RangeCount + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest; + uint64_t RequestTimeAsBytes = uint64_t(SpeedBytesPerSec * RequestCount * LatencySec); + + if (RangeCount == 1) + { + // Does fetching the full block add less time than the time it takes to complete a single request? + if (TotalBlockSize - RangeTotalSize < SpeedBytesPerSec * LatencySec) + { + const std::uint64_t InitialRangeTotalSize = + std::accumulate(ExactRanges.begin(), + ExactRanges.end(), + uint64_t(0u), + [](uint64_t Current, const RangeDescriptor& Value) { return Current + Value.RangeLength; }); + + ZEN_DEBUG( + "Latency round trip takes as long as receiving the extra redundant bytes - go full block, dropping {} of slack, " + "adding {} of bytes to fetch, for block of size {}", + NiceBytes(TotalBlockSize - RangeTotalSize), + NiceBytes(TotalBlockSize - InitialRangeTotalSize), + NiceBytes(TotalBlockSize)); + return {}; + } + else + { + return Ranges; + } + } + + if (RequestTimeAsBytes < (TotalBlockSize - RangeTotalSize)) + { + return Ranges; + } + + if (RangeCount == 2) + { + // Merge to single range + Ranges.front().RangeLength = Ranges.back().RangeStart - Ranges.front().RangeStart + Ranges.back().RangeLength; + Ranges.front().ChunkBlockIndexCount = + Ranges.back().ChunkBlockIndexStart - Ranges.front().ChunkBlockIndexStart + Ranges.back().ChunkBlockIndexCount; + Ranges.pop_back(); + } + else + { + MergeCheapestRange(Ranges); + } + } + } + +} // namespace chunkblock_impl + ChunkBlockDescription ParseChunkBlockDescription(const CbObjectView& BlockObject) { @@ -455,9 +629,299 @@ FindReuseBlocks(OperationLogOutput& Output, return FilteredReuseBlockIndexes; } +ChunkBlockAnalyser::ChunkBlockAnalyser(OperationLogOutput& LogOutput, + std::span<const ChunkBlockDescription> BlockDescriptions, + const Options& Options) +: m_LogOutput(LogOutput) +, m_BlockDescriptions(BlockDescriptions) +, m_Options(Options) +{ +} + +std::vector<ChunkBlockAnalyser::NeededBlock> 
+ChunkBlockAnalyser::GetNeeded(const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& ChunkHashToChunkIndex, + std::function<bool(uint32_t ChunkIndex)>&& NeedsBlockChunk) +{ + ZEN_TRACE_CPU("ChunkBlockAnalyser::GetNeeded"); + + std::vector<NeededBlock> Result; + + std::vector<bool> ChunkIsNeeded(ChunkHashToChunkIndex.size()); + for (uint32_t ChunkIndex = 0; ChunkIndex < ChunkHashToChunkIndex.size(); ChunkIndex++) + { + ChunkIsNeeded[ChunkIndex] = NeedsBlockChunk(ChunkIndex); + } + + std::vector<uint64_t> BlockSlack(m_BlockDescriptions.size(), 0u); + for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + uint64_t BlockUsedSize = 0; + uint64_t BlockSize = 0; + + for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) + { + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end()) + { + const uint32_t RemoteChunkIndex = It->second; + if (ChunkIsNeeded[RemoteChunkIndex]) + { + BlockUsedSize += BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + } + } + BlockSize += BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + } + BlockSlack[BlockIndex] = BlockSize - BlockUsedSize; + } + + std::vector<uint32_t> BlockOrder(m_BlockDescriptions.size()); + std::iota(BlockOrder.begin(), BlockOrder.end(), 0); + + std::sort(BlockOrder.begin(), BlockOrder.end(), [&BlockSlack](uint32_t Lhs, uint32_t Rhs) { + return BlockSlack[Lhs] < BlockSlack[Rhs]; + }); + + std::vector<bool> ChunkIsPickedUp(ChunkHashToChunkIndex.size(), false); + + for (uint32_t BlockIndex : BlockOrder) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + std::vector<uint32_t> BlockChunkIndexNeeded; + + for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < 
BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) + { + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end()) + { + const uint32_t RemoteChunkIndex = It->second; + if (ChunkIsNeeded[RemoteChunkIndex]) + { + if (!ChunkIsPickedUp[RemoteChunkIndex]) + { + ChunkIsPickedUp[RemoteChunkIndex] = true; + BlockChunkIndexNeeded.push_back(ChunkBlockIndex); + } + } + } + else + { + ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash); + } + } + + if (!BlockChunkIndexNeeded.empty()) + { + Result.push_back(NeededBlock{.BlockIndex = BlockIndex, .ChunkIndexes = std::move(BlockChunkIndexNeeded)}); + } + } + return Result; +} + +ChunkBlockAnalyser::BlockResult +ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span<const NeededBlock> NeededBlocks, + std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes) +{ + ZEN_TRACE_CPU("ChunkBlockAnalyser::CalculatePartialBlockDownloads"); + + Stopwatch PartialAnalisysTimer; + + ChunkBlockAnalyser::BlockResult Result; + + { + uint64_t MinRequestCount = 0; + uint64_t RequestCount = 0; + uint64_t RangeCount = 0; + uint64_t IdealDownloadTotalSize = 0; + uint64_t ActualDownloadTotalSize = 0; + uint64_t FullDownloadTotalSize = 0; + for (const NeededBlock& NeededBlock : NeededBlocks) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; + std::span<const uint32_t> BlockChunkIndexNeeded(NeededBlock.ChunkIndexes); + const uint32_t ChunkStartOffsetInBlock = + gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), + BlockDescription.ChunkCompressedLengths.end(), + uint64_t(ChunkStartOffsetInBlock)); + uint64_t ExactRangesSize = 0; + uint64_t DownloadRangesSize = 0; + uint64_t FullDownloadSize = 0; + + bool 
CanDoPartialBlockDownload = (BlockDescription.HeaderSize > 0) && + (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size()); + + if (NeededBlock.ChunkIndexes.size() == BlockDescription.ChunkRawHashes.size() || !CanDoPartialBlockDownload) + { + // Full block + ExactRangesSize = TotalBlockSize; + DownloadRangesSize = TotalBlockSize; + FullDownloadSize = TotalBlockSize; + MinRequestCount++; + RequestCount++; + RangeCount++; + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + } + else if (NeededBlock.ChunkIndexes.empty()) + { + // Not needed + } + else + { + FullDownloadSize = TotalBlockSize; + std::vector<chunkblock_impl::RangeDescriptor> Ranges = + chunkblock_impl::GetBlockRanges(BlockDescription, ChunkStartOffsetInBlock, BlockChunkIndexNeeded); + ExactRangesSize = std::accumulate( + Ranges.begin(), + Ranges.end(), + uint64_t(0), + [](uint64_t Current, const chunkblock_impl::RangeDescriptor& Range) { return Current + Range.RangeLength; }); + + EPartialBlockDownloadMode PartialBlockDownloadMode = BlockPartialDownloadModes[NeededBlock.BlockIndex]; + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Off) + { + // Use full block + MinRequestCount++; + RangeCount++; + RequestCount++; + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + DownloadRangesSize = TotalBlockSize; + } + else + { + const bool IsHighSpeed = (PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed); + uint64_t MaxRangeCountPerRequest = + IsHighSpeed ? 
m_Options.HostHighSpeedMaxRangeCountPerRequest : m_Options.HostMaxRangeCountPerRequest; + ZEN_ASSERT(MaxRangeCountPerRequest != 0); + + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Exact) + { + // Use exact ranges + for (const chunkblock_impl::RangeDescriptor& Range : Ranges) + { + Result.BlockRanges.push_back(BlockRangeDescriptor{.BlockIndex = NeededBlock.BlockIndex, + .RangeStart = Range.RangeStart, + .RangeLength = Range.RangeLength, + .ChunkBlockIndexStart = Range.ChunkBlockIndexStart, + .ChunkBlockIndexCount = Range.ChunkBlockIndexCount}); + } + + MinRequestCount++; + RangeCount += Ranges.size(); + RequestCount += MaxRangeCountPerRequest == (uint64_t)-1 + ? 1 + : (Ranges.size() + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest; + DownloadRangesSize = ExactRangesSize; + } + else + { + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::SingleRange) + { + // Use single range + if (Ranges.size() > 1) + { + Ranges = {chunkblock_impl::RangeDescriptor{ + .RangeStart = Ranges.front().RangeStart, + .RangeLength = Ranges.back().RangeStart + Ranges.back().RangeLength - Ranges.front().RangeStart, + .ChunkBlockIndexStart = Ranges.front().ChunkBlockIndexStart, + .ChunkBlockIndexCount = Ranges.back().ChunkBlockIndexStart + Ranges.back().ChunkBlockIndexCount - + Ranges.front().ChunkBlockIndexStart}}; + } + + // We still do the optimize pass to see if it is more effective to use a full block + } + + double LatencySec = IsHighSpeed ? m_Options.HostHighSpeedLatencySec : m_Options.HostLatencySec; + uint64_t SpeedBytesPerSec = IsHighSpeed ? 
m_Options.HostHighSpeedBytesPerSec : m_Options.HostSpeedBytesPerSec; + if (LatencySec > 0.0 && SpeedBytesPerSec > 0u) + { + Ranges = chunkblock_impl::OptimizeRanges(TotalBlockSize, + Ranges, + LatencySec, + SpeedBytesPerSec, + MaxRangeCountPerRequest, + m_Options.MaxRangesPerBlock); + } + + MinRequestCount++; + if (Ranges.empty()) + { + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + RequestCount++; + RangeCount++; + DownloadRangesSize = TotalBlockSize; + } + else + { + for (const chunkblock_impl::RangeDescriptor& Range : Ranges) + { + Result.BlockRanges.push_back(BlockRangeDescriptor{.BlockIndex = NeededBlock.BlockIndex, + .RangeStart = Range.RangeStart, + .RangeLength = Range.RangeLength, + .ChunkBlockIndexStart = Range.ChunkBlockIndexStart, + .ChunkBlockIndexCount = Range.ChunkBlockIndexCount}); + } + RangeCount += Ranges.size(); + RequestCount += MaxRangeCountPerRequest == (uint64_t)-1 + ? 1 + : (Ranges.size() + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest; + } + + DownloadRangesSize = Ranges.empty() + ? TotalBlockSize + : std::accumulate(Ranges.begin(), + Ranges.end(), + uint64_t(0), + [](uint64_t Current, const chunkblock_impl::RangeDescriptor& Range) { + return Current + Range.RangeLength; + }); + } + } + } + IdealDownloadTotalSize += ExactRangesSize; + ActualDownloadTotalSize += DownloadRangesSize; + FullDownloadTotalSize += FullDownloadSize; + + if (ExactRangesSize < FullDownloadSize) + { + ZEN_DEBUG("Block {}: Full: {}, Ideal: {}, Actual: {}, Saves: {}", + NeededBlock.BlockIndex, + NiceBytes(FullDownloadSize), + NiceBytes(ExactRangesSize), + NiceBytes(DownloadRangesSize), + NiceBytes(FullDownloadSize - DownloadRangesSize)); + } + } + uint64_t Actual = FullDownloadTotalSize - ActualDownloadTotalSize; + uint64_t Ideal = FullDownloadTotalSize - IdealDownloadTotalSize; + if (Ideal < FullDownloadTotalSize && !m_Options.IsQuiet) + { + const double AchievedPercent = Ideal == 0 ? 
100.0 : (100.0 * Actual) / Ideal; + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of " + "possible {} using {} extra ranges " + "via {} extra requests. Completed in {}", + NeededBlocks.size(), + NiceBytes(FullDownloadTotalSize), + NiceBytes(IdealDownloadTotalSize), + NiceBytes(ActualDownloadTotalSize), + NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize), + AchievedPercent, + NiceBytes(Ideal), + RangeCount - MinRequestCount, + RequestCount - MinRequestCount, + NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs())); + } + } + + return Result; +} + #if ZEN_WITH_TESTS -namespace testutils { +namespace chunkblock_testutils { static std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments( const std::span<const size_t>& Sizes, OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, @@ -474,12 +938,14 @@ namespace testutils { return Result; } -} // namespace testutils +} // namespace chunkblock_testutils + +TEST_SUITE_BEGIN("remotestore.chunkblock"); -TEST_CASE("project.store.block") +TEST_CASE("chunkblock.block") { using namespace std::literals; - using namespace testutils; + using namespace chunkblock_testutils; std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251, @@ -504,10 +970,10 @@ TEST_CASE("project.store.block") HeaderSize)); } -TEST_CASE("project.store.reuseblocks") +TEST_CASE("chunkblock.reuseblocks") { using namespace std::literals; - using namespace testutils; + using namespace chunkblock_testutils; std::vector<std::vector<std::size_t>> BlockAttachmentSizes( {std::vector<std::size_t>{7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, @@ -744,6 +1210,894 @@ TEST_CASE("project.store.reuseblocks") } } +namespace chunkblock_analyser_testutils { 
+ + // Build a ChunkBlockDescription without any real payload. + // Hashes are derived deterministically from (BlockSeed XOR ChunkIndex) so that the same + // seed produces the same hashes — useful for deduplication tests. + static ChunkBlockDescription MakeBlockDesc(uint64_t HeaderSize, + std::initializer_list<uint32_t> CompressedLengths, + uint32_t BlockSeed = 0) + { + ChunkBlockDescription Desc; + Desc.HeaderSize = HeaderSize; + uint32_t ChunkIndex = 0; + for (uint32_t Length : CompressedLengths) + { + uint64_t HashInput = uint64_t(BlockSeed ^ ChunkIndex); + Desc.ChunkRawHashes.push_back(IoHash::HashBuffer(MemoryView(&HashInput, sizeof(HashInput)))); + Desc.ChunkRawLengths.push_back(Length); + Desc.ChunkCompressedLengths.push_back(Length); + ChunkIndex++; + } + return Desc; + } + + // Build the robin_map<IoHash, uint32_t> needed by GetNeeded from a flat list of blocks. + // First occurrence of each hash wins; index is assigned sequentially across all blocks. + [[maybe_unused]] static tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> MakeHashMap(const std::vector<ChunkBlockDescription>& Blocks) + { + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> Result; + uint32_t Index = 0; + for (const ChunkBlockDescription& Block : Blocks) + { + for (const IoHash& Hash : Block.ChunkRawHashes) + { + if (!Result.contains(Hash)) + { + Result.emplace(Hash, Index++); + } + } + } + return Result; + } + +} // namespace chunkblock_analyser_testutils + +TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap") +{ + using RD = chunkblock_impl::RangeDescriptor; + // Gap between ranges 0-1 is 50, gap between 1-2 is 150 → pair 0-1 gets merged + std::vector<RD> Ranges = { + {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 150, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, + {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + }; + 
chunkblock_impl::MergeCheapestRange(Ranges); + + REQUIRE_EQ(2u, Ranges.size()); + CHECK_EQ(0u, Ranges[0].RangeStart); + CHECK_EQ(250u, Ranges[0].RangeLength); // 150+100 + CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(2u, Ranges[0].ChunkBlockIndexCount); + CHECK_EQ(400u, Ranges[1].RangeStart); + CHECK_EQ(100u, Ranges[1].RangeLength); + CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.mergecheapestrange.tiebreak_smaller_merged") +{ + using RD = chunkblock_impl::RangeDescriptor; + // Gap 0-1 == gap 1-2 == 100; merged size 0-1 (250) < merged size 1-2 (350) → pair 0-1 wins + std::vector<RD> Ranges = { + {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, + {.RangeStart = 350, .RangeLength = 200, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + }; + chunkblock_impl::MergeCheapestRange(Ranges); + + REQUIRE_EQ(2u, Ranges.size()); + // Pair 0-1 merged: start=0, length = (200+50)-0 = 250 + CHECK_EQ(0u, Ranges[0].RangeStart); + CHECK_EQ(250u, Ranges[0].RangeLength); + CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(2u, Ranges[0].ChunkBlockIndexCount); + // Pair 1 unchanged (was index 2) + CHECK_EQ(350u, Ranges[1].RangeStart); + CHECK_EQ(200u, Ranges[1].RangeLength); + CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency") +{ + using RD = chunkblock_impl::RangeDescriptor; + // With MaxRangeCountPerRequest unlimited, RequestCount=1 + // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 → all ranges preserved + std::vector<RD> ExactRanges = { + {.RangeStart = 0, .RangeLength = 1000, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 2000, .RangeLength = 1000, .ChunkBlockIndexStart = 1, 
.ChunkBlockIndexCount = 1}, + {.RangeStart = 4000, .RangeLength = 1000, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + }; + uint64_t TotalBlockSize = 10000; + double LatencySec = 0.001; + uint64_t SpeedBytesPerSec = 100000; + uint64_t MaxRangeCountPerReq = (uint64_t)-1; + uint64_t MaxRangesPerBlock = 1024; + + auto Result = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); + + REQUIRE_EQ(3u, Result.size()); +} + +TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block") +{ + using RD = chunkblock_impl::RangeDescriptor; + // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 → full block (empty result) + std::vector<RD> ExactRanges = { + {.RangeStart = 100, .RangeLength = 900, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 3}, + }; + uint64_t TotalBlockSize = 1000; + double LatencySec = 0.01; + uint64_t SpeedBytesPerSec = 20000; + uint64_t MaxRangeCountPerReq = (uint64_t)-1; + uint64_t MaxRangesPerBlock = 1024; + + auto Result = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); + + CHECK(Result.empty()); +} + +TEST_CASE("chunkblock.optimizeranges.maxrangesperblock_clamp") +{ + using RD = chunkblock_impl::RangeDescriptor; + // 5 input ranges; MaxRangesPerBlock=2 clamps to ≤2 before the cost model runs + std::vector<RD> ExactRanges = { + {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 300, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, + {.RangeStart = 600, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + {.RangeStart = 900, .RangeLength = 100, .ChunkBlockIndexStart = 3, .ChunkBlockIndexCount = 1}, + {.RangeStart = 1200, .RangeLength = 100, .ChunkBlockIndexStart = 4, .ChunkBlockIndexCount = 1}, + }; + uint64_t TotalBlockSize = 5000; + double LatencySec = 
0.001; + uint64_t SpeedBytesPerSec = 100000; + uint64_t MaxRangeCountPerReq = (uint64_t)-1; + uint64_t MaxRangesPerBlock = 2; + + auto Result = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); + + CHECK(Result.size() <= 2u); + CHECK(!Result.empty()); +} + +TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge") +{ + using RD = chunkblock_impl::RangeDescriptor; + // MaxRangeCountPerRequest=1 means RequestCount==RangeCount; high latency drives merging + // With MaxRangeCountPerRequest=-1 the same 3 ranges would be preserved (verified by comment below) + std::vector<RD> ExactRanges = { + {.RangeStart = 100, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 250, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, + {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + }; + uint64_t TotalBlockSize = 1000; + double LatencySec = 1.0; + uint64_t SpeedBytesPerSec = 500; + // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 → preserved + // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 → merged + uint64_t MaxRangesPerBlock = 1024; + + auto Unlimited = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, (uint64_t)-1, MaxRangesPerBlock); + CHECK_EQ(3u, Unlimited.size()); + + auto Limited = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, uint64_t(1), MaxRangesPerBlock); + CHECK(Limited.size() < 3u); +} + +TEST_CASE("chunkblock.optimizeranges.unlimited_rangecountperrequest_no_extra_cost") +{ + using RD = chunkblock_impl::RangeDescriptor; + // MaxRangeCountPerRequest=-1 → RequestCount always 1, even with many ranges and high latency + std::vector<RD> ExactRanges = { + {.RangeStart = 0, .RangeLength = 50, 
.ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1}, + {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1}, + {.RangeStart = 400, .RangeLength = 50, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1}, + {.RangeStart = 600, .RangeLength = 50, .ChunkBlockIndexStart = 3, .ChunkBlockIndexCount = 1}, + {.RangeStart = 800, .RangeLength = 50, .ChunkBlockIndexStart = 4, .ChunkBlockIndexCount = 1}, + }; + uint64_t TotalBlockSize = 5000; + double LatencySec = 0.1; + uint64_t SpeedBytesPerSec = 10000; // RequestTimeAsBytes=1000 << slack=4750 + uint64_t MaxRangeCountPerReq = (uint64_t)-1; + uint64_t MaxRangesPerBlock = 1024; + + auto Result = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); + + CHECK_EQ(5u, Result.size()); +} + +TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path") +{ + using RD = chunkblock_impl::RangeDescriptor; + // Exactly 2 ranges; cost model demands merge; exercises the RangeCount==2 direct-merge branch + // After direct merge → 1 range with small slack → full block (empty) + std::vector<RD> ExactRanges = { + {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 2}, + {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 2}, + }; + uint64_t TotalBlockSize = 600; + double LatencySec = 0.1; + uint64_t SpeedBytesPerSec = 5000; // RequestTimeAsBytes=500 > slack=400 on first iter + uint64_t MaxRangeCountPerReq = (uint64_t)-1; + uint64_t MaxRangesPerBlock = 1024; + + // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 → direct merge + // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 → full block + auto Result = + chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock); + + CHECK(Result.empty()); +} + 
+TEST_CASE("chunkblock.getneeded.all_chunks") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + auto HashMap = MakeHashMap({Block}); + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; }); + + REQUIRE_EQ(1u, NeededBlocks.size()); + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); + REQUIRE_EQ(4u, NeededBlocks[0].ChunkIndexes.size()); + CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]); + CHECK_EQ(1u, NeededBlocks[0].ChunkIndexes[1]); + CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[2]); + CHECK_EQ(3u, NeededBlocks[0].ChunkIndexes[3]); +} + +TEST_CASE("chunkblock.getneeded.no_chunks") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + auto HashMap = MakeHashMap({Block}); + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return false; }); + + CHECK(NeededBlocks.empty()); +} + +TEST_CASE("chunkblock.getneeded.subset_within_block") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + auto HashMap = MakeHashMap({Block}); + // Indices 0 and 2 are needed; 1 and 3 are not + auto NeededBlocks = 
Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0 || ChunkIndex == 2; }); + + REQUIRE_EQ(1u, NeededBlocks.size()); + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); + REQUIRE_EQ(2u, NeededBlocks[0].ChunkIndexes.size()); + CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]); + CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[1]); +} + +TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // Block 0: {H0, H1, SharedH, H3} — 3 of 4 needed (H3 not needed); slack = 100 + // Block 1: {H4, H5, SharedH, H6} — only SharedH needed; slack = 300 + // Block 0 has less slack → processed first → SharedH assigned to block 0 + IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_dedup", 18)); + IoHash H0 = IoHash::HashBuffer(MemoryView("block0_chunk0", 13)); + IoHash H1 = IoHash::HashBuffer(MemoryView("block0_chunk1", 13)); + IoHash H3 = IoHash::HashBuffer(MemoryView("block0_chunk3", 13)); + IoHash H4 = IoHash::HashBuffer(MemoryView("block1_chunk0", 13)); + IoHash H5 = IoHash::HashBuffer(MemoryView("block1_chunk1", 13)); + IoHash H6 = IoHash::HashBuffer(MemoryView("block1_chunk3", 13)); + + ChunkBlockDescription Block0; + Block0.HeaderSize = 50; + Block0.ChunkRawHashes = {H0, H1, SharedH, H3}; + Block0.ChunkRawLengths = {100, 100, 100, 100}; + Block0.ChunkCompressedLengths = {100, 100, 100, 100}; + + ChunkBlockDescription Block1; + Block1.HeaderSize = 50; + Block1.ChunkRawHashes = {H4, H5, SharedH, H6}; + Block1.ChunkRawLengths = {100, 100, 100, 100}; + Block1.ChunkCompressedLengths = {100, 100, 100, 100}; + + std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + + // Map: H0→0, H1→1, SharedH→2, H3→3, H4→4, H5→5, H6→6 + auto HashMap = MakeHashMap(Blocks); + // Need H0(0), H1(1), SharedH(2) 
from block 0; SharedH from block 1 (already index 2) + // H3(3) not needed; H4,H5,H6 not needed + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex <= 2; }); + + // Block 0 slack=100 (H3 unused), block 1 slack=300 (H4,H5,H6 unused) + // Block 0 processed first; picks up H0, H1, SharedH + // Block 1 tries SharedH but it's already picked up → empty → not added + REQUIRE_EQ(1u, NeededBlocks.size()); + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); + REQUIRE_EQ(3u, NeededBlocks[0].ChunkIndexes.size()); + CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]); + CHECK_EQ(1u, NeededBlocks[0].ChunkIndexes[1]); + CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[2]); +} + +TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // SharedH appears in both blocks; should appear in the result exactly once + IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_nodup", 18)); + IoHash H0 = IoHash::HashBuffer(MemoryView("unique_chunk_b0", 15)); + IoHash H1 = IoHash::HashBuffer(MemoryView("unique_chunk_b1a", 16)); + IoHash H2 = IoHash::HashBuffer(MemoryView("unique_chunk_b1b", 16)); + IoHash H3 = IoHash::HashBuffer(MemoryView("unique_chunk_b1c", 16)); + + ChunkBlockDescription Block0; + Block0.HeaderSize = 50; + Block0.ChunkRawHashes = {SharedH, H0}; + Block0.ChunkRawLengths = {100, 100}; + Block0.ChunkCompressedLengths = {100, 100}; + + ChunkBlockDescription Block1; + Block1.HeaderSize = 50; + Block1.ChunkRawHashes = {H1, H2, H3, SharedH}; + Block1.ChunkRawLengths = {100, 100, 100, 100}; + Block1.ChunkCompressedLengths = {100, 100, 100, 100}; + + std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + + // Map: SharedH→0, H0→1, H1→2, H2→3, H3→4 + // Only SharedH (index 0) needed; 
no other chunks + auto HashMap = MakeHashMap(Blocks); + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0; }); + + // Block 0: SharedH needed, H0 not needed → slack=100 + // Block 1: SharedH needed, H1/H2/H3 not needed → slack=300 + // Block 0 processed first → picks up SharedH; Block 1 skips it + + // Count total occurrences of SharedH across all NeededBlocks + uint32_t SharedOccurrences = 0; + for (const auto& NB : NeededBlocks) + { + for (uint32_t Idx : NB.ChunkIndexes) + { + // SharedH is at block-local index 0 in Block0 and index 3 in Block1 + (void)Idx; + SharedOccurrences++; + } + } + CHECK_EQ(1u, SharedOccurrences); + REQUIRE_EQ(1u, NeededBlocks.size()); + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); +} + +TEST_CASE("chunkblock.getneeded.skips_unrequested_chunks") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // Block has 4 chunks but only 2 appear in the hash map → ChunkIndexes has exactly those 2 + auto Block = MakeBlockDesc(50, {100, 100, 100, 100}); + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + // Only put chunks at positions 0 and 2 in the map + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> HashMap; + HashMap.emplace(Block.ChunkRawHashes[0], 0u); + HashMap.emplace(Block.ChunkRawHashes[2], 1u); + + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; }); + + REQUIRE_EQ(1u, NeededBlocks.size()); + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); + REQUIRE_EQ(2u, NeededBlocks[0].ChunkIndexes.size()); + CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]); + CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[1]); +} + +TEST_CASE("chunkblock.getneeded.two_blocks_both_contribute") +{ + using namespace chunkblock_analyser_testutils; + + LoggerRef LogRef = Log(); + 
std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // Block 0: all 4 needed (slack=0); block 1: 3 of 4 needed (slack=100) + // Both blocks contribute chunks → 2 NeededBlocks in result + auto Block0 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/0); + auto Block1 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/200); + + std::vector<ChunkBlockDescription> Blocks = {Block0, Block1}; + ChunkBlockAnalyser::Options Options; + ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + + // HashMap: Block0 hashes → indices 0-3, Block1 hashes → indices 4-7 + auto HashMap = MakeHashMap(Blocks); + // Need all Block0 chunks (0-3) and Block1 chunks 0-2 (indices 4-6); not chunk index 7 (Block1 chunk 3) + auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex <= 6; }); + + CHECK_EQ(2u, NeededBlocks.size()); + // Block 0 has slack=0 (all 4 needed), Block 1 has slack=100 (1 not needed) + // Block 0 comes first in result + CHECK_EQ(0u, NeededBlocks[0].BlockIndex); + CHECK_EQ(4u, NeededBlocks[0].ChunkIndexes.size()); + CHECK_EQ(1u, NeededBlocks[1].BlockIndex); + CHECK_EQ(3u, NeededBlocks[1].ChunkIndexes.size()); +} + +TEST_CASE("chunkblock.calc.off_mode") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // HeaderSize > 0, chunks size matches → CanDoPartialBlockDownload = true + // But mode Off forces full block regardless + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::Off}; + + auto Result = 
Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); + CHECK_EQ(0u, Result.FullBlockIndexes[0]); + CHECK(Result.BlockRanges.empty()); +} + +TEST_CASE("chunkblock.calc.exact_mode") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + // Need chunks 0 and 2 → 2 non-contiguous ranges; Exact mode passes them straight through + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::Exact}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + CHECK(Result.FullBlockIndexes.empty()); + REQUIRE_EQ(2u, Result.BlockRanges.size()); + + CHECK_EQ(0u, Result.BlockRanges[0].BlockIndex); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(0u, Result.BlockRanges[0].ChunkBlockIndexStart); + CHECK_EQ(1u, Result.BlockRanges[0].ChunkBlockIndexCount); + + CHECK_EQ(0u, Result.BlockRanges[1].BlockIndex); + CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart); // 100+200 before chunk 2 + CHECK_EQ(300u, Result.BlockRanges[1].RangeLength); + CHECK_EQ(2u, Result.BlockRanges[1].ChunkBlockIndexStart); + CHECK_EQ(1u, Result.BlockRanges[1].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.calc.singlerange_mode") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef 
LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + // Default HostLatencySec=-1 → OptimizeRanges not called after SingleRange collapse + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + // Need chunks 0 and 2 → 2 ranges that get collapsed to 1 + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::SingleRange}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + CHECK(Result.FullBlockIndexes.empty()); + REQUIRE_EQ(1u, Result.BlockRanges.size()); + CHECK_EQ(0u, Result.BlockRanges[0].BlockIndex); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + // Spans from chunk 0 start to chunk 2 end: 100+200+300=600 + CHECK_EQ(600u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(0u, Result.BlockRanges[0].ChunkBlockIndexStart); + // ChunkBlockIndexCount = (2+1) - 0 = 3 + CHECK_EQ(3u, Result.BlockRanges[0].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.calc.multirange_mode") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + // Low latency: RequestTimeAsBytes=100 << slack → OptimizeRanges preserves ranges + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + Options.HostLatencySec = 0.001; + Options.HostSpeedBytesPerSec = 100000; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + uint64_t ChunkStartOffset = 
CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::MultiRange}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + CHECK(Result.FullBlockIndexes.empty()); + REQUIRE_EQ(2u, Result.BlockRanges.size()); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart); + CHECK_EQ(300u, Result.BlockRanges[1].RangeLength); +} + +TEST_CASE("chunkblock.calc.multirangehighspeed_mode") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + // Block slack ≈ 714 bytes (TotalBlockSize≈1114, RangeTotalSize=400 for chunks 0+2) + // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 → ranges preserved + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + Options.HostHighSpeedLatencySec = 0.001; + Options.HostHighSpeedBytesPerSec = 400000; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::MultiRangeHighSpeed}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + CHECK(Result.FullBlockIndexes.empty()); + REQUIRE_EQ(2u, Result.BlockRanges.size()); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(ChunkStartOffset + 300u, 
Result.BlockRanges[1].RangeStart); + CHECK_EQ(300u, Result.BlockRanges[1].RangeLength); +} + +TEST_CASE("chunkblock.calc.all_chunks_needed_full_block") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + Options.HostLatencySec = 0.001; + Options.HostSpeedBytesPerSec = 100000; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + // All 4 chunks needed → short-circuit to full block regardless of mode + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 1, 2, 3}}}; + std::vector<Mode> Modes = {Mode::Exact}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); + CHECK_EQ(0u, Result.FullBlockIndexes[0]); + CHECK(Result.BlockRanges.empty()); +} + +TEST_CASE("chunkblock.calc.headersize_zero_forces_full_block") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // HeaderSize=0 → CanDoPartialBlockDownload=false → full block even in Exact mode + auto Block = MakeBlockDesc(0, {100, 200, 300, 400}); + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::Exact}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + REQUIRE_EQ(1u, 
Result.FullBlockIndexes.size()); + CHECK_EQ(0u, Result.FullBlockIndexes[0]); + CHECK(Result.BlockRanges.empty()); +} + +TEST_CASE("chunkblock.calc.low_maxrangecountperrequest") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // 5 chunks of 100 bytes each; need chunks 0, 2, 4 → 3 non-contiguous ranges + // With MaxRangeCountPerRequest=1 and high latency, cost model merges aggressively → full block + auto Block = MakeBlockDesc(10, {100, 100, 100, 100, 100}); + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + Options.HostLatencySec = 0.1; + Options.HostSpeedBytesPerSec = 1000; + Options.HostMaxRangeCountPerRequest = 1; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options); + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2, 4}}}; + std::vector<Mode> Modes = {Mode::MultiRange}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + // Cost model drives merging: 3 requests × 1000 × 0.1 = 300 > slack ≈ 210+headersize + // After merges converges to full block + REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); + CHECK_EQ(0u, Result.FullBlockIndexes[0]); + CHECK(Result.BlockRanges.empty()); +} + +TEST_CASE("chunkblock.calc.no_latency_skips_optimize") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + // Default HostLatencySec=-1 → OptimizeRanges not called; raw GetBlockRanges result used + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + ChunkBlockAnalyser Analyser(*LogOutput, std::span<const 
ChunkBlockDescription>(&Block, 1), Options); + + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}}; + std::vector<Mode> Modes = {Mode::MultiRange}; + + auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + // No optimize pass → exact ranges from GetBlockRanges + CHECK(Result.FullBlockIndexes.empty()); + REQUIRE_EQ(2u, Result.BlockRanges.size()); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart); + CHECK_EQ(300u, Result.BlockRanges[1].RangeLength); +} + +TEST_CASE("chunkblock.calc.multiple_blocks_different_modes") +{ + using namespace chunkblock_analyser_testutils; + using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode; + + LoggerRef LogRef = Log(); + std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef)); + + // 3 blocks with different modes: Off, Exact, MultiRange + auto Block0 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/0); + auto Block1 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/10); + auto Block2 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/20); + + ChunkBlockAnalyser::Options Options; + Options.IsQuiet = true; + Options.HostLatencySec = 0.001; + Options.HostSpeedBytesPerSec = 100000; + + std::vector<ChunkBlockDescription> Blocks = {Block0, Block1, Block2}; + ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options); + + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + 50; + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = { + {.BlockIndex = 0, .ChunkIndexes = {0, 2}}, + {.BlockIndex = 1, .ChunkIndexes = {0, 2}}, + {.BlockIndex = 2, .ChunkIndexes = {0, 2}}, + }; + std::vector<Mode> Modes = {Mode::Off, Mode::Exact, Mode::MultiRange}; + + auto 
Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes); + + // Block 0: Off → FullBlockIndexes + REQUIRE_EQ(1u, Result.FullBlockIndexes.size()); + CHECK_EQ(0u, Result.FullBlockIndexes[0]); + + // Block 1: Exact → 2 ranges; Block 2: MultiRange (low latency) → 2 ranges + // Total: 4 ranges + REQUIRE_EQ(4u, Result.BlockRanges.size()); + + // First 2 ranges belong to Block 1 (Exact) + CHECK_EQ(1u, Result.BlockRanges[0].BlockIndex); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[0].RangeLength); + CHECK_EQ(1u, Result.BlockRanges[1].BlockIndex); + CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart); + CHECK_EQ(300u, Result.BlockRanges[1].RangeLength); + + // Last 2 ranges belong to Block 2 (MultiRange preserved) + CHECK_EQ(2u, Result.BlockRanges[2].BlockIndex); + CHECK_EQ(ChunkStartOffset, Result.BlockRanges[2].RangeStart); + CHECK_EQ(100u, Result.BlockRanges[2].RangeLength); + CHECK_EQ(2u, Result.BlockRanges[3].BlockIndex); + CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[3].RangeStart); + CHECK_EQ(300u, Result.BlockRanges[3].RangeLength); +} + +TEST_CASE("chunkblock.getblockranges.first_chunk_only") +{ + using namespace chunkblock_analyser_testutils; + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {0}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(1u, Ranges.size()); + CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart); + CHECK_EQ(100u, Ranges[0].RangeLength); + CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.getblockranges.last_chunk_only") +{ + using namespace chunkblock_analyser_testutils; + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + uint64_t ChunkStartOffset = 
CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {3}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(1u, Ranges.size()); + CHECK_EQ(ChunkStartOffset + 600u, Ranges[0].RangeStart); // 100+200+300 before chunk 3 + CHECK_EQ(400u, Ranges[0].RangeLength); + CHECK_EQ(3u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.getblockranges.middle_chunk_only") +{ + using namespace chunkblock_analyser_testutils; + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {1}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(1u, Ranges.size()); + CHECK_EQ(ChunkStartOffset + 100u, Ranges[0].RangeStart); // 100 before chunk 1 + CHECK_EQ(200u, Ranges[0].RangeLength); + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.getblockranges.all_chunks") +{ + using namespace chunkblock_analyser_testutils; + + auto Block = MakeBlockDesc(50, {100, 200, 300, 400}); + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {0, 1, 2, 3}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(1u, Ranges.size()); + CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart); + CHECK_EQ(1000u, Ranges[0].RangeLength); // 100+200+300+400 + CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(4u, Ranges[0].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.getblockranges.non_contiguous") +{ + using namespace chunkblock_analyser_testutils; + + // Chunks 0 and 2 needed, chunk 1 skipped → two separate ranges + auto Block = MakeBlockDesc(50, {100, 200, 300}); + uint64_t ChunkStartOffset = 
CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {0, 2}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(2u, Ranges.size()); + + CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart); + CHECK_EQ(100u, Ranges[0].RangeLength); + CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount); + + CHECK_EQ(ChunkStartOffset + 300u, Ranges[1].RangeStart); // 100+200 before chunk 2 + CHECK_EQ(300u, Ranges[1].RangeLength); + CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart); + CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount); +} + +TEST_CASE("chunkblock.getblockranges.contiguous_run") +{ + using namespace chunkblock_analyser_testutils; + + // Chunks 1, 2, 3 needed (consecutive) → one merged range + auto Block = MakeBlockDesc(50, {50, 100, 150, 200, 250}); + uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize; + + std::vector<uint32_t> Needed = {1, 2, 3}; + auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed); + + REQUIRE_EQ(1u, Ranges.size()); + CHECK_EQ(ChunkStartOffset + 50u, Ranges[0].RangeStart); // 50 before chunk 1 + CHECK_EQ(450u, Ranges[0].RangeLength); // 100+150+200 + CHECK_EQ(1u, Ranges[0].ChunkBlockIndexStart); + CHECK_EQ(3u, Ranges[0].ChunkBlockIndexCount); +} + +TEST_SUITE_END(); + void chunkblock_forcelink() { diff --git a/src/zenremotestore/chunking/chunkedcontent.cpp b/src/zenremotestore/chunking/chunkedcontent.cpp index 26d179f14..c09ab9d3a 100644 --- a/src/zenremotestore/chunking/chunkedcontent.cpp +++ b/src/zenremotestore/chunking/chunkedcontent.cpp @@ -166,7 +166,6 @@ namespace { if (Chunked.Info.ChunkSequence.empty()) { AddChunkSequence(Stats, OutChunkedContent.ChunkedContent, ChunkHashToChunkIndex, Chunked.Info.RawHash, RawSize); - Stats.UniqueSequencesFound++; } else { @@ -186,7 +185,6 @@ namespace { Chunked.Info.ChunkHashes, ChunkSizes); } - 
Stats.UniqueSequencesFound++; } }); Stats.FilesChunked++; @@ -253,7 +251,7 @@ FolderContent::operator==(const FolderContent& Rhs) const if ((Platform == Rhs.Platform) && (RawSizes == Rhs.RawSizes) && (Attributes == Rhs.Attributes) && (ModificationTicks == Rhs.ModificationTicks) && (Paths.size() == Rhs.Paths.size())) { - size_t PathCount = 0; + size_t PathCount = Paths.size(); for (size_t PathIndex = 0; PathIndex < PathCount; PathIndex++) { if (Paths[PathIndex].generic_string() != Rhs.Paths[PathIndex].generic_string()) @@ -1706,6 +1704,8 @@ namespace chunkedcontent_testutils { } // namespace chunkedcontent_testutils +TEST_SUITE_BEGIN("remotestore.chunkedcontent"); + TEST_CASE("chunkedcontent.DeletePathsFromContent") { FastRandom BaseRandom; @@ -1924,6 +1924,8 @@ TEST_CASE("chunkedcontent.ApplyChunkedContentOverlay") } } +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS } // namespace zen diff --git a/src/zenremotestore/chunking/chunkedfile.cpp b/src/zenremotestore/chunking/chunkedfile.cpp index 652110605..633ddfd0d 100644 --- a/src/zenremotestore/chunking/chunkedfile.cpp +++ b/src/zenremotestore/chunking/chunkedfile.cpp @@ -211,6 +211,8 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { # if 0 +TEST_SUITE_BEGIN("remotestore.chunkedfile"); + TEST_CASE("chunkedfile.findparams") { # if 1 @@ -513,6 +515,8 @@ TEST_CASE("chunkedfile.findparams") // WorkLatch.CountDown(); // WorkLatch.Wait(); } + +TEST_SUITE_END(); # endif // 0 void diff --git a/src/zenremotestore/chunking/chunkingcache.cpp b/src/zenremotestore/chunking/chunkingcache.cpp index 7f0a26330..e9b783a00 100644 --- a/src/zenremotestore/chunking/chunkingcache.cpp +++ b/src/zenremotestore/chunking/chunkingcache.cpp @@ -75,13 +75,13 @@ public: { Lock.ReleaseNow(); RwLock::ExclusiveLockScope EditLock(m_Lock); - if (auto RemoveIt = m_PathHashToEntry.find(PathHash); It != m_PathHashToEntry.end()) + if (auto RemoveIt = m_PathHashToEntry.find(PathHash); RemoveIt != m_PathHashToEntry.end()) { - CachedEntry& DeleteEntry = 
m_Entries[It->second]; + CachedEntry& DeleteEntry = m_Entries[RemoveIt->second]; DeleteEntry.Chunked = {}; DeleteEntry.ModificationTick = 0; - m_FreeEntryIndexes.push_back(It->second); - m_PathHashToEntry.erase(It); + m_FreeEntryIndexes.push_back(RemoveIt->second); + m_PathHashToEntry.erase(RemoveIt); } } } @@ -461,6 +461,8 @@ namespace chunkingcache_testutils { } } // namespace chunkingcache_testutils +TEST_SUITE_BEGIN("remotestore.chunkingcache"); + TEST_CASE("chunkingcache.nullchunkingcache") { using namespace chunkingcache_testutils; @@ -617,6 +619,8 @@ TEST_CASE("chunkingcache.diskchunkingcache") } } +TEST_SUITE_END(); + void chunkingcache_forcelink() { diff --git a/src/zenremotestore/filesystemutils.cpp b/src/zenremotestore/filesystemutils.cpp index fa1ce6f78..fdb2143d8 100644 --- a/src/zenremotestore/filesystemutils.cpp +++ b/src/zenremotestore/filesystemutils.cpp @@ -637,6 +637,8 @@ namespace { void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); } } // namespace +TEST_SUITE_BEGIN("remotestore.filesystemutils"); + TEST_CASE("filesystemutils.CleanDirectory") { ScopedTemporaryDirectory TmpDir; @@ -692,6 +694,8 @@ TEST_CASE("filesystemutils.CleanDirectory") CHECK(!IsFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme")); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h index 85dabc59f..da8437a58 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h @@ -53,15 +53,24 @@ public: std::function<IoBuffer(uint64_t Offset, uint64_t Size)>&& Transmitter, std::function<void(uint64_t, bool)>&& OnSentBytes) = 0; - virtual IoBuffer GetBuildBlob(const Oid& BuildId, - const IoHash& RawHash, - uint64_t RangeOffset = 0, - uint64_t RangeBytes = (uint64_t)-1) = 0; + virtual IoBuffer GetBuildBlob(const 
Oid& BuildId, + const IoHash& RawHash, + uint64_t RangeOffset = 0, + uint64_t RangeBytes = (uint64_t)-1) = 0; + + struct BuildBlobRanges + { + IoBuffer PayloadBuffer; + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + }; + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0; virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, std::function<void(uint64_t Offset, const IoBuffer& Chunk)>&& OnReceive, - std::function<void()>&& OnComplete) = 0; + std::function<void()>&& OnComplete) = 0; [[nodiscard]] virtual bool PutBlockMetadata(const Oid& BuildId, const IoHash& BlockRawHash, const CbObject& MetaData) = 0; virtual CbObject FindBlocks(const Oid& BuildId, uint64_t MaxBlockCount) = 0; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h index bb5b1c5f4..24702df0f 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h @@ -37,6 +37,14 @@ public: const IoHash& RawHash, uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) = 0; + struct BuildBlobRanges + { + IoBuffer PayloadBuffer; + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + }; + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0; virtual void PutBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes, std::span<const CbObject> MetaDatas) = 0; virtual std::vector<CbObject> GetBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes) = 0; @@ -61,10 +69,19 @@ std::unique_ptr<BuildStorageCache> CreateZenBuildStorageCache(HttpClient& H const std::filesystem::path& TempFolderPath, WorkerThreadPool& BackgroundWorkerPool); 
+#if ZEN_WITH_TESTS +std::unique_ptr<BuildStorageCache> CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported, + BuildStorageCache::Statistics& Stats, + double LatencySec = 0.0, + double DelayPerKBSec = 0.0); +#endif // ZEN_WITH_TESTS + struct ZenCacheEndpointTestResult { bool Success = false; std::string FailureReason; + double LatencySeconds = -1.0; + uint64_t MaxRangeCountPerRequest = 1; }; ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 6304159ae..0d2eded58 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -7,7 +7,9 @@ #include <zencore/uid.h> #include <zencore/zencore.h> #include <zenremotestore/builds/buildstoragecache.h> +#include <zenremotestore/chunking/chunkblock.h> #include <zenremotestore/chunking/chunkedcontent.h> +#include <zenremotestore/partialblockrequestmode.h> #include <zenutil/bufferedwritefilecache.h> #include <atomic> @@ -108,17 +110,6 @@ struct RebuildFolderStateStatistics uint64_t FinalizeTreeElapsedWallTimeUs = 0; }; -enum EPartialBlockRequestMode -{ - Off, - ZenCacheOnly, - Mixed, - All, - Invalid -}; - -EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString); - std::filesystem::path ZenStateFilePath(const std::filesystem::path& ZenFolderPath); std::filesystem::path ZenTempFolderPath(const std::filesystem::path& ZenFolderPath); @@ -170,7 +161,7 @@ public: DownloadStatistics m_DownloadStats; WriteChunkStatistics m_WriteChunkStats; RebuildFolderStateStatistics m_RebuildFolderStateStats; - std::atomic<uint64_t> m_WrittenChunkByteCount; + std::atomic<uint64_t> m_WrittenChunkByteCount = 0; private: struct BlockWriteOps @@ -195,7 +186,7 @@ 
private: uint32_t ScavengedContentIndex = (uint32_t)-1; uint32_t ScavengedPathIndex = (uint32_t)-1; uint32_t RemoteSequenceIndex = (uint32_t)-1; - uint64_t RawSize = (uint32_t)-1; + uint64_t RawSize = (uint64_t)-1; }; struct CopyChunkData @@ -218,33 +209,6 @@ private: uint64_t ElapsedTimeMs = 0; }; - struct BlockRangeDescriptor - { - uint32_t BlockIndex = (uint32_t)-1; - uint64_t RangeStart = 0; - uint64_t RangeLength = 0; - uint32_t ChunkBlockIndexStart = 0; - uint32_t ChunkBlockIndexCount = 0; - }; - - struct BlockRangeLimit - { - uint16_t SizePercent; - uint16_t MaxRangeCount; - }; - - static constexpr uint16_t FullBlockRangePercentLimit = 95; - - static constexpr BuildsOperationUpdateFolder::BlockRangeLimit ForceMergeLimits[] = { - {.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1}, - {.SizePercent = 90, .MaxRangeCount = 2}, - {.SizePercent = 85, .MaxRangeCount = 8}, - {.SizePercent = 80, .MaxRangeCount = 16}, - {.SizePercent = 70, .MaxRangeCount = 32}, - {.SizePercent = 60, .MaxRangeCount = 48}, - {.SizePercent = 2, .MaxRangeCount = 56}, - {.SizePercent = 0, .MaxRangeCount = 64}}; - void ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedChunkHashesFound, tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedSequenceHashesFound); void ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedBlocksFound); @@ -299,25 +263,14 @@ private: ParallelWork& Work, std::function<void(IoBuffer&& Payload)>&& OnDownloaded); - BlockRangeDescriptor MergeBlockRanges(std::span<const BlockRangeDescriptor> Ranges); - std::optional<std::vector<BlockRangeDescriptor>> MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, - const BlockRangeDescriptor& Range); - const BlockRangeLimit* GetBlockRangeLimitForRange(std::span<const BlockRangeLimit> Limits, - uint64_t TotalBlockSize, - std::span<const BlockRangeDescriptor> Ranges); - std::vector<BlockRangeDescriptor> CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, - 
std::span<const BlockRangeDescriptor> BlockRanges); - uint64_t CalculateNextGap(std::span<const BlockRangeDescriptor> BlockRanges); - std::optional<std::vector<BlockRangeDescriptor>> CalculateBlockRanges(uint32_t BlockIndex, - const ChunkBlockDescription& BlockDescription, - std::span<const uint32_t> BlockChunkIndexNeeded, - bool LimitToSingleRange, - const uint64_t ChunkStartOffsetInBlock, - const uint64_t TotalBlockSize, - uint64_t& OutTotalWantedChunksSize); - void DownloadPartialBlock(const BlockRangeDescriptor BlockRange, - const BlobsExistsResult& ExistsResult, - std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath)>&& OnDownloaded); + void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, + size_t BlockRangeIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + std::function<void(IoBuffer&& InMemoryBuffer, + const std::filesystem::path& OnDiskPath, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded); std::vector<uint32_t> WriteLocalChunkToCache(CloneQueryInterface* CloneQuery, const CopyChunkData& CopyData, @@ -339,7 +292,8 @@ private: const uint64_t FileOffset, const uint32_t PathIndex); - bool GetBlockWriteOps(std::span<const IoHash> ChunkRawHashes, + bool GetBlockWriteOps(const IoHash& BlockRawHash, + std::span<const IoHash> ChunkRawHashes, std::span<const uint32_t> ChunkCompressedLengths, std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters, std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -408,7 +362,7 @@ private: const std::filesystem::path m_TempDownloadFolderPath; const std::filesystem::path m_TempBlockFolderPath; - std::atomic<uint64_t> m_ValidatedChunkByteCount; + std::atomic<uint64_t> m_ValidatedChunkByteCount = 0; }; struct FindBlocksStatistics diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h 
b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h index ab3037c89..7306188ca 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h @@ -14,13 +14,20 @@ class BuildStorageCache; struct BuildStorageResolveResult { - std::string HostUrl; - std::string HostName; - bool HostAssumeHttp2 = false; - - std::string CacheUrl; - std::string CacheName; - bool CacheAssumeHttp2 = false; + struct Capabilities + { + uint64_t MaxRangeCountPerRequest = 1; + }; + struct Host + { + std::string Address; + std::string Name; + bool AssumeHttp2 = false; + double LatencySec = -1.0; + Capabilities Caps; + }; + Host Cloud; + Host Cache; }; enum class ZenCacheResolveMode @@ -43,7 +50,6 @@ std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Out BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, - const Oid& BuildPartId, std::span<const IoHash> BlockRawHashes, bool AttemptFallback, bool IsQuiet, @@ -51,12 +57,13 @@ std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Out struct StorageInstance { - std::unique_ptr<HttpClient> BuildStorageHttp; - std::unique_ptr<BuildStorageBase> BuildStorage; - std::string StorageName; + BuildStorageResolveResult::Host BuildStorageHost; + std::unique_ptr<HttpClient> BuildStorageHttp; + std::unique_ptr<BuildStorageBase> BuildStorage; + + BuildStorageResolveResult::Host CacheHost; std::unique_ptr<HttpClient> CacheHttp; - std::unique_ptr<BuildStorageCache> BuildCacheStorage; - std::string CacheName; + std::unique_ptr<BuildStorageCache> CacheStorage; }; } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index d339b0f94..931bb2097 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ 
b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -7,8 +7,9 @@ #include <zencore/compactbinary.h> #include <zencore/compress.h> -#include <optional> -#include <vector> +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { @@ -20,13 +21,14 @@ struct ThinChunkBlockDescription struct ChunkBlockDescription : public ThinChunkBlockDescription { - uint64_t HeaderSize; + uint64_t HeaderSize = 0; std::vector<uint32_t> ChunkRawLengths; std::vector<uint32_t> ChunkCompressedLengths; }; std::vector<ChunkBlockDescription> ParseChunkBlockDescriptionList(const CbObjectView& BlocksObject); ChunkBlockDescription ParseChunkBlockDescription(const CbObjectView& BlockObject); +std::vector<ChunkBlockDescription> ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas); CbObject BuildChunkBlockDescription(const ChunkBlockDescription& Block, CbObjectView MetaData); ChunkBlockDescription GetChunkBlockDescription(const SharedBuffer& BlockPayload, const IoHash& RawHash); typedef std::function<std::pair<uint64_t, CompressedBuffer>(const IoHash& RawHash)> FetchChunkFunc; @@ -73,6 +75,70 @@ std::vector<size_t> FindReuseBlocks(OperationLogOutput& Output, std::span<const uint32_t> ChunkIndexes, std::vector<uint32_t>& OutUnusedChunkIndexes); +class ChunkBlockAnalyser +{ +public: + struct Options + { + bool IsQuiet = false; + bool IsVerbose = false; + double HostLatencySec = -1.0; + double HostHighSpeedLatencySec = -1.0; + uint64_t HostSpeedBytesPerSec = (1u * 1024u * 1024u * 1024u) / 8u; // 1GBit + uint64_t HostHighSpeedBytesPerSec = (2u * 1024u * 1024u * 1024u) / 8u; // 2GBit + uint64_t HostMaxRangeCountPerRequest = (uint64_t)-1; + uint64_t HostHighSpeedMaxRangeCountPerRequest = (uint64_t)-1; // No limit + uint64_t MaxRangesPerBlock = 1024u; + }; + + ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options); + + struct BlockRangeDescriptor + { + 
uint32_t BlockIndex = (uint32_t)-1; + uint64_t RangeStart = 0; + uint64_t RangeLength = 0; + uint32_t ChunkBlockIndexStart = 0; + uint32_t ChunkBlockIndexCount = 0; + }; + + struct NeededBlock + { + uint32_t BlockIndex; + std::vector<uint32_t> ChunkIndexes; + }; + + std::vector<NeededBlock> GetNeeded(const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& ChunkHashToChunkIndex, + std::function<bool(uint32_t ChunkIndex)>&& NeedsBlockChunk); + + enum class EPartialBlockDownloadMode + { + Off, + SingleRange, + MultiRange, + MultiRangeHighSpeed, + Exact + }; + + struct BlockResult + { + std::vector<BlockRangeDescriptor> BlockRanges; + std::vector<uint32_t> FullBlockIndexes; + }; + + BlockResult CalculatePartialBlockDownloads(std::span<const NeededBlock> NeededBlocks, + std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes); + +private: + OperationLogOutput& m_LogOutput; + const std::span<const ChunkBlockDescription> m_BlockDescriptions; + const Options m_Options; +}; + +#if ZEN_WITH_TESTS + void chunkblock_forcelink(); +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h index d402bd3f0..f44381e42 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h @@ -231,7 +231,7 @@ GetSequenceIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& Raw inline uint32_t GetChunkIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& RawHash) { - return Lookup.RawHashToSequenceIndex.at(RawHash); + return Lookup.ChunkHashToChunkIndex.at(RawHash); } inline uint32_t diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h index 432496bc1..caf7ecd28 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h +++ 
b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h @@ -2,6 +2,7 @@ #pragma once +#include <cstdint> #include <string> #include <string_view> #include <vector> @@ -28,6 +29,8 @@ struct JupiterEndpointTestResult { bool Success = false; std::string FailureReason; + double LatencySeconds = -1.0; + uint64_t MaxRangeCountPerRequest = 1; }; JupiterEndpointTestResult TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h index eaf6962fd..8721bc37f 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h +++ b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h @@ -56,6 +56,11 @@ struct FinalizeBuildPartResult : JupiterResult std::vector<IoHash> Needs; }; +struct BuildBlobRangesResult : JupiterResult +{ + std::vector<std::pair<uint64_t, uint64_t>> Ranges; +}; + /** * Context for performing Jupiter operations * @@ -135,6 +140,13 @@ public: uint64_t Offset = 0, uint64_t Size = (uint64_t)-1); + BuildBlobRangesResult GetBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + std::filesystem::path TempFolderPath, + std::span<const std::pair<uint64_t, uint64_t>> Ranges); + JupiterResult PutMultipartBuildBlob(std::string_view Namespace, std::string_view BucketId, const Oid& BuildId, diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h index 9693e69cf..32b95f50f 100644 --- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h +++ b/src/zenremotestore/include/zenremotestore/operationlogoutput.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/fmtutils.h> +#include <zencore/logbase.h> namespace zen { @@ -10,7 +11,7 @@ class OperationLogOutput { public: virtual ~OperationLogOutput() {} - virtual 
void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) = 0; + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) = 0; virtual void SetLogOperationName(std::string_view Name) = 0; virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0; @@ -57,23 +58,19 @@ public: virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) = 0; }; -struct LoggerRef; +OperationLogOutput* CreateStandardLogOutput(LoggerRef Log); -OperationLogOutput* CreateStandardLogOutput(LoggerRef& Log); - -#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - OutputTarget.EmitLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ +#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + (OutputTarget).EmitLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ } while (false) -#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) 
ZEN_OPERATION_LOG(OutputTarget, zen::logging::Warn, fmtstr, ##__VA_ARGS__) } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h new file mode 100644 index 000000000..54adea2b2 --- /dev/null +++ b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <string_view> + +namespace zen { + +enum EPartialBlockRequestMode +{ + Off, + ZenCacheOnly, + Mixed, + All, + Invalid +}; + +EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString); + +} // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h index e8b7c15c0..c058e1c1f 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h @@ -2,6 +2,7 @@ #pragma once +#include <zenhttp/httpclient.h> #include <zenremotestore/projectstore/remoteprojectstore.h> namespace zen { @@ -10,9 +11,6 @@ class AuthMgr; struct BuildsRemoteStoreOptions : RemoteStoreOptions { - std::string Host; - std::string OverrideHost; - std::string ZenHost; std::string Namespace; std::string Bucket; Oid BuildId; @@ -22,18 +20,16 @@ struct BuildsRemoteStoreOptions : RemoteStoreOptions std::filesystem::path OidcExePath; bool ForceDisableBlocks = false; bool ForceDisableTempBlocks = false; - bool AssumeHttp2 = false; - bool PopulateCache = true; IoBuffer MetaData; size_t MaximumInMemoryDownloadSize = 1024u * 1024u; }; -std::shared_ptr<RemoteProjectStore> CreateJupiterBuildsRemoteStore(LoggerRef InLog, - const BuildsRemoteStoreOptions& Options, - const std::filesystem::path& TempFilePath, - bool Quiet, - bool Unattended, - bool Hidden, - 
WorkerThreadPool& CacheBackgroundWorkerPool); +struct BuildStorageResolveResult; + +std::shared_ptr<RemoteProjectStore> CreateJupiterBuildsRemoteStore(LoggerRef InLog, + const BuildStorageResolveResult& ResolveResult, + std::function<HttpClientAccessToken()>&& TokenProvider, + const BuildsRemoteStoreOptions& Options, + const std::filesystem::path& TempFilePath); } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 008f94351..084d975a2 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -5,7 +5,9 @@ #include <zencore/jobqueue.h> #include <zenstore/projectstore.h> +#include <zenremotestore/builds/buildstoragecache.h> #include <zenremotestore/chunking/chunkblock.h> +#include <zenremotestore/partialblockrequestmode.h> #include <unordered_set> @@ -73,24 +75,35 @@ public: std::vector<ChunkBlockDescription> Blocks; }; + struct GetBlockDescriptionsResult : public Result + { + std::vector<ChunkBlockDescription> Blocks; + }; + + struct LoadAttachmentRangesResult : public Result + { + IoBuffer Bytes; + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + }; + struct RemoteStoreInfo { - bool CreateBlocks; - bool UseTempBlockFiles; - bool AllowChunking; + bool CreateBlocks = false; + bool UseTempBlockFiles = false; + bool AllowChunking = false; std::string ContainerName; std::string Description; }; struct Stats { - std::uint64_t m_SentBytes; - std::uint64_t m_ReceivedBytes; - std::uint64_t m_RequestTimeNS; - std::uint64_t m_RequestCount; - std::uint64_t m_PeakSentBytes; - std::uint64_t m_PeakReceivedBytes; - std::uint64_t m_PeakBytesPerSec; + std::uint64_t m_SentBytes = 0; + std::uint64_t m_ReceivedBytes = 0; + std::uint64_t m_RequestTimeNS = 0; + std::uint64_t m_RequestCount = 0; + std::uint64_t m_PeakSentBytes = 0; + 
std::uint64_t m_PeakReceivedBytes = 0; + std::uint64_t m_PeakBytesPerSec = 0; }; struct ExtendedStats @@ -111,12 +124,17 @@ public: virtual FinalizeResult FinalizeContainer(const IoHash& RawHash) = 0; virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0; - virtual LoadContainerResult LoadContainer() = 0; - virtual GetKnownBlocksResult GetKnownBlocks() = 0; - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; - virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0; + virtual LoadContainerResult LoadContainer() = 0; + virtual GetKnownBlocksResult GetKnownBlocks() = 0; + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) = 0; + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; - virtual void Flush() = 0; + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0; }; struct RemoteStoreOptions @@ -153,14 +171,15 @@ RemoteProjectStore::LoadContainerResult BuildContainer( class JobContext; -RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog, - const CbObject& ContainerObject, - const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments, - const std::function<bool(const IoHash& RawHash)>& HasAttachment, - const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, - const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, - const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment, - JobContext* OptionalContext); +RemoteProjectStore::Result SaveOplogContainer( + ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<void(std::span<IoHash> 
RawHashes)>& OnReferencedAttachments, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, + const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment, + JobContext* OptionalContext); RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, @@ -177,15 +196,29 @@ RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, bool IgnoreMissingAttachments, JobContext* OptionalContext); -RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - JobContext* OptionalContext); +struct LoadOplogContext +{ + CidStore& ChunkStore; + RemoteProjectStore& RemoteStore; + BuildStorageCache* OptionalCache = nullptr; + Oid CacheBuildId = Oid::Zero; + BuildStorageCache::Statistics* OptionalCacheStats = nullptr; + ProjectStore::Oplog& Oplog; + WorkerThreadPool& NetworkWorkerPool; + WorkerThreadPool& WorkerPool; + bool ForceDownload = false; + bool IgnoreMissingAttachments = false; + bool CleanOplog = false; + EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::All; + bool PopulateCache = false; + double StoreLatencySec = -1.0; + uint64_t StoreMaxRangeCountPerRequest = 1; + double CacheLatencySec = -1.0; + uint64_t CacheMaxRangeCountPerRequest = 1; + JobContext* OptionalJobContext = nullptr; +}; + +RemoteProjectStore::Result LoadOplog(LoadOplogContext&& Context); std::vector<IoHash> GetBlockHashesFromOplog(CbObjectView ContainerObject); std::vector<ThinChunkBlockDescription> GetBlocksFromOplog(CbObjectView ContainerObject, std::span<const IoHash> IncludeBlockHashes); diff --git 
a/src/zenremotestore/jupiter/jupiterhost.cpp b/src/zenremotestore/jupiter/jupiterhost.cpp index 7706f00c2..314aafc78 100644 --- a/src/zenremotestore/jupiter/jupiterhost.cpp +++ b/src/zenremotestore/jupiter/jupiterhost.cpp @@ -59,7 +59,22 @@ TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpClient::Response TestResponse = TestHttpClient.Get("/health/live"); if (TestResponse.IsSuccess()) { - return {.Success = true}; + // TODO: dan.engelbrecht 20260305 - replace this naive nginx detection with proper capabilites end point once it exists in Jupiter + uint64_t MaxRangeCountPerRequest = 1; + if (auto It = TestResponse.Header.Entries.find("Server"); It != TestResponse.Header.Entries.end()) + { + if (StrCaseCompare(It->second.c_str(), "nginx", 5) == 0) + { + MaxRangeCountPerRequest = 128u; // This leaves more than 2k header space for auth token etc + } + } + LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health/ready"); + + if (!LatencyResult.Success) + { + return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + } + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest}; } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; } diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index 1bc6564ce..52f9eb678 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -852,6 +852,71 @@ JupiterSession::GetBuildBlob(std::string_view Namespace, return detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv); } +BuildBlobRangesResult +JupiterSession::GetBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + std::filesystem::path TempFolderPath, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) +{ + HttpClient::KeyValueMap Headers; + if 
(!Ranges.empty()) + { + ExtendableStringBuilder<512> SB; + for (const std::pair<uint64_t, uint64_t>& R : Ranges) + { + if (SB.Size() > 0) + { + SB << ", "; + } + SB << R.first << "-" << R.first + R.second - 1; + } + Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())}); + } + std::string Url = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect={}", + Namespace, + BucketId, + BuildId, + Hash.ToHexString(), + m_AllowRedirect ? "true"sv : "false"sv); + + HttpClient::Response Response = m_HttpClient.Download(Url, TempFolderPath, Headers); + if (Response.StatusCode == HttpResponseCode::RangeNotSatisfiable && Ranges.size() > 1) + { + // Requests to Jupiter that is not served via nginx (content not stored locally in the file system) can not serve multi-range + // requests (asp.net limitation) This rejection is not implemented as of 2026-03-02, it is in the backlog (@joakim.lindqvist) + // If we encounter this error we fall back to a single range which covers all the requested ranges + uint64_t RangeStart = Ranges.front().first; + uint64_t RangeEnd = Ranges.back().first + Ranges.back().second - 1; + Headers.Entries.insert_or_assign("Range", fmt::format("bytes={}-{}", RangeStart, RangeEnd)); + Response = m_HttpClient.Download(Url, TempFolderPath, Headers); + } + if (Response.IsSuccess()) + { + // If we get a redirect to S3 or a non-Jupiter endpoint the content type will not be correct, validate it and set it + if (m_AllowRedirect && (Response.ResponsePayload.GetContentType() == HttpContentType::kBinary)) + { + IoHash ValidateRawHash; + uint64_t ValidateRawSize = 0; + if (!Headers.Entries.contains("Range")) + { + ZEN_ASSERT_SLOW(CompressedBuffer::ValidateCompressedHeader(Response.ResponsePayload, + ValidateRawHash, + ValidateRawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)); + ZEN_ASSERT_SLOW(ValidateRawHash == Hash); + ZEN_ASSERT_SLOW(ValidateRawSize > 0); + ZEN_UNUSED(ValidateRawHash, ValidateRawSize); + 
Response.ResponsePayload.SetContentType(ZenContentType::kCompressedBinary); + } + } + } + BuildBlobRangesResult Result = {detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv)}; + Result.Ranges = Response.GetRanges(Ranges); + return Result; +} + JupiterResult JupiterSession::PutBlockMetadata(std::string_view Namespace, std::string_view BucketId, diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp index 0837ed716..5ed844c9d 100644 --- a/src/zenremotestore/operationlogoutput.cpp +++ b/src/zenremotestore/operationlogoutput.cpp @@ -3,6 +3,7 @@ #include <zenremotestore/operationlogoutput.h> #include <zencore/logging.h> +#include <zencore/logging/logger.h> ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> @@ -30,13 +31,11 @@ class StandardLogOutput : public OperationLogOutput { public: StandardLogOutput(LoggerRef& Log) : m_Log(Log) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override { - if (m_Log.ShouldLog(LogLevel)) + if (m_Log.ShouldLog(Point.Level)) { - fmt::basic_memory_buffer<char, 250> MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); - ZEN_LOG(m_Log, LogLevel, "{}", std::string_view(MessageBuffer.data(), MessageBuffer.size())); + m_Log->Log(Point, Args); } } @@ -47,7 +46,7 @@ public: } virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { - const size_t PercentDone = StepCount > 0u ? gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u; + [[maybe_unused]] const size_t PercentDone = StepCount > 0u ? 
gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u; ZEN_OPERATION_LOG_INFO(*this, "{}: {}%", m_LogOperationName, PercentDone); } virtual uint32_t GetProgressUpdateDelayMS() override { return 2000; } @@ -59,13 +58,14 @@ public: private: LoggerRef m_Log; std::string m_LogOperationName; + LoggerRef Log() { return m_Log; } }; void StandardLogOutputProgressBar::UpdateState(const State& NewState, bool DoLinebreak) { ZEN_UNUSED(DoLinebreak); - const size_t PercentDone = + [[maybe_unused]] const size_t PercentDone = NewState.TotalCount > 0u ? gsl::narrow<uint8_t>((100 * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount) : 0u; std::string Task = NewState.Task; switch (NewState.Status) @@ -95,7 +95,7 @@ StandardLogOutputProgressBar::Finish() } OperationLogOutput* -CreateStandardLogOutput(LoggerRef& Log) +CreateStandardLogOutput(LoggerRef Log) { return new StandardLogOutput(Log); } diff --git a/src/zenremotestore/partialblockrequestmode.cpp b/src/zenremotestore/partialblockrequestmode.cpp new file mode 100644 index 000000000..b3edf515b --- /dev/null +++ b/src/zenremotestore/partialblockrequestmode.cpp @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenremotestore/partialblockrequestmode.h> + +#include <zencore/string.h> + +namespace zen { + +EPartialBlockRequestMode +PartialBlockRequestModeFromString(const std::string_view ModeString) +{ + switch (HashStringAsLowerDjb2(ModeString)) + { + case HashStringDjb2("false"): + return EPartialBlockRequestMode::Off; + case HashStringDjb2("zencacheonly"): + return EPartialBlockRequestMode::ZenCacheOnly; + case HashStringDjb2("mixed"): + return EPartialBlockRequestMode::Mixed; + case HashStringDjb2("true"): + return EPartialBlockRequestMode::All; + default: + return EPartialBlockRequestMode::Invalid; + } +} + +} // namespace zen diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index a8e883dde..2282a31dd 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -7,8 +7,6 @@ #include <zencore/fmtutils.h> #include <zencore/scopeguard.h> -#include <zenhttp/httpclientauth.h> -#include <zenremotestore/builds/buildstoragecache.h> #include <zenremotestore/builds/buildstorageutil.h> #include <zenremotestore/builds/jupiterbuildstorage.h> #include <zenremotestore/operationlogoutput.h> @@ -26,18 +24,14 @@ class BuildsRemoteStore : public RemoteProjectStore public: BuildsRemoteStore(LoggerRef InLog, const HttpClientSettings& ClientSettings, - HttpClientSettings* OptionalCacheClientSettings, std::string_view HostUrl, - std::string_view CacheUrl, const std::filesystem::path& TempFilePath, - WorkerThreadPool& CacheBackgroundWorkerPool, std::string_view Namespace, std::string_view Bucket, const Oid& BuildId, const IoBuffer& MetaData, bool ForceDisableBlocks, - bool ForceDisableTempBlocks, - bool PopulateCache) + bool ForceDisableTempBlocks) : m_Log(InLog) , m_BuildStorageHttp(HostUrl, ClientSettings) , m_BuildStorage(CreateJupiterBuildStorage(Log(), @@ -53,20 +47,8 @@ public: , m_MetaData(MetaData) , 
m_EnableBlocks(!ForceDisableBlocks) , m_UseTempBlocks(!ForceDisableTempBlocks) - , m_PopulateCache(PopulateCache) { m_MetaData.MakeOwned(); - if (OptionalCacheClientSettings) - { - ZEN_ASSERT(!CacheUrl.empty()); - m_BuildCacheStorageHttp = std::make_unique<HttpClient>(CacheUrl, *OptionalCacheClientSettings); - m_BuildCacheStorage = CreateZenBuildStorageCache(*m_BuildCacheStorageHttp, - m_StorageCacheStats, - Namespace, - Bucket, - TempFilePath, - CacheBackgroundWorkerPool); - } } virtual RemoteStoreInfo GetInfo() const override @@ -75,9 +57,8 @@ public: .UseTempBlockFiles = m_UseTempBlocks, .AllowChunking = true, .ContainerName = fmt::format("{}/{}/{}", m_Namespace, m_Bucket, m_BuildId), - .Description = fmt::format("[cloud] {}{}. SessionId: {}. {}/{}/{}"sv, + .Description = fmt::format("[cloud] {}. SessionId: {}. {}/{}/{}"sv, m_BuildStorageHttp.GetBaseUri(), - m_BuildCacheStorage ? fmt::format(" (Cache: {})", m_BuildCacheStorageHttp->GetBaseUri()) : ""sv, m_BuildStorageHttp.GetSessionId(), m_Namespace, m_Bucket, @@ -86,15 +67,13 @@ public: virtual Stats GetStats() const override { - return { - .m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load() + m_StorageCacheStats.TotalBytesWritten.load(), - .m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load() + m_StorageCacheStats.TotalBytesRead.load(), - .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000 + m_StorageCacheStats.TotalRequestTimeUs.load() * 1000, - .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load() + m_StorageCacheStats.TotalRequestCount.load(), - .m_PeakSentBytes = Max(m_BuildStorageStats.PeakSentBytes.load(), m_StorageCacheStats.PeakSentBytes.load()), - .m_PeakReceivedBytes = Max(m_BuildStorageStats.PeakReceivedBytes.load(), m_StorageCacheStats.PeakReceivedBytes.load()), - .m_PeakBytesPerSec = Max(m_BuildStorageStats.PeakBytesPerSec.load(), m_StorageCacheStats.PeakBytesPerSec.load()), - }; + return {.m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load(), + 
.m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load(), + .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000, + .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load(), + .m_PeakSentBytes = m_BuildStorageStats.PeakSentBytes.load(), + .m_PeakReceivedBytes = m_BuildStorageStats.PeakReceivedBytes.load(), + .m_PeakBytesPerSec = m_BuildStorageStats.PeakBytesPerSec.load()}; } virtual bool GetExtendedStats(ExtendedStats& OutStats) const override @@ -109,11 +88,6 @@ public: } Result = true; } - if (m_BuildCacheStorage) - { - OutStats.m_ReceivedBytesPerSource.insert_or_assign("Cache", m_StorageCacheStats.TotalBytesRead); - Result = true; - } return Result; } @@ -441,7 +415,7 @@ public: catch (const HttpClientError& Ex) { Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -451,7 +425,7 @@ public: catch (const std::exception& Ex) { Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. 
Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -462,6 +436,53 @@ public: return Result; } + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override + { + std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log())); + + ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); + ZEN_ASSERT(OptionalCache == nullptr || CacheBuildId == m_BuildId); + + GetBlockDescriptionsResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); + + try + { + Result.Blocks = zen::GetBlockDescriptions(*Output, + *m_BuildStorage, + OptionalCache, + m_BuildId, + BlockHashes, + /*AttemptFallback*/ false, + /*IsQuiet*/ false, + /*IsVerbose)*/ false); + } + catch (const HttpClientError& Ex) + { + Result.ErrorCode = MakeErrorCode(Ex); + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. 
Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + return Result; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); @@ -472,44 +493,73 @@ public: try { - if (m_BuildCacheStorage) - { - IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash); - if (CachedBlob) - { - Result.Bytes = std::move(CachedBlob); - } - } - if (!Result.Bytes) + Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); + } + catch (const HttpClientError& Ex) + { + Result.ErrorCode = MakeErrorCode(Ex); + Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + RawHash, + Ex.what()); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. 
Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + RawHash, + Ex.what()); + } + + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + LoadAttachmentRangesResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); + + try + { + BuildStorageBase::BuildBlobRanges BlobRanges = m_BuildStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); + if (BlobRanges.PayloadBuffer) { - Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); - if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) - { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - RawHash, - Result.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(Result.Bytes))); - } + Result.Bytes = std::move(BlobRanges.PayloadBuffer); + Result.Ranges = std::move(BlobRanges.Ranges); } } catch (const HttpClientError& Ex) { Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. Reason: '{}'", + Ranges.size(), m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, m_BuildId, + RawHash, Ex.what()); } catch (const std::exception& Ex) { Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. 
Reason: '{}'", + Ranges.size(), m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, m_BuildId, + RawHash, Ex.what()); } @@ -524,38 +574,6 @@ public: std::vector<IoHash> AttachmentsLeftToFind = RawHashes; - if (m_BuildCacheStorage) - { - std::vector<BuildStorageCache::BlobExistsResult> ExistCheck = m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes); - if (ExistCheck.size() == RawHashes.size()) - { - AttachmentsLeftToFind.clear(); - for (size_t BlobIndex = 0; BlobIndex < RawHashes.size(); BlobIndex++) - { - const IoHash& Hash = RawHashes[BlobIndex]; - const BuildStorageCache::BlobExistsResult& BlobExists = ExistCheck[BlobIndex]; - if (BlobExists.HasBody) - { - IoBuffer CachedPayload = m_BuildCacheStorage->GetBuildBlob(m_BuildId, Hash); - if (CachedPayload) - { - Result.Chunks.emplace_back( - std::pair<IoHash, CompressedBuffer>{Hash, - CompressedBuffer::FromCompressedNoValidate(std::move(CachedPayload))}); - } - else - { - AttachmentsLeftToFind.push_back(Hash); - } - } - else - { - AttachmentsLeftToFind.push_back(Hash); - } - } - } - } - for (const IoHash& Hash : AttachmentsLeftToFind) { LoadAttachmentResult ChunkResult = LoadAttachment(Hash); @@ -564,27 +582,12 @@ public: return LoadAttachmentsResult{ChunkResult}; } ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000))); - if (m_BuildCacheStorage && ChunkResult.Bytes && m_PopulateCache) - { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - Hash, - ChunkResult.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(ChunkResult.Bytes))); - } Result.Chunks.emplace_back( std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); } return Result; } - virtual void Flush() override - { - if (m_BuildCacheStorage) - { - m_BuildCacheStorage->Flush(100, [](intptr_t) { return false; }); - } - } - private: static int MakeErrorCode(const HttpClientError& Ex) { @@ -601,10 +604,6 @@ private: HttpClient 
m_BuildStorageHttp; std::unique_ptr<BuildStorageBase> m_BuildStorage; - BuildStorageCache::Statistics m_StorageCacheStats; - std::unique_ptr<HttpClient> m_BuildCacheStorageHttp; - std::unique_ptr<BuildStorageCache> m_BuildCacheStorage; - const std::string m_Namespace; const std::string m_Bucket; const Oid m_BuildId; @@ -613,120 +612,35 @@ private: const bool m_EnableBlocks = true; const bool m_UseTempBlocks = true; const bool m_AllowRedirect = false; - const bool m_PopulateCache = true; }; std::shared_ptr<RemoteProjectStore> -CreateJupiterBuildsRemoteStore(LoggerRef InLog, - const BuildsRemoteStoreOptions& Options, - const std::filesystem::path& TempFilePath, - bool Quiet, - bool Unattended, - bool Hidden, - WorkerThreadPool& CacheBackgroundWorkerPool) +CreateJupiterBuildsRemoteStore(LoggerRef InLog, + const BuildStorageResolveResult& ResolveResult, + std::function<HttpClientAccessToken()>&& TokenProvider, + const BuildsRemoteStoreOptions& Options, + const std::filesystem::path& TempFilePath) { - std::string Host = Options.Host; - if (!Host.empty() && Host.find("://"sv) == std::string::npos) - { - // Assume https URL - Host = fmt::format("https://{}"sv, Host); - } - std::string OverrideUrl = Options.OverrideHost; - if (!OverrideUrl.empty() && OverrideUrl.find("://"sv) == std::string::npos) - { - // Assume https URL - OverrideUrl = fmt::format("https://{}"sv, OverrideUrl); - } - std::string ZenHost = Options.ZenHost; - if (!ZenHost.empty() && ZenHost.find("://"sv) == std::string::npos) - { - // Assume https URL - ZenHost = fmt::format("https://{}"sv, ZenHost); - } - - // 1) openid-provider if given (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider - // 2) Access token as parameter in request - // 3) Environment variable (different win vs linux/mac) - // 4) Default openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider - - std::function<HttpClientAccessToken()> TokenProvider; - if 
(!Options.OpenIdProvider.empty()) - { - TokenProvider = httpclientauth::CreateFromOpenIdProvider(Options.AuthManager, Options.OpenIdProvider); - } - else if (!Options.AccessToken.empty()) - { - TokenProvider = httpclientauth::CreateFromStaticToken(Options.AccessToken); - } - else if (!Options.OidcExePath.empty()) - { - if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(Options.OidcExePath, - Host.empty() ? OverrideUrl : Host, - Quiet, - Unattended, - Hidden); - TokenProviderMaybe) - { - TokenProvider = TokenProviderMaybe.value(); - } - } - - if (!TokenProvider) - { - TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(Options.AuthManager); - } - - BuildStorageResolveResult ResolveRes; - { - HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", - .AccessTokenProvider = TokenProvider, - .AssumeHttp2 = Options.AssumeHttp2, - .AllowResume = true, - .RetryCount = 2}; - - std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(InLog)); - - ResolveRes = - ResolveBuildStorage(*Output, ClientSettings, Host, OverrideUrl, ZenHost, ZenCacheResolveMode::Discovery, /*Verbose*/ false); - } - HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", .ConnectTimeout = std::chrono::milliseconds(3000), .Timeout = std::chrono::milliseconds(1800000), .AccessTokenProvider = std::move(TokenProvider), - .AssumeHttp2 = ResolveRes.HostAssumeHttp2, + .AssumeHttp2 = ResolveResult.Cloud.AssumeHttp2, .AllowResume = true, .RetryCount = 4, .MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize}; - std::unique_ptr<HttpClientSettings> CacheClientSettings; - - if (!ResolveRes.CacheUrl.empty()) - { - CacheClientSettings = - std::make_unique<HttpClientSettings>(HttpClientSettings{.LogCategory = "httpcacheclient", - .ConnectTimeout = std::chrono::milliseconds{3000}, - .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, - .AllowResume = true, - .RetryCount = 0, - 
.MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize}); - } - std::shared_ptr<RemoteProjectStore> RemoteStore = std::make_shared<BuildsRemoteStore>(InLog, ClientSettings, - CacheClientSettings.get(), - ResolveRes.HostUrl, - ResolveRes.CacheUrl, + ResolveResult.Cloud.Address, TempFilePath, - CacheBackgroundWorkerPool, Options.Namespace, Options.Bucket, Options.BuildId, Options.MetaData, Options.ForceDisableBlocks, - Options.ForceDisableTempBlocks, - Options.PopulateCache); + Options.ForceDisableTempBlocks); + return RemoteStore; } diff --git a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp index 3a67d3842..bb21de12c 100644 --- a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp @@ -7,8 +7,12 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/timer.h> #include <zenhttp/httpcommon.h> +#include <zenremotestore/builds/buildstoragecache.h> + +#include <numeric> namespace zen { @@ -74,9 +78,11 @@ public: virtual SaveResult SaveContainer(const IoBuffer& Payload) override { - Stopwatch Timer; SaveResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + { CbObject ContainerObject = LoadCompactBinaryObject(Payload); @@ -87,6 +93,10 @@ public: { Result.Needs.insert(AttachmentHash); } + else if (std::filesystem::path AttachmentMetaPath = GetAttachmentMetaPath(AttachmentHash); IsFile(AttachmentMetaPath)) + { + BasicFile TouchIt(AttachmentMetaPath, BasicFile::Mode::kWrite); + } }); } @@ -112,14 +122,18 @@ public: Result.Reason = fmt::format("Failed saving oplog container to '{}'. 
Reason: {}", ContainerPath, Ex.what()); } AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } - virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash, ChunkBlockDescription&&) override + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, + const IoHash& RawHash, + ChunkBlockDescription&& BlockDescription) override { - Stopwatch Timer; - SaveAttachmentResult Result; + SaveAttachmentResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); if (!IsFile(ChunkPath)) { @@ -142,14 +156,33 @@ public: Result.Reason = fmt::format("Failed saving oplog attachment to '{}'. Reason: {}", ChunkPath, Ex.what()); } } + if (!Result.ErrorCode && BlockDescription.BlockHash != IoHash::Zero) + { + try + { + std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash); + CbObject MetaData = BuildChunkBlockDescription(BlockDescription, {}); + SharedBuffer MetaBuffer = MetaData.GetBuffer(); + BasicFile MetaFile; + MetaFile.Open(MetaPath, BasicFile::Mode::kTruncate); + MetaFile.Write(MetaBuffer.GetView(), 0); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed saving block description to '{}'. 
Reason: {}", RawHash, Ex.what()); + } + } AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override { + SaveAttachmentsResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); for (const SharedBuffer& Chunk : Chunks) { @@ -157,12 +190,10 @@ public: SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash(), {}); if (ChunkResult.ErrorCode) { - ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - return SaveAttachmentsResult{ChunkResult}; + Result = SaveAttachmentsResult{ChunkResult}; + break; } } - SaveAttachmentsResult Result; - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } @@ -172,21 +203,60 @@ public: virtual GetKnownBlocksResult GetKnownBlocks() override { + Stopwatch Timer; if (m_OptionalBaseName.empty()) { - return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent)}}; + size_t MaxBlockCount = 10000; + + GetKnownBlocksResult Result; + + DirectoryContent Content; + GetDirectoryContent( + m_OutputPath, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | DirectoryContentFlags::IncludeModificationTick, + Content); + std::vector<size_t> RecentOrder(Content.Files.size()); + std::iota(RecentOrder.begin(), RecentOrder.end(), 0u); + std::sort(RecentOrder.begin(), RecentOrder.end(), [&Content](size_t Lhs, size_t Rhs) { + return Content.FileModificationTicks[Lhs] > Content.FileModificationTicks[Rhs]; + }); + + for (size_t FileIndex : RecentOrder) + { + std::filesystem::path MetaPath = Content.Files[FileIndex]; + if (MetaPath.extension() == MetaExtension) + { + IoBuffer MetaFile = ReadFile(MetaPath).Flatten(); + CbValidateError Err; + CbObject ValidatedObject = 
ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err); + if (Err == CbValidateError::None) + { + ChunkBlockDescription Description = ParseChunkBlockDescription(ValidatedObject); + if (Description.BlockHash != IoHash::Zero) + { + Result.Blocks.emplace_back(std::move(Description)); + if (Result.Blocks.size() == MaxBlockCount) + { + break; + } + } + } + } + } + + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; } LoadContainerResult LoadResult = LoadContainer(m_OptionalBaseName); if (LoadResult.ErrorCode) { return GetKnownBlocksResult{LoadResult}; } - Stopwatch Timer; std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(LoadResult.ContainerObject); if (BlockHashes.empty()) { return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent), - .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; + .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; } std::vector<IoHash> ExistingBlockHashes; for (const IoHash& RawHash : BlockHashes) @@ -200,15 +270,15 @@ public: if (ExistingBlockHashes.empty()) { return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent), - .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; + .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; } std::vector<ThinChunkBlockDescription> ThinKnownBlocks = GetBlocksFromOplog(LoadResult.ContainerObject, ExistingBlockHashes); - const size_t KnowBlockCount = ThinKnownBlocks.size(); + const size_t KnownBlockCount = ThinKnownBlocks.size(); - GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; - Result.Blocks.resize(KnowBlockCount); - for (size_t BlockIndex = 0; BlockIndex < KnowBlockCount; BlockIndex++) + GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; + 
Result.Blocks.resize(KnownBlockCount); + for (size_t BlockIndex = 0; BlockIndex < KnownBlockCount; BlockIndex++) { Result.Blocks[BlockIndex].BlockHash = ThinKnownBlocks[BlockIndex].BlockHash; Result.Blocks[BlockIndex].ChunkRawHashes = std::move(ThinKnownBlocks[BlockIndex].ChunkRawHashes); @@ -217,16 +287,88 @@ public: return Result; } + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override + { + GetBlockDescriptionsResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + + Result.Blocks.reserve(BlockHashes.size()); + + uint64_t ByteCount = 0; + + std::vector<ChunkBlockDescription> UnorderedList; + { + if (OptionalCache) + { + std::vector<CbObject> CacheBlockMetadatas = OptionalCache->GetBlobMetadatas(CacheBuildId, BlockHashes); + for (const CbObject& BlockObject : CacheBlockMetadatas) + { + ByteCount += BlockObject.GetSize(); + } + UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas); + } + + tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup; + BlockDescriptionLookup.reserve(BlockHashes.size()); + for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++) + { + const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); + } + + if (UnorderedList.size() < BlockHashes.size()) + { + for (const IoHash& RawHash : BlockHashes) + { + if (!BlockDescriptionLookup.contains(RawHash)) + { + std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash); + IoBuffer MetaFile = ReadFile(MetaPath).Flatten(); + ByteCount += MetaFile.GetSize(); + CbValidateError Err; + CbObject ValidatedObject = ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err); + if (Err == CbValidateError::None) + { + ChunkBlockDescription Description 
= ParseChunkBlockDescription(ValidatedObject); + if (Description.BlockHash != IoHash::Zero) + { + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, UnorderedList.size()); + UnorderedList.emplace_back(std::move(Description)); + } + } + } + } + } + + Result.Blocks.reserve(UnorderedList.size()); + for (const IoHash& RawHash : BlockHashes) + { + if (auto It = BlockDescriptionLookup.find(RawHash); It != BlockDescriptionLookup.end()) + { + Result.Blocks.emplace_back(std::move(UnorderedList[It->second])); + } + } + } + AddStats(0, ByteCount, Timer.GetElapsedTimeUs() * 1000); + return Result; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { - Stopwatch Timer; - LoadAttachmentResult Result; + LoadAttachmentResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); if (!IsFile(ChunkPath)) { Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. 
Reason: 'The file does not exist'", ChunkPath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } { @@ -235,7 +377,41 @@ public: Result.Bytes = ChunkFile.ReadAll(); } AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + LoadAttachmentRangesResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!IsFile(ChunkPath)) + { + Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); + Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string()); + return Result; + } + { + uint64_t Start = Ranges.front().first; + uint64_t Length = Ranges.back().first + Ranges.back().second - Ranges.front().first; + Result.Bytes = IoBufferBuilder::MakeFromFile(ChunkPath, Start, Length); + Result.Ranges.reserve(Ranges.size()); + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + Result.Ranges.push_back(std::make_pair(Range.first - Start, Range.second)); + } + } + AddStats(0, + std::accumulate(Result.Ranges.begin(), + Result.Ranges.end(), + uint64_t(0), + [](uint64_t Current, const std::pair<uint64_t, uint64_t>& Value) { return Current + Value.second; }), + Timer.GetElapsedTimeUs() * 1000); return Result; } @@ -258,20 +434,20 @@ public: return Result; } - virtual void Flush() override {} - private: LoadContainerResult LoadContainer(const std::string& Name) { - Stopwatch Timer; - LoadContainerResult Result; + LoadContainerResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = 
Timer.GetElapsedTimeMs() / 1000.0; }); + std::filesystem::path SourcePath = m_OutputPath; SourcePath.append(Name); if (!IsFile(SourcePath)) { Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); Result.Reason = fmt::format("Failed loading oplog container from '{}'. Reason: 'The file does not exist'", SourcePath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } IoBuffer ContainerPayload; @@ -285,18 +461,16 @@ private: if (Result.ContainerObject = ValidateAndReadCompactBinaryObject(std::move(ContainerPayload), ValidateResult); ValidateResult != CbValidateError::None || !Result.ContainerObject) { - Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')", - SourcePath.string(), - ToString(ValidateResult)); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')", + SourcePath.string(), + ToString(ValidateResult)); return Result; } - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } - std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const + std::filesystem::path GetAttachmentBasePath(const IoHash& RawHash) const { ExtendablePathBuilder<128> ShardedPath; ShardedPath.Append(m_OutputPath.c_str()); @@ -315,6 +489,19 @@ private: return ShardedPath.ToPath(); } + static constexpr std::string_view BlobExtension = ".blob"; + static constexpr std::string_view MetaExtension = ".meta"; + + std::filesystem::path GetAttachmentPath(const IoHash& RawHash) + { + return GetAttachmentBasePath(RawHash).replace_extension(BlobExtension); + } + + std::filesystem::path GetAttachmentMetaPath(const IoHash& RawHash) + { + return GetAttachmentBasePath(RawHash).replace_extension(MetaExtension); + } + void 
AddStats(uint64_t UploadedBytes, uint64_t DownloadedBytes, uint64_t ElapsedNS) { m_SentBytes.fetch_add(UploadedBytes); diff --git a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp index 462de2988..5b456cb4c 100644 --- a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp @@ -212,13 +212,43 @@ public: return Result; } + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override + { + ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId); + return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { - JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); - JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + LoadAttachmentResult Result; + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + AddStats(GetResult); + + Result = {ConvertResult(GetResult), std::move(GetResult.Response)}; + if (GetResult.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. 
Reason: '{}'", + m_JupiterClient->ServiceUrl(), + m_Namespace, + RawHash, + Result.Reason); + } + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + LoadAttachmentRangesResult Result; + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); AddStats(GetResult); - LoadAttachmentResult Result{ConvertResult(GetResult), std::move(GetResult.Response)}; + Result = LoadAttachmentRangesResult{ConvertResult(GetResult), std::move(GetResult.Response)}; if (GetResult.ErrorCode) { Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", @@ -227,6 +257,10 @@ public: RawHash, Result.Reason); } + else + { + Result.Ranges = std::vector<std::pair<uint64_t, uint64_t>>(Ranges.begin(), Ranges.end()); + } return Result; } @@ -247,8 +281,6 @@ public: return Result; } - virtual void Flush() override {} - private: LoadContainerResult LoadContainer(const IoHash& Key) { diff --git a/src/zenremotestore/projectstore/projectstoreoperations.cpp b/src/zenremotestore/projectstore/projectstoreoperations.cpp index becac3d4c..36dc4d868 100644 --- a/src/zenremotestore/projectstore/projectstoreoperations.cpp +++ b/src/zenremotestore/projectstore/projectstoreoperations.cpp @@ -426,19 +426,19 @@ ProjectStoreOperationDownloadAttachments::Execute() auto GetBuildBlob = [this](const IoHash& RawHash, const std::filesystem::path& OutputPath) { IoBuffer Payload; - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { - Payload = m_Storage.BuildCacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); + Payload = m_Storage.CacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); } if (!Payload) { Payload = m_Storage.BuildStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); - if 
(m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_State.GetBuildId(), - RawHash, - Payload.GetContentType(), - CompositeBuffer(SharedBuffer(Payload))); + m_Storage.CacheStorage->PutBuildBlob(m_State.GetBuildId(), + RawHash, + Payload.GetContentType(), + CompositeBuffer(SharedBuffer(Payload))); } } uint64_t PayloadSize = Payload.GetSize(); diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 8be8eb0df..247bd6cb9 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -14,6 +14,8 @@ #include <zencore/trace.h> #include <zencore/workthreadpool.h> #include <zenhttp/httpcommon.h> +#include <zenremotestore/builds/buildstoragecache.h> +#include <zenremotestore/chunking/chunkedcontent.h> #include <zenremotestore/chunking/chunkedfile.h> #include <zenremotestore/operationlogoutput.h> #include <zenstore/cidstore.h> @@ -123,14 +125,17 @@ namespace remotestore_impl { return OptionalContext->IsCancelled(); } - std::string GetStats(const RemoteProjectStore::Stats& Stats, uint64_t ElapsedWallTimeMS) + std::string GetStats(const RemoteProjectStore::Stats& Stats, + const BuildStorageCache::Statistics* OptionalCacheStats, + uint64_t ElapsedWallTimeMS) { - return fmt::format( - "Sent: {} ({}bits/s) Recv: {} ({}bits/s)", - NiceBytes(Stats.m_SentBytes), - NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((Stats.m_SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u), - NiceBytes(Stats.m_ReceivedBytes), - NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((Stats.m_ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u)); + uint64_t SentBytes = Stats.m_SentBytes + (OptionalCacheStats ? OptionalCacheStats->TotalBytesWritten.load() : 0); + uint64_t ReceivedBytes = Stats.m_ReceivedBytes + (OptionalCacheStats ? 
OptionalCacheStats->TotalBytesRead.load() : 0); + return fmt::format("Sent: {} ({}bits/s) Recv: {} ({}bits/s)", + NiceBytes(SentBytes), + NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u), + NiceBytes(ReceivedBytes), + NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u)); } void LogRemoteStoreStatsDetails(const RemoteProjectStore::Stats& Stats) @@ -229,44 +234,66 @@ namespace remotestore_impl { struct DownloadInfo { - uint64_t OplogSizeBytes = 0; - std::atomic<uint64_t> AttachmentsDownloaded = 0; - std::atomic<uint64_t> AttachmentBlocksDownloaded = 0; - std::atomic<uint64_t> AttachmentBytesDownloaded = 0; - std::atomic<uint64_t> AttachmentBlockBytesDownloaded = 0; - std::atomic<uint64_t> AttachmentsStored = 0; - std::atomic<uint64_t> AttachmentBytesStored = 0; - std::atomic_size_t MissingAttachmentCount = 0; + uint64_t OplogSizeBytes = 0; + std::atomic<uint64_t> AttachmentsDownloaded = 0; + std::atomic<uint64_t> AttachmentBlocksDownloaded = 0; + std::atomic<uint64_t> AttachmentBlocksRangesDownloaded = 0; + std::atomic<uint64_t> AttachmentBytesDownloaded = 0; + std::atomic<uint64_t> AttachmentBlockBytesDownloaded = 0; + std::atomic<uint64_t> AttachmentBlockRangeBytesDownloaded = 0; + std::atomic<uint64_t> AttachmentsStored = 0; + std::atomic<uint64_t> AttachmentBytesStored = 0; + std::atomic_size_t MissingAttachmentCount = 0; }; - void DownloadAndSaveBlockChunks(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - Latch& AttachmentsDownloadLatch, - Latch& AttachmentsWriteLatch, - AsyncRemoteResult& RemoteResult, - DownloadInfo& Info, - Stopwatch& LoadAttachmentsTimer, - std::atomic_uint64_t& DownloadStartMS, - const std::vector<IoHash>& Chunks) + class JobContextLogOutput : public OperationLogOutput + { + public: + 
JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override + { + if (m_OptionalContext) + { + fmt::basic_memory_buffer<char, 250> MessageBuffer; + fmt::vformat_to(fmt::appender(MessageBuffer), Point.FormatString, Args); + remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); + } + } + + virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); } + virtual uint32_t GetProgressUpdateDelayMS() override { return 0; } + virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override + { + ZEN_UNUSED(InSubTask); + return nullptr; + } + + private: + JobContext* m_OptionalContext; + }; + + void DownloadAndSaveBlockChunks(LoadOplogContext& Context, + Latch& AttachmentsDownloadLatch, + Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + ThinChunkBlockDescription&& ThinBlockDescription, + std::vector<uint32_t>&& NeededChunkIndexes) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( - [&RemoteStore, - &ChunkStore, - &WorkerPool, + Context.NetworkWorkerPool.ScheduleWork( + [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &RemoteResult, - Chunks = Chunks, + ThinBlockDescription = std::move(ThinBlockDescription), + NeededChunkIndexes = std::move(NeededChunkIndexes), &Info, &LoadAttachmentsTimer, - &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext]() { + &DownloadStartMS]() { ZEN_TRACE_CPU("DownloadBlockChunks"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -276,34 +303,47 @@ namespace remotestore_impl { } try { + 
std::vector<IoHash> Chunks; + Chunks.reserve(NeededChunkIndexes.size()); + for (uint32_t ChunkIndex : NeededChunkIndexes) + { + Chunks.push_back(ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } + uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks); + RemoteProjectStore::LoadAttachmentsResult Result = Context.RemoteStore.LoadAttachments(Chunks); if (Result.ErrorCode) { - ReportMessage(OptionalContext, + ReportMessage(Context.OptionalJobContext, fmt::format("Failed to load attachments with {} chunks ({}): {}", Chunks.size(), RemoteResult.GetError(), RemoteResult.GetErrorReason())); Info.MissingAttachmentCount.fetch_add(1); - if (IgnoreMissingAttachments) + if (Context.IgnoreMissingAttachments) { RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); } return; } - Info.AttachmentsDownloaded.fetch_add(Chunks.size()); - ZEN_INFO("Loaded {} bulk attachments in {}", - Chunks.size(), - NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + Info.AttachmentsDownloaded.fetch_add(Result.Chunks.size()); + for (const auto& It : Result.Chunks) + { + uint64_t ChunkSize = It.second.GetCompressedSize(); + Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); + } + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("Loaded {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)))); if (RemoteResult.IsError()) { return; } AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&AttachmentsWriteLatch, &RemoteResult, &Info, &ChunkStore, Chunks = std::move(Result.Chunks)]() { + Context.WorkerPool.ScheduleWork( + [&AttachmentsWriteLatch, &RemoteResult, &Info, &Context, Chunks = std::move(Result.Chunks)]() { auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); if 
(RemoteResult.IsError()) { @@ -320,13 +360,13 @@ namespace remotestore_impl { for (const auto& It : Chunks) { - uint64_t ChunkSize = It.second.GetCompressedSize(); - Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); WriteAttachmentBuffers.push_back(It.second.GetCompressed().Flatten().AsIoBuffer()); WriteRawHashes.push_back(It.first); } std::vector<CidStore::InsertResult> InsertResults = - ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes, CidStore::InsertMode::kCopyOnly); + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, + WriteRawHashes, + CidStore::InsertMode::kCopyOnly); for (size_t Index = 0; Index < InsertResults.size(); Index++) { @@ -350,46 +390,38 @@ namespace remotestore_impl { catch (const std::exception& Ex) { RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError), - fmt::format("Failed to bulk load {} attachments", Chunks.size()), + fmt::format("Failed to bulk load {} attachments", NeededChunkIndexes.size()), Ex.what()); } }, WorkerThreadPool::EMode::EnableBacklog); }; - void DownloadAndSaveBlock(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - Latch& AttachmentsDownloadLatch, - Latch& AttachmentsWriteLatch, - AsyncRemoteResult& RemoteResult, - DownloadInfo& Info, - Stopwatch& LoadAttachmentsTimer, - std::atomic_uint64_t& DownloadStartMS, - const IoHash& BlockHash, - const std::vector<IoHash>& Chunks, - uint32_t RetriesLeft) + void DownloadAndSaveBlock(LoadOplogContext& Context, + Latch& AttachmentsDownloadLatch, + Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + const IoHash& BlockHash, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& AllNeededPartialChunkHashesLookup, + std::span<std::atomic<bool>> ChunkDownloadedFlags, + uint32_t RetriesLeft) { 
AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( + Context.NetworkWorkerPool.ScheduleWork( [&AttachmentsDownloadLatch, &AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, - BlockHash, + &Context, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, RetriesLeft, - Chunks = std::vector<IoHash>(Chunks)]() { + BlockHash = IoHash(BlockHash), + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags]() { ZEN_TRACE_CPU("DownloadBlock"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -401,51 +433,65 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); - if (BlockResult.ErrorCode) + + IoBuffer BlobBuffer; + if (Context.OptionalCache) { - ReportMessage(OptionalContext, - fmt::format("Failed to download block attachment {} ({}): {}", - BlockHash, - RemoteResult.GetError(), - RemoteResult.GetErrorReason())); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) - { - RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); - } - return; + BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, BlockHash); } - if (RemoteResult.IsError()) + + if (!BlobBuffer) { - return; + RemoteProjectStore::LoadAttachmentResult BlockResult = Context.RemoteStore.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode) + { + ReportMessage(Context.OptionalJobContext, + fmt::format("Failed to download block attachment {} ({}): {}", + BlockHash, + BlockResult.Reason, + BlockResult.Text)); + Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + } + 
return; + } + if (RemoteResult.IsError()) + { + return; + } + BlobBuffer = std::move(BlockResult.Bytes); + ZEN_DEBUG("Loaded block attachment '{}' in {} ({})", + BlockHash, + NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)), + NiceBytes(BlobBuffer.Size())); + if (Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + BlockHash, + BlobBuffer.GetContentType(), + CompositeBuffer(SharedBuffer(BlobBuffer))); + } } - uint64_t BlockSize = BlockResult.Bytes.GetSize(); + uint64_t BlockSize = BlobBuffer.GetSize(); Info.AttachmentBlocksDownloaded.fetch_add(1); - ZEN_INFO("Loaded block attachment '{}' in {} ({})", - BlockHash, - NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)), - NiceBytes(BlockSize)); Info.AttachmentBlockBytesDownloaded.fetch_add(BlockSize); AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( + Context.WorkerPool.ScheduleWork( [&AttachmentsDownloadLatch, &AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, - BlockHash, + &Context, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, RetriesLeft, - Chunks = std::move(Chunks), - Bytes = std::move(BlockResult.Bytes)]() { + BlockHash = IoHash(BlockHash), + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + Bytes = std::move(BlobBuffer)]() { auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); if (RemoteResult.IsError()) { @@ -454,64 +500,107 @@ namespace remotestore_impl { try { ZEN_ASSERT(Bytes.Size() > 0); - std::unordered_set<IoHash, IoHash::Hasher> WantedChunks; - WantedChunks.reserve(Chunks.size()); - WantedChunks.insert(Chunks.begin(), Chunks.end()); std::vector<IoBuffer> WriteAttachmentBuffers; std::vector<IoHash> WriteRawHashes; IoHash RawHash; uint64_t RawSize; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Bytes), RawHash, 
RawSize); + + std::string ErrorString; + if (!Compressed) { - if (RetriesLeft > 0) + ErrorString = + fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash); + } + else if (RawHash != BlockHash) + { + ErrorString = fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash); + } + else if (CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); !BlockPayload) + { + ErrorString = fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash); + } + else + { + uint64_t PotentialSize = 0; + uint64_t UsedSize = 0; + uint64_t BlockSize = BlockPayload.GetSize(); + + uint64_t BlockHeaderSize = 0; + + bool StoreChunksOK = IterateChunkBlock( + BlockPayload.Flatten(), + [&AllNeededPartialChunkHashesLookup, + &ChunkDownloadedFlags, + &WriteAttachmentBuffers, + &WriteRawHashes, + &Info, + &PotentialSize](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(AttachmentRawHash); + if (ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) + { + WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); + IoHash RawHash; + uint64_t RawSize; + ZEN_ASSERT(CompressedBuffer::ValidateCompressedHeader( + WriteAttachmentBuffers.back(), + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)); + ZEN_ASSERT(RawHash == AttachmentRawHash); + WriteRawHashes.emplace_back(AttachmentRawHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + } + }, + BlockHeaderSize); + + if (!StoreChunksOK) { - ReportMessage( - OptionalContext, - fmt::format( - "Block attachment {} is malformed, can't parse as compressed binary, retrying download", - BlockHash)); - return DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - 
NetworkWorkerPool, - WorkerPool, - AttachmentsDownloadLatch, - AttachmentsWriteLatch, - RemoteResult, - Info, - LoadAttachmentsTimer, - DownloadStartMS, - BlockHash, - std::move(Chunks), - RetriesLeft - 1); + ErrorString = fmt::format("Invalid format for block {}", BlockHash); + } + else + { + if (!WriteAttachmentBuffers.empty()) + { + std::vector<CidStore::InsertResult> Results = + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + for (size_t Index = 0; Index < Results.size(); Index++) + { + const CidStore::InsertResult& Result = Results[Index]; + if (Result.New) + { + Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); + Info.AttachmentsStored.fetch_add(1); + UsedSize += WriteAttachmentBuffers[Index].GetSize(); + } + } + if (UsedSize < BlockSize) + { + ZEN_DEBUG("Used {} (skipping {}) out of {} for block {} ({} %) (use of matching {}%)", + NiceBytes(UsedSize), + NiceBytes(BlockSize - UsedSize), + NiceBytes(BlockSize), + BlockHash, + (100 * UsedSize) / BlockSize, + PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); + } + } } - ReportMessage( - OptionalContext, - fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash)); - RemoteResult.SetError( - gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash), - {}); - return; } - CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); - if (!BlockPayload) + + if (!ErrorString.empty()) { if (RetriesLeft > 0) { - ReportMessage( - OptionalContext, - fmt::format("Block attachment {} is malformed, can't decompress payload, retrying download", - BlockHash)); - return DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + ReportMessage(Context.OptionalJobContext, fmt::format("{}, retrying download", ErrorString)); + + return DownloadAndSaveBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -519,91 +608,16 @@ namespace remotestore_impl { LoadAttachmentsTimer, DownloadStartMS, BlockHash, - std::move(Chunks), + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, RetriesLeft - 1); } - ReportMessage(OptionalContext, - fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash)); - RemoteResult.SetError( - gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash), - {}); - return; - } - if (RawHash != BlockHash) - { - ReportMessage(OptionalContext, - fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash)); - RemoteResult.SetError( - gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash), - {}); - return; - } - - uint64_t PotentialSize = 0; - uint64_t UsedSize = 0; - uint64_t BlockSize = BlockPayload.GetSize(); - - 
uint64_t BlockHeaderSize = 0; - bool StoreChunksOK = IterateChunkBlock( - BlockPayload.Flatten(), - [&WantedChunks, &WriteAttachmentBuffers, &WriteRawHashes, &Info, &PotentialSize]( - CompressedBuffer&& Chunk, - const IoHash& AttachmentRawHash) { - if (WantedChunks.contains(AttachmentRawHash)) - { - WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); - IoHash RawHash; - uint64_t RawSize; - ZEN_ASSERT( - CompressedBuffer::ValidateCompressedHeader(WriteAttachmentBuffers.back(), - RawHash, - RawSize, - /*OutOptionalTotalCompressedSize*/ nullptr)); - ZEN_ASSERT(RawHash == AttachmentRawHash); - WriteRawHashes.emplace_back(AttachmentRawHash); - WantedChunks.erase(AttachmentRawHash); - PotentialSize += WriteAttachmentBuffers.back().GetSize(); - } - }, - BlockHeaderSize); - - if (!StoreChunksOK) - { - ReportMessage(OptionalContext, - fmt::format("Block attachment {} has invalid format ({}): {}", - BlockHash, - RemoteResult.GetError(), - RemoteResult.GetErrorReason())); - RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), - fmt::format("Invalid format for block {}", BlockHash), - {}); - return; - } - - ZEN_ASSERT(WantedChunks.empty()); - - if (!WriteAttachmentBuffers.empty()) - { - auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); - for (size_t Index = 0; Index < Results.size(); Index++) + else { - const auto& Result = Results[Index]; - if (Result.New) - { - Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); - Info.AttachmentsStored.fetch_add(1); - UsedSize += WriteAttachmentBuffers[Index].GetSize(); - } + ReportMessage(Context.OptionalJobContext, ErrorString); + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), ErrorString, {}); + return; } - ZEN_DEBUG("Used {} (matching {}) out of {} for block {} ({} %) (use of matching {}%)", - NiceBytes(UsedSize), - NiceBytes(PotentialSize), - NiceBytes(BlockSize), - BlockHash, - (100 * 
UsedSize) / BlockSize, - PotentialSize > 0 ? (UsedSize * 100) / PotentialSize : 0); } } catch (const std::exception& Ex) @@ -618,19 +632,458 @@ namespace remotestore_impl { catch (const std::exception& Ex) { RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError), - fmt::format("Failed to block attachment {}", BlockHash), + fmt::format("Failed to download block attachment {}", BlockHash), + Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + }; + + void DownloadPartialBlock(LoadOplogContext& Context, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + double& DownloadTimeSeconds, + const ChunkBlockDescription& BlockDescription, + bool BlockExistsInCache, + std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRangeDescriptors, + size_t BlockRangeIndexStart, + size_t BlockRangeCount, + std::function<void(IoBuffer&& Buffer, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded) + { + ZEN_ASSERT(Context.StoreMaxRangeCountPerRequest != 0); + ZEN_ASSERT(BlockExistsInCache == false || Context.CacheMaxRangeCountPerRequest != 0); + + std::vector<std::pair<uint64_t, uint64_t>> Ranges; + Ranges.reserve(BlockRangeDescriptors.size()); + for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; BlockRangeIndex++) + { + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRangeDescriptors[BlockRangeIndex]; + Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); + } + + size_t SubBlockRangeCount = BlockRangeCount; + size_t SubRangeCountComplete = 0; + std::span<const std::pair<uint64_t, uint64_t>> RangesSpan(Ranges); + + while (SubRangeCountComplete < SubBlockRangeCount) + { + if (RemoteResult.IsError()) + { + break; + } + + size_t SubRangeStartIndex = BlockRangeIndexStart + SubRangeCountComplete; + if (BlockExistsInCache) + { + ZEN_ASSERT(Context.OptionalCache); + size_t 
SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.CacheMaxRangeCountPerRequest); + + if (SubRangeCount == 1) + { + // Legacy single-range path, prefer that for max compatibility + + const std::pair<uint64_t, uint64_t> SubRange = RangesSpan[SubRangeCountComplete]; + Stopwatch CacheTimer; + IoBuffer PayloadBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, + BlockDescription.BlockHash, + SubRange.first, + SubRange.second); + DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0; + if (RemoteResult.IsError()) + { + break; + } + if (PayloadBuffer) + { + OnDownloaded(std::move(PayloadBuffer), + SubRangeStartIndex, + std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)}); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + else + { + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + Stopwatch CacheTimer; + BuildStorageCache::BuildBlobRanges RangeBuffers = + Context.OptionalCache->GetBuildBlobRanges(Context.CacheBuildId, BlockDescription.BlockHash, SubRanges); + DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0; + if (RemoteResult.IsError()) + { + break; + } + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) + { + SubRangeCount = Ranges.size() - SubRangeCountComplete; + OnDownloaded(std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount)); + SubRangeCountComplete += SubRangeCount; + continue; + } + else if (RangeBuffers.Ranges.size() == SubRangeCount) + { + OnDownloaded(std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + } + } + + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.StoreMaxRangeCountPerRequest); + + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + RemoteProjectStore::LoadAttachmentRangesResult BlockResult = 
+ Context.RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, SubRanges); + DownloadTimeSeconds += BlockResult.ElapsedSeconds; + if (RemoteResult.IsError()) + { + break; + } + if (BlockResult.ErrorCode || !BlockResult.Bytes) + { + ReportMessage(Context.OptionalJobContext, + fmt::format("Failed to download {} ranges from block attachment '{}' ({}): {}", + SubRanges.size(), + BlockDescription.BlockHash, + BlockResult.ErrorCode, + BlockResult.Reason)); + Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + break; + } + } + else + { + if (BlockResult.Ranges.empty()) + { + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Use the whole payload for the remaining ranges + + if (Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector<IoBuffer>{BlockResult.Bytes})); + if (RemoteResult.IsError()) + { + break; + } + } + SubRangeCount = Ranges.size() - SubRangeCountComplete; + OnDownloaded(std::move(BlockResult.Bytes), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount)); + } + else + { + if (BlockResult.Ranges.size() != SubRanges.size()) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + fmt::format("Range response for block {} contains {} ranges, expected {} ranges", + BlockDescription.BlockHash, + BlockResult.Ranges.size(), + SubRanges.size()), + ""); + break; + } + OnDownloaded(std::move(BlockResult.Bytes), SubRangeStartIndex, BlockResult.Ranges); + } + } + + SubRangeCountComplete += SubRangeCount; + } + } + + void DownloadAndSavePartialBlock(LoadOplogContext& Context, + Latch& AttachmentsDownloadLatch, + Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + 
DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + const ChunkBlockDescription& BlockDescription, + bool BlockExistsInCache, + std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRangeDescriptors, + size_t BlockRangeIndexStart, + size_t BlockRangeCount, + const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& AllNeededPartialChunkHashesLookup, + std::span<std::atomic<bool>> ChunkDownloadedFlags, + uint32_t RetriesLeft) + { + AttachmentsDownloadLatch.AddCount(1); + Context.NetworkWorkerPool.ScheduleWork( + [&AttachmentsDownloadLatch, + &AttachmentsWriteLatch, + &Context, + &RemoteResult, + &Info, + &LoadAttachmentsTimer, + &DownloadStartMS, + BlockDescription, + BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeIndexStart, + BlockRangeCount, + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + RetriesLeft]() { + ZEN_TRACE_CPU("DownloadBlockRanges"); + + auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); + try + { + uint64_t Unset = (std::uint64_t)-1; + DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); + + double DownloadElapsedSeconds = 0; + uint64_t DownloadedBytes = 0; + + DownloadPartialBlock( + Context, + RemoteResult, + Info, + DownloadElapsedSeconds, + BlockDescription, + BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeIndexStart, + BlockRangeCount, + [&](IoBuffer&& Buffer, + size_t BlockRangeStartIndex, + std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) { + uint64_t BlockPartSize = Buffer.GetSize(); + DownloadedBytes += BlockPartSize; + + Info.AttachmentBlockRangeBytesDownloaded.fetch_add(BlockPartSize); + Info.AttachmentBlocksRangesDownloaded++; + + AttachmentsWriteLatch.AddCount(1); + Context.WorkerPool.ScheduleWork( + [&AttachmentsWriteLatch, + &Context, + &AttachmentsDownloadLatch, + &RemoteResult, + &Info, + &LoadAttachmentsTimer, + &DownloadStartMS, + BlockDescription, + 
BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeStartIndex, + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + RetriesLeft, + BlockPayload = std::move(Buffer), + OffsetAndLengths = + std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())]() { + auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); + try + { + ZEN_ASSERT(BlockPayload.Size() > 0); + + size_t RangeCount = OffsetAndLengths.size(); + for (size_t RangeOffset = 0; RangeOffset < RangeCount; RangeOffset++) + { + if (RemoteResult.IsError()) + { + return; + } + + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = + BlockRangeDescriptors[BlockRangeStartIndex + RangeOffset]; + const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengths[RangeOffset]; + IoBuffer BlockRangeBuffer(BlockPayload, OffsetAndLength.first, OffsetAndLength.second); + + std::vector<IoBuffer> WriteAttachmentBuffers; + std::vector<IoHash> WriteRawHashes; + + uint64_t PotentialSize = 0; + uint64_t UsedSize = 0; + uint64_t BlockPartSize = BlockRangeBuffer.GetSize(); + + uint32_t OffsetInBlock = 0; + for (uint32_t ChunkBlockIndex = BlockRange.ChunkBlockIndexStart; + ChunkBlockIndex < BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount; + ChunkBlockIndex++) + { + if (RemoteResult.IsError()) + { + break; + } + + const uint32_t ChunkCompressedSize = + BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + + if (auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(ChunkHash); + ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) + { + if (!ChunkDownloadedFlags[ChunkIndexIt->second]) + { + IoHash VerifyChunkHash; + uint64_t VerifyChunkSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed( + SharedBuffer(IoBuffer(BlockRangeBuffer, OffsetInBlock, ChunkCompressedSize)), + VerifyChunkHash, + VerifyChunkSize); + + 
std::string ErrorString; + + if (!CompressedChunk) + { + ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' is not a valid compressed buffer", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash); + } + else if (VerifyChunkHash != ChunkHash) + { + ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' has mismatching hash, expected " + "{}, got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + ChunkHash, + VerifyChunkHash); + } + else if (VerifyChunkSize != BlockDescription.ChunkRawLengths[ChunkBlockIndex]) + { + ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' has mismatching raw size, " + "expected {}, " + "got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + BlockDescription.ChunkRawLengths[ChunkBlockIndex], + VerifyChunkSize); + } + + if (!ErrorString.empty()) + { + if (RetriesLeft > 0) + { + ReportMessage(Context.OptionalJobContext, + fmt::format("{}, retrying download", ErrorString)); + return DownloadAndSavePartialBlock(Context, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + BlockDescription, + BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeStartIndex, + RangeCount, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + RetriesLeft - 1); + } + + ReportMessage(Context.OptionalJobContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::NotFound), + "Malformed chunk block", + ErrorString); + } + } + else + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, + true)) + { + WriteAttachmentBuffers.emplace_back( + CompressedChunk.GetCompressed().Flatten().AsIoBuffer()); + WriteRawHashes.emplace_back(ChunkHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + 
} + } + } + OffsetInBlock += ChunkCompressedSize; + } + + if (!WriteAttachmentBuffers.empty()) + { + std::vector<CidStore::InsertResult> Results = + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + for (size_t Index = 0; Index < Results.size(); Index++) + { + const CidStore::InsertResult& Result = Results[Index]; + if (Result.New) + { + Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); + Info.AttachmentsStored.fetch_add(1); + UsedSize += WriteAttachmentBuffers[Index].GetSize(); + } + } + if (UsedSize < BlockPartSize) + { + ZEN_DEBUG( + "Used {} (skipping {}) out of {} for block {} range {}, {} ({} %) (use of matching " + "{}%)", + NiceBytes(UsedSize), + NiceBytes(BlockPartSize - UsedSize), + NiceBytes(BlockPartSize), + BlockDescription.BlockHash, + BlockRange.RangeStart, + BlockRange.RangeLength, + (100 * UsedSize) / BlockPartSize, + PotentialSize > 0 ? (UsedSize * 100) / PotentialSize : 0); + } + } + } + } + catch (const std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError), + fmt::format("Failed saving {} ranges from block attachment {}", + OffsetAndLengths.size(), + BlockDescription.BlockHash), + Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + }); + if (!RemoteResult.IsError()) + { + ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})", + BlockRangeCount, + BlockDescription.BlockHash, + NiceTimeSpanMs(static_cast<uint64_t>(DownloadElapsedSeconds * 1000)), + NiceBytes(DownloadedBytes)); + } + } + catch (const std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError), + fmt::format("Failed to download block attachment {} ranges", BlockDescription.BlockHash), Ex.what()); } }, WorkerThreadPool::EMode::EnableBacklog); }; - void DownloadAndSaveAttachment(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - 
WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, + void DownloadAndSaveAttachment(LoadOplogContext& Context, Latch& AttachmentsDownloadLatch, Latch& AttachmentsWriteLatch, AsyncRemoteResult& RemoteResult, @@ -640,19 +1093,15 @@ namespace remotestore_impl { const IoHash& RawHash) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( - [&RemoteStore, - &ChunkStore, - &WorkerPool, + Context.NetworkWorkerPool.ScheduleWork( + [&Context, &RemoteResult, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, RawHash, &LoadAttachmentsTimer, &DownloadStartMS, - &Info, - IgnoreMissingAttachments, - OptionalContext]() { + &Info]() { ZEN_TRACE_CPU("DownloadAttachment"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -664,43 +1113,52 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); - if (AttachmentResult.ErrorCode) + IoBuffer BlobBuffer; + if (Context.OptionalCache) { - ReportMessage(OptionalContext, - fmt::format("Failed to download large attachment {}: '{}', error code : {}", - RawHash, - AttachmentResult.Reason, - AttachmentResult.ErrorCode)); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, RawHash); + } + if (!BlobBuffer) + { + RemoteProjectStore::LoadAttachmentResult AttachmentResult = Context.RemoteStore.LoadAttachment(RawHash); + if (AttachmentResult.ErrorCode) { - RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + ReportMessage(Context.OptionalJobContext, + fmt::format("Failed to download large attachment {}: '{}', error code : {}", + RawHash, + AttachmentResult.Reason, + AttachmentResult.ErrorCode)); + 
Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + } + return; + } + BlobBuffer = std::move(AttachmentResult.Bytes); + ZEN_DEBUG("Loaded large attachment '{}' in {} ({})", + RawHash, + NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)), + NiceBytes(BlobBuffer.GetSize())); + if (Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + RawHash, + BlobBuffer.GetContentType(), + CompositeBuffer(SharedBuffer(BlobBuffer))); } - return; } - uint64_t AttachmentSize = AttachmentResult.Bytes.GetSize(); - ZEN_INFO("Loaded large attachment '{}' in {} ({})", - RawHash, - NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)), - NiceBytes(AttachmentSize)); - Info.AttachmentsDownloaded.fetch_add(1); if (RemoteResult.IsError()) { return; } + uint64_t AttachmentSize = BlobBuffer.GetSize(); + Info.AttachmentsDownloaded.fetch_add(1); Info.AttachmentBytesDownloaded.fetch_add(AttachmentSize); AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&AttachmentsWriteLatch, - &RemoteResult, - &Info, - &ChunkStore, - RawHash, - AttachmentSize, - Bytes = std::move(AttachmentResult.Bytes), - OptionalContext]() { + Context.WorkerPool.ScheduleWork( + [&Context, &AttachmentsWriteLatch, &RemoteResult, &Info, RawHash, AttachmentSize, Bytes = std::move(BlobBuffer)]() { ZEN_TRACE_CPU("WriteAttachment"); auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); @@ -710,7 +1168,7 @@ namespace remotestore_impl { } try { - CidStore::InsertResult InsertResult = ChunkStore.AddChunk(Bytes, RawHash); + CidStore::InsertResult InsertResult = Context.ChunkStore.AddChunk(Bytes, RawHash); if (InsertResult.New) { Info.AttachmentBytesStored.fetch_add(AttachmentSize); @@ -1126,7 +1584,9 @@ namespace remotestore_impl { uint64_t 
PartialTransferWallTimeMS = Timer.GetElapsedTimeMs(); ReportProgress(OptionalContext, "Saving attachments"sv, - fmt::format("{} remaining... {}", Remaining, GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), + fmt::format("{} remaining... {}", + Remaining, + GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, PartialTransferWallTimeMS)), AttachmentsToSave, Remaining); } @@ -1135,7 +1595,7 @@ namespace remotestore_impl { { ReportProgress(OptionalContext, "Saving attachments"sv, - fmt::format("{}", GetStats(RemoteStore.GetStats(), ElapsedTimeMS)), + fmt::format("{}", GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS)), AttachmentsToSave, 0); } @@ -1146,7 +1606,7 @@ namespace remotestore_impl { LargeAttachmentCountToUpload, BulkAttachmentCountToUpload, NiceTimeSpanMs(ElapsedTimeMS), - GetStats(RemoteStore.GetStats(), ElapsedTimeMS))); + GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS))); } } // namespace remotestore_impl @@ -1224,35 +1684,7 @@ BuildContainer(CidStore& ChunkStore, { using namespace std::literals; - class JobContextLogOutput : public OperationLogOutput - { - public: - JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override - { - ZEN_UNUSED(LogLevel); - if (m_OptionalContext) - { - fmt::basic_memory_buffer<char, 250> MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); - remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); - } - } - - virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); } - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() override { return 0; } - virtual ProgressBar* 
CreateProgressBar(std::string_view InSubTask) override - { - ZEN_UNUSED(InSubTask); - return nullptr; - } - - private: - JobContext* m_OptionalContext; - }; - - std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<JobContextLogOutput>(OptionalContext)); + std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(OptionalContext)); size_t OpCount = 0; @@ -1783,31 +2215,36 @@ BuildContainer(CidStore& ChunkStore, } ResolveAttachmentsLatch.CountDown(); - while (!ResolveAttachmentsLatch.Wait(1000)) { - ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + ptrdiff_t AttachmentCountToUseForProgress = ResolveAttachmentsLatch.Remaining(); + while (!ResolveAttachmentsLatch.Wait(1000)) { - RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); - while (!ResolveAttachmentsLatch.Wait(1000)) + ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining(); + if (remotestore_impl::IsCancelled(OptionalContext)) { - Remaining = ResolveAttachmentsLatch.Remaining(); - remotestore_impl::ReportProgress(OptionalContext, - "Resolving attachments"sv, - fmt::format("Aborting, {} attachments remaining...", Remaining), - UploadAttachments.size(), - Remaining); + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); + while (!ResolveAttachmentsLatch.Wait(1000)) + { + Remaining = ResolveAttachmentsLatch.Remaining(); + remotestore_impl::ReportProgress(OptionalContext, + "Resolving attachments"sv, + fmt::format("Aborting, {} attachments remaining...", Remaining), + UploadAttachments.size(), + Remaining); + } + 
remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0); + return {}; } - remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0); - return {}; + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress(OptionalContext, + "Resolving attachments"sv, + fmt::format("{} remaining...", Remaining), + AttachmentCountToUseForProgress, + Remaining); } - remotestore_impl::ReportProgress(OptionalContext, - "Resolving attachments"sv, - fmt::format("{} remaining...", Remaining), - UploadAttachments.size(), - Remaining); } if (UploadAttachments.size() > 0) { @@ -2010,14 +2447,13 @@ BuildContainer(CidStore& ChunkStore, AsyncOnBlock, RemoteResult); ComposedBlocks++; + // Worker will set Blocks[BlockIndex] = Block (including ChunkRawHashes) under shared lock } else { ZEN_INFO("Bulk group {} attachments", ChunkCount); OnBlockChunks(std::move(ChunksInBlock)); - } - { - // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + // We can share the lock as we are not resizing the vector and only touch our own index RwLock::SharedLockScope _(BlocksLock); Blocks[BlockIndex].ChunkRawHashes = std::move(ChunkRawHashes); } @@ -2195,12 +2631,14 @@ BuildContainer(CidStore& ChunkStore, 0); } - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and in {}", - ChunkAssembleCount, - TotalOpCount, - GeneratedBlockCount, - NiceTimeSpanMs(static_cast<uint64_t>(Timer.GetElapsedTimeMs())))); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and {} loose attachments in {}", + ChunkAssembleCount, + TotalOpCount, + GeneratedBlockCount, + LargeChunkHashes.size(), + 
NiceTimeSpanMs(static_cast<uint64_t>(Timer.GetElapsedTimeMs())))); if (remotestore_impl::IsCancelled(OptionalContext)) { @@ -2752,30 +3190,32 @@ SaveOplog(CidStore& ChunkStore, remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats()); - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}", - RemoteStoreInfo.ContainerName, - RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE", - NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)), - NiceBytes(Info.OplogSizeBytes), - Info.AttachmentBlocksUploaded.load(), - NiceBytes(Info.AttachmentBlockBytesUploaded.load()), - Info.AttachmentsUploaded.load(), - NiceBytes(Info.AttachmentBytesUploaded.load()), - remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}", + RemoteStoreInfo.ContainerName, + RemoteResult.GetError() == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)), + NiceBytes(Info.OplogSizeBytes), + Info.AttachmentBlocksUploaded.load(), + NiceBytes(Info.AttachmentBlockBytesUploaded.load()), + Info.AttachmentsUploaded.load(), + NiceBytes(Info.AttachmentBytesUploaded.load()), + remotestore_impl::GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, TransferWallTimeMS))); return Result; }; RemoteProjectStore::Result -ParseOplogContainer(const CbObject& ContainerObject, - const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments, - const std::function<bool(const IoHash& RawHash)>& HasAttachment, - const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, - const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, - const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment, - CbObject& OutOplogSection, - JobContext* OptionalContext) +ParseOplogContainer( + const CbObject& ContainerObject, + const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, + const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment, + CbObject& OutOplogSection, + JobContext* OptionalContext) { using namespace std::literals; @@ -2801,22 +3241,43 @@ ParseOplogContainer(const CbObject& ContainerObject, "Section has unexpected data type", "Failed to save oplog container"}; } - std::unordered_set<IoHash, IoHash::Hasher> OpsAttachments; + std::unordered_set<IoHash, IoHash::Hasher> NeededAttachments; { CbArrayView OpsArray = OutOplogSection["ops"sv].AsArrayView(); + + size_t OpCount = OpsArray.Num(); + size_t OpsCompleteCount = 0; + + 
remotestore_impl::ReportMessage(OptionalContext, fmt::format("Scanning {} ops for attachments", OpCount)); + for (CbFieldView OpEntry : OpsArray) { - OpEntry.IterateAttachments([&](CbFieldView FieldView) { OpsAttachments.insert(FieldView.AsAttachment()); }); + OpEntry.IterateAttachments([&](CbFieldView FieldView) { NeededAttachments.insert(FieldView.AsAttachment()); }); if (remotestore_impl::IsCancelled(OptionalContext)) { return RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::OK), .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, .Reason = "Operation cancelled"}; } + OpsCompleteCount++; + if ((OpsCompleteCount & 4095) == 0) + { + remotestore_impl::ReportProgress( + OptionalContext, + "Scanning oplog"sv, + fmt::format("{} attachments found, {} ops remaining...", NeededAttachments.size(), OpCount - OpsCompleteCount), + OpCount, + OpCount - OpsCompleteCount); + } } + remotestore_impl::ReportProgress(OptionalContext, + "Scanning oplog"sv, + fmt::format("{} attachments found", NeededAttachments.size()), + OpCount, + OpCount - OpsCompleteCount); } { - std::vector<IoHash> ReferencedAttachments(OpsAttachments.begin(), OpsAttachments.end()); + std::vector<IoHash> ReferencedAttachments(NeededAttachments.begin(), NeededAttachments.end()); OnReferencedAttachments(ReferencedAttachments); } @@ -2827,24 +3288,41 @@ ParseOplogContainer(const CbObject& ContainerObject, .Reason = "Operation cancelled"}; } - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", OpsAttachments.size())); + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", NeededAttachments.size())); CbArrayView ChunkedFilesArray = ContainerObject["chunkedfiles"sv].AsArrayView(); for (CbFieldView ChunkedFileField : ChunkedFilesArray) { CbObjectView ChunkedFileView = ChunkedFileField.AsObjectView(); IoHash RawHash = ChunkedFileView["rawhash"sv].AsHash(); - if (OpsAttachments.contains(RawHash) && 
(!HasAttachment(RawHash))) + if (NeededAttachments.erase(RawHash) == 1) { - ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView); + if (!HasAttachment(RawHash)) + { + ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView); + + size_t NeededChunkAttachmentCount = 0; - OnReferencedAttachments(Chunked.ChunkHashes); - OpsAttachments.insert(Chunked.ChunkHashes.begin(), Chunked.ChunkHashes.end()); - OnChunkedAttachment(Chunked); - ZEN_INFO("Requesting chunked attachment '{}' ({}) built from {} chunks", - Chunked.RawHash, - NiceBytes(Chunked.RawSize), - Chunked.ChunkHashes.size()); + OnReferencedAttachments(Chunked.ChunkHashes); + for (const IoHash& ChunkHash : Chunked.ChunkHashes) + { + if (!HasAttachment(ChunkHash)) + { + if (NeededAttachments.insert(ChunkHash).second) + { + NeededChunkAttachmentCount++; + } + } + } + OnChunkedAttachment(Chunked); + + remotestore_impl::ReportMessage(OptionalContext, + fmt::format("Requesting chunked attachment '{}' ({}) built from {} chunks, need {} chunks", + Chunked.RawHash, + NiceBytes(Chunked.RawSize), + Chunked.ChunkHashes.size(), + NeededChunkAttachmentCount)); + } } if (remotestore_impl::IsCancelled(OptionalContext)) { @@ -2854,6 +3332,8 @@ ParseOplogContainer(const CbObject& ContainerObject, } } + std::vector<ThinChunkBlockDescription> ThinBlocksDescriptions; + size_t NeedBlockCount = 0; CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView(); for (CbFieldView BlockField : BlocksArray) @@ -2863,45 +3343,38 @@ ParseOplogContainer(const CbObject& ContainerObject, CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView(); - std::vector<IoHash> NeededChunks; - NeededChunks.reserve(ChunksArray.Num()); - if (BlockHash == IoHash::Zero) + std::vector<IoHash> ChunkHashes; + ChunkHashes.reserve(ChunksArray.Num()); + for (CbFieldView ChunkField : ChunksArray) { - for (CbFieldView ChunkField : ChunksArray) - { - IoHash ChunkHash = ChunkField.AsBinaryAttachment(); - if (OpsAttachments.erase(ChunkHash) == 1) - { - if 
(!HasAttachment(ChunkHash)) - { - NeededChunks.emplace_back(ChunkHash); - } - } - } + ChunkHashes.push_back(ChunkField.AsHash()); } - else + ThinBlocksDescriptions.push_back(ThinChunkBlockDescription{.BlockHash = BlockHash, .ChunkRawHashes = std::move(ChunkHashes)}); + } + + for (ThinChunkBlockDescription& ThinBlockDescription : ThinBlocksDescriptions) + { + std::vector<uint32_t> NeededBlockChunkIndexes; + for (uint32_t ChunkIndex = 0; ChunkIndex < ThinBlockDescription.ChunkRawHashes.size(); ChunkIndex++) { - for (CbFieldView ChunkField : ChunksArray) + const IoHash& ChunkHash = ThinBlockDescription.ChunkRawHashes[ChunkIndex]; + if (NeededAttachments.erase(ChunkHash) == 1) { - const IoHash ChunkHash = ChunkField.AsHash(); - if (OpsAttachments.erase(ChunkHash) == 1) + if (!HasAttachment(ChunkHash)) { - if (!HasAttachment(ChunkHash)) - { - NeededChunks.emplace_back(ChunkHash); - } + NeededBlockChunkIndexes.push_back(ChunkIndex); } } } - - if (!NeededChunks.empty()) + if (!NeededBlockChunkIndexes.empty()) { - OnNeedBlock(BlockHash, std::move(NeededChunks)); - if (BlockHash != IoHash::Zero) + if (ThinBlockDescription.BlockHash != IoHash::Zero) { NeedBlockCount++; } + OnNeedBlock(std::move(ThinBlockDescription), std::move(NeededBlockChunkIndexes)); } + if (remotestore_impl::IsCancelled(OptionalContext)) { return RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::OK), @@ -2909,6 +3382,7 @@ ParseOplogContainer(const CbObject& ContainerObject, .Reason = "Operation cancelled"}; } } + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Requesting {} of {} attachment blocks", NeedBlockCount, BlocksArray.Num())); @@ -2918,7 +3392,7 @@ ParseOplogContainer(const CbObject& ContainerObject, { IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment(); - if (OpsAttachments.erase(AttachmentHash) == 1) + if (NeededAttachments.erase(AttachmentHash) == 1) { if (!HasAttachment(AttachmentHash)) { @@ -2941,14 +3415,15 @@ ParseOplogContainer(const 
CbObject& ContainerObject, } RemoteProjectStore::Result -SaveOplogContainer(ProjectStore::Oplog& Oplog, - const CbObject& ContainerObject, - const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments, - const std::function<bool(const IoHash& RawHash)>& HasAttachment, - const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, - const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, - const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment, - JobContext* OptionalContext) +SaveOplogContainer( + ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment, + const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment, + JobContext* OptionalContext) { using namespace std::literals; @@ -2972,18 +3447,12 @@ SaveOplogContainer(ProjectStore::Oplog& Oplog, } RemoteProjectStore::Result -LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - JobContext* OptionalContext) +LoadOplog(LoadOplogContext&& Context) { using namespace std::literals; + std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(Context.OptionalJobContext)); + remotestore_impl::DownloadInfo Info; Stopwatch Timer; @@ -2991,25 +3460,25 @@ LoadOplog(CidStore& ChunkStore, std::unordered_set<IoHash, IoHash::Hasher> Attachments; uint64_t BlockCountToDownload = 0; - RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = RemoteStore.GetInfo(); - 
remotestore_impl::ReportMessage(OptionalContext, fmt::format("Loading oplog container '{}'", RemoteStoreInfo.ContainerName)); + RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = Context.RemoteStore.GetInfo(); + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loading oplog container '{}'", RemoteStoreInfo.ContainerName)); uint64_t TransferWallTimeMS = 0; Stopwatch LoadContainerTimer; - RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer(); + RemoteProjectStore::LoadContainerResult LoadContainerResult = Context.RemoteStore.LoadContainer(); TransferWallTimeMS += LoadContainerTimer.GetElapsedTimeMs(); if (LoadContainerResult.ErrorCode) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Failed to load oplog container: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode)); return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode, .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, .Reason = LoadContainerResult.Reason, .Text = LoadContainerResult.Text}; } - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loaded container in {} ({})", NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000)), NiceBytes(LoadContainerResult.ContainerObject.GetSize()))); @@ -3023,22 +3492,27 @@ LoadOplog(CidStore& ChunkStore, Stopwatch LoadAttachmentsTimer; std::atomic_uint64_t DownloadStartMS = (std::uint64_t)-1; - auto HasAttachment = [&Oplog, &ChunkStore, ForceDownload](const IoHash& RawHash) { - if (ForceDownload) + auto HasAttachment = [&Context](const IoHash& RawHash) { + if (Context.ForceDownload) { return false; } - if (ChunkStore.ContainsChunk(RawHash)) + if (Context.ChunkStore.ContainsChunk(RawHash)) { return true; } return false; }; - auto OnNeedBlock = [&RemoteStore, - &ChunkStore, - &NetworkWorkerPool, - &WorkerPool, + 
struct NeededBlockDownload + { + ThinChunkBlockDescription ThinBlockDescription; + std::vector<uint32_t> NeededChunkIndexes; + }; + + std::vector<NeededBlockDownload> NeededBlockDownloads; + + auto OnNeedBlock = [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &AttachmentCount, @@ -3047,8 +3521,8 @@ LoadOplog(CidStore& ChunkStore, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext](const IoHash& BlockHash, std::vector<IoHash>&& Chunks) { + &NeededBlockDownloads](ThinChunkBlockDescription&& ThinBlockDescription, + std::vector<uint32_t>&& NeededChunkIndexes) { if (RemoteResult.IsError()) { return; @@ -3056,47 +3530,26 @@ LoadOplog(CidStore& ChunkStore, BlockCountToDownload++; AttachmentCount.fetch_add(1); - if (BlockHash == IoHash::Zero) - { - DownloadAndSaveBlockChunks(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + if (ThinBlockDescription.BlockHash == IoHash::Zero) + { + DownloadAndSaveBlockChunks(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, Info, LoadAttachmentsTimer, DownloadStartMS, - Chunks); + std::move(ThinBlockDescription), + std::move(NeededChunkIndexes)); } else { - DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, - AttachmentsDownloadLatch, - AttachmentsWriteLatch, - RemoteResult, - Info, - LoadAttachmentsTimer, - DownloadStartMS, - BlockHash, - Chunks, - 3); + NeededBlockDownloads.push_back(NeededBlockDownload{.ThinBlockDescription = std::move(ThinBlockDescription), + .NeededChunkIndexes = std::move(NeededChunkIndexes)}); } }; - auto OnNeedAttachment = [&RemoteStore, - &Oplog, - &ChunkStore, - &NetworkWorkerPool, - &WorkerPool, + auto OnNeedAttachment = [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &RemoteResult, @@ -3104,9 +3557,7 @@ LoadOplog(CidStore& ChunkStore, &AttachmentCount, &LoadAttachmentsTimer, 
&DownloadStartMS, - &Info, - IgnoreMissingAttachments, - OptionalContext](const IoHash& RawHash) { + &Info](const IoHash& RawHash) { if (!Attachments.insert(RawHash).second) { return; @@ -3116,12 +3567,7 @@ LoadOplog(CidStore& ChunkStore, return; } AttachmentCount.fetch_add(1); - DownloadAndSaveAttachment(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSaveAttachment(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3132,18 +3578,13 @@ LoadOplog(CidStore& ChunkStore, }; std::vector<ChunkedInfo> FilesToDechunk; - auto OnChunkedAttachment = [&Oplog, &ChunkStore, &FilesToDechunk, ForceDownload](const ChunkedInfo& Chunked) { - if (ForceDownload || !ChunkStore.ContainsChunk(Chunked.RawHash)) - { - FilesToDechunk.push_back(Chunked); - } - }; + auto OnChunkedAttachment = [&FilesToDechunk](const ChunkedInfo& Chunked) { FilesToDechunk.push_back(Chunked); }; - auto OnReferencedAttachments = [&Oplog](std::span<IoHash> RawHashes) { Oplog.CaptureAddedAttachments(RawHashes); }; + auto OnReferencedAttachments = [&Context](std::span<IoHash> RawHashes) { Context.Oplog.CaptureAddedAttachments(RawHashes); }; // Make sure we retain any attachments we download before writing the oplog - Oplog.EnableUpdateCapture(); - auto _ = MakeGuard([&Oplog]() { Oplog.DisableUpdateCapture(); }); + Context.Oplog.EnableUpdateCapture(); + auto _ = MakeGuard([&Context]() { Context.Oplog.DisableUpdateCapture(); }); CbObject OplogSection; RemoteProjectStore::Result Result = ParseOplogContainer(LoadContainerResult.ContainerObject, @@ -3153,40 +3594,268 @@ LoadOplog(CidStore& ChunkStore, OnNeedAttachment, OnChunkedAttachment, OplogSection, - OptionalContext); + Context.OptionalJobContext); if (Result.ErrorCode != 0) { RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); } - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, 
fmt::format("Parsed oplog in {}, found {} attachments, {} blocks and {} chunked files to download", NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)), Attachments.size(), BlockCountToDownload, FilesToDechunk.size())); - AttachmentsDownloadLatch.CountDown(); - while (!AttachmentsDownloadLatch.Wait(1000)) + std::vector<IoHash> BlockHashes; + std::vector<IoHash> AllNeededChunkHashes; + BlockHashes.reserve(NeededBlockDownloads.size()); + for (const NeededBlockDownload& BlockDownload : NeededBlockDownloads) { - ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + BlockHashes.push_back(BlockDownload.ThinBlockDescription.BlockHash); + for (uint32_t ChunkIndex : BlockDownload.NeededChunkIndexes) { - if (!RemoteResult.IsError()) + AllNeededChunkHashes.push_back(BlockDownload.ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } + } + + tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> AllNeededPartialChunkHashesLookup = BuildHashLookup(AllNeededChunkHashes); + std::vector<std::atomic<bool>> ChunkDownloadedFlags(AllNeededChunkHashes.size()); + std::vector<bool> DownloadedViaLegacyChunkFlag(AllNeededChunkHashes.size(), false); + ChunkBlockAnalyser::BlockResult PartialBlocksResult; + + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Fetching descriptions for {} blocks", BlockHashes.size())); + + RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions = + Context.RemoteStore.GetBlockDescriptions(BlockHashes, Context.OptionalCache, Context.CacheBuildId); + + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("GetBlockDescriptions took {}. 
Found {} blocks", + NiceTimeSpanMs(uint64_t(BlockDescriptions.ElapsedSeconds * 1000)), + BlockDescriptions.Blocks.size())); + + std::vector<IoHash> BlocksWithDescription; + BlocksWithDescription.reserve(BlockDescriptions.Blocks.size()); + for (const ChunkBlockDescription& BlockDescription : BlockDescriptions.Blocks) + { + BlocksWithDescription.push_back(BlockDescription.BlockHash); + } + { + auto WantIt = NeededBlockDownloads.begin(); + auto FindIt = BlockDescriptions.Blocks.begin(); + while (WantIt != NeededBlockDownloads.end()) + { + if (FindIt == BlockDescriptions.Blocks.end()) { - RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); + // Fall back to full download as we can't get enough information about the block + DownloadAndSaveBlock(Context, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + WantIt->ThinBlockDescription.BlockHash, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + 3); + for (uint32_t BlockChunkIndex : WantIt->NeededChunkIndexes) + { + const IoHash& ChunkHash = WantIt->ThinBlockDescription.ChunkRawHashes[BlockChunkIndex]; + auto It = AllNeededPartialChunkHashesLookup.find(ChunkHash); + ZEN_ASSERT(It != AllNeededPartialChunkHashesLookup.end()); + uint32_t ChunkIndex = It->second; + DownloadedViaLegacyChunkFlag[ChunkIndex] = true; + } + WantIt++; + } + else if (WantIt->ThinBlockDescription.BlockHash == FindIt->BlockHash) + { + // Found + FindIt++; + WantIt++; + } + else + { + // Not a requested block? 
+ ZEN_ASSERT(false); } } - uint64_t PartialTransferWallTimeMS = TransferWallTimeMS; - if (DownloadStartMS != (uint64_t)-1) + } + if (!AllNeededChunkHashes.empty()) + { + std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> PartialBlockDownloadModes; + std::vector<bool> BlockExistsInCache(BlocksWithDescription.size(), false); + + if (Context.PartialBlockRequestMode == EPartialBlockRequestMode::Off) { - PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load(); + PartialBlockDownloadModes.resize(BlocksWithDescription.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else + { + if (Context.OptionalCache) + { + std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = + Context.OptionalCache->BlobsExists(Context.CacheBuildId, BlocksWithDescription); + if (CacheExistsResult.size() == BlocksWithDescription.size()) + { + for (size_t BlobIndex = 0; BlobIndex < CacheExistsResult.size(); BlobIndex++) + { + BlockExistsInCache[BlobIndex] = CacheExistsResult[BlobIndex].HasBody; + } + } + uint64_t FoundBlocks = + std::accumulate(BlockExistsInCache.begin(), + BlockExistsInCache.end(), + uint64_t(0u), + [](uint64_t Current, bool Exists) -> uint64_t { return Current + (Exists ? 1 : 0); }); + if (FoundBlocks > 0) + { + remotestore_impl::ReportMessage( + Context.OptionalJobContext, + fmt::format("Found {} out of {} blocks in cache", FoundBlocks, BlockExistsInCache.size())); + } + } + + ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + + switch (Context.PartialBlockRequestMode) + { + case EPartialBlockRequestMode::Off: + break; + case EPartialBlockRequestMode::ZenCacheOnly: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + break; + case EPartialBlockRequestMode::Mixed: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + case EPartialBlockRequestMode::All: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = Context.StoreMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + } + + PartialBlockDownloadModes.reserve(BlocksWithDescription.size()); + for (uint32_t BlockIndex = 0; BlockIndex < BlocksWithDescription.size(); BlockIndex++) + { + const bool BlockExistInCache = BlockExistsInCache[BlockIndex]; + PartialBlockDownloadModes.push_back(BlockExistInCache ? 
CachePartialDownloadMode : CloudPartialDownloadMode); + } + } + + ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size()); + + ChunkBlockAnalyser PartialAnalyser( + *LogOutput, + BlockDescriptions.Blocks, + ChunkBlockAnalyser::Options{.IsQuiet = false, + .IsVerbose = false, + .HostLatencySec = Context.StoreLatencySec, + .HostHighSpeedLatencySec = Context.CacheLatencySec, + .HostMaxRangeCountPerRequest = Context.StoreMaxRangeCountPerRequest, + .HostHighSpeedMaxRangeCountPerRequest = Context.CacheMaxRangeCountPerRequest}); + + std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = + PartialAnalyser.GetNeeded(AllNeededPartialChunkHashesLookup, + [&](uint32_t ChunkIndex) { return !DownloadedViaLegacyChunkFlag[ChunkIndex]; }); + + PartialBlocksResult = PartialAnalyser.CalculatePartialBlockDownloads(NeededBlocks, PartialBlockDownloadModes); + + for (uint32_t FullBlockIndex : PartialBlocksResult.FullBlockIndexes) + { + DownloadAndSaveBlock(Context, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + BlockDescriptions.Blocks[FullBlockIndex].BlockHash, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + 3); + } + + for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocksResult.BlockRanges.size();) + { + size_t RangeCount = 1; + size_t RangesLeft = PartialBlocksResult.BlockRanges.size() - BlockRangeIndex; + const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocksResult.BlockRanges[BlockRangeIndex]; + while (RangeCount < RangesLeft && + CurrentBlockRange.BlockIndex == PartialBlocksResult.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) + { + RangeCount++; + } + + DownloadAndSavePartialBlock(Context, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + BlockDescriptions.Blocks[CurrentBlockRange.BlockIndex], + BlockExistsInCache[CurrentBlockRange.BlockIndex], + 
PartialBlocksResult.BlockRanges, + BlockRangeIndex, + RangeCount, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + /* RetriesLeft*/ 3); + + BlockRangeIndex += RangeCount; + } + } + + AttachmentsDownloadLatch.CountDown(); + { + ptrdiff_t AttachmentCountToUseForProgress = AttachmentsDownloadLatch.Remaining(); + while (!AttachmentsDownloadLatch.Wait(1000)) + { + ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining(); + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) + { + if (!RemoteResult.IsError()) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); + } + } + uint64_t PartialTransferWallTimeMS = TransferWallTimeMS; + if (DownloadStartMS != (uint64_t)-1) + { + PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load(); + } + + uint64_t AttachmentsDownloaded = + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); + uint64_t AttachmentBytesDownloaded = Info.AttachmentBlockBytesDownloaded.load() + + Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load(); + + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress( + Context.OptionalJobContext, + "Loading attachments"sv, + fmt::format( + "{} ({}) downloaded, {} ({}) stored, {} remaining. {}", + AttachmentsDownloaded, + NiceBytes(AttachmentBytesDownloaded), + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Remaining, + remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, PartialTransferWallTimeMS)), + AttachmentCountToUseForProgress, + Remaining); } - remotestore_impl::ReportProgress( - OptionalContext, - "Loading attachments"sv, - fmt::format("{} remaining. 
{}", Remaining, remotestore_impl::GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), - AttachmentCount.load(), - Remaining); } if (DownloadStartMS != (uint64_t)-1) { @@ -3195,57 +3864,58 @@ LoadOplog(CidStore& ChunkStore, if (AttachmentCount.load() > 0) { - remotestore_impl::ReportProgress(OptionalContext, - "Loading attachments"sv, - fmt::format("{}", remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS)), - AttachmentCount.load(), - 0); + remotestore_impl::ReportProgress( + Context.OptionalJobContext, + "Loading attachments"sv, + fmt::format("{}", remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS)), + AttachmentCount.load(), + 0); } AttachmentsWriteLatch.CountDown(); - while (!AttachmentsWriteLatch.Wait(1000)) { - ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + ptrdiff_t AttachmentCountToUseForProgress = AttachmentsWriteLatch.Remaining(); + while (!AttachmentsWriteLatch.Wait(1000)) { - if (!RemoteResult.IsError()) + ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining(); + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) { - RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); + if (!RemoteResult.IsError()) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); + } } + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress(Context.OptionalJobContext, + "Writing attachments"sv, + fmt::format("{} ({}), {} remaining.", + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Remaining), + AttachmentCountToUseForProgress, + Remaining); } - remotestore_impl::ReportProgress(OptionalContext, - "Writing attachments"sv, - fmt::format("{} remaining.", Remaining), - AttachmentCount.load(), - Remaining); } if (AttachmentCount.load() > 0) { - 
remotestore_impl::ReportProgress(OptionalContext, "Writing attachments", ""sv, AttachmentCount.load(), 0); + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Writing attachments", ""sv, AttachmentCount.load(), 0); } if (Result.ErrorCode == 0) { if (!FilesToDechunk.empty()) { - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size())); + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size())); Latch DechunkLatch(1); - std::filesystem::path TempFilePath = Oplog.TempPath(); + std::filesystem::path TempFilePath = Context.Oplog.TempPath(); for (const ChunkedInfo& Chunked : FilesToDechunk) { std::filesystem::path TempFileName = TempFilePath / Chunked.RawHash.ToHexString(); DechunkLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&ChunkStore, - &DechunkLatch, - TempFileName, - &Chunked, - &RemoteResult, - IgnoreMissingAttachments, - &Info, - OptionalContext]() { + Context.WorkerPool.ScheduleWork( + [&Context, &DechunkLatch, TempFileName, &Chunked, &RemoteResult, &Info]() { ZEN_TRACE_CPU("DechunkAttachment"); auto _ = MakeGuard([&DechunkLatch, &TempFileName] { @@ -3279,16 +3949,16 @@ LoadOplog(CidStore& ChunkStore, for (std::uint32_t SequenceIndex : Chunked.ChunkSequence) { const IoHash& ChunkHash = Chunked.ChunkHashes[SequenceIndex]; - IoBuffer Chunk = ChunkStore.FindChunkByCid(ChunkHash); + IoBuffer Chunk = Context.ChunkStore.FindChunkByCid(ChunkHash); if (!Chunk) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Missing chunk {} for chunked attachment {}", ChunkHash, Chunked.RawHash)); // We only add 1 as the resulting missing count will be 1 for the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow<int>(HttpResponseCode::NotFound), @@ -3306,7 +3976,7 @@ LoadOplog(CidStore& 
ChunkStore, if (RawHash != ChunkHash) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Mismatching raw hash {} for chunk {} for chunked attachment {}", RawHash, ChunkHash, @@ -3314,7 +3984,7 @@ LoadOplog(CidStore& ChunkStore, // We only add 1 as the resulting missing count will be 1 for the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow<int>(HttpResponseCode::NotFound), @@ -3351,14 +4021,14 @@ LoadOplog(CidStore& ChunkStore, })) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Failed to decompress chunk {} for chunked attachment {}", ChunkHash, Chunked.RawHash)); // We only add 1 as the resulting missing count will be 1 for the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow<int>(HttpResponseCode::NotFound), @@ -3380,11 +4050,12 @@ LoadOplog(CidStore& ChunkStore, TmpFile.Close(); TmpBuffer = IoBufferBuilder::MakeFromTemporaryFile(TempFileName); } + uint64_t TmpBufferSize = TmpBuffer.GetSize(); CidStore::InsertResult InsertResult = - ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace); + Context.ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace); if (InsertResult.New) { - Info.AttachmentBytesStored.fetch_add(TmpBuffer.GetSize()); + Info.AttachmentBytesStored.fetch_add(TmpBufferSize); Info.AttachmentsStored.fetch_add(1); } @@ -3407,54 +4078,58 @@ LoadOplog(CidStore& ChunkStore, while (!DechunkLatch.Wait(1000)) { ptrdiff_t Remaining = DechunkLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) { if (!RemoteResult.IsError()) { 
RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", ""); remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); } } - remotestore_impl::ReportProgress(OptionalContext, + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Dechunking attachments"sv, fmt::format("{} remaining...", Remaining), FilesToDechunk.size(), Remaining); } - remotestore_impl::ReportProgress(OptionalContext, "Dechunking attachments"sv, ""sv, FilesToDechunk.size(), 0); + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Dechunking attachments"sv, ""sv, FilesToDechunk.size(), 0); } Result = RemoteResult.ConvertResult(); } if (Result.ErrorCode == 0) { - if (CleanOplog) + if (Context.CleanOplog) { - RemoteStore.Flush(); - if (!Oplog.Reset()) + if (Context.OptionalCache) + { + Context.OptionalCache->Flush(100, [](intptr_t) { return /*DontWaitForPendingOperation*/ false; }); + } + if (!Context.Oplog.Reset()) { Result = RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::InternalServerError), .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, - .Reason = fmt::format("Failed to clean existing oplog '{}'", Oplog.OplogId())}; - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason)); + .Reason = fmt::format("Failed to clean existing oplog '{}'", Context.Oplog.OplogId())}; + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason)); } } if (Result.ErrorCode == 0) { - remotestore_impl::WriteOplogSection(Oplog, OplogSection, OptionalContext); + remotestore_impl::WriteOplogSection(Context.Oplog, OplogSection, Context.OptionalJobContext); } } Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats()); + 
remotestore_impl::LogRemoteStoreStatsDetails(Context.RemoteStore.GetStats()); { std::string DownloadDetails; RemoteProjectStore::ExtendedStats ExtendedStats; - if (RemoteStore.GetExtendedStats(ExtendedStats)) + if (Context.RemoteStore.GetExtendedStats(ExtendedStats)) { if (!ExtendedStats.m_ReceivedBytesPerSource.empty()) { @@ -3473,26 +4148,37 @@ LoadOplog(CidStore& ChunkStore, Total += It.second; } - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView())); + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView())); } } } + uint64_t TotalDownloads = + 1 + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); + uint64_t TotalBytesDownloaded = Info.OplogSizeBytes + Info.AttachmentBlockBytesDownloaded.load() + + Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load(); + remotestore_impl::ReportMessage( - OptionalContext, - fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}), Stored: {} ({}), Missing: {} {}", + Context.OptionalJobContext, + fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), BlockRanges: {} ({}), Attachments: {} " + "({}), Total: {} ({}), Stored: {} ({}), Missing: {} {}", RemoteStoreInfo.ContainerName, Result.ErrorCode == 0 ? 
"SUCCESS" : "FAILURE", NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)), NiceBytes(Info.OplogSizeBytes), Info.AttachmentBlocksDownloaded.load(), NiceBytes(Info.AttachmentBlockBytesDownloaded.load()), + Info.AttachmentBlocksRangesDownloaded.load(), + NiceBytes(Info.AttachmentBlockRangeBytesDownloaded.load()), Info.AttachmentsDownloaded.load(), NiceBytes(Info.AttachmentBytesDownloaded.load()), + TotalDownloads, + NiceBytes(TotalBytesDownloaded), Info.AttachmentsStored.load(), NiceBytes(Info.AttachmentBytesStored.load()), Info.MissingAttachmentCount.load(), - remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); + remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS))); return Result; } @@ -3537,7 +4223,7 @@ RemoteProjectStore::~RemoteProjectStore() #if ZEN_WITH_TESTS -namespace testutils { +namespace projectstore_testutils { using namespace std::literals; static std::string OidAsString(const Oid& Id) @@ -3589,7 +4275,29 @@ namespace testutils { return Result; } -} // namespace testutils + class TestJobContext : public JobContext + { + public: + explicit TestJobContext(int& OpIndex) : m_OpIndex(OpIndex) {} + virtual bool IsCancelled() const { return false; } + virtual void ReportMessage(std::string_view Message) { ZEN_INFO("Job {}: {}", m_OpIndex, Message); } + virtual void ReportProgress(std::string_view CurrentOp, std::string_view Details, ptrdiff_t TotalCount, ptrdiff_t RemainingCount) + { + ZEN_INFO("Job {}: Op '{}'{} {}/{}", + m_OpIndex, + CurrentOp, + Details.empty() ? 
"" : fmt::format(" {}", Details), + TotalCount - RemainingCount, + TotalCount); + } + + private: + int& m_OpIndex; + }; + +} // namespace projectstore_testutils + +TEST_SUITE_BEGIN("remotestore.projectstore"); struct ExportForceDisableBlocksTrue_ForceTempBlocksFalse { @@ -3616,7 +4324,7 @@ TEST_CASE_TEMPLATE("project.store.export", ExportForceDisableBlocksFalse_ForceTempBlocksTrue) { using namespace std::literals; - using namespace testutils; + using namespace projectstore_testutils; ScopedTemporaryDirectory TempDir; ScopedTemporaryDirectory ExportDir; @@ -3684,56 +4392,712 @@ TEST_CASE_TEMPLATE("project.store.export", false, nullptr); - CHECK(ExportResult.ErrorCode == 0); + REQUIRE(ExportResult.ErrorCode == 0); Ref<ProjectStore::Oplog> OplogImport = Project->NewOplog("oplog2", {}); CHECK(OplogImport); - RemoteProjectStore::Result ImportResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ false, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ false, - nullptr); + int OpJobIndex = 0; + TestJobContext OpJobContext(OpJobIndex); + + RemoteProjectStore::Result ImportResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportForceResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ true, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ false, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportForceResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + 
.CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = true, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportForceResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportCleanResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ false, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ true, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportCleanResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = true, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportCleanResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportForceCleanResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ true, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ true, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportForceCleanResult = + LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = true, + .IgnoreMissingAttachments = false, + .CleanOplog = true, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportForceCleanResult.ErrorCode == 0); + OpJobIndex++; } +// Common oplog setup used by the two tests below. 
+// Returns a FileRemoteStore backed by ExportDir that has been populated with a SaveOplog call. +// Keeps the test data identical to project.store.export so the two test suites exercise the same blocks/attachments. +static RemoteProjectStore::Result +SetupExportStore(CidStore& CidStore, + ProjectStore::Project& Project, + WorkerThreadPool& NetworkPool, + WorkerThreadPool& WorkerPool, + const std::filesystem::path& ExportDir, + std::shared_ptr<RemoteProjectStore>& OutRemoteStore) +{ + using namespace projectstore_testutils; + using namespace std::literals; + + Ref<ProjectStore::Oplog> Oplog = Project.NewOplog("oplog_export", {}); + if (!Oplog) + { + return RemoteProjectStore::Result{.ErrorCode = -1}; + } + + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77}))); + Oplog->AppendNewOplogEntry( + CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99}))); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122}))); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage( + Oid::NewOid(), + CreateAttachments(std::initializer_list<size_t>{256u * 1024u, 92u * 1024u}, OodleCompressionLevel::None))); + + FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 64u * 1024, + .MaxChunksPerBlock = 1000, + .MaxChunkEmbedSize = 32 * 1024u, + .ChunkFileSizeLimit = 64u * 1024u}, + /*.FolderPath =*/ExportDir, + /*.Name =*/std::string("oplog_export"), + /*.OptionalBaseName =*/std::string(), + /*.ForceDisableBlocks =*/false, + /*.ForceEnableTempBlocks =*/false}; + + OutRemoteStore = CreateFileRemoteStore(Log(), Options); + return SaveOplog(CidStore, + *OutRemoteStore, + Project, + *Oplog, + NetworkPool, + WorkerPool, + Options.MaxBlockSize, + Options.MaxChunksPerBlock, + Options.MaxChunkEmbedSize, + 
Options.ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); +} + +// Creates an export store with a single oplog entry that packs six 512 KB chunks into one +// ~3 MB block (MaxBlockSize = 8 MB). The resulting block slack (~1.5 MB) far exceeds the +// 512 KB threshold that ChunkBlockAnalyser requires before it will consider partial-block +// downloads instead of full-block downloads. +// +// This function is self-contained: it creates its own GcManager, CidStore, ProjectStore and +// Project internally so that each call is independent of any outer test context. After +// SaveOplog returns, all persistent data lives on disk inside ExportDir and the caller can +// freely query OutRemoteStore without holding any references to the internal context. +static RemoteProjectStore::Result +SetupPartialBlockExportStore(WorkerThreadPool& NetworkPool, + WorkerThreadPool& WorkerPool, + const std::filesystem::path& ExportDir, + std::shared_ptr<RemoteProjectStore>& OutRemoteStore) +{ + using namespace projectstore_testutils; + using namespace std::literals; + + // Self-contained CAS and project store. Subdirectories of ExportDir keep everything + // together without relying on the outer TEST_CASE's ExportCidStore / ExportProject. 
+ GcManager LocalGc; + CidStore LocalCidStore(LocalGc); + CidStoreConfiguration LocalCidConfig = {.RootDirectory = ExportDir / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + LocalCidStore.Initialize(LocalCidConfig); + + std::filesystem::path LocalProjectBasePath = ExportDir / "proj"; + ProjectStore LocalProjectStore(LocalCidStore, LocalProjectBasePath, LocalGc, ProjectStore::Configuration{}); + Ref<ProjectStore::Project> LocalProject(LocalProjectStore.NewProject(LocalProjectBasePath / "p"sv, + "p"sv, + (ExportDir / "root").string(), + (ExportDir / "engine").string(), + (ExportDir / "game").string(), + (ExportDir / "game" / "game.uproject").string())); + + Ref<ProjectStore::Oplog> Oplog = LocalProject->NewOplog("oplog_partial_block", {}); + if (!Oplog) + { + return RemoteProjectStore::Result{.ErrorCode = -1}; + } + + // Six 512 KB chunks with OodleCompressionLevel::None so the compressed size stays large + // and the block genuinely exceeds the 512 KB slack threshold. + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage( + Oid::NewOid(), + CreateAttachments(std::initializer_list<size_t>{512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u}, + OodleCompressionLevel::None))); + + // MaxChunkEmbedSize must be larger than the compressed size of each 512 KB chunk + // (OodleCompressionLevel::None → compressed ≈ raw ≈ 512 KB). With the legacy + // 32 KB limit all six chunks would become loose large attachments and no block would + // be created, so we use the production default of 1.5 MB instead. 
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 8u * 1024u * 1024u, + .MaxChunksPerBlock = 1000, + .MaxChunkEmbedSize = RemoteStoreOptions::DefaultMaxChunkEmbedSize, + .ChunkFileSizeLimit = 64u * 1024u * 1024u}, + /*.FolderPath =*/ExportDir, + /*.Name =*/std::string("oplog_partial_block"), + /*.OptionalBaseName =*/std::string(), + /*.ForceDisableBlocks =*/false, + /*.ForceEnableTempBlocks =*/false}; + OutRemoteStore = CreateFileRemoteStore(Log(), Options); + return SaveOplog(LocalCidStore, + *OutRemoteStore, + *LocalProject, + *Oplog, + NetworkPool, + WorkerPool, + Options.MaxBlockSize, + Options.MaxChunksPerBlock, + Options.MaxChunkEmbedSize, + Options.ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); +} + +// Returns the first block hash that has at least MinChunkCount chunks, or a zero IoHash +// if no qualifying block exists in Store. +static IoHash +FindBlockWithMultipleChunks(RemoteProjectStore& Store, size_t MinChunkCount) +{ + RemoteProjectStore::LoadContainerResult ContainerResult = Store.LoadContainer(); + if (ContainerResult.ErrorCode != 0) + { + return {}; + } + std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(ContainerResult.ContainerObject); + if (BlockHashes.empty()) + { + return {}; + } + RemoteProjectStore::GetBlockDescriptionsResult Descriptions = Store.GetBlockDescriptions(BlockHashes, nullptr, Oid{}); + if (Descriptions.ErrorCode != 0) + { + return {}; + } + for (const ChunkBlockDescription& Desc : Descriptions.Blocks) + { + if (Desc.ChunkRawHashes.size() >= MinChunkCount) + { + return Desc.BlockHash; + } + } + return {}; +} + +// Loads BlockHash from Source and inserts every even-indexed chunk (0, 2, 4, …) into +// TargetCidStore. 
Odd-indexed chunks are left absent so that when an import is run +// against the same block, HasAttachment returns false for three non-adjacent positions +// — the minimum needed to exercise the multi-range partial-block download paths. +static void +SeedCidStoreWithAlternateChunks(CidStore& TargetCidStore, RemoteProjectStore& Source, const IoHash& BlockHash) +{ + RemoteProjectStore::LoadAttachmentResult BlockResult = Source.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode != 0 || !BlockResult.Bytes) + { + return; + } + + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(BlockResult.Bytes), RawHash, RawSize); + if (!Compressed) + { + return; + } + CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); + if (!BlockPayload) + { + return; + } + + uint32_t ChunkIndex = 0; + uint64_t HeaderSize = 0; + IterateChunkBlock( + BlockPayload.Flatten(), + [&TargetCidStore, &ChunkIndex](CompressedBuffer&& Chunk, const IoHash& AttachmentHash) { + if (ChunkIndex % 2 == 0) + { + IoBuffer ChunkData = Chunk.GetCompressed().Flatten().AsIoBuffer(); + TargetCidStore.AddChunk(ChunkData, AttachmentHash); + } + ++ChunkIndex; + }, + HeaderSize); +} + +TEST_CASE("project.store.import.context_settings") +{ + using namespace std::literals; + using namespace projectstore_testutils; + + ScopedTemporaryDirectory TempDir; + ScopedTemporaryDirectory ExportDir; + + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + // Export-side CAS and project store: used only by SetupExportStore to build the remote store + // payload. Kept separate from the import side so the two CAS instances are disjoint. 
+ GcManager ExportGc; + CidStore ExportCidStore(ExportGc); + CidStoreConfiguration ExportCidConfig = {.RootDirectory = TempDir.Path() / "export_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ExportCidStore.Initialize(ExportCidConfig); + + std::filesystem::path ExportBasePath = TempDir.Path() / "export_projectstore"; + ProjectStore ExportProjectStore(ExportCidStore, ExportBasePath, ExportGc, ProjectStore::Configuration{}); + Ref<ProjectStore::Project> ExportProject(ExportProjectStore.NewProject(ExportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + uint32_t NetworkWorkerCount = Max(GetHardwareConcurrency() / 4u, 2u); + uint32_t WorkerCount = (NetworkWorkerCount < GetHardwareConcurrency()) ? Max(GetHardwareConcurrency() - NetworkWorkerCount, 4u) : 4u; + WorkerThreadPool WorkerPool(WorkerCount); + WorkerThreadPool NetworkPool(NetworkWorkerCount); + + std::shared_ptr<RemoteProjectStore> RemoteStore; + RemoteProjectStore::Result ExportResult = + SetupExportStore(ExportCidStore, *ExportProject, NetworkPool, WorkerPool, ExportDir.Path(), RemoteStore); + REQUIRE(ExportResult.ErrorCode == 0); + + // Import-side CAS and project store: starts empty, mirroring a fresh machine that has never + // downloaded the data. HasAttachment() therefore returns false for every chunk, so the import + // genuinely contacts the remote store without needing ForceDownload on the populate pass. 
+ GcManager ImportGc; + CidStore ImportCidStore(ImportGc); + CidStoreConfiguration ImportCidConfig = {.RootDirectory = TempDir.Path() / "import_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ImportCidStore.Initialize(ImportCidConfig); + + std::filesystem::path ImportBasePath = TempDir.Path() / "import_projectstore"; + ProjectStore ImportProjectStore(ImportCidStore, ImportBasePath, ImportGc, ProjectStore::Configuration{}); + Ref<ProjectStore::Project> ImportProject(ImportProjectStore.NewProject(ImportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + const Oid CacheBuildId = Oid::NewOid(); + BuildStorageCache::Statistics CacheStats; + std::unique_ptr<BuildStorageCache> Cache = CreateInMemoryBuildStorageCache(256u, CacheStats); + auto ResetCacheStats = [&]() { + CacheStats.TotalBytesRead = 0; + CacheStats.TotalBytesWritten = 0; + CacheStats.TotalRequestCount = 0; + CacheStats.TotalRequestTimeUs = 0; + CacheStats.TotalExecutionTimeUs = 0; + CacheStats.PeakSentBytes = 0; + CacheStats.PeakReceivedBytes = 0; + CacheStats.PeakBytesPerSec = 0; + CacheStats.PutBlobCount = 0; + CacheStats.PutBlobByteCount = 0; + }; + + int OpJobIndex = 0; + + TestJobContext OpJobContext(OpJobIndex); + + // Helper: run a LoadOplog against the import-side CAS/project with the given context knobs. + // Each call creates a fresh oplog so repeated calls within one SUBCASE don't short-circuit on + // already-present data. 
+ auto DoImport = [&](BuildStorageCache* OptCache, + EPartialBlockRequestMode Mode, + double StoreLatency, + uint64_t StoreRanges, + double CacheLatency, + uint64_t CacheRanges, + bool PopulateCache, + bool ForceDownload) -> RemoteProjectStore::Result { + Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog(fmt::format("import_{}", OpJobIndex++), {}); + return LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = OptCache, + .CacheBuildId = CacheBuildId, + .Oplog = *ImportOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = ForceDownload, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = Mode, + .PopulateCache = PopulateCache, + .StoreLatencySec = StoreLatency, + .StoreMaxRangeCountPerRequest = StoreRanges, + .CacheLatencySec = CacheLatency, + .CacheMaxRangeCountPerRequest = CacheRanges, + .OptionalJobContext = &OpJobContext}); + }; + + // Shorthand: Mode=All, low latency, 128 ranges for both store and cache. + auto ImportAll = [&](BuildStorageCache* OptCache, bool Populate, bool Force) { + return DoImport(OptCache, EPartialBlockRequestMode::All, 0.001, 128u, 0.001, 128u, Populate, Force); + }; + + SUBCASE("mode_off_no_cache") + { + // Baseline: no partial block requests, no cache. + RemoteProjectStore::Result R = + DoImport(nullptr, EPartialBlockRequestMode::Off, -1.0, (uint64_t)-1, -1.0, (uint64_t)-1, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_all_multirange_cloud_no_cache") + { + // StoreMaxRangeCountPerRequest > 1 → MultiRange cloud path. + RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 128u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_all_singlerange_cloud_no_cache") + { + // StoreMaxRangeCountPerRequest == 1 → SingleRange cloud path. 
+ RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 1u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_mixed_high_latency_no_cache") + { + // High store latency encourages range merging; Mixed uses SingleRange for cloud, Off for cache. + RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::Mixed, 0.1, 128u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("cache_populate_and_hit") + { + // First import: ImportCidStore is empty so all blocks are downloaded from the remote store + // and written to the cache. + RemoteProjectStore::Result PopulateResult = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(PopulateResult.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + + // Re-import with ForceDownload=true: all chunks are now in ImportCidStore but Force overrides + // HasAttachment() so the download logic re-runs and serves blocks from the cache instead of + // the remote store. + ResetCacheStats(); + RemoteProjectStore::Result HitResult = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true); + CHECK(HitResult.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount == 0); + // TotalRequestCount covers both full-blob cache hits and partial-range cache hits. + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("cache_no_populate_flag") + { + // Cache is provided but PopulateCache=false: blocks are downloaded to ImportCidStore but + // nothing should be written to the cache. + RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/false); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount == 0); + } + + SUBCASE("mode_zencacheonly_cache_multirange") + { + // Pre-populate the cache via a plain import, then re-import with ZenCacheOnly + + // CacheMaxRangeCountPerRequest=128. 
With 100% of chunks needed, all blocks go to + // FullBlockIndexes and GetBuildBlob (full blob) is called from the cache. + // CacheMaxRangeCountPerRequest > 1 would route partial downloads through GetBuildBlobRanges + // if the analyser ever emits BlockRanges entries. + RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 128u, false, true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("mode_zencacheonly_cache_singlerange") + { + // Pre-populate the cache, then re-import with ZenCacheOnly + CacheMaxRangeCountPerRequest=1. + // With 100% of chunks needed the analyser sends all blocks to FullBlockIndexes (full-block + // download path), which calls GetBuildBlob with no range offset — a full-blob cache hit. + // The single-range vs multi-range distinction only matters for the partial-block (BlockRanges) + // path, which is not reached when all chunks are needed. + RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 1u, false, true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("mode_all_cache_and_cloud_multirange") + { + // Pre-populate cache; All mode uses multi-range for both the cache and cloud paths. 
+ RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("partial_block_cloud_multirange") + { + // Export store with 6 × 512 KB chunks packed into one ~3 MB block. + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr<RemoteProjectStore> PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + // Seeding even-indexed chunks (0, 2, 4) leaves odd ones (1, 3, 5) absent in + // ImportCidStore. Three non-adjacent needed positions → three BlockRangeDescriptors. + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash); + + // StoreMaxRangeCountPerRequest=128 → all three ranges sent in one LoadAttachmentRanges call. 
+ Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_multi_{}", OpJobIndex++), {}); + RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = CacheBuildId, + .Oplog = *PartialOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = -1.0, + .CacheMaxRangeCountPerRequest = 0u, + .OptionalJobContext = &OpJobContext}); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("partial_block_cloud_singlerange") + { + // Same block layout as partial_block_cloud_multirange but StoreMaxRangeCountPerRequest=1. + // DownloadPartialBlock issues one LoadAttachmentRanges call per range. + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr<RemoteProjectStore> PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash); + + Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_single_{}", OpJobIndex++), {}); + RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = CacheBuildId, + .Oplog = *PartialOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = 
EPartialBlockRequestMode::All, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 1u, + .CacheLatencySec = -1.0, + .CacheMaxRangeCountPerRequest = 0u, + .OptionalJobContext = &OpJobContext}); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("partial_block_cache_multirange") + { + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr<RemoteProjectStore> PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + + // Phase 1: ImportCidStore starts empty → full block download from remote → PutBuildBlob + // populates the cache. + { + Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p1_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase1Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = true, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase1R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + } + ResetCacheStats(); + + // Phase 2: fresh CidStore with only even-indexed chunks seeded. + // HasAttachment returns false for odd chunks (1, 3, 5) → three BlockRangeDescriptors. + // Block is in cache from Phase 1 → cache partial path. + // CacheMaxRangeCountPerRequest=128 → SubRangeCount=3 > 1 → GetBuildBlobRanges. 
+ GcManager Phase2Gc; + CidStore Phase2CidStore(Phase2Gc); + CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + Phase2CidStore.Initialize(Phase2CidConfig); + SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); + + Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p2_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase2Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase2R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("partial_block_cache_singlerange") + { + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr<RemoteProjectStore> PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + + // Phase 1: full block download from remote into cache. 
+ { + Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p1_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase1Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = true, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase1R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + } + ResetCacheStats(); + + // Phase 2: fresh CidStore with only even-indexed chunks seeded. + // CacheMaxRangeCountPerRequest=1 → SubRangeCount=Min(3,1)=1 → GetBuildBlob with range + // offset (single-range legacy cache path), called once per needed chunk range. 
+ GcManager Phase2Gc; + CidStore Phase2CidStore(Phase2Gc); + CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas_single", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + Phase2CidStore.Initialize(Phase2CidConfig); + SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); + + Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p2_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase2Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 1u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase2R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } +} + +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS void diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp index ab82edbef..115d6438d 100644 --- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp @@ -159,7 +159,8 @@ public: virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override { - std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); + LoadAttachmentsResult Result; + std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); CbObject Request; { @@ -187,7 +188,7 @@ public: HttpClient::Response Response = m_Client.Post(LoadRequest, Request, 
HttpClient::Accept(ZenContentType::kCbPackage)); AddStats(Response); - LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)}; + Result = LoadAttachmentsResult{ConvertResult(Response)}; if (Result.ErrorCode) { Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. Reason: '{}'", @@ -249,20 +250,49 @@ public: return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent)}}; } + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override + { + ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId); + return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { + LoadAttachmentResult Result; std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); HttpClient::Response Response = m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); AddStats(Response); - LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)}; - if (!Result.ErrorCode) + Result = LoadAttachmentResult{ConvertResult(Response)}; + if (Result.ErrorCode) { - Result.Bytes = Response.ResponsePayload; - Result.Bytes.MakeOwned(); + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. 
Reason: '{}'", + m_ProjectStoreUrl, + m_Project, + m_Oplog, + RawHash, + Result.Reason); } - if (!Result.ErrorCode) + Result.Bytes = Response.ResponsePayload; + Result.Bytes.MakeOwned(); + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span<const std::pair<uint64_t, uint64_t>> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + LoadAttachmentRangesResult Result; + std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); + HttpClient::Response Response = + m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); + AddStats(Response); + + Result = LoadAttachmentRangesResult{ConvertResult(Response)}; + if (Result.ErrorCode) { Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", m_ProjectStoreUrl, @@ -271,11 +301,13 @@ public: RawHash, Result.Reason); } + else + { + Result.Ranges = std::vector<std::pair<uint64_t, uint64_t>>(Ranges.begin(), Ranges.end()); + } return Result; } - virtual void Flush() override {} - private: void AddStats(const HttpClient::Response& Result) { diff --git a/src/zenserver-test/buildstore-tests.cpp b/src/zenserver-test/buildstore-tests.cpp index 02b308485..cf9b10896 100644 --- a/src/zenserver-test/buildstore-tests.cpp +++ b/src/zenserver-test/buildstore-tests.cpp @@ -27,6 +27,8 @@ namespace zen::tests { using namespace std::literals; +TEST_SUITE_BEGIN("server.buildstore"); + TEST_CASE("buildstore.blobs") { std::filesystem::path SystemRootPath = TestEnv.CreateNewTestDir(); @@ -36,7 +38,8 @@ TEST_CASE("buildstore.blobs") std::string_view Bucket = "bkt"sv; Oid BuildId = Oid::NewOid(); - std::vector<IoHash> CompressedBlobsHashes; + std::vector<IoHash> CompressedBlobsHashes; + std::vector<uint64_t> CompressedBlobsSizes; { ZenServerInstance Instance(TestEnv); @@ -51,6 +54,7 @@ TEST_CASE("buildstore.blobs") IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7); 
CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); Payload.SetContentType(ZenContentType::kCompressedBinary); @@ -107,6 +111,7 @@ TEST_CASE("buildstore.blobs") IoBuffer Blob = CreateSemiRandomBlob(5713 + I * 7); CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); Payload.SetContentType(ZenContentType::kCompressedBinary); @@ -141,6 +146,201 @@ TEST_CASE("buildstore.blobs") CHECK(IoHash::HashBuffer(Decompressed) == RawHash); } } + + { + // Single-range Get + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector<std::pair<uint64_t, uint64_t>> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}}; + + uint64_t RangeSizeSum = Ranges.front().second; + + HttpClient::KeyValueMap Headers; + + Headers.Entries.insert( + {"Range", fmt::format("bytes={}-{}", Ranges.front().first, Ranges.front().first + Ranges.front().second - 1)}); + + HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), Headers); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + CHECK_EQ(RangeSizeSum, Payload.GetSize()); + + HttpClient::Response FullBlobResult = 
Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Ranges.front().first, Ranges.front().second); + MemoryView RangeView = Payload.GetView(); + CHECK(ActualRange.EqualBytes(RangeView)); + } + } + + { + // Single-range Post + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + uint64_t RangeSizeSum = 0; + + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector<std::pair<uint64_t, uint64_t>> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}}; + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + RangeSizeSum += Range.second; + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage); + + CbPackage ResponsePackage = ParsePackageMessage(Payload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + CHECK_EQ(RangeArray.Num(), Ranges.size()); + size_t RangeOffset = 0; + for (CbFieldView View : RangeArray) + { + CbObjectView Range = View.AsObjectView(); + CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first); 
+ CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second); + RangeOffset++; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + REQUIRE(DataAttachment); + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize()); + + HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + + uint64_t Offset = 0; + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second); + MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second); + CHECK(ActualRange.EqualBytes(RangeView)); + Offset += Range.second; + } + } + } + + { + // Multi-range + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + uint64_t RangeSizeSum = 0; + + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector<std::pair<uint64_t, uint64_t>> Ranges = { + {BlobSize / 16 * 1, BlobSize / 20}, + {BlobSize / 16 * 3, BlobSize / 32}, + {BlobSize / 16 * 5, BlobSize / 16}, + {BlobSize - BlobSize / 16, BlobSize / 16 - 1}, + }; + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + RangeSizeSum += Range.second; + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + 
Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage); + + CbPackage ResponsePackage = ParsePackageMessage(Payload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + CHECK_EQ(RangeArray.Num(), Ranges.size()); + size_t RangeOffset = 0; + for (CbFieldView View : RangeArray) + { + CbObjectView Range = View.AsObjectView(); + CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first); + CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second); + RangeOffset++; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + REQUIRE(DataAttachment); + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize()); + + HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + + uint64_t Offset = 0; + for (const std::pair<uint64_t, uint64_t>& Range : Ranges) + { + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second); + MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second); + CHECK(ActualRange.EqualBytes(RangeView)); + Offset += Range.second; + } + } + } } namespace { @@ -191,7 +391,7 @@ TEST_CASE("buildstore.metadata") HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/getBlobMetadata", Namespace, Bucket, BuildId), Payload, HttpClient::Accept(ZenContentType::kCbObject)); - CHECK(Result); + REQUIRE(Result); std::vector<CbObject> ResultMetadatas; @@ -372,7 +572,7 @@ TEST_CASE("buildstore.cache") { std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + 
REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount; I++) { CHECK(Exists[I].HasBody); @@ -411,7 +611,7 @@ TEST_CASE("buildstore.cache") { std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount; I++) { CHECK(Exists[I].HasBody); @@ -419,7 +619,7 @@ TEST_CASE("buildstore.cache") } std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { @@ -440,7 +640,7 @@ TEST_CASE("buildstore.cache") { std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount * 2; I++) { CHECK(Exists[I].HasBody); @@ -451,7 +651,7 @@ TEST_CASE("buildstore.cache") CHECK_EQ(BlobCount, MetaDatas.size()); std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { @@ -474,7 +674,7 @@ TEST_CASE("buildstore.cache") CreateZenBuildStorageCache(Client, Stats, Namespace, Bucket, TempDir, GetTinyWorkerPool(EWorkloadType::Background))); std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount * 2; I++) { CHECK(Exists[I].HasBody); @@ -493,7 +693,7 @@ TEST_CASE("buildstore.cache") CHECK_EQ(BlobCount, MetaDatas.size()); std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + 
REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { @@ -502,5 +702,7 @@ TEST_CASE("buildstore.cache") } } +TEST_SUITE_END(); + } // namespace zen::tests #endif diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp index 0272d3797..334dd04ab 100644 --- a/src/zenserver-test/cache-tests.cpp +++ b/src/zenserver-test/cache-tests.cpp @@ -23,6 +23,8 @@ namespace zen::tests { +TEST_SUITE_BEGIN("server.cache"); + TEST_CASE("zcache.basic") { using namespace std::literals; @@ -145,7 +147,7 @@ TEST_CASE("zcache.cbpackage") for (const zen::CbAttachment& LhsAttachment : LhsAttachments) { const zen::CbAttachment* RhsAttachment = Rhs.FindAttachment(LhsAttachment.GetHash()); - CHECK(RhsAttachment); + REQUIRE(RhsAttachment); zen::SharedBuffer LhsBuffer = LhsAttachment.AsCompressedBinary().Decompress(); CHECK(!LhsBuffer.IsNull()); @@ -1373,14 +1375,8 @@ TEST_CASE("zcache.rpc") } } -TEST_CASE("zcache.failing.upstream") +TEST_CASE("zcache.failing.upstream" * doctest::skip()) { - // This is an exploratory test that takes a long time to run, so lets skip it by default - if (true) - { - return; - } - using namespace std::literals; using namespace utils; @@ -2669,6 +2665,8 @@ TEST_CASE("zcache.batchoperations") } } +TEST_SUITE_END(); + } // namespace zen::tests #endif diff --git a/src/zenserver-test/cacherequests.cpp b/src/zenserver-test/cacherequests.cpp index 46339aebb..f5302a359 100644 --- a/src/zenserver-test/cacherequests.cpp +++ b/src/zenserver-test/cacherequests.cpp @@ -1037,6 +1037,8 @@ namespace zen { namespace cacherequests { static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); }; + TEST_SUITE_BEGIN("server.cacherequests"); + TEST_CASE("cacherequests.put.cache.records") { PutCacheRecordsRequest EmptyRequest; @@ -1458,5 +1460,7 @@ namespace zen { namespace cacherequests { "!default!", Invalid)); } + + TEST_SUITE_END(); #endif }} // namespace 
zen::cacherequests diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp new file mode 100644 index 000000000..c90ac5d8b --- /dev/null +++ b/src/zenserver-test/compute-tests.cpp @@ -0,0 +1,1700 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/zenbase.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/filesystem.h> +# include <zencore/guid.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpserver.h> +# include <zencompute/computeservice.h> +# include <zenstore/zenstore.h> +# include <zenutil/zenserverprocess.h> + +# include "zenserver-test.h" + +# include <thread> + +namespace zen::tests::compute { + +using namespace std::literals; + +// BuildSystemVersion and function version GUIDs matching zentest-appstub +static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; +static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; +static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; + +// In-memory implementation of ChunkResolver for test use. +// Stores compressed data keyed by decompressed content hash. 
+class InMemoryChunkResolver : public ChunkResolver +{ +public: + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override + { + auto It = m_Chunks.find(DecompressedId); + if (It != m_Chunks.end()) + { + return It->second; + } + return {}; + } + + void AddChunk(const IoHash& DecompressedId, IoBuffer Data) { m_Chunks[DecompressedId] = std::move(Data); } + +private: + std::unordered_map<IoHash, IoBuffer> m_Chunks; +}; + +// Read, compress, and register zentest-appstub as a worker. +// Returns the WorkerId (hash of the worker package object). +static IoHash +RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) +{ + std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL); + + FileContents AppStubData = zen::ReadFile(AppStubPath); + REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message())); + + IoBuffer AppStubBuffer = AppStubData.Flatten(); + + CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash AppStubRawHash = AppStubCompressed.DecodeRawHash(); + const uint64_t AppStubRawSize = AppStubBuffer.Size(); + + CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash); + + CbObjectWriter WorkerWriter; + WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion); + WorkerWriter << "path"sv + << "zentest-appstub"sv; + + WorkerWriter.BeginArray("executables"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "zentest-appstub"sv; + WorkerWriter.AddAttachment("hash"sv, AppStubAttachment); + WorkerWriter << "size"sv << AppStubRawSize; + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + WorkerWriter.BeginArray("functions"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Rot13"sv; + WorkerWriter << 
"version"sv << Guid::FromString(kRot13Version); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Sleep"sv; + WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerWriter.Save()); + WorkerPackage.AddAttachment(AppStubAttachment); + + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + + const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); + HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage)); + REQUIRE_MESSAGE(RegisterResp, + fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText())); + + return WorkerId; +} + +// Build a Rot13 action CbPackage for the given input string. +static CbPackage +BuildRot13ActionPackage(std::string_view Input) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Rot13"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Sleep action CbPackage. 
The worker sleeps for SleepTimeMs before returning its input. +static CbPackage +BuildSleepActionPackage(std::string_view Input, uint64_t SleepTimeMs) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Sleep"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "SleepTimeMs"sv << SleepTimeMs; + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Sleep action CbObject and populate the chunk resolver with the input attachment. 
+static CbObject +BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemoryChunkResolver& Resolver) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Sleep"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "SleepTimeMs"sv << SleepTimeMs; + ActionWriter.EndObject(); + + return ActionWriter.Save(); +} + +static HttpClient::Response +PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000) +{ + HttpClient::Response Resp; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + Resp = Client.Get(ResultUrl); + + if (Resp.StatusCode == HttpResponseCode::OK) + { + break; + } + + Sleep(100); + } + + return Resp; +} + +static bool +PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int Lsn, uint64_t TimeoutMs = 30'000) +{ + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + HttpClient::Response Resp = Client.Get(CompletedUrl); + + if (Resp) + { + for (auto& Item : Resp.AsObject()["completed"sv]) + { + if (Item.AsInt32() == Lsn) + { + return true; + } + } + } + + 
Sleep(100); + } + + return false; +} + +static std::string +GetRot13Output(const CbPackage& ResultPackage) +{ + CbObject ResultObj = ResultPackage.GetObject(); + + IoHash OutputHash; + CbFieldView ValuesField = ResultObj["Values"sv]; + + if (CbFieldViewIterator It = begin(ValuesField); It.HasValue()) + { + OutputHash = (*It).AsObjectView()["RawHash"sv].AsHash(); + } + + REQUIRE_MESSAGE(OutputHash != IoHash::Zero, "Expected non-zero output hash in result Values array"); + + const CbAttachment* OutputAttachment = ResultPackage.FindAttachment(OutputHash); + REQUIRE_MESSAGE(OutputAttachment != nullptr, "Output attachment not found in result package"); + + CompressedBuffer OutputCompressed = OutputAttachment->AsCompressedBinary(); + SharedBuffer OutputData = OutputCompressed.Decompress(); + + return std::string(static_cast<const char*>(OutputData.GetData()), OutputData.GetSize()); +} + +// Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response. +class MockOrchestratorService : public HttpService +{ +public: + MockOrchestratorService() + { + // Initialize with empty worker list + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + Cbo.EndArray(); + m_WorkerList = Cbo.Save(); + } + + const char* BaseUri() const override { return "/orch/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + if (Request.RequestVerb() == HttpVerb::kGet && Request.RelativeUri() == "agents"sv) + { + RwLock::SharedLockScope Lock(m_Lock); + Request.WriteResponse(HttpResponseCode::OK, m_WorkerList); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } + + void SetWorkerList(CbObject WorkerList) + { + RwLock::ExclusiveLockScope Lock(m_Lock); + m_WorkerList = std::move(WorkerList); + } + +private: + RwLock m_Lock; + CbObject m_WorkerList; +}; + +// Manages in-process ASIO HTTP server lifecycle for mock orchestrator. 
+struct MockOrchestratorFixture +{ + MockOrchestratorService Service; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + uint16_t Port = 0; + + MockOrchestratorFixture() + { + HttpServerConfig Config; + Config.ServerClass = "asio"; + Config.ForceLoopback = true; + Server = CreateHttpServer(Config); + Server->RegisterService(Service); + Port = static_cast<uint16_t>(Server->Initialize(TestEnv.GetNewPortNumber(), TmpDir.Path())); + ZEN_ASSERT(Port != 0); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~MockOrchestratorFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + std::string GetEndpoint() const { return fmt::format("http://localhost:{}", Port); } +}; + +// Build the CbObject response for /orch/agents matching the format UpdateCoordinatorState expects. +static CbObject +BuildAgentListResponse(std::initializer_list<std::pair<std::string_view, std::string_view>> Workers) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const auto& [Id, Uri] : Workers) + { + Cbo.BeginObject(); + Cbo << "id"sv << Id; + Cbo << "uri"sv << Uri; + Cbo << "hostname"sv + << "localhost"sv; + Cbo << "reachable"sv << true; + Cbo << "dt"sv << uint64_t(0); + Cbo.EndObject(); + } + Cbo.EndArray(); + return Cbo.Save(); +} + +// Build the worker CbPackage for zentest-appstub AND populate the chunk resolver. +// This is the same logic as RegisterWorker() but returns the package instead of POSTing it. 
+static CbPackage +BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver) +{ + std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL); + + FileContents AppStubData = zen::ReadFile(AppStubPath); + REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message())); + + IoBuffer AppStubBuffer = AppStubData.Flatten(); + + CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash AppStubRawHash = AppStubCompressed.DecodeRawHash(); + const uint64_t AppStubRawSize = AppStubBuffer.Size(); + + // Store compressed data in chunk resolver for when the remote runner needs it + Resolver.AddChunk(AppStubRawHash, AppStubCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash); + + CbObjectWriter WorkerWriter; + WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion); + WorkerWriter << "path"sv + << "zentest-appstub"sv; + + WorkerWriter.BeginArray("executables"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "zentest-appstub"sv; + WorkerWriter.AddAttachment("hash"sv, AppStubAttachment); + WorkerWriter << "size"sv << AppStubRawSize; + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + WorkerWriter.BeginArray("functions"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Rot13"sv; + WorkerWriter << "version"sv << Guid::FromString(kRot13Version); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Sleep"sv; + WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerWriter.Save()); + 
WorkerPackage.AddAttachment(AppStubAttachment); + + return WorkerPackage; +} + +// Build a Rot13 action CbObject (not CbPackage) and populate the chunk resolver with the input attachment. +static CbObject +BuildRot13ActionForSession(std::string_view Input, InMemoryChunkResolver& Resolver) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + // Store compressed data in chunk resolver + Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Rot13"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + + return ActionWriter.Save(); +} + +TEST_SUITE_BEGIN("server.function"); + +TEST_CASE("function.rot13") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Submit action via legacy /jobs/{worker} endpoint + const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); + HttpClient::Response SubmitResp 
= Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); + + // Poll for result via legacy /jobs/{lsn} endpoint + const std::string ResultUrl = fmt::format("/jobs/{}", Lsn); + HttpClient::Response ResultResp = PollForResult(Client, ResultUrl); + REQUIRE_MESSAGE( + ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Action failed (empty result package)\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); +} + +TEST_CASE("function.workers") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + // Before registration, GET /workers should return an empty list + HttpClient::Response EmptyListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(EmptyListResp, "Failed to list workers before registration"); + CHECK_EQ(EmptyListResp.AsObject()["workers"sv].AsArrayView().Num(), 0); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // GET /workers — the registered worker should appear in the listing + HttpClient::Response ListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list workers after 
registration"); + + bool WorkerFound = false; + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerId) + { + WorkerFound = true; + break; + } + } + + REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString())); + + // GET /workers/{worker} — descriptor should match what was registered + const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); + HttpClient::Response DescResp = Client.Get(WorkerUrl); + REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode))); + + CbObject Desc = DescResp.AsObject(); + CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion)); + CHECK_EQ(Desc["path"sv].AsString(), "zentest-appstub"sv); + + bool Rot13Found = false; + bool SleepFound = false; + for (auto& Item : Desc["functions"sv]) + { + std::string_view Name = Item.AsObjectView()["name"sv].AsString(); + if (Name == "Rot13"sv) + { + CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kRot13Version)); + Rot13Found = true; + } + else if (Name == "Sleep"sv) + { + CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kSleepVersion)); + SleepFound = true; + } + } + + CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor"); + CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor"); + + // GET /workers/{unknown} — should return 404 + const std::string UnknownUrl = fmt::format("/workers/{}", IoHash::Zero.ToHexString()); + HttpClient::Response NotFoundResp = Client.Get(UnknownUrl); + CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound); +} + +TEST_CASE("function.queues.lifecycle") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, 
Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Verify the queue appears in the listing + HttpClient::Response ListResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list queues"); + + bool QueueFound = false; + for (auto& Item : ListResp.AsObject()["queues"sv]) + { + if (Item.AsObjectView()["queue_id"sv].AsInt32() == QueueId) + { + QueueFound = true; + break; + } + } + + REQUIRE_MESSAGE(QueueFound, fmt::format("Queue {} not found in queue listing", QueueId)); + + // Submit action via queue-scoped endpoint + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission"); + + // Poll for completion via queue-scoped /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Retrieve result via queue-scoped /jobs/{lsn} endpoint + const 
std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + REQUIRE_MESSAGE( + ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + // Verify queue status reflects completion + const std::string StatusUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response StatusResp = Client.Get(StatusUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 0); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "active"); +} + +TEST_CASE("function.queues.cancel") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Submit a job + const std::string JobUrl = 
fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + // Cancel the queue + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // Verify queue status shows cancelled + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); +} + +TEST_CASE("function.queues.remote") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a remote queue — response includes both an integer queue_id and an OID queue_token + HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); + REQUIRE_MESSAGE(CreateResp, + fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + CbObject CreateObj = CreateResp.AsObject(); + const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString()); + REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); + + // All subsequent 
requests use the opaque token in place of the integer queue id + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission"); + + // Poll for completion via the token-addressed /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); + REQUIRE_MESSAGE( + PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in remote queue completed list within timeout\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Retrieve result via the token-addressed /jobs/{lsn} endpoint + const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}", + int(ResultResp.StatusCode), + Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); +} + +TEST_CASE("function.queues.cancel_running") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = 
fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Submit a Sleep job long enough that it will still be running when we cancel + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); + + // Wait for the worker process to start executing before cancelling + Sleep(1'000); + + // Cancel the queue, which should interrupt the running Sleep job + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // The cancelled job should appear in the /completed endpoint once the process exits + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list after cancel\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the queue reflects one cancelled action + HttpClient::Response StatusResp = Client.Get(QueueUrl); + 
REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); + CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.queues.remote_cancel") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a remote queue to obtain an OID token for token-addressed cancellation + HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); + REQUIRE_MESSAGE(CreateResp, + fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString()); + REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); + + // Submit a long-running Sleep job via the token-addressed endpoint + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); + + // Wait for the worker process to start executing before cancelling + Sleep(1'000); + + // Cancel the queue via its OID token 
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // The cancelled job should appear in the token-addressed /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); + REQUIRE_MESSAGE( + PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in remote queue completed list after cancel\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Verify the queue status reflects the cancellation + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get remote queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); + CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.queues.drain") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a long-running job so we can verify it completes even after drain + const std::string 
JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000)); + REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode))); + const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32(); + + // Drain the queue + const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId); + HttpClient::Response DrainResp = Client.Post(DrainUrl); + REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText())); + CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining"); + + // Second submission should be rejected with 424 + HttpClient::Response Submit2 = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv)); + CHECK_EQ(Submit2.StatusCode, HttpResponseCode::FailedDependency); + CHECK_EQ(std::string(Submit2.AsObject()["error"sv].AsString()), "queue is draining"); + + // First job should still complete + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn1), + fmt::format("LSN {} did not complete after drain\nServer log:\n{}", Lsn1, Instance.GetLogOutput())); + + // Queue status should show draining + complete + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "draining"); + CHECK(QueueStatus["is_complete"sv].AsBool()); +} + +TEST_CASE("function.priority") +{ + // Spawn server with max-actions=1 to guarantee serialized action execution, + // which lets us deterministically verify that higher-priority pending jobs + // are scheduled before lower-priority ones. 
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1"); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue for all test jobs + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + // Submit a blocker Sleep job to occupy the single execution slot. + // Once the blocker is running, the scheduler must choose among the pending + // jobs by priority when the slot becomes free. + const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + + // Submit 3 low-priority Rot13 jobs + const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + + HttpClient::Response LowResp1 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low1"sv)); + REQUIRE_MESSAGE(LowResp1, "Low-priority job 1 submission failed"); + const int LsnLow1 = LowResp1.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp2 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low2"sv)); + REQUIRE_MESSAGE(LowResp2, "Low-priority job 2 submission failed"); + const int LsnLow2 = LowResp2.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp3 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low3"sv)); + 
REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed"); + const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32(); + + // Submit 1 high-priority Rot13 job — should execute before the low-priority ones + const std::string HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString()); + HttpClient::Response HighResp = Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv)); + REQUIRE_MESSAGE(HighResp, "High-priority job submission failed"); + const int LsnHigh = HighResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for all 4 priority-test jobs to appear in the queue's completed list. + // This avoids any snapshot-timing race: by the time we compare timestamps, all + // jobs have already finished and their history entries are stable. + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + { + bool AllCompleted = false; + Stopwatch WaitTimer; + + while (!AllCompleted && WaitTimer.GetElapsedTimeMs() < 30'000) + { + HttpClient::Response Resp = Client.Get(CompletedUrl); + + if (Resp) + { + bool FoundHigh = false; + bool FoundLow1 = false; + bool FoundLow2 = false; + bool FoundLow3 = false; + + CbObject RespObj = Resp.AsObject(); + + for (auto& Item : RespObj["completed"sv]) + { + const int Lsn = Item.AsInt32(); + if (Lsn == LsnHigh) + { + FoundHigh = true; + } + else if (Lsn == LsnLow1) + { + FoundLow1 = true; + } + else if (Lsn == LsnLow2) + { + FoundLow2 = true; + } + else if (Lsn == LsnLow3) + { + FoundLow3 = true; + } + } + + AllCompleted = FoundHigh && FoundLow1 && FoundLow2 && FoundLow3; + } + + if (!AllCompleted) + { + Sleep(100); + } + } + + REQUIRE_MESSAGE( + AllCompleted, + fmt::format( + "Not all priority test jobs completed within timeout (lsnHigh={} lsnLow1={} lsnLow2={} lsnLow3={})\nServer log:\n{}", + LsnHigh, + LsnLow1, + LsnLow2, + LsnLow3, + Instance.GetLogOutput())); + } + + // Query the queue-scoped history to obtain the time_Completed timestamp for each + // job. 
The history endpoint records when each RunnerAction::State transition + // occurred, so time_Completed is the wall-clock tick at which the action finished. + // Using the queue-scoped endpoint avoids exposing history from other queues. + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + CbObject HistoryObj = HistoryResp.AsObject(); + + auto GetCompletedTimestamp = [&](int Lsn) -> uint64_t { + for (auto& Item : HistoryObj["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + return Item.AsObjectView()["time_Completed"sv].AsUInt64(); + } + } + return 0; + }; + + const uint64_t TimeHigh = GetCompletedTimestamp(LsnHigh); + const uint64_t TimeLow1 = GetCompletedTimestamp(LsnLow1); + const uint64_t TimeLow2 = GetCompletedTimestamp(LsnLow2); + const uint64_t TimeLow3 = GetCompletedTimestamp(LsnLow3); + + REQUIRE_MESSAGE(TimeHigh != 0, fmt::format("lsnHigh={} not found in action history", LsnHigh)); + REQUIRE_MESSAGE(TimeLow1 != 0, fmt::format("lsnLow1={} not found in action history", LsnLow1)); + REQUIRE_MESSAGE(TimeLow2 != 0, fmt::format("lsnLow2={} not found in action history", LsnLow2)); + REQUIRE_MESSAGE(TimeLow3 != 0, fmt::format("lsnLow3={} not found in action history", LsnLow3)); + + // The high-priority job must have completed strictly before every low-priority job + CHECK_MESSAGE(TimeHigh < TimeLow1, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow1={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow1, + TimeLow1)); + CHECK_MESSAGE(TimeHigh < TimeLow2, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow2={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow2, + TimeLow2)); + CHECK_MESSAGE(TimeHigh < TimeLow3, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} 
but lsnLow3={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow3, + TimeLow3)); +} + +////////////////////////////////////////////////////////////////////////// +// Remote worker synchronization tests +// +// These tests exercise the orchestrator discovery path where new compute +// nodes appear over time and must receive previously registered workers +// via SyncWorkersToRunner(). + +TEST_CASE("function.remote.worker_sync_on_discovery") +{ + // Spawn real zenserver in compute mode + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t ServerPort = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(ServerPort != 0, Instance.GetLogOutput()); + + const std::string ServerUri = fmt::format("http://localhost:{}", ServerPort); + + // Start mock orchestrator with empty worker list + MockOrchestratorFixture MockOrch; + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session (stored locally, no runners yet) + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Update mock orchestrator to advertise the real server + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}})); + + // Wait for scheduler to discover the runner (~5s throttle + margin) + Sleep(7'000); + + // Submit Rot13 action via session + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action 
enqueue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.late_runner_discovery") +{ + // Spawn first server + ZenServerInstance Instance1(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance1.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port1 = Instance1.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port1 != 0, Instance1.GetLogOutput()); + + const std::string ServerUri1 = fmt::format("http://localhost:{}", Port1); + + // Start mock orchestrator advertising W1 + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}})); + + // Create session and register worker + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for W1 discovery + Sleep(7'000); + + // Baseline: submit Rot13 action and verify it completes on W1 + { + CbObject ActionObj = BuildRot13ActionForSession("Hello 
World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance1.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + } + + // Spawn second server + ZenServerInstance Instance2(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance2.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port2 = Instance2.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port2 != 0, Instance2.GetLogOutput()); + + const std::string ServerUri2 = fmt::format("http://localhost:{}", Port2); + + // Update mock orchestrator to include both W1 and W2 + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}})); + + // Wait for W2 discovery + Sleep(7'000); + + // Verify W2 received the worker by querying its /compute/workers endpoint directly + { + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2); + HttpClient Client(ComputeBaseUri); + HttpClient::Response ListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2"); + + bool WorkerFound = false; + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerPackage.GetObjectHash()) + { + WorkerFound = true; + break; + } + } + + REQUIRE_MESSAGE(WorkerFound, + fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}", + Instance2.GetLogOutput())); + 
} + + // Submit another action and verify it completes (could run on either W1 or W2) + { + CbObject ActionObj = BuildRot13ActionForSession("Second Test"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Second action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Second action did not complete in time\nW1 log:\n{}\nW2 log:\n{}", + Instance1.GetLogOutput(), + Instance2.GetLogOutput())); + + // Rot13("Second Test") = "Frpbaq Grfg" + CHECK_EQ(GetRot13Output(ResultPackage), "Frpbaq Grfg"sv); + } + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_association") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + 
Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit action to it + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action enqueue to queue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. 
Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + // Verify that a non-implicit remote queue was created on the compute node + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueFound = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueFound = true; + break; + } + } + + CHECK_MESSAGE(RemoteQueueFound, "Expected a non-implicit remote queue on the compute node"); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_cancel_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit 
a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Cancel the local queue — this should propagate to the remote + Session.CancelQueue(QueueId); + + // Poll for the action to complete (as cancelled) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + // Verify the local queue shows cancelled + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK(QueueStatus.State == zen::compute::ComputeServiceSession::QueueState::Cancelled); + + // Verify the remote queue on the compute node is also cancelled + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueCancelled = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled"; + break; + } + } + + CHECK_MESSAGE(RemoteQueueCancelled, "Expected the remote queue to be cancelled"); + + Session.Shutdown(); +} + +TEST_CASE("function.abandon_running_http") +{ + // Spawn a real zenserver to execute a long-running action, then abandon via HTTP endpoint + ZenServerInstance 
Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue and submit a long-running Sleep job + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); + + // Wait for the process to start running + Sleep(1'000); + + // Verify the ready endpoint returns OK before abandon + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::OK); + } + + // Trigger abandon via the HTTP endpoint + HttpClient::Response AbandonResp = Client.Post("/abandon"sv); + REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, + fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + + // Ready endpoint should now return 503 + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::ServiceUnavailable); + } + + // The abandoned action should appear in the completed endpoint once the process exits + const std::string CompletedUrl 
= fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list after abandon\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the queue reflects one abandoned action + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after abandon"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["abandoned_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + + // Submitting new work should be rejected + HttpClient::Response RejectedResp = Client.Post(JobUrl, BuildRot13ActionPackage("rejected"sv)); + CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state"); +} + +TEST_CASE("function.session.abandon_pending") +{ + // Create a session with no runners so actions stay pending + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Enqueue several actions — they will stay pending because there are no runners + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + + CbObject ActionObj = BuildRot13ActionForSession("abandon-test"sv, Resolver); + + auto Enqueue1 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue2 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue3 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + 
REQUIRE_MESSAGE(Enqueue1, "Failed to enqueue action 1"); + REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2"); + REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); + + // Transition to Abandoned — should mark all pending actions as Abandoned + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK(!Session.IsHealthy()); + + // Give the scheduler thread time to process the state changes + Sleep(2'000); + + // All three actions should now be in the results map as abandoned + for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn}) + { + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + } + + // Queue should show 0 active, 3 abandoned + auto Status = Session.GetQueueStatus(QueueResult.QueueId); + CHECK_EQ(Status.ActiveCount, 0); + CHECK_EQ(Status.AbandonedCount, 3); + + // New actions should be rejected + auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state"); + + // Abandoned → Sunset should be valid + CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset)); + + Session.Shutdown(); +} + +TEST_CASE("function.session.abandon_running") +{ + // Spawn a real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + 
MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(!Session.IsHealthy()); + + // Poll for the action to complete (as abandoned) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the queue shows 
abandoned, not completed + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + CHECK_EQ(QueueStatus.CompletedCount, 0); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.abandon_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action and propagate + bool Transitioned = 
Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + + // Poll for the action to complete + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the local queue shows abandoned + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + + // Session should not be healthy + CHECK(!Session.IsHealthy()); + + // The remote compute node should still be healthy (only the parent abandoned) + HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); + HttpClient::Response ReadyResp = RemoteClient.Get("/ready"sv); + CHECK_MESSAGE(ReadyResp.StatusCode == HttpResponseCode::OK, "Remote compute node should still be healthy"); + + Session.Shutdown(); +} + +TEST_SUITE_END(); + +} // namespace zen::tests::compute + +#endif diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp index 42a5dcae4..11531e30f 100644 --- a/src/zenserver-test/hub-tests.cpp +++ b/src/zenserver-test/hub-tests.cpp @@ -24,7 +24,7 @@ namespace zen::tests::hub { using namespace std::literals; -TEST_SUITE_BEGIN("hub.lifecycle"); +TEST_SUITE_BEGIN("server.hub"); TEST_CASE("hub.lifecycle.basic") { @@ -230,9 +230,7 @@ TEST_CASE("hub.lifecycle.children") } } -TEST_SUITE_END(); - -TEST_CASE("hub.consul.lifecycle") +TEST_CASE("hub.consul.lifecycle" * doctest::skip()) { zen::consul::ConsulProcess ConsulProc; ConsulProc.SpawnConsulAgent(); @@ -248,5 +246,7 @@ TEST_CASE("hub.consul.lifecycle") 
ConsulProc.StopConsulAgent(); } +TEST_SUITE_END(); + } // namespace zen::tests::hub #endif diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp new file mode 100644 index 000000000..2e530ff92 --- /dev/null +++ b/src/zenserver-test/logging-tests.cpp @@ -0,0 +1,261 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS + +# include "zenserver-test.h" + +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/testing.h> +# include <zenutil/zenserverprocess.h> + +namespace zen::tests { + +using namespace std::literals; + +TEST_SUITE_BEGIN("server.logging"); + +////////////////////////////////////////////////////////////////////////// + +static bool +LogContains(const std::string& Log, std::string_view Needle) +{ + return Log.find(Needle) != std::string::npos; +} + +static std::string +ReadFileToString(const std::filesystem::path& Path) +{ + FileContents Contents = ReadFile(Path); + if (Contents.ErrorCode) + { + return {}; + } + + IoBuffer Content = Contents.Flatten(); + if (!Content) + { + return {}; + } + + return std::string(static_cast<const char*>(Content.Data()), Content.Size()); +} + +////////////////////////////////////////////////////////////////////////// + +// Verify that a log file is created at the default location (DataDir/logs/zenserver.log) +// even without --abslog. The file must contain "server session id" (logged at INFO +// to all registered loggers during init) and "log starting at" (emitted once a file +// sink is first opened). 
+TEST_CASE("logging.file.default") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path DefaultLogFile = TestDir / "logs" / "zenserver.log"; + CHECK_MESSAGE(std::filesystem::exists(DefaultLogFile), "Default log file was not created"); + const std::string FileLog = ReadFileToString(DefaultLogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --quiet sets the console sink level to WARN. The formatted "[info] ..." +// entry written by the default logger's console sink must therefore not appear +// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_* +// macros — may still emit plain-text messages without a level marker, so we +// check for the absence of the FullFormatter "[info]" prefix rather than the +// message text itself.) +TEST_CASE("logging.console.quiet") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--quiet"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "[info] server session id"), Log); +} + +// --noconsole removes the stdout sink entirely, so the captured console output +// must not contain any log entries from the logging system. 
+TEST_CASE("logging.console.disabled") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--noconsole"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "server session id"), Log); +} + +// --abslog <path> creates a rotating log file at the specified path. +// The file must contain "server session id" (logged at INFO to all loggers +// during init) and "log starting at" (emitted once a file sink is active). +TEST_CASE("logging.file.basic") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --abslog with a .json extension selects the JSON formatter. +// Each log entry must be a JSON object containing at least the "message" +// and "source" fields. 
+TEST_CASE("logging.file.json") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); +} + +// --log-id <id> is automatically set to the server instance name in test mode. +// The JSON formatter emits this value as the "id" field, so every entry in a +// .json log file must carry a non-empty "id". +TEST_CASE("logging.log_id") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + // The JSON formatter writes the log-id as: "id": "<value>", + CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog); +} + +// --log-warn <logger> raises the level threshold above INFO so that INFO messages +// are filtered. 
"server session id" is broadcast at INFO to all loggers: it must +// appear in the main file sink (default logger unaffected) but must NOT appear in +// http.log where the http_requests logger now has a WARN threshold. +TEST_CASE("logging.level.warn_suppresses_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-warn http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-info <logger> sets an explicit INFO threshold. The INFO "server session id" +// broadcast must still land in http.log, confirming that INFO messages are not +// filtered when the logger level is exactly INFO. 
+TEST_CASE("logging.level.info_allows_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-info http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-off <logger> silences a named logger entirely. +// "server session id" is broadcast at INFO to all registered loggers via +// spdlog::apply_all during init. When the "http_requests" logger is set to +// OFF its dedicated http.log file must not contain that message. +// The main file sink (via --abslog) must be unaffected. 
+TEST_CASE("logging.level.off_specific_logger") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-off http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + // Main log file must still have the startup message + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + // http.log is created by the RotatingFileSink but the logger is OFF, so + // the broadcast "server session id" message must not have been written to it + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +TEST_SUITE_END(); + +} // namespace zen::tests + +#endif diff --git a/src/zenserver-test/nomad-tests.cpp b/src/zenserver-test/nomad-tests.cpp new file mode 100644 index 000000000..f8f5a9a30 --- /dev/null +++ b/src/zenserver-test/nomad-tests.cpp @@ -0,0 +1,130 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS && ZEN_WITH_NOMAD +# include "zenserver-test.h" +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/testing.h> +# include <zencore/timer.h> +# include <zenhttp/httpclient.h> +# include <zennomad/nomadclient.h> +# include <zennomad/nomadprocess.h> +# include <zenutil/zenserverprocess.h> + +# include <fmt/format.h> + +namespace zen::tests::nomad_tests { + +using namespace std::literals; + +TEST_SUITE_BEGIN("server.nomad"); + +TEST_CASE("nomad.client.lifecycle" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + zen::nomad::NomadTestClient Client("http://localhost:4646/"); + + // Submit a simple batch job that sleeps briefly +# if ZEN_PLATFORM_WINDOWS + auto Job = Client.SubmitJob("zen-test-job", "cmd.exe", {"/C", "timeout /t 10 /nobreak"}); +# else + auto Job = Client.SubmitJob("zen-test-job", "/bin/sleep", {"10"}); +# endif + REQUIRE(!Job.Id.empty()); + CHECK_EQ(Job.Status, "pending"); + + // Poll until the job is running (or dead) + { + Stopwatch Timer; + bool FoundRunning = false; + while (Timer.GetElapsedTimeMs() < 15000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "running") + { + FoundRunning = true; + break; + } + if (Status.Status == "dead") + { + break; + } + Sleep(500); + } + CHECK(FoundRunning); + } + + // Verify allocations exist + auto Allocs = Client.GetAllocations("zen-test-job"); + CHECK(!Allocs.empty()); + + // Stop the job + Client.StopJob("zen-test-job"); + + // Verify it reaches dead state + { + Stopwatch Timer; + bool FoundDead = false; + while (Timer.GetElapsedTimeMs() < 10000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "dead") + { + FoundDead = true; + break; + } + Sleep(500); + } + CHECK(FoundDead); + } + + NomadProc.StopNomadAgent(); +} + +TEST_CASE("nomad.provisioner.integration" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + 
// Spawn zenserver in compute mode with Nomad provisioning enabled + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + + std::filesystem::path ZenServerPath = TestEnv.ProgramBaseDir() / "zenserver" ZEN_EXE_SUFFIX_LITERAL; + + std::string NomadArgs = fmt::format( + "--nomad-enabled=true" + " --nomad-server=http://localhost:4646" + " --nomad-driver=raw_exec" + " --nomad-binary-path={}" + " --nomad-max-cores=32" + " --nomad-cores-per-job=32", + ZenServerPath.string()); + + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(NomadArgs); + REQUIRE(Port != 0); + + // Give the provisioner time to submit jobs. + // The management thread has a 5s wait between cycles, and the HTTP client has + // a 10s connect timeout, so we need to allow enough time for at least one full cycle. + Sleep(15000); + + // Verify jobs were submitted to Nomad + zen::nomad::NomadTestClient NomadClient("http://localhost:4646/"); + + auto Jobs = NomadClient.ListJobs("zenserver-worker"); + + ZEN_INFO("nomad.provisioner.integration: found {} jobs with prefix 'zenserver-worker'", Jobs.size()); + CHECK_MESSAGE(!Jobs.empty(), Instance.GetLogOutput()); + + Instance.Shutdown(); + NomadProc.StopNomadAgent(); +} + +TEST_SUITE_END(); + +} // namespace zen::tests::nomad_tests +#endif diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp new file mode 100644 index 000000000..f3db5fdf6 --- /dev/null +++ b/src/zenserver-test/objectstore-tests.cpp @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS +# include "zenserver-test.h" +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zenutil/zenserverprocess.h> +# include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <tsl/robin_set.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::tests { + +using namespace std::literals; + +TEST_SUITE_BEGIN("server.objectstore"); + +TEST_CASE("objectstore.blobs") +{ + std::string_view Bucket = "bkt"sv; + + std::vector<IoHash> CompressedBlobsHashes; + std::vector<uint64_t> BlobsSizes; + std::vector<uint64_t> CompressedBlobsSizes; + { + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--objectstore-enabled")); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/obj/"); + + for (size_t I = 0; I < 5; I++) + { + IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7); + BlobsSizes.push_back(Blob.GetSize()); + CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); + CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); + IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + + std::string ObjectPath = fmt::format("{}/{}.utoc", + CompressedBlobsHashes.back().ToHexString().substr(0, 2), + CompressedBlobsHashes.back().ToHexString()); + + HttpClient::Response Result = Client.Put(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath), Payload); + CHECK(Result); + } + + for (size_t I = 0; I < 5; I++) + { + std::string ObjectPath = + fmt::format("{}/{}.utoc", CompressedBlobsHashes[I].ToHexString().substr(0, 2), CompressedBlobsHashes[I].ToHexString()); + HttpClient::Response Result = Client.Get(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath)); + CHECK(Result); + CHECK_EQ(Result.ResponsePayload.GetSize(), 
CompressedBlobsSizes[I]); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(std::move(Result.ResponsePayload)), RawHash, RawSize); + CHECK(Compressed); + CHECK_EQ(RawHash, CompressedBlobsHashes[I]); + CHECK_EQ(RawSize, BlobsSizes[I]); + } + } +} + +TEST_SUITE_END(); + +} // namespace zen::tests +#endif diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp index ead062628..eb2e187d7 100644 --- a/src/zenserver-test/projectstore-tests.cpp +++ b/src/zenserver-test/projectstore-tests.cpp @@ -27,6 +27,8 @@ namespace zen::tests { using namespace std::literals; +TEST_SUITE_BEGIN("server.projectstore"); + TEST_CASE("project.basic") { using namespace std::literals; @@ -71,7 +73,7 @@ TEST_CASE("project.basic") { auto Response = Http.Get("/prj/test"sv); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); CbObject ResponseObject = Response.AsObject(); @@ -92,7 +94,7 @@ TEST_CASE("project.basic") { auto Response = Http.Get(""sv); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); CbObject ResponseObject = Response.AsObject(); @@ -213,7 +215,7 @@ TEST_CASE("project.basic") auto Response = Http.Get(ChunkGetUri); REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); IoBuffer Data = Response.ResponsePayload; IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); @@ -235,13 +237,13 @@ TEST_CASE("project.basic") auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); IoBuffer Data = Response.ResponsePayload; IoHash RawHash; uint64_t RawSize; CompressedBuffer Compressed = 
CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); - CHECK(Compressed); + REQUIRE(Compressed); IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); CHECK(RawSize == ReferenceData.GetSize()); @@ -436,14 +438,14 @@ TEST_CASE("project.remote") HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), ProjectPayload); - CHECK(Response); + REQUIRE(Response); }; auto MakeOplog = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName) { HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject); - CHECK(Response); + REQUIRE(Response); }; auto MakeOp = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) { @@ -454,7 +456,7 @@ TEST_CASE("project.remote") HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body); - CHECK(Response); + REQUIRE(Response); }; MakeProject(Servers.GetInstance(0).GetBaseUri(), "proj0"); @@ -505,7 +507,7 @@ TEST_CASE("project.remote") HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", Project, Oplog), Payload, {{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Response); + REQUIRE(Response); CbPackage ResponsePackage = ParsePackageMessage(Response.ResponsePayload); CHECK(ResponsePackage.GetAttachments().size() == AttachmentHashes.size()); for (auto A : ResponsePackage.GetAttachments()) @@ -520,7 +522,7 @@ TEST_CASE("project.remote") HttpClient Http{Servers.GetInstance(ServerIndex).GetBaseUri()}; HttpClient::Response Response = Http.Get(fmt::format("/prj/{}/oplog/{}/entries", Project, Oplog)); - CHECK(Response); + REQUIRE(Response); IoBuffer Payload(Response.ResponsePayload); CbObject OplogResonse = 
LoadCompactBinaryObject(Payload); @@ -542,7 +544,7 @@ TEST_CASE("project.remote") auto HttpWaitForCompletion = [](ZenServerInstance& Server, const HttpClient::Response& Response) { REQUIRE(Response); const uint64_t JobId = ParseInt<uint64_t>(Response.AsText()).value_or(0); - CHECK(JobId != 0); + REQUIRE(JobId != 0); HttpClient Http{Server.GetBaseUri()}; @@ -550,10 +552,10 @@ TEST_CASE("project.remote") { HttpClient::Response StatusResponse = Http.Get(fmt::format("/admin/jobs/{}", JobId), {{"Accept", ToString(ZenContentType::kCbObject)}}); - CHECK(StatusResponse); + REQUIRE(StatusResponse); CbObject ResponseObject = StatusResponse.AsObject(); std::string_view Status = ResponseObject["Status"sv].AsString(); - CHECK(Status != "Aborted"sv); + REQUIRE(Status != "Aborted"sv); if (Status == "Complete"sv) { return; @@ -888,17 +890,17 @@ TEST_CASE("project.rpcappendop") Project.AddString("project"sv, ""sv); Project.AddString("projectfile"sv, ""sv); HttpClient::Response Response = Client.Post(fmt::format("/prj/{}", ProjectName), Project.Save()); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); }; auto MakeOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) { HttpClient::Response Response = Client.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); }; auto GetOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) { HttpClient::Response Response = Client.Get(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName)); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); return Response.AsObject(); }; @@ -912,7 +914,7 @@ TEST_CASE("project.rpcappendop") } 
Request.EndArray(); // "ops" HttpClient::Response Response = Client.Post(fmt::format("/prj/{}/oplog/{}/rpc", ProjectName, OplogName), Request.Save()); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); CbObjectView ResponsePayload = Response.AsPackage().GetObject(); CbArrayView NeedArray = ResponsePayload["need"sv].AsArrayView(); @@ -1055,6 +1057,8 @@ TEST_CASE("project.rpcappendop") } } +TEST_SUITE_END(); + } // namespace zen::tests #endif diff --git a/src/zenserver-test/workspace-tests.cpp b/src/zenserver-test/workspace-tests.cpp index 7595d790a..655f28872 100644 --- a/src/zenserver-test/workspace-tests.cpp +++ b/src/zenserver-test/workspace-tests.cpp @@ -73,6 +73,8 @@ GenerateFolderContent2(const std::filesystem::path& RootPath) return Result; } +TEST_SUITE_BEGIN("server.workspace"); + TEST_CASE("workspaces.create") { using namespace std::literals; @@ -514,9 +516,9 @@ TEST_CASE("workspaces.share") } IoBuffer BatchResponse = Client.Post(fmt::format("/ws/{}/{}/batch", WorkspaceId, ShareId), BuildChunkBatchRequest(BatchEntries)).ResponsePayload; - CHECK(BatchResponse); + REQUIRE(BatchResponse); std::vector<IoBuffer> BatchResult = ParseChunkBatchResponse(BatchResponse); - CHECK(BatchResult.size() == Files.size()); + REQUIRE(BatchResult.size() == Files.size()); for (const RequestChunkEntry& Request : BatchEntries) { IoBuffer Result = BatchResult[Request.CorrelationId]; @@ -537,5 +539,7 @@ TEST_CASE("workspaces.share") CHECK(Client.Get(fmt::format("/ws/{}", WorkspaceId)).StatusCode == HttpResponseCode::NotFound); } +TEST_SUITE_END(); + } // namespace zen::tests #endif diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua index 2a269cea1..7b208bbc7 100644 --- a/src/zenserver-test/xmake.lua +++ b/src/zenserver-test/xmake.lua @@ -6,10 +6,15 @@ target("zenserver-test") add_headerfiles("**.h") add_files("*.cpp") add_files("zenserver-test.cpp", {unity_ignored = true }) - 
add_deps("zencore", "zenremotestore", "zenhttp") + add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore") add_deps("zenserver", {inherit=false}) + add_deps("zentest-appstub", {inherit=false}) add_packages("http_parser") + if has_config("zennomad") then + add_deps("zennomad") + end + if is_plat("macosx") then add_ldflags("-framework CoreFoundation") add_ldflags("-framework Security") diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index 9a42bb73d..8d5400294 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -4,12 +4,12 @@ #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include "zenserver-test.h" # include <zencore/except.h> # include <zencore/fmtutils.h> # include <zencore/logging.h> +# include <zencore/logging/registry.h> # include <zencore/stream.h> # include <zencore/string.h> # include <zencore/testutils.h> @@ -17,8 +17,8 @@ # include <zencore/timer.h> # include <zenhttp/httpclient.h> # include <zenhttp/packageformat.h> -# include <zenutil/commandlineoptions.h> -# include <zenutil/logging/testformatter.h> +# include <zenutil/config/commandlineoptions.h> +# include <zenutil/logging/fullformatter.h> # include <zenutil/zenserverprocess.h> # include <atomic> @@ -86,8 +86,9 @@ main(int argc, char** argv) zen::logging::InitializeLogging(); - zen::logging::SetLogLevel(zen::logging::level::Debug); - spdlog::set_formatter(std::make_unique<zen::logging::full_test_formatter>("test", std::chrono::system_clock::now())); + zen::logging::SetLogLevel(zen::logging::Debug); + zen::logging::Registry::Instance().SetFormatter( + std::make_unique<zen::logging::FullFormatter>("test", std::chrono::system_clock::now())); std::filesystem::path ProgramBaseDir = GetRunningExecutablePath().parent_path(); std::filesystem::path TestBaseDir = std::filesystem::current_path() / ".test"; @@ -97,6 +98,7 @@ main(int argc, char** argv) // somehow in the future std::string ServerClass; 
+ bool Verbose = false; for (int i = 1; i < argc; ++i) { @@ -107,13 +109,23 @@ main(int argc, char** argv) ServerClass = argv[++i]; } } + else if (argv[i] == "--verbose"sv) + { + Verbose = true; + } } zen::tests::TestEnv.InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass); + if (Verbose) + { + zen::tests::TestEnv.SetPassthroughOutput(true); + } + ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir); zen::testing::TestRunner Runner; + Runner.SetDefaultSuiteFilter("server.*"); Runner.ApplyCommandLine(argc, argv); return Runner.Run(); @@ -121,6 +133,8 @@ main(int argc, char** argv) namespace zen::tests { +TEST_SUITE_BEGIN("server.zenserver"); + TEST_CASE("default.single") { std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); @@ -327,6 +341,8 @@ TEST_CASE("http.package") CHECK_EQ(ResponsePackage, TestPackage); } +TEST_SUITE_END(); + # if 0 TEST_CASE("lifetime.owner") { diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp new file mode 100644 index 000000000..c64f081b3 --- /dev/null +++ b/src/zenserver/compute/computeserver.cpp @@ -0,0 +1,1021 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "computeserver.h" +#include <zencompute/cloudmetadata.h> +#include <zencompute/httpcomputeservice.h> +#include <zencompute/httporchestrator.h> +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/fmtutils.h> +# include <zencore/memory/llm.h> +# include <zencore/memory/memorytrace.h> +# include <zencore/memory/tagtrace.h> +# include <zencore/scopeguard.h> +# include <zencore/sentryintegration.h> +# include <zencore/system.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/windows.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpapiservice.h> +# include <zenstore/cidstore.h> +# include <zenutil/service.h> +# if ZEN_WITH_HORDE +# include <zenhorde/hordeconfig.h> +# include <zenhorde/hordeprovisioner.h> +# endif +# if ZEN_WITH_NOMAD +# include <zennomad/nomadconfig.h> +# include <zennomad/nomadprovisioner.h> +# endif + +ZEN_THIRD_PARTY_INCLUDES_START +# include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +void +ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) +{ + Options.add_option("compute", + "", + "max-actions", + "Maximum number of concurrent local actions (0 = auto)", + cxxopts::value<int32_t>(m_ServerOptions.MaxConcurrentActions)->default_value("0"), + ""); + + Options.add_option("compute", + "", + "upstream-notification-endpoint", + "Endpoint URL for upstream notifications", + cxxopts::value<std::string>(m_ServerOptions.UpstreamNotificationEndpoint)->default_value(""), + ""); + + Options.add_option("compute", + "", + "instance-id", + "Instance ID for use in notifications", + cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""), + ""); + + Options.add_option("compute", + "", + "coordinator-endpoint", + "Endpoint URL for coordinator service", + cxxopts::value<std::string>(m_ServerOptions.CoordinatorEndpoint)->default_value(""), + ""); + + Options.add_option("compute", + "", + "idms", + "Enable IDMS cloud detection; optionally specify a custom probe endpoint", 
+ cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), + ""); + + Options.add_option("compute", + "", + "worker-websocket", + "Use WebSocket for worker-orchestrator link (instant reachability detection)", + cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), + ""); + +# if ZEN_WITH_HORDE + // Horde provisioning options + Options.add_option("horde", + "", + "horde-enabled", + "Enable Horde worker provisioning", + cxxopts::value<bool>(m_ServerOptions.HordeConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("horde", + "", + "horde-server", + "Horde server URL", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-token", + "Horde authentication token", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.AuthToken)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-pool", + "Horde pool name", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Pool)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-cluster", + "Horde cluster ID ('default' or '_auto' for auto-resolve)", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Cluster)->default_value("default"), + ""); + + Options.add_option("horde", + "", + "horde-mode", + "Horde connection mode (direct, tunnel, relay)", + cxxopts::value<std::string>(m_HordeModeStr)->default_value("direct"), + ""); + + Options.add_option("horde", + "", + "horde-encryption", + "Horde transport encryption (none, aes)", + cxxopts::value<std::string>(m_HordeEncryptionStr)->default_value("none"), + ""); + + Options.add_option("horde", + "", + "horde-max-cores", + "Maximum number of Horde cores to provision", + cxxopts::value<int>(m_ServerOptions.HordeConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("horde", + "", + "horde-host", + "Host address for Horde agents to connect 
back to", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-condition", + "Additional Horde agent filter condition", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Condition)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-binaries", + "Path to directory containing zenserver binary for remote upload", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.BinariesPath)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-zen-service-port", + "Port number for Zen service communication", + cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), + ""); +# endif + +# if ZEN_WITH_NOMAD + // Nomad provisioning options + Options.add_option("nomad", + "", + "nomad-enabled", + "Enable Nomad worker provisioning", + cxxopts::value<bool>(m_ServerOptions.NomadConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("nomad", + "", + "nomad-server", + "Nomad HTTP API URL", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-token", + "Nomad ACL token", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.AclToken)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-datacenter", + "Nomad target datacenter", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Datacenter)->default_value("dc1"), + ""); + + Options.add_option("nomad", + "", + "nomad-namespace", + "Nomad namespace", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Namespace)->default_value("default"), + ""); + + Options.add_option("nomad", + "", + "nomad-region", + "Nomad region (empty for server default)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Region)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-driver", + "Nomad task driver (raw_exec, 
docker)", + cxxopts::value<std::string>(m_NomadDriverStr)->default_value("raw_exec"), + ""); + + Options.add_option("nomad", + "", + "nomad-distribution", + "Binary distribution mode (predeployed, artifact)", + cxxopts::value<std::string>(m_NomadDistributionStr)->default_value("predeployed"), + ""); + + Options.add_option("nomad", + "", + "nomad-binary-path", + "Path to zenserver on Nomad clients (predeployed mode)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.BinaryPath)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-artifact-source", + "URL to download zenserver binary (artifact mode)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ArtifactSource)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-docker-image", + "Docker image for zenserver (docker driver)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.DockerImage)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-max-jobs", + "Maximum concurrent Nomad jobs", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxJobs)->default_value("64"), + ""); + + Options.add_option("nomad", + "", + "nomad-cpu-mhz", + "CPU MHz allocated per Nomad task", + cxxopts::value<int>(m_ServerOptions.NomadConfig.CpuMhz)->default_value("1000"), + ""); + + Options.add_option("nomad", + "", + "nomad-memory-mb", + "Memory MB allocated per Nomad task", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MemoryMb)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-cores-per-job", + "Estimated cores per Nomad job (for scaling)", + cxxopts::value<int>(m_ServerOptions.NomadConfig.CoresPerJob)->default_value("32"), + ""); + + Options.add_option("nomad", + "", + "nomad-max-cores", + "Maximum total cores to provision via Nomad", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-job-prefix", + "Prefix for generated Nomad job 
IDs", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.JobPrefix)->default_value("zenserver-worker"), + ""); +# endif +} + +void +ZenComputeServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenComputeServerConfigurator::ApplyOptions(cxxopts::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenComputeServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) +{ + ZEN_UNUSED(LuaOptions); +} + +void +ZenComputeServerConfigurator::ValidateOptions() +{ +# if ZEN_WITH_HORDE + horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); + horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); +# endif + +# if ZEN_WITH_NOMAD + nomad::FromString(m_ServerOptions.NomadConfig.TaskDriver, m_NomadDriverStr); + nomad::FromString(m_ServerOptions.NomadConfig.BinDistribution, m_NomadDistributionStr); +# endif +} + +/////////////////////////////////////////////////////////////////////////// + +ZenComputeServer::ZenComputeServer() +{ +} + +ZenComputeServer::~ZenComputeServer() +{ + Cleanup(); +} + +int +ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry) +{ + ZEN_TRACE_CPU("ZenComputeServer::Initialize"); + ZEN_MEMSCOPE(GetZenserverTag()); + + ZEN_INFO(ZEN_APP_NAME " initializing in COMPUTE server mode"); + + const int EffectiveBasePort = ZenServerBase::Initialize(ServerConfig, ServerEntry); + if (EffectiveBasePort < 0) + { + return EffectiveBasePort; + } + + m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_InstanceId = ServerConfig.InstanceId; + m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; + + // This is a workaround to make sure we can have automated tests. Without + // this the ranges for different child zen compute processes could overlap with + // the main test range. 
    ZenServerEnvironment::SetBaseChildId(2000);

    // Debug flag forwarded from the config; presumably drives a forced-crash
    // test path — confirm where it is consumed.
    m_DebugOptionForcedCrash = ServerConfig.ShouldCrash;

    InitializeState(ServerConfig);
    InitializeServices(ServerConfig);
    RegisterServices(ServerConfig);

    ZenServerBase::Finalize();

    return EffectiveBasePort;
}

// Tears everything down in dependency order: timers first (so nothing
// re-enqueues), then provisioners, then the WebSocket link, then services,
// then the io_context thread, and the HTTP server last. Exceptions are logged
// rather than propagated so teardown always completes.
void
ZenComputeServer::Cleanup()
{
    ZEN_TRACE_CPU("ZenComputeServer::Cleanup");
    ZEN_INFO(ZEN_APP_NAME " cleaning up");
    try
    {
        // Cancel the maintenance timer so it stops re-enqueuing before we
        // tear down the provisioners it references.
        m_ProvisionerMaintenanceTimer.cancel();
        m_AnnounceTimer.cancel();

# if ZEN_WITH_HORDE
        // Shut down Horde provisioner first — this signals all agent threads
        // to exit and joins them before we tear down HTTP services.
        m_HordeProvisioner.reset();
# endif

# if ZEN_WITH_NOMAD
        // Shut down Nomad provisioner — stops the management thread and
        // sends stop requests for all tracked jobs.
        m_NomadProvisioner.reset();
# endif

        // Close the orchestrator WebSocket client before stopping the io_context
        m_WsReconnectTimer.cancel();
        if (m_OrchestratorWsClient)
        {
            m_OrchestratorWsClient->Close();
            m_OrchestratorWsClient.reset();
        }
        m_OrchestratorWsHandler.reset();

        // NOTE(review): ResolveCloudMetadata() blocks on the std::async probe
        // started in InitializeServices() if it has not completed yet, so
        // Cleanup() can stall here for the duration of that probe — confirm
        // this is acceptable on shutdown.
        ResolveCloudMetadata();
        m_CloudMetadata.reset();

        // Shut down services that own threads or use the io_context before we
        // stop the io_context and close the HTTP server.
        if (m_OrchestratorService)
        {
            m_OrchestratorService->Shutdown();
        }
        if (m_ComputeService)
        {
            m_ComputeService->Shutdown();
        }

        // Stop the io_context and join its runner thread before destroying
        // anything the pending handlers might still reference.
        m_IoContext.stop();
        if (m_IoRunner.joinable())
        {
            m_IoRunner.join();
        }

        ShutdownServices();

        if (m_Http)
        {
            m_Http->Close();
        }
    }
    catch (const std::exception& Ex)
    {
        ZEN_ERROR("exception thrown during Cleanup() in {}: '{}'", ZEN_APP_NAME, Ex.what());
    }
}

// Intentionally empty — the compute server has no per-instance state to set
// up beyond what InitializeServices() creates.
void
ZenComputeServer::InitializeState(const ZenComputeServerConfig& ServerConfig)
{
    ZEN_UNUSED(ServerConfig);
}

// Constructs the CID (CAS) store, optionally kicks off the asynchronous cloud
// metadata probe, and instantiates the API / orchestrator / compute / frontend
// HTTP services plus any enabled worker provisioners.
void
ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
{
    ZEN_TRACE_CPU("ZenComputeServer::InitializeServices");
    ZEN_INFO("initializing compute services");

    CidStoreConfiguration Config;
    Config.RootDirectory = m_DataRoot / "cas";

    m_CidStore = std::make_unique<CidStore>(m_GcManager);
    m_CidStore->Initialize(Config);

    // The IDMS probe does network I/O, so it runs on a std::async task; the
    // result is collected lazily via ResolveCloudMetadata().
    if (!ServerConfig.IdmsEndpoint.empty())
    {
        ZEN_INFO("detecting cloud environment (async)");
        if (ServerConfig.IdmsEndpoint == "auto")
        {
            m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir] {
                return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud");
            });
        }
        else
        {
            ZEN_INFO("using custom IDMS endpoint: {}", ServerConfig.IdmsEndpoint);
            m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir, Endpoint = ServerConfig.IdmsEndpoint] {
                return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud", Endpoint);
            });
        }
    }

    ZEN_INFO("instantiating API service");
    m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http);

    ZEN_INFO("instantiating orchestrator service");
    m_OrchestratorService =
        std::make_unique<zen::compute::HttpOrchestratorService>(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket);

    ZEN_INFO("instantiating function service");
    m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_CidStore,
                                                                         m_StatsService,
                                                                         ServerConfig.DataDir / "functions",
                                                                         ServerConfig.MaxConcurrentActions);

    m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService);

# if ZEN_WITH_NOMAD
    // Nomad provisioner
    if (ServerConfig.NomadConfig.Enabled && !ServerConfig.NomadConfig.ServerUrl.empty())
    {
        ZEN_INFO("instantiating Nomad provisioner (server: {})", ServerConfig.NomadConfig.ServerUrl);

        const auto& NomadCfg = ServerConfig.NomadConfig;

        if (!NomadCfg.Validate())
        {
            // Invalid config is logged but not fatal: the server still runs,
            // just without Nomad provisioning.
            ZEN_ERROR("invalid Nomad configuration");
        }
        else
        {
            // Endpoint workers call back into; normalized to end with '/'
            ExtendableStringBuilder<256> OrchestratorEndpoint;
            OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get());
            if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/')
            {
                OrchestratorEndpoint << '/';
            }

            m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint);
        }
    }
# endif

# if ZEN_WITH_HORDE
    // Horde provisioner
    if (ServerConfig.HordeConfig.Enabled && !ServerConfig.HordeConfig.ServerUrl.empty())
    {
        ZEN_INFO("instantiating Horde provisioner (server: {})", ServerConfig.HordeConfig.ServerUrl);

        const auto& HordeConfig = ServerConfig.HordeConfig;

        if (!HordeConfig.Validate())
        {
            // Same non-fatal handling as the Nomad path above
            ZEN_ERROR("invalid Horde configuration");
        }
        else
        {
            // Endpoint workers call back into; normalized to end with '/'
            ExtendableStringBuilder<256> OrchestratorEndpoint;
            OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get());
            if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/')
            {
                OrchestratorEndpoint << '/';
            }

            // If no binaries path is specified, just use the running executable's directory
            std::filesystem::path BinariesPath = HordeConfig.BinariesPath.empty() ?
                                                     GetRunningExecutablePath().parent_path()
                                                 : std::filesystem::path(HordeConfig.BinariesPath);
            std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde";

            m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint);
        }
    }
# endif
}

// Collects the result of the async cloud-metadata probe started in
// InitializeServices(). Blocks on the future if it is still running; a no-op
// when no probe was started or the result was already collected.
void
ZenComputeServer::ResolveCloudMetadata()
{
    if (m_CloudMetadataFuture.valid())
    {
        m_CloudMetadata = m_CloudMetadataFuture.get();
    }
}

// Returns the identity used in coordinator announcements: the explicit
// --instance-id when given, otherwise "<machine-name>-<pid>".
std::string
ZenComputeServer::GetInstanceId() const
{
    if (!m_InstanceId.empty())
    {
        return m_InstanceId;
    }
    return fmt::format("{}-{}", GetMachineName(), GetCurrentProcessId());
}

// URL we advertise to the coordinator. NOTE(review): assumes
// GetServiceUri(nullptr) yields the server's base URI — confirm.
std::string
ZenComputeServer::GetAnnounceUrl() const
{
    return m_Http->GetServiceUri(nullptr);
}

// Attaches every service constructed in InitializeServices() to the HTTP
// server; each is optional, so only non-null services are registered.
void
ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig)
{
    ZEN_TRACE_CPU("ZenComputeServer::RegisterServices");
    ZEN_UNUSED(ServerConfig);

    if (m_ApiService)
    {
        m_Http->RegisterService(*m_ApiService);
    }

    if (m_OrchestratorService)
    {
        m_Http->RegisterService(*m_OrchestratorService);
    }

    if (m_ComputeService)
    {
        m_Http->RegisterService(*m_ComputeService);
    }

    if (m_FrontendService)
    {
        m_Http->RegisterService(*m_FrontendService);
    }
}

// Builds the compact-binary announce document sent to the coordinator:
// identity, host metrics, HTTP byte counters, action counts, provisioner
// origin and (when resolved) cloud metadata.
CbObject
ZenComputeServer::BuildAnnounceBody()
{
    CbObjectWriter AnnounceBody;
    AnnounceBody << "id" << GetInstanceId();
    AnnounceBody << "uri" << GetAnnounceUrl();
    AnnounceBody << "hostname" << GetMachineName();
    AnnounceBody << "platform" << GetRuntimePlatformName();

    ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query());

    AnnounceBody.BeginObject("metrics");
    Describe(Sm, AnnounceBody);
    AnnounceBody.EndObject();

    // MiB -> bytes. NOTE(review): if SystemMemoryMiB is a 32-bit type the
    // multiplication can overflow before widening — confirm its type.
    AnnounceBody << "cpu_usage" << Sm.CpuUsagePercent;
    AnnounceBody << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024;
    AnnounceBody << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024;

    AnnounceBody << "bytes_received" <<
        m_Http->GetTotalBytesReceived();
    AnnounceBody << "bytes_sent" << m_Http->GetTotalBytesSent();

    auto Actions = m_ComputeService->GetActionCounts();
    AnnounceBody << "actions_pending" << Actions.Pending;
    AnnounceBody << "actions_running" << Actions.Running;
    AnnounceBody << "actions_completed" << Actions.Completed;
    AnnounceBody << "active_queues" << Actions.ActiveQueues;

    // Derive provisioner from instance ID prefix (e.g. "horde-xxx" or "nomad-xxx")
    if (m_InstanceId.starts_with("horde-"))
    {
        AnnounceBody << "provisioner"
                     << "horde";
    }
    else if (m_InstanceId.starts_with("nomad-"))
    {
        AnnounceBody << "provisioner"
                     << "nomad";
    }

    // Blocks if the async cloud probe has not finished yet (see
    // ResolveCloudMetadata); metadata is only included once available.
    ResolveCloudMetadata();
    if (m_CloudMetadata)
    {
        m_CloudMetadata->Describe(AnnounceBody);
    }

    return AnnounceBody.Save();
}

// Sends the announce document to the coordinator: over the WebSocket link when
// it is open, otherwise via HTTP POST to "<coordinator>/announce". All
// failures are logged and swallowed — announcing is best-effort.
void
ZenComputeServer::PostAnnounce()
{
    ZEN_TRACE_CPU("ZenComputeServer::PostAnnounce");

    // Nothing to announce without a compute service or a coordinator to tell
    if (!m_ComputeService || m_CoordinatorEndpoint.empty())
    {
        return;
    }

    ZEN_INFO("notifying coordinator at '{}' of our availability at '{}'", m_CoordinatorEndpoint, GetAnnounceUrl());

    try
    {
        CbObject Body = BuildAnnounceBody();

        // If we have an active WebSocket connection, send via that instead of HTTP POST
        if (m_OrchestratorWsClient && m_OrchestratorWsClient->IsOpen())
        {
            MemoryView View = Body.GetView();
            m_OrchestratorWsClient->SendBinary(std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(View.GetData()), View.GetSize()));
            ZEN_INFO("announced to coordinator via WebSocket");
            return;
        }

        HttpClient CoordinatorHttp(m_CoordinatorEndpoint);
        HttpClient::Response Result = CoordinatorHttp.Post("announce", std::move(Body));

        if (Result.Error)
        {
            ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}",
                      m_CoordinatorEndpoint,
                      Result.Error->ErrorCode,
                      Result.Error->ErrorMessage);
        }
        else if (!IsHttpOk(Result.StatusCode))
        {
            ZEN_ERROR("failed to notify coordinator at '{}': unexpected HTTP status code {}",
m_CoordinatorEndpoint, + static_cast<int>(Result.StatusCode)); + } + else + { + ZEN_INFO("successfully notified coordinator at '{}'", m_CoordinatorEndpoint); + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("failed to notify coordinator at '{}': {}", m_CoordinatorEndpoint, Ex.what()); + } +} + +void +ZenComputeServer::EnqueueAnnounceTimer() +{ + if (!m_ComputeService || m_CoordinatorEndpoint.empty()) + { + return; + } + + m_AnnounceTimer.expires_after(std::chrono::seconds(15)); + m_AnnounceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + PostAnnounce(); + EnqueueAnnounceTimer(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::InitializeOrchestratorWebSocket() +{ + if (!m_EnableWorkerWebSocket || m_CoordinatorEndpoint.empty()) + { + return; + } + + // Convert http://host:port → ws://host:port/orch/ws + std::string WsUrl = m_CoordinatorEndpoint; + if (WsUrl.starts_with("http://")) + { + WsUrl = "ws://" + WsUrl.substr(7); + } + else if (WsUrl.starts_with("https://")) + { + WsUrl = "wss://" + WsUrl.substr(8); + } + if (!WsUrl.empty() && WsUrl.back() != '/') + { + WsUrl += '/'; + } + WsUrl += "orch/ws"; + + ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl); + + m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this); + m_OrchestratorWsClient = + std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, m_IoContext, HttpWsClientSettings{.LogCategory = "orch_ws"}); + + m_OrchestratorWsClient->Connect(); + EnsureIoRunner(); +} + +void +ZenComputeServer::EnqueueWsReconnect() +{ + m_WsReconnectTimer.expires_after(std::chrono::seconds(5)); + m_WsReconnectTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec && m_OrchestratorWsClient) + { + ZEN_INFO("attempting WebSocket reconnect to orchestrator"); + m_OrchestratorWsClient->Connect(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsOpen() +{ + ZEN_INFO("WebSocket link to orchestrator 
established"); + + // Send initial announce immediately over the WebSocket + Server.PostAnnounce(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // Orchestrator does not push messages to workers; ignore +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket link to orchestrator closed (code {}), falling back to HTTP announce", Code); + + // Trigger an immediate HTTP announce so the orchestrator has fresh state, + // then schedule a reconnect attempt. + Server.PostAnnounce(); + Server.EnqueueWsReconnect(); +} + +void +ZenComputeServer::ProvisionerMaintenanceTick() +{ +# if ZEN_WITH_HORDE + if (m_HordeProvisioner) + { + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.ActiveCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_DEBUG("Nomad maintenance: target={}, estimated={}, running jobs={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.RunningJobCount); + } +# endif +} + +void +ZenComputeServer::EnqueueProvisionerMaintenanceTimer() +{ + bool HasProvisioner = false; +# if ZEN_WITH_HORDE + HasProvisioner = HasProvisioner || (m_HordeProvisioner != nullptr); +# endif +# if ZEN_WITH_NOMAD + HasProvisioner = HasProvisioner || (m_NomadProvisioner != nullptr); +# endif + + if (!HasProvisioner) + { + return; + } + + m_ProvisionerMaintenanceTimer.expires_after(std::chrono::seconds(15)); + m_ProvisionerMaintenanceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + ProvisionerMaintenanceTick(); + EnqueueProvisionerMaintenanceTimer(); + } + }); + 
EnsureIoRunner(); +} + +void +ZenComputeServer::Run() +{ + ZEN_TRACE_CPU("ZenComputeServer::Run"); + + if (m_ProcessMonitor.IsActive()) + { + CheckOwnerPid(); + } + + if (!m_TestMode) + { + // clang-format off + ZEN_INFO( R"(__________ _________ __ )" "\n" + R"(\____ /____ ____ \_ ___ \ ____ _____ ______ __ ___/ |_ ____ )" "\n" + R"( / // __ \ / \/ \ \/ / _ \ / \\____ \| | \ __\/ __ \ )" "\n" + R"( / /\ ___/| | \ \___( <_> ) Y Y \ |_> > | /| | \ ___/ )" "\n" + R"(/_______ \___ >___| /\______ /\____/|__|_| / __/|____/ |__| \___ >)" "\n" + R"( \/ \/ \/ \/ \/|__| \/ )"); + // clang-format on + + ExtendableStringBuilder<256> BuildOptions; + GetBuildOptions(BuildOptions, '\n'); + ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + } + + ZEN_INFO(ZEN_APP_NAME " now running as COMPUTE (pid: {})", GetCurrentProcessId()); + +# if ZEN_PLATFORM_WINDOWS + if (zen::windows::IsRunningOnWine()) + { + ZEN_INFO("detected Wine session - " ZEN_APP_NAME " is not formally tested on Wine and may therefore not work or perform well"); + } +# endif + +# if ZEN_USE_SENTRY + ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED"); + if (m_UseSentry) + { + SentryIntegration::ClearCaches(); + } +# endif + + if (m_DebugOptionForcedCrash) + { + ZEN_DEBUG_BREAK(); + } + + const bool IsInteractiveMode = IsInteractiveSession(); // &&!m_TestMode; + + SetNewState(kRunning); + + OnReady(); + + PostAnnounce(); + EnqueueAnnounceTimer(); + InitializeOrchestratorWebSocket(); + +# if ZEN_WITH_HORDE + // Start Horde provisioning if configured — request maximum allowed cores. + // SetTargetCoreCount clamps to HordeConfig::MaxCores internally. 
+ if (m_HordeProvisioner) + { + ZEN_INFO("Horde provisioning starting"); + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_INFO("Horde provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + // Start Nomad provisioning if configured — request maximum allowed cores. + // SetTargetCoreCount clamps to NomadConfig::MaxCores internally. + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_INFO("Nomad provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + + EnqueueProvisionerMaintenanceTimer(); + + m_Http->Run(IsInteractiveMode); + + SetNewState(kShuttingDown); + + ZEN_INFO(ZEN_APP_NAME " exiting"); +} + +////////////////////////////////////////////////////////////////////////////////// + +ZenComputeServerMain::ZenComputeServerMain(ZenComputeServerConfig& ServerOptions) +: ZenServerMain(ServerOptions) +, m_ServerOptions(ServerOptions) +{ +} + +void +ZenComputeServerMain::DoRun(ZenServerState::ZenServerEntry* Entry) +{ + ZEN_TRACE_CPU("ZenComputeServerMain::DoRun"); + + ZenComputeServer Server; + Server.SetDataRoot(m_ServerOptions.DataDir); + Server.SetContentRoot(m_ServerOptions.ContentDir); + Server.SetTestMode(m_ServerOptions.IsTest); + Server.SetDedicatedMode(m_ServerOptions.IsDedicated); + + const int EffectiveBasePort = Server.Initialize(m_ServerOptions, Entry); + if (EffectiveBasePort == -1) + { + // Server.Initialize has already logged what the issue is - just exit with failure code here. 
+ std::exit(1); + } + + Entry->EffectiveListenPort = uint16_t(EffectiveBasePort); + if (EffectiveBasePort != m_ServerOptions.BasePort) + { + ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort); + m_ServerOptions.BasePort = EffectiveBasePort; + } + + std::unique_ptr<std::thread> ShutdownThread; + std::unique_ptr<NamedEvent> ShutdownEvent; + + ExtendableStringBuilder<64> ShutdownEventName; + ShutdownEventName << "Zen_" << m_ServerOptions.BasePort << "_Shutdown"; + ShutdownEvent.reset(new NamedEvent{ShutdownEventName}); + + // Monitor shutdown signals + + ShutdownThread.reset(new std::thread{[&] { + SetCurrentThreadName("shutdown_mon"); + + ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}' for process {}", ShutdownEventName, zen::GetCurrentProcessId()); + + if (ShutdownEvent->Wait()) + { + ZEN_INFO("shutdown signal for pid {} received", zen::GetCurrentProcessId()); + Server.RequestExit(0); + } + else + { + ZEN_INFO("shutdown signal wait() failed"); + } + }}); + + auto CleanupShutdown = MakeGuard([&ShutdownEvent, &ShutdownThread] { + ReportServiceStatus(ServiceStatus::Stopping); + + if (ShutdownEvent) + { + ShutdownEvent->Set(); + } + if (ShutdownThread && ShutdownThread->joinable()) + { + ShutdownThread->join(); + } + }); + + // If we have a parent process, establish the mechanisms we need + // to be able to communicate readiness with the parent + + Server.SetIsReadyFunc([&] { + std::error_code Ec; + m_LockFile.Update(MakeLockData(true), Ec); + ReportServiceStatus(ServiceStatus::Running); + NotifyReady(); + }); + + Server.Run(); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h new file mode 100644 index 000000000..8f4edc0f0 --- /dev/null +++ b/src/zenserver/compute/computeserver.h @@ -0,0 +1,188 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "zenserver.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <future> +# include <zencore/system.h> +# include <zenhttp/httpwsclient.h> +# include <zenstore/gc.h> +# include "frontend/frontend.h" + +namespace cxxopts { +class Options; +} +namespace zen::LuaConfig { +struct Options; +} + +namespace zen::compute { +class CloudMetadata; +class HttpComputeService; +class HttpOrchestratorService; +} // namespace zen::compute + +# if ZEN_WITH_HORDE +# include <zenhorde/hordeconfig.h> +namespace zen::horde { +class HordeProvisioner; +} // namespace zen::horde +# endif + +# if ZEN_WITH_NOMAD +# include <zennomad/nomadconfig.h> +namespace zen::nomad { +class NomadProvisioner; +} // namespace zen::nomad +# endif + +namespace zen { + +class CidStore; +class HttpApiService; + +struct ZenComputeServerConfig : public ZenServerConfig +{ + std::string UpstreamNotificationEndpoint; + std::string InstanceId; // For use in notifications + std::string CoordinatorEndpoint; + std::string IdmsEndpoint; + int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) + bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + +# if ZEN_WITH_HORDE + horde::HordeConfig HordeConfig; +# endif + +# if ZEN_WITH_NOMAD + nomad::NomadConfig NomadConfig; +# endif +}; + +struct ZenComputeServerConfigurator : public ZenServerConfiguratorBase +{ + ZenComputeServerConfigurator(ZenComputeServerConfig& ServerOptions) + : ZenServerConfiguratorBase(ServerOptions) + , m_ServerOptions(ServerOptions) + { + } + + ~ZenComputeServerConfigurator() = default; + +private: + virtual void AddCliOptions(cxxopts::Options& Options) override; + virtual void AddConfigOptions(LuaConfig::Options& Options) override; + virtual void ApplyOptions(cxxopts::Options& Options) override; + virtual void OnConfigFileParsed(LuaConfig::Options& LuaOptions) override; + virtual void ValidateOptions() override; + + ZenComputeServerConfig& m_ServerOptions; + +# if 
ZEN_WITH_HORDE + std::string m_HordeModeStr = "direct"; + std::string m_HordeEncryptionStr = "none"; +# endif + +# if ZEN_WITH_NOMAD + std::string m_NomadDriverStr = "raw_exec"; + std::string m_NomadDistributionStr = "predeployed"; +# endif +}; + +class ZenComputeServerMain : public ZenServerMain +{ +public: + ZenComputeServerMain(ZenComputeServerConfig& ServerOptions); + virtual void DoRun(ZenServerState::ZenServerEntry* Entry) override; + + ZenComputeServerMain(const ZenComputeServerMain&) = delete; + ZenComputeServerMain& operator=(const ZenComputeServerMain&) = delete; + + typedef ZenComputeServerConfig Config; + typedef ZenComputeServerConfigurator Configurator; + +private: + ZenComputeServerConfig& m_ServerOptions; +}; + +/** + * The compute server handles DDC build function execution requests + * only. It's intended to be used on a pure compute resource and does + * not handle any storage tasks. The actual scheduling happens upstream + * in a storage server instance. + */ + +class ZenComputeServer : public ZenServerBase +{ + ZenComputeServer& operator=(ZenComputeServer&&) = delete; + ZenComputeServer(ZenComputeServer&&) = delete; + +public: + ZenComputeServer(); + ~ZenComputeServer(); + + int Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry); + void Run(); + void Cleanup(); + +private: + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<CidStore> m_CidStore; + std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<zen::compute::HttpComputeService> m_ComputeService; + std::unique_ptr<zen::compute::HttpOrchestratorService> m_OrchestratorService; + std::unique_ptr<zen::compute::CloudMetadata> m_CloudMetadata; + std::future<std::unique_ptr<zen::compute::CloudMetadata>> m_CloudMetadataFuture; + std::unique_ptr<HttpFrontendService> m_FrontendService; +# if ZEN_WITH_HORDE + std::unique_ptr<zen::horde::HordeProvisioner> m_HordeProvisioner; +# endif +# if ZEN_WITH_NOMAD + 
std::unique_ptr<zen::nomad::NomadProvisioner> m_NomadProvisioner; +# endif + SystemMetricsTracker m_MetricsTracker; + std::string m_CoordinatorEndpoint; + std::string m_InstanceId; + + asio::steady_timer m_AnnounceTimer{m_IoContext}; + asio::steady_timer m_ProvisionerMaintenanceTimer{m_IoContext}; + + void InitializeState(const ZenComputeServerConfig& ServerConfig); + void InitializeServices(const ZenComputeServerConfig& ServerConfig); + void RegisterServices(const ZenComputeServerConfig& ServerConfig); + void ResolveCloudMetadata(); + void PostAnnounce(); + void EnqueueAnnounceTimer(); + void EnqueueProvisionerMaintenanceTimer(); + void ProvisionerMaintenanceTick(); + std::string GetAnnounceUrl() const; + std::string GetInstanceId() const; + CbObject BuildAnnounceBody(); + + // Worker→orchestrator WebSocket client + struct OrchestratorWsHandler : public IWsClientHandler + { + ZenComputeServer& Server; + explicit OrchestratorWsHandler(ZenComputeServer& S) : Server(S) {} + + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + }; + + std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler; + std::unique_ptr<HttpWsClient> m_OrchestratorWsClient; + asio::steady_timer m_WsReconnectTimer{m_IoContext}; + bool m_EnableWorkerWebSocket = false; + + void InitializeOrchestratorWebSocket(); + void EnqueueWsReconnect(); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index 07913e891..e36352dae 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -16,8 +16,8 @@ #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/string.h> -#include <zenutil/commandlineoptions.h> -#include <zenutil/environmentoptions.h> +#include <zenutil/config/commandlineoptions.h> +#include <zenutil/config/environmentoptions.h> 
ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> @@ -119,10 +119,17 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions ZenServerConfig& ServerOptions = m_ServerOptions; + // logging + + LuaOptions.AddOption("server.logid"sv, ServerOptions.LoggingConfig.LogId, "log-id"sv); + LuaOptions.AddOption("server.abslog"sv, ServerOptions.LoggingConfig.AbsLogFile, "abslog"sv); + LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.LoggingConfig.OtelEndpointUri, "otlp-endpoint"sv); + LuaOptions.AddOption("server.quiet"sv, ServerOptions.LoggingConfig.QuietConsole, "quiet"sv); + LuaOptions.AddOption("server.noconsole"sv, ServerOptions.LoggingConfig.NoConsoleOutput, "noconsole"sv); + // server LuaOptions.AddOption("server.dedicated"sv, ServerOptions.IsDedicated, "dedicated"sv); - LuaOptions.AddOption("server.logid"sv, ServerOptions.LogId, "log-id"sv); LuaOptions.AddOption("server.sentry.disable"sv, ServerOptions.SentryConfig.Disable, "no-sentry"sv); LuaOptions.AddOption("server.sentry.allowpersonalinfo"sv, ServerOptions.SentryConfig.AllowPII, "sentry-allow-personal-info"sv); LuaOptions.AddOption("server.sentry.dsn"sv, ServerOptions.SentryConfig.Dsn, "sentry-dsn"sv); @@ -131,12 +138,9 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions LuaOptions.AddOption("server.systemrootdir"sv, ServerOptions.SystemRootDir, "system-dir"sv); LuaOptions.AddOption("server.datadir"sv, ServerOptions.DataDir, "data-dir"sv); LuaOptions.AddOption("server.contentdir"sv, ServerOptions.ContentDir, "content-dir"sv); - LuaOptions.AddOption("server.abslog"sv, ServerOptions.AbsLogFile, "abslog"sv); - LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.OtelEndpointUri, "otlp-endpoint"sv); LuaOptions.AddOption("server.debug"sv, ServerOptions.IsDebug, "debug"sv); LuaOptions.AddOption("server.clean"sv, ServerOptions.IsCleanStart, "clean"sv); - LuaOptions.AddOption("server.quiet"sv, ServerOptions.QuietConsole, "quiet"sv); - 
LuaOptions.AddOption("server.noconsole"sv, ServerOptions.NoConsoleOutput, "noconsole"sv); + LuaOptions.AddOption("server.security.configpath"sv, ServerOptions.SecurityConfigPath, "security-config-path"sv); ////// network @@ -182,8 +186,10 @@ struct ZenServerCmdLineOptions std::string SystemRootDir; std::string ContentDir; std::string DataDir; - std::string AbsLogFile; std::string BaseSnapshotDir; + std::string SecurityConfigPath; + + ZenLoggingCmdLineOptions LoggingOptions; void AddCliOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions); void ApplyOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions); @@ -249,22 +255,7 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"), ""); - // clang-format off - options.add_options("logging") - ("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile)) - ("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId)) - ("quiet", "Configure console logger output to level WARN", cxxopts::value<bool>(ServerOptions.QuietConsole)->default_value("false")) - ("noconsole", "Disable console logging", cxxopts::value<bool>(ServerOptions.NoConsoleOutput)->default_value("false")) - ("log-trace", "Change selected loggers to level TRACE", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Trace])) - ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Debug])) - ("log-info", "Change selected loggers to level INFO", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Info])) - ("log-warn", "Change selected loggers to level WARN", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Warn])) - ("log-error", "Change selected loggers to level ERROR", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Err])) - ("log-critical", "Change 
selected loggers to level CRITICAL", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Critical])) - ("log-off", "Change selected loggers to level OFF", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Off])) - ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value<std::string>(ServerOptions.OtelEndpointUri)) - ; - // clang-format on + LoggingOptions.AddCliOptions(options, ServerOptions.LoggingConfig); options .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>"); @@ -311,6 +302,13 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi cxxopts::value<bool>(ServerOptions.HttpConfig.ForceLoopback)->default_value("false"), "<http forceloopback>"); + options.add_option("network", + "", + "security-config-path", + "Path to http security configuration file", + cxxopts::value<std::string>(SecurityConfigPath), + "<security config path>"); + #if ZEN_WITH_HTTPSYS options.add_option("httpsys", "", @@ -391,12 +389,14 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir)); } - ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); - ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); - ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.AbsLogFile = MakeSafeAbsolutePath(AbsLogFile); - ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); - ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); + ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); + ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); + ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); + ServerOptions.BaseSnapshotDir = 
MakeSafeAbsolutePath(BaseSnapshotDir); + ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); + + LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig); } ////////////////////////////////////////////////////////////////////////// @@ -466,34 +466,7 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[]) } #endif - if (m_ServerOptions.QuietConsole) - { - bool HasExplicitConsoleLevel = false; - for (int i = 0; i < logging::level::LogLevelCount; ++i) - { - if (m_ServerOptions.Loggers[i].find("console") != std::string::npos) - { - HasExplicitConsoleLevel = true; - break; - } - } - - if (!HasExplicitConsoleLevel) - { - std::string& WarnLoggers = m_ServerOptions.Loggers[logging::level::Warn]; - if (!WarnLoggers.empty()) - { - WarnLoggers += ","; - } - WarnLoggers += "console"; - } - } - - for (int i = 0; i < logging::level::LogLevelCount; ++i) - { - logging::ConfigureLogLevels(logging::level::LogLevel(i), m_ServerOptions.Loggers[i]); - } - logging::RefreshLogLevels(); + ApplyLoggingOptions(options, m_ServerOptions.LoggingConfig); BaseOptions.ApplyOptions(options, m_ServerOptions); ApplyOptions(options); @@ -532,9 +505,9 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[]) m_ServerOptions.DataDir = PickDefaultStateDirectory(m_ServerOptions.SystemRootDir); } - if (m_ServerOptions.AbsLogFile.empty()) + if (m_ServerOptions.LoggingConfig.AbsLogFile.empty()) { - m_ServerOptions.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log"; + m_ServerOptions.LoggingConfig.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log"; } m_ServerOptions.HttpConfig.IsDedicatedServer = m_ServerOptions.IsDedicated; diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h index 7c3192a1f..55aee07f9 100644 --- a/src/zenserver/config/config.h +++ b/src/zenserver/config/config.h @@ -6,6 +6,7 @@ #include <zencore/trace.h> #include <zencore/zencore.h> #include <zenhttp/httpserver.h> +#include 
<zenutil/config/loggingconfig.h> #include <filesystem> #include <string> #include <vector> @@ -42,29 +43,26 @@ struct ZenServerConfig HttpServerConfig HttpConfig; ZenSentryConfig SentryConfig; ZenStatsConfig StatsConfig; - int BasePort = 8558; // Service listen port (used for both UDP and TCP) - int OwnerPid = 0; // Parent process id (zero for standalone) - bool IsDebug = false; - bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not - bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization - bool IsTest = false; - bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux) - bool NoConsoleOutput = false; // Control default use of stdout for diagnostics - bool QuietConsole = false; // Configure console logger output to level WARN - int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number - bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements - bool ShouldCrash = false; // Option for testing crash handling - bool IsFirstRun = false; - std::filesystem::path ConfigFile; // Path to Lua config file - std::filesystem::path SystemRootDir; // System root directory (used for machine level config) - std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) - std::filesystem::path DataDir; // Root directory for state (used for testing) - std::filesystem::path AbsLogFile; // Absolute path to main log file - std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start) - std::string ChildId; // Id assigned by parent process (used for lifetime management) - std::string LogId; // Id for tagging log output - std::string Loggers[zen::logging::level::LogLevelCount]; - std::string OtelEndpointUri; // OpenTelemetry endpoint URI + ZenLoggingConfig LoggingConfig; + int BasePort = 8558; // Service listen port 
(used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) + bool IsDebug = false; + bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not + bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization + bool IsTest = false; + bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux) + int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number + int LieCpu = 0; + bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements + bool ShouldCrash = false; // Option for testing crash handling + bool IsFirstRun = false; + std::filesystem::path ConfigFile; // Path to Lua config file + std::filesystem::path SystemRootDir; // System root directory (used for machine level config) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start) + std::string ChildId; // Id assigned by parent process (used for lifetime management) + std::filesystem::path SecurityConfigPath; // Path to a Json security configuration file #if ZEN_WITH_TRACE bool HasTraceCommandlineOptions = false; diff --git a/src/zenserver/config/luaconfig.h b/src/zenserver/config/luaconfig.h index ce7013a9a..e3ac3b343 100644 --- a/src/zenserver/config/luaconfig.h +++ b/src/zenserver/config/luaconfig.h @@ -4,7 +4,7 @@ #include <zenbase/concepts.h> #include <zencore/fmtutils.h> -#include <zenutil/commandlineoptions.h> +#include <zenutil/config/commandlineoptions.h> ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp index d8d53b0e3..dd4b8956c 100644 --- a/src/zenserver/diag/diagsvcs.cpp +++ 
b/src/zenserver/diag/diagsvcs.cpp @@ -9,12 +9,11 @@ #include <zencore/logging.h> #include <zencore/memory/llm.h> #include <zencore/string.h> +#include <zencore/system.h> #include <fstream> #include <sstream> -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/logger.h> -ZEN_THIRD_PARTY_INCLUDES_END +#include <zencore/logging/logger.h> namespace zen { @@ -53,6 +52,36 @@ HttpHealthService::HttpHealthService() Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; + Writer << "Port"sv << m_HealthInfo.Port; + Writer << "Pid"sv << m_HealthInfo.Pid; + Writer << "IsDedicated"sv << m_HealthInfo.IsDedicated; + Writer << "StartTimeMs"sv << m_HealthInfo.StartTimeMs; + } + + Writer.BeginObject("RuntimeConfig"sv); + for (const auto& Opt : m_HealthInfo.RuntimeConfig) + { + Writer << Opt.first << Opt.second; + } + Writer.EndObject(); + + Writer.BeginObject("BuildConfig"sv); + for (const auto& Opt : m_HealthInfo.BuildOptions) + { + Writer << Opt.first << Opt.second; + } + Writer.EndObject(); + + Writer << "Hostname"sv << GetMachineName(); + Writer << "Platform"sv << GetRuntimePlatformName(); + Writer << "Arch"sv << GetCpuName(); + Writer << "OS"sv << GetOperatingSystemVersion(); + + { + auto Metrics = GetSystemMetrics(); + Writer.BeginObject("System"sv); + Describe(Metrics, Writer); + Writer.EndObject(); } HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); @@ -64,7 +93,7 @@ HttpHealthService::HttpHealthService() [this](HttpRouterRequest& RoutedReq) { HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - zen::Log().SpdLogger->flush(); + zen::Log().Flush(); std::filesystem::path Path = [&] { RwLock::SharedLockScope _(m_InfoLock); diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h index 8cc869c83..87ce80b3c 100644 --- a/src/zenserver/diag/diagsvcs.h +++ b/src/zenserver/diag/diagsvcs.h @@ -6,6 +6,7 @@ #include 
<zenhttp/httpserver.h> #include <filesystem> +#include <vector> ////////////////////////////////////////////////////////////////////////// @@ -89,10 +90,16 @@ private: struct HealthServiceInfo { - std::filesystem::path DataRoot; - std::filesystem::path AbsLogPath; - std::string HttpServerClass; - std::string BuildVersion; + std::filesystem::path DataRoot; + std::filesystem::path AbsLogPath; + std::string HttpServerClass; + std::string BuildVersion; + int Port = 0; + int Pid = 0; + bool IsDedicated = false; + int64_t StartTimeMs = 0; + std::vector<std::pair<std::string_view, bool>> BuildOptions; + std::vector<std::pair<std::string_view, std::string>> RuntimeConfig; }; /** Health monitoring endpoint diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp index 4962b9006..178c3d3b5 100644 --- a/src/zenserver/diag/logging.cpp +++ b/src/zenserver/diag/logging.cpp @@ -6,6 +6,8 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> +#include <zencore/logging/logger.h> +#include <zencore/logging/registry.h> #include <zencore/memory/llm.h> #include <zencore/session.h> #include <zencore/string.h> @@ -14,10 +16,6 @@ #include "otlphttp.h" -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/spdlog.h> -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen { void @@ -28,10 +26,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) const LoggingOptions LogOptions = {.IsDebug = InOptions.IsDebug, .IsVerbose = false, .IsTest = InOptions.IsTest, - .NoConsoleOutput = InOptions.NoConsoleOutput, - .QuietConsole = InOptions.QuietConsole, - .AbsLogFile = InOptions.AbsLogFile, - .LogId = InOptions.LogId}; + .NoConsoleOutput = InOptions.LoggingConfig.NoConsoleOutput, + .QuietConsole = InOptions.LoggingConfig.QuietConsole, + .AbsLogFile = InOptions.LoggingConfig.AbsLogFile, + .LogId = InOptions.LoggingConfig.LogId}; BeginInitializeLogging(LogOptions); @@ -43,13 +41,12 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool 
WithCacheService) std::filesystem::path HttpLogPath = InOptions.DataDir / "logs" / "http.log"; zen::CreateDirectories(HttpLogPath.parent_path()); - auto HttpSink = std::make_shared<zen::logging::RotatingFileSink>(HttpLogPath, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ true); - auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink); - spdlog::apply_logger_env_levels(HttpLogger); - spdlog::register_logger(HttpLogger); + logging::SinkPtr HttpSink(new zen::logging::RotatingFileSink(HttpLogPath, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true)); + Ref<logging::Logger> HttpLogger(new logging::Logger("http_requests", std::vector<logging::SinkPtr>{HttpSink})); + logging::Registry::Instance().Register(HttpLogger); if (WithCacheService) { @@ -57,33 +54,30 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) std::filesystem::path CacheLogPath = InOptions.DataDir / "logs" / "z$.log"; zen::CreateDirectories(CacheLogPath.parent_path()); - auto CacheSink = std::make_shared<zen::logging::RotatingFileSink>(CacheLogPath, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ false); - auto CacheLogger = std::make_shared<spdlog::logger>("z$", CacheSink); - spdlog::apply_logger_env_levels(CacheLogger); - spdlog::register_logger(CacheLogger); + logging::SinkPtr CacheSink(new zen::logging::RotatingFileSink(CacheLogPath, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ false)); + Ref<logging::Logger> CacheLogger(new logging::Logger("z$", std::vector<logging::SinkPtr>{CacheSink})); + logging::Registry::Instance().Register(CacheLogger); // Jupiter - only log upstream HTTP traffic to file - auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink); - spdlog::apply_logger_env_levels(JupiterLogger); - spdlog::register_logger(JupiterLogger); + Ref<logging::Logger> JupiterLogger(new 
logging::Logger("jupiter", std::vector<logging::SinkPtr>{FileSink})); + logging::Registry::Instance().Register(JupiterLogger); // Zen - only log upstream HTTP traffic to file - auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink); - spdlog::apply_logger_env_levels(ZenClientLogger); - spdlog::register_logger(ZenClientLogger); + Ref<logging::Logger> ZenClientLogger(new logging::Logger("zenclient", std::vector<logging::SinkPtr>{FileSink})); + logging::Registry::Instance().Register(ZenClientLogger); } #if ZEN_WITH_OTEL - if (!InOptions.OtelEndpointUri.empty()) + if (!InOptions.LoggingConfig.OtelEndpointUri.empty()) { // TODO: Should sanity check that endpoint is reachable? Also, a valid URI? - auto OtelSink = std::make_shared<zen::logging::OtelHttpProtobufSink>(InOptions.OtelEndpointUri); - zen::logging::Default().SpdLogger->sinks().push_back(std::move(OtelSink)); + logging::SinkPtr OtelSink(new zen::logging::OtelHttpProtobufSink(InOptions.LoggingConfig.OtelEndpointUri)); + zen::logging::Default()->AddSink(std::move(OtelSink)); } #endif @@ -91,9 +85,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) const zen::Oid ServerSessionId = zen::GetSessionId(); - spdlog::apply_all([&](auto Logger) { + static constinit logging::LogPoint SessionIdPoint{{}, logging::Info, "server session id: {}"}; + logging::Registry::Instance().ApplyAll([&](auto Logger) { ZEN_MEMSCOPE(ELLMTag::Logging); - Logger->info("server session id: {}", ServerSessionId); + Logger->Log(SessionIdPoint, fmt::make_format_args(ServerSessionId)); }); } diff --git a/src/zenserver/diag/otlphttp.cpp b/src/zenserver/diag/otlphttp.cpp index d62ccccb6..d6e24cbe3 100644 --- a/src/zenserver/diag/otlphttp.cpp +++ b/src/zenserver/diag/otlphttp.cpp @@ -10,11 +10,18 @@ #include <protozero/buffer_string.hpp> #include <protozero/pbf_builder.hpp> +#include <cstdio> + #if ZEN_WITH_OTEL namespace zen::logging { 
////////////////////////////////////////////////////////////////////////// +// +// Important note: in general we cannot use ZEN_WARN/ZEN_ERROR etc in this +// file as it could cause recursive logging calls when we attempt to log +// errors from the OTLP HTTP client itself. +// OtelHttpProtobufSink::OtelHttpProtobufSink(const std::string_view& Uri) : m_OtelHttp(Uri) { @@ -36,14 +43,44 @@ OtelHttpProtobufSink::~OtelHttpProtobufSink() } void +OtelHttpProtobufSink::CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept +{ + if (!Result.IsSuccess()) + { + uint32_t PrevFailures = m_ConsecutivePostFailures.fetch_add(1); + if (PrevFailures < kMaxReportedFailures) + { + fprintf(stderr, "OtelHttpProtobufSink: %s\n", Result.ErrorMessage(Endpoint).c_str()); + if (PrevFailures + 1 == kMaxReportedFailures) + { + fprintf(stderr, "OtelHttpProtobufSink: suppressing further export errors\n"); + } + } + } + else + { + m_ConsecutivePostFailures.store(0); + } +} + +void OtelHttpProtobufSink::RecordSpans(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans) { - std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans); + try + { + std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans); + + IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; + Payload.SetContentType(ZenContentType::kProtobuf); - IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; - Payload.SetContentType(ZenContentType::kProtobuf); + HttpClient::Response Result = m_OtelHttp.Post("/v1/traces", Payload); - auto Result = m_OtelHttp.Post("/v1/traces", Payload); + CheckPostResult(Result, "POST /v1/traces"); + } + catch (const std::exception& Ex) + { + fprintf(stderr, "OtelHttpProtobufSink: exception exporting traces: %s\n", Ex.what()); + } } void @@ -53,28 +90,26 @@ OtelHttpProtobufSink::TraceRecorder::RecordSpans(zen::otel::TraceId Trace, std:: } void -OtelHttpProtobufSink::log(const spdlog::details::log_msg& Msg) +OtelHttpProtobufSink::Log(const LogMessage& Msg) 
{ + try { std::string Data = m_Encoder.FormatOtelProtobuf(Msg); IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; Payload.SetContentType(ZenContentType::kProtobuf); - auto Result = m_OtelHttp.Post("/v1/logs", Payload); - } + HttpClient::Response Result = m_OtelHttp.Post("/v1/logs", Payload); + CheckPostResult(Result, "POST /v1/logs"); + } + catch (const std::exception& Ex) { - std::string Data = m_Encoder.FormatOtelMetrics(); - - IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; - Payload.SetContentType(ZenContentType::kProtobuf); - - auto Result = m_OtelHttp.Post("/v1/metrics", Payload); + fprintf(stderr, "OtelHttpProtobufSink: exception exporting logs: %s\n", Ex.what()); } } void -OtelHttpProtobufSink::flush() +OtelHttpProtobufSink::Flush() { } diff --git a/src/zenserver/diag/otlphttp.h b/src/zenserver/diag/otlphttp.h index 2281bdcc0..64b3dbc87 100644 --- a/src/zenserver/diag/otlphttp.h +++ b/src/zenserver/diag/otlphttp.h @@ -3,23 +3,25 @@ #pragma once -#include <spdlog/sinks/sink.h> +#include <zencore/logging/sink.h> #include <zencore/zencore.h> #include <zenhttp/httpclient.h> #include <zentelemetry/otlpencoder.h> #include <zentelemetry/otlptrace.h> +#include <atomic> + #if ZEN_WITH_OTEL namespace zen::logging { /** - * OTLP/HTTP sink for spdlog + * OTLP/HTTP sink for logging * * Sends log messages and traces to an OpenTelemetry collector via OTLP over HTTP */ -class OtelHttpProtobufSink : public spdlog::sinks::sink +class OtelHttpProtobufSink : public Sink { public: // Note that this URI should be the base URI of the OTLP HTTP endpoint, e.g. 
@@ -31,12 +33,12 @@ public: OtelHttpProtobufSink& operator=(const OtelHttpProtobufSink&) = delete; private: - virtual void log(const spdlog::details::log_msg& Msg) override; - virtual void flush() override; - virtual void set_pattern(const std::string& pattern) override { ZEN_UNUSED(pattern); } - virtual void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) override { ZEN_UNUSED(sink_formatter); } + virtual void Log(const LogMessage& Msg) override; + virtual void Flush() override; + virtual void SetFormatter(std::unique_ptr<Formatter>) override {} void RecordSpans(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans); + void CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept; // This is just a thin wrapper to call back into the sink while participating in // reference counting from the OTEL trace back-end @@ -54,11 +56,15 @@ private: OtelHttpProtobufSink* m_Sink; }; - HttpClient m_OtelHttp; - OtlpEncoder m_Encoder; - Ref<TraceRecorder> m_TraceRecorder; + static constexpr uint32_t kMaxReportedFailures = 5; + + RwLock m_Lock; + std::atomic<uint32_t> m_ConsecutivePostFailures{0}; + HttpClient m_OtelHttp; + OtlpEncoder m_Encoder; + Ref<TraceRecorder> m_TraceRecorder; }; } // namespace zen::logging -#endif
\ No newline at end of file +#endif diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp index 2b157581f..579a65c5a 100644 --- a/src/zenserver/frontend/frontend.cpp +++ b/src/zenserver/frontend/frontend.cpp @@ -38,7 +38,7 @@ HttpFrontendService::HttpFrontendService(std::filesystem::path Directory, HttpSt #if ZEN_EMBED_HTML_ZIP // Load an embedded Zip archive IoBuffer HtmlZipDataBuffer(IoBuffer::Wrap, gHtmlZipData, sizeof(gHtmlZipData) - 1); - m_ZipFs = ZipFs(std::move(HtmlZipDataBuffer)); + m_ZipFs = std::make_unique<ZipFs>(std::move(HtmlZipDataBuffer)); #endif if (m_Directory.empty() && !m_ZipFs) @@ -114,6 +114,8 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) { using namespace std::literals; + ExtendableStringBuilder<256> UriBuilder; + std::string_view Uri = Request.RelativeUriWithExtension(); for (; Uri.length() > 0 && Uri[0] == '/'; Uri = Uri.substr(1)) ; @@ -121,6 +123,11 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) { Uri = "index.html"sv; } + else if (Uri.back() == '/') + { + UriBuilder << Uri << "index.html"sv; + Uri = UriBuilder; + } // Dismiss if the URI contains .. 
anywhere to prevent arbitrary file reads if (Uri.find("..") != Uri.npos) @@ -145,24 +152,47 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) return Request.WriteResponse(HttpResponseCode::Forbidden); } - // The given content directory overrides any zip-fs discovered in the binary - if (!m_Directory.empty()) - { - auto FullPath = m_Directory / std::filesystem::path(Uri).make_preferred(); - FileContents File = ReadFile(FullPath); + auto WriteResponseForUri = [this, + &Request](std::string_view InUri, HttpResponseCode ResponseCode, HttpContentType ContentType) -> bool { + // The given content directory overrides any zip-fs discovered in the binary + if (!m_Directory.empty()) + { + auto FullPath = m_Directory / std::filesystem::path(InUri).make_preferred(); + FileContents File = ReadFile(FullPath); - if (!File.ErrorCode) + if (!File.ErrorCode) + { + Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); + + return true; + } + } + + if (m_ZipFs) { - return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]); + if (IoBuffer FileBuffer = m_ZipFs->GetFile(InUri)) + { + Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + + return true; + } } - } - if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri)) + return false; + }; + + if (WriteResponseForUri(Uri, HttpResponseCode::OK, ContentType)) { - return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + return; + } + else if (WriteResponseForUri("404.html"sv, HttpResponseCode::NotFound, HttpContentType::kHTML)) + { + return; + } + else + { + Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); } - - Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); } } // namespace zen diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h index 84ffaac42..6d8585b72 100644 --- a/src/zenserver/frontend/frontend.h +++ b/src/zenserver/frontend/frontend.h @@ 
-7,6 +7,7 @@ #include "zipfs.h" #include <filesystem> +#include <memory> namespace zen { @@ -20,9 +21,9 @@ public: virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: - ZipFs m_ZipFs; - std::filesystem::path m_Directory; - HttpStatusService& m_StatusService; + std::unique_ptr<ZipFs> m_ZipFs; + std::filesystem::path m_Directory; + HttpStatusService& m_StatusService; }; } // namespace zen diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip Binary files differindex 5d33302dd..84472ff08 100644 --- a/src/zenserver/frontend/html.zip +++ b/src/zenserver/frontend/html.zip diff --git a/src/zenserver/frontend/html/404.html b/src/zenserver/frontend/html/404.html new file mode 100644 index 000000000..829ef2097 --- /dev/null +++ b/src/zenserver/frontend/html/404.html @@ -0,0 +1,486 @@ +<!DOCTYPE html> +<html lang="en"> +<head> +<meta charset="UTF-8"> +<meta name="viewport" content="width=device-width, initial-scale=1.0"> +<title>Ooops</title> +<style> + * { margin: 0; padding: 0; box-sizing: border-box; } + + :root { + --deep-space: #00000f; + --nebula-blue: #0a0a2e; + --star-white: #ffffff; + --star-blue: #c8d8ff; + --star-yellow: #fff3c0; + --star-red: #ffd0c0; + --nebula-glow: rgba(60, 80, 180, 0.12); + } + + body { + background: var(--deep-space); + min-height: 100vh; + display: flex; + align-items: center; + justify-content: center; + font-family: 'Courier New', monospace; + overflow: hidden; + } + + starfield-bg { + display: block; + position: fixed; + inset: 0; + z-index: 0; + } + + canvas { + display: block; + width: 100%; + height: 100%; + } + + .page-content { + position: relative; + z-index: 1; + text-align: center; + color: rgba(200, 216, 255, 0.85); + letter-spacing: 0.25em; + text-transform: uppercase; + pointer-events: none; + user-select: none; + } + + .page-content h1 { + font-size: clamp(1.2rem, 4vw, 2.4rem); + font-weight: 300; + letter-spacing: 0.6em; + text-shadow: 0 0 40px rgba(120, 160, 255, 0.6), 0 0 
80px rgba(80, 120, 255, 0.3); + animation: pulse 6s ease-in-out infinite; + } + + .page-content p { + margin-top: 1.2rem; + font-size: clamp(0.55rem, 1.5vw, 0.75rem); + letter-spacing: 0.4em; + opacity: 0.45; + } + + @keyframes pulse { + 0%, 100% { opacity: 0.7; } + 50% { opacity: 1; } + } + + .globe-link { + display: block; + margin: 0 auto 2rem; + width: 160px; + height: 160px; + pointer-events: auto; + cursor: pointer; + border-radius: 50%; + position: relative; + } + + .globe-link:hover .globe-glow { + opacity: 0.6; + } + + .globe-glow { + position: absolute; + inset: -18px; + border-radius: 50%; + background: radial-gradient(circle, rgba(80, 140, 255, 0.35) 0%, transparent 70%); + opacity: 0.35; + transition: opacity 0.4s; + pointer-events: none; + } + + .globe-link canvas { + display: block; + width: 160px; + height: 160px; + border-radius: 50%; + } +</style> +</head> +<body> + +<starfield-bg + star-count="380" + speed="0.6" + depth="true" + nebula="true" + shooting-stars="true" +></starfield-bg> + +<div class="page-content"> + <a class="globe-link" href="/dashboard/" title="Back to Dashboard"> + <div class="globe-glow"></div> + <canvas id="globe" width="320" height="320"></canvas> + </a> + <h1>404 NOT FOUND</h1> +</div> + +<script> +class StarfieldBg extends HTMLElement { + constructor() { + super(); + this.attachShadow({ mode: 'open' }); + } + + connectedCallback() { + this.shadowRoot.innerHTML = ` + <style> + :host { display: block; position: absolute; inset: 0; overflow: hidden; } + canvas { width: 100%; height: 100%; display: block; } + </style> + <canvas></canvas> + `; + + this.canvas = this.shadowRoot.querySelector('canvas'); + this.ctx = this.canvas.getContext('2d'); + + this.starCount = parseInt(this.getAttribute('star-count') || '350'); + this.speed = parseFloat(this.getAttribute('speed') || '0.6'); + this.useDepth = this.getAttribute('depth') !== 'false'; + this.useNebula = this.getAttribute('nebula') !== 'false'; + this.useShooting = 
this.getAttribute('shooting-stars') !== 'false'; + + this.stars = []; + this.shooters = []; + this.nebulaTime = 0; + this.frame = 0; + + this.resize(); + this.init(); + + this._ro = new ResizeObserver(() => { this.resize(); this.init(); }); + this._ro.observe(this); + + this.raf = requestAnimationFrame(this.tick.bind(this)); + } + + disconnectedCallback() { + cancelAnimationFrame(this.raf); + this._ro.disconnect(); + } + + resize() { + const dpr = window.devicePixelRatio || 1; + const rect = this.getBoundingClientRect(); + this.W = rect.width || window.innerWidth; + this.H = rect.height || window.innerHeight; + this.canvas.width = this.W * dpr; + this.canvas.height = this.H * dpr; + this.ctx.setTransform(dpr, 0, 0, dpr, 0, 0); + } + + init() { + const COLORS = ['#ffffff', '#c8d8ff', '#d0e8ff', '#fff3c0', '#ffd0c0', '#e0f0ff']; + this.stars = Array.from({ length: this.starCount }, () => ({ + x: Math.random() * this.W, + y: Math.random() * this.H, + z: this.useDepth ? Math.random() : 1, // depth: 0=far, 1=near + r: Math.random() * 1.4 + 0.2, + color: COLORS[Math.floor(Math.random() * COLORS.length)], + twinkleOffset: Math.random() * Math.PI * 2, + twinkleSpeed: 0.008 + Math.random() * 0.012, + })); + } + + spawnShooter() { + const edge = Math.random() < 0.7 ? 'top' : 'left'; + const angle = (Math.random() * 30 + 15) * (Math.PI / 180); + this.shooters.push({ + x: edge === 'top' ? Math.random() * this.W : -10, + y: edge === 'top' ? 
-10 : Math.random() * this.H * 0.5, + vx: Math.cos(angle) * (6 + Math.random() * 6), + vy: Math.sin(angle) * (6 + Math.random() * 6), + len: 80 + Math.random() * 120, + life: 1, + decay: 0.012 + Math.random() * 0.018, + }); + } + + tick() { + this.raf = requestAnimationFrame(this.tick.bind(this)); + this.frame++; + const ctx = this.ctx; + const W = this.W, H = this.H; + + // Background + ctx.fillStyle = '#00000f'; + ctx.fillRect(0, 0, W, H); + + // Nebula clouds (subtle) + if (this.useNebula) { + this.nebulaTime += 0.003; + this.drawNebula(ctx, W, H); + } + + // Stars + for (const s of this.stars) { + const twinkle = 0.55 + 0.45 * Math.sin(this.frame * s.twinkleSpeed + s.twinkleOffset); + const radius = s.r * (this.useDepth ? (0.3 + s.z * 0.7) : 1); + const alpha = (this.useDepth ? (0.25 + s.z * 0.75) : 1) * twinkle; + + // Tiny drift + s.x += (s.z * this.speed * 0.08) * (this.useDepth ? 1 : 0); + s.y += (s.z * this.speed * 0.04) * (this.useDepth ? 1 : 0); + if (s.x > W + 2) s.x = -2; + if (s.y > H + 2) s.y = -2; + + // Glow for bright stars + if (radius > 1.1 && alpha > 0.6) { + const grd = ctx.createRadialGradient(s.x, s.y, 0, s.x, s.y, radius * 3.5); + grd.addColorStop(0, s.color.replace(')', `, ${alpha * 0.5})`).replace('rgb', 'rgba')); + grd.addColorStop(1, 'transparent'); + ctx.beginPath(); + ctx.arc(s.x, s.y, radius * 3.5, 0, Math.PI * 2); + ctx.fillStyle = grd; + ctx.fill(); + } + + ctx.beginPath(); + ctx.arc(s.x, s.y, radius, 0, Math.PI * 2); + ctx.fillStyle = hexToRgba(s.color, alpha); + ctx.fill(); + } + + // Shooting stars + if (this.useShooting) { + if (this.frame % 140 === 0 && Math.random() < 0.65) this.spawnShooter(); + for (let i = this.shooters.length - 1; i >= 0; i--) { + const s = this.shooters[i]; + const tailX = s.x - s.vx * (s.len / Math.hypot(s.vx, s.vy)); + const tailY = s.y - s.vy * (s.len / Math.hypot(s.vx, s.vy)); + + const grd = ctx.createLinearGradient(tailX, tailY, s.x, s.y); + grd.addColorStop(0, `rgba(255,255,255,0)`); + 
grd.addColorStop(0.7, `rgba(200,220,255,${s.life * 0.5})`); + grd.addColorStop(1, `rgba(255,255,255,${s.life})`); + + ctx.beginPath(); + ctx.moveTo(tailX, tailY); + ctx.lineTo(s.x, s.y); + ctx.strokeStyle = grd; + ctx.lineWidth = 1.5 * s.life; + ctx.lineCap = 'round'; + ctx.stroke(); + + // Head dot + ctx.beginPath(); + ctx.arc(s.x, s.y, 1.5 * s.life, 0, Math.PI * 2); + ctx.fillStyle = `rgba(255,255,255,${s.life})`; + ctx.fill(); + + s.x += s.vx; + s.y += s.vy; + s.life -= s.decay; + + if (s.life <= 0 || s.x > W + 200 || s.y > H + 200) { + this.shooters.splice(i, 1); + } + } + } + } + + drawNebula(ctx, W, H) { + const t = this.nebulaTime; + const blobs = [ + { x: W * 0.25, y: H * 0.3, rx: W * 0.35, ry: H * 0.25, color: '40,60,180', a: 0.055 }, + { x: W * 0.75, y: H * 0.65, rx: W * 0.30, ry: H * 0.22, color: '100,40,160', a: 0.04 }, + { x: W * 0.5, y: H * 0.5, rx: W * 0.45, ry: H * 0.35, color: '20,50,120', a: 0.035 }, + ]; + ctx.save(); + for (const b of blobs) { + const ox = Math.sin(t * 0.7 + b.x) * 30; + const oy = Math.cos(t * 0.5 + b.y) * 20; + const grd = ctx.createRadialGradient(b.x + ox, b.y + oy, 0, b.x + ox, b.y + oy, Math.max(b.rx, b.ry)); + grd.addColorStop(0, `rgba(${b.color}, ${b.a})`); + grd.addColorStop(0.5, `rgba(${b.color}, ${b.a * 0.4})`); + grd.addColorStop(1, `rgba(${b.color}, 0)`); + ctx.save(); + ctx.scale(b.rx / Math.max(b.rx, b.ry), b.ry / Math.max(b.rx, b.ry)); + ctx.beginPath(); + const scale = Math.max(b.rx, b.ry); + ctx.arc((b.x + ox) / (b.rx / scale), (b.y + oy) / (b.ry / scale), scale, 0, Math.PI * 2); + ctx.fillStyle = grd; + ctx.fill(); + ctx.restore(); + } + ctx.restore(); + } +} + +function hexToRgba(hex, alpha) { + // Handle named-ish values or full hex + const c = hex.startsWith('#') ? 
hex : '#ffffff'; + const r = parseInt(c.slice(1,3), 16); + const g = parseInt(c.slice(3,5), 16); + const b = parseInt(c.slice(5,7), 16); + return `rgba(${r},${g},${b},${alpha.toFixed(3)})`; +} + +customElements.define('starfield-bg', StarfieldBg); +</script> + +<script> +(function() { + const canvas = document.getElementById('globe'); + const ctx = canvas.getContext('2d'); + const W = canvas.width, H = canvas.height; + const R = W * 0.44; + const cx = W / 2, cy = H / 2; + + // Simplified continent outlines as lon/lat polygon chains (degrees). + // Each continent is an array of [lon, lat] points. + const continents = [ + // North America + [[-130,50],[-125,55],[-120,60],[-115,65],[-100,68],[-85,70],[-75,65],[-60,52],[-65,45],[-70,42],[-75,35],[-80,30],[-85,28],[-90,28],[-95,25],[-100,20],[-105,20],[-110,25],[-115,30],[-120,35],[-125,42],[-130,50]], + // South America + [[-80,10],[-75,5],[-70,5],[-65,0],[-60,-5],[-55,-5],[-50,-10],[-45,-15],[-40,-20],[-40,-25],[-42,-30],[-48,-32],[-52,-34],[-55,-38],[-60,-42],[-65,-50],[-68,-55],[-70,-48],[-72,-40],[-75,-30],[-78,-15],[-80,-5],[-80,5],[-80,10]], + // Europe + [[-10,36],[-5,38],[0,40],[2,43],[5,44],[8,46],[10,48],[15,50],[18,54],[20,56],[25,58],[28,60],[30,62],[35,65],[40,68],[38,60],[35,55],[30,50],[28,48],[25,45],[22,40],[20,38],[15,36],[10,36],[5,36],[0,36],[-5,36],[-10,36]], + // Africa + [[-15,14],[-17,16],[-15,22],[-12,28],[-5,32],[0,35],[5,37],[10,35],[15,32],[20,30],[25,30],[30,28],[35,25],[38,18],[40,12],[42,5],[44,0],[42,-5],[40,-12],[38,-18],[35,-25],[32,-30],[30,-34],[25,-33],[20,-30],[15,-28],[12,-20],[10,-10],[8,-5],[5,0],[2,5],[0,5],[-5,5],[-10,6],[-15,10],[-15,14]], + // Asia (simplified) + [[30,35],[35,38],[40,40],[45,42],[50,45],[55,48],[60,50],[65,55],[70,60],[75,65],[80,68],[90,70],[100,68],[110,65],[120,60],[125,55],[130,50],[135,45],[140,40],[138,35],[130,30],[120,25],[110,20],[105,15],[100,10],[95,12],[90,20],[85,22],[80,25],[75,28],[70,30],[65,35],[55,35],[45,35],[40,35],[35,35],[30,35]], + // 
Australia + [[115,-12],[120,-14],[125,-15],[130,-14],[135,-13],[138,-16],[140,-18],[145,-20],[148,-22],[150,-25],[152,-28],[150,-33],[148,-35],[145,-37],[140,-38],[135,-36],[130,-33],[125,-30],[120,-25],[118,-22],[116,-20],[114,-18],[115,-15],[115,-12]], + ]; + + function project(lon, lat, rotation) { + // Convert to radians and apply rotation + var lonR = (lon + rotation) * Math.PI / 180; + var latR = lat * Math.PI / 180; + + var x3 = Math.cos(latR) * Math.sin(lonR); + var y3 = -Math.sin(latR); + var z3 = Math.cos(latR) * Math.cos(lonR); + + // Only visible if facing us + if (z3 < 0) return null; + + return { x: cx + x3 * R, y: cy + y3 * R, z: z3 }; + } + + var rotation = 0; + + function draw() { + requestAnimationFrame(draw); + rotation += 0.15; + ctx.clearRect(0, 0, W, H); + + // Atmosphere glow + var atm = ctx.createRadialGradient(cx, cy, R * 0.85, cx, cy, R * 1.15); + atm.addColorStop(0, 'rgba(60,130,255,0.12)'); + atm.addColorStop(0.5, 'rgba(60,130,255,0.06)'); + atm.addColorStop(1, 'rgba(60,130,255,0)'); + ctx.beginPath(); + ctx.arc(cx, cy, R * 1.15, 0, Math.PI * 2); + ctx.fillStyle = atm; + ctx.fill(); + + // Ocean sphere + var oceanGrad = ctx.createRadialGradient(cx - R * 0.3, cy - R * 0.3, R * 0.1, cx, cy, R); + oceanGrad.addColorStop(0, '#1a4a8a'); + oceanGrad.addColorStop(0.5, '#0e2d5e'); + oceanGrad.addColorStop(1, '#071838'); + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.fillStyle = oceanGrad; + ctx.fill(); + + // Draw continents + for (var c = 0; c < continents.length; c++) { + var pts = continents[c]; + var projected = []; + var allVisible = true; + + for (var i = 0; i < pts.length; i++) { + var p = project(pts[i][0], pts[i][1], rotation); + if (!p) { allVisible = false; break; } + projected.push(p); + } + + if (!allVisible || projected.length < 3) continue; + + ctx.beginPath(); + ctx.moveTo(projected[0].x, projected[0].y); + for (var i = 1; i < projected.length; i++) { + ctx.lineTo(projected[i].x, projected[i].y); + } + 
ctx.closePath(); + + // Shade based on average depth + var avgZ = 0; + for (var i = 0; i < projected.length; i++) avgZ += projected[i].z; + avgZ /= projected.length; + var brightness = 0.3 + avgZ * 0.7; + + var r = Math.round(30 * brightness); + var g = Math.round(100 * brightness); + var b = Math.round(50 * brightness); + ctx.fillStyle = 'rgb(' + r + ',' + g + ',' + b + ')'; + ctx.fill(); + } + + // Grid lines (longitude) + ctx.strokeStyle = 'rgba(100,160,255,0.08)'; + ctx.lineWidth = 0.7; + for (var lon = -180; lon < 180; lon += 30) { + ctx.beginPath(); + var started = false; + for (var lat = -90; lat <= 90; lat += 3) { + var p = project(lon, lat, rotation); + if (p) { + if (!started) { ctx.moveTo(p.x, p.y); started = true; } + else ctx.lineTo(p.x, p.y); + } else { + started = false; + } + } + ctx.stroke(); + } + + // Grid lines (latitude) + for (var lat = -60; lat <= 60; lat += 30) { + ctx.beginPath(); + var started = false; + for (var lon = -180; lon <= 180; lon += 3) { + var p = project(lon, lat, rotation); + if (p) { + if (!started) { ctx.moveTo(p.x, p.y); started = true; } + else ctx.lineTo(p.x, p.y); + } else { + started = false; + } + } + ctx.stroke(); + } + + // Specular highlight + var spec = ctx.createRadialGradient(cx - R * 0.35, cy - R * 0.35, 0, cx - R * 0.35, cy - R * 0.35, R * 0.8); + spec.addColorStop(0, 'rgba(180,210,255,0.18)'); + spec.addColorStop(0.4, 'rgba(120,160,255,0.05)'); + spec.addColorStop(1, 'rgba(0,0,0,0)'); + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.fillStyle = spec; + ctx.fill(); + + // Rim light + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.strokeStyle = 'rgba(80,140,255,0.2)'; + ctx.lineWidth = 1.5; + ctx.stroke(); + } + + draw(); +})(); +</script> +</body> +</html> diff --git a/src/zenserver/frontend/html/banner.js b/src/zenserver/frontend/html/banner.js new file mode 100644 index 000000000..2e878dedf --- /dev/null +++ b/src/zenserver/frontend/html/banner.js @@ -0,0 +1,338 @@ +/** + * 
zen-banner.js — Zen dashboard banner Web Component + * + * Usage: + * <script src="banner.js" defer></script> + * + * <zen-banner></zen-banner> + * <zen-banner variant="compact"></zen-banner> + * <zen-banner cluster-status="degraded" load="78"></zen-banner> + * + * Attributes: + * variant "full" (default) | "compact" + * cluster-status "nominal" (default) | "degraded" | "offline" + * load 0–100 integer, shown as a percentage (default: hidden) + * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator") + * subtitle text after "ZEN" in the wordmark (default: "COMPUTE") + */ + +class ZenBanner extends HTMLElement { + + static get observedAttributes() { + return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle', 'logo-src']; + } + + attributeChangedCallback() { + if (this.shadowRoot) this._render(); + } + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + // ───────────────────────────────────────────── + // Derived values + // ───────────────────────────────────────────── + + get _variant() { return this.getAttribute('variant') || 'full'; } + get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); } + get _load() { return this.getAttribute('load'); } // null → hidden + get _tagline() { return this.getAttribute('tagline'); } // null → default + get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE" + get _logoSrc() { return this.getAttribute('logo-src'); } // null → inline SVG + + get _statusColor() { + return { nominal: '#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] ?? '#7ecfb8'; + } + + get _statusLabel() { + return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 
'NOMINAL'; + } + + get _loadColor() { + const v = parseInt(this._load, 10); + if (isNaN(v)) return '#7ecfb8'; + if (v >= 85) return '#c0504d'; + if (v >= 60) return '#d4a84b'; + return '#7ecfb8'; + } + + // ───────────────────────────────────────────── + // Render + // ───────────────────────────────────────────── + + _render() { + const compact = this._variant === 'compact'; + this.shadowRoot.innerHTML = ` + <style>${this._css(compact)}</style> + ${this._html(compact)} + `; + } + + // ───────────────────────────────────────────── + // CSS + // ───────────────────────────────────────────── + + _css(compact) { + const height = compact ? '60px' : '100px'; + const padding = compact ? '0 24px' : '0 32px'; + const gap = compact ? '16px' : '24px'; + const markSize = compact ? '34px' : '52px'; + const divH = compact ? '32px' : '48px'; + const nameSize = compact ? '15px' : '22px'; + const tagSize = compact ? '9px' : '11px'; + const sc = this._statusColor; + const lc = this._loadColor; + + return ` + @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap'); + + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + font-family: 'Space Mono', monospace; + } + + .banner { + width: 100%; + height: ${height}; + background: var(--theme_g3, #0b0d10); + border: 1px solid var(--theme_g2, #1e2330); + border-radius: 6px; + display: flex; + align-items: center; + padding: ${padding}; + gap: ${gap}; + position: relative; + overflow: hidden; + text-decoration: none; + color: inherit; + cursor: pointer; + } + + /* scan-line texture */ + .banner::before { + content: ''; + position: absolute; + inset: 0; + background: repeating-linear-gradient( + 0deg, + transparent, transparent 3px, + rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px + ); + pointer-events: none; + } + + /* ambient glow */ + .banner::after { + content: ''; + position: absolute; + right: -60px; + 
top: 50%; + transform: translateY(-50%); + width: 280px; + height: 280px; + background: radial-gradient(circle, rgba(130,200,180,0.06) 0%, transparent 70%); + pointer-events: none; + } + + .logo-mark { + flex-shrink: 0; + width: ${markSize}; + height: ${markSize}; + } + + .logo-mark svg, .logo-mark img { width: 100%; height: 100%; object-fit: contain; } + + .divider { + width: 1px; + height: ${divH}; + background: linear-gradient(to bottom, transparent, var(--theme_g2, #2a3040), transparent); + flex-shrink: 0; + } + + .text-block { + display: flex; + flex-direction: column; + gap: 4px; + } + + .wordmark { + font-weight: 700; + font-size: ${nameSize}; + letter-spacing: 0.12em; + color: var(--theme_bright, #e8e4dc); + text-transform: uppercase; + line-height: 1; + } + + .wordmark span { color: #7ecfb8; } + + .tagline { + font-family: 'Noto Serif JP', serif; + font-weight: 300; + font-size: ${tagSize}; + letter-spacing: 0.3em; + color: var(--theme_faint, #4a5a68); + text-transform: uppercase; + } + + .spacer { flex: 1; } + + /* ── right-side decorative circuit ── */ + .circuit { flex-shrink: 0; opacity: 0.22; } + + /* ── status cluster ── */ + .status-cluster { + display: flex; + flex-direction: column; + align-items: flex-end; + gap: 6px; + } + + .status-row { + display: flex; + align-items: center; + gap: 8px; + } + + .status-lbl { + font-size: 9px; + letter-spacing: 0.18em; + color: var(--theme_faint, #3a4555); + text-transform: uppercase; + } + + .pill { + display: flex; + align-items: center; + gap: 5px; + border-radius: 20px; + padding: 2px 10px; + font-size: 10px; + letter-spacing: 0.1em; + } + + .pill.cluster { + color: ${sc}; + background: color-mix(in srgb, ${sc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${sc} 28%, transparent); + } + + .pill.load-pill { + color: ${lc}; + background: color-mix(in srgb, ${lc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${lc} 28%, transparent); + } + + .dot { + width: 5px; + height: 5px; + 
border-radius: 50%; + animation: pulse 2.4s ease-in-out infinite; + } + + .dot.cluster { background: ${sc}; } + .dot.load-dot { background: ${lc}; animation-delay: 0.5s; } + + @keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.25; } + } + `; + } + + // ───────────────────────────────────────────── + // HTML template + // ───────────────────────────────────────────── + + _html(compact) { + const loadAttr = this._load; + const hasCluster = !compact && this.hasAttribute('cluster-status'); + const hasLoad = !compact && loadAttr !== null; + const showRight = hasCluster || hasLoad; + + const circuit = showRight ? ` + <svg class="circuit" width="60" height="60" viewBox="0 0 60 60" fill="none"> + <path d="M5 30 H22 L28 18 H60" stroke="#7ecfb8" stroke-width="0.8"/> + <path d="M5 38 H18 L24 46 H60" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="22" cy="30" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="18" cy="38" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="10" cy="30" r="1.2" fill="#7ecfb8"/> + <circle cx="10" cy="38" r="1.2" fill="#7ecfb8"/> + </svg>` : ''; + + const clusterRow = hasCluster ? ` + <div class="status-row"> + <span class="status-lbl">Cluster</span> + <div class="pill cluster"> + <div class="dot cluster"></div> + ${this._statusLabel} + </div> + </div>` : ''; + + const loadRow = hasLoad ? ` + <div class="status-row"> + <span class="status-lbl">Load</span> + <div class="pill load-pill"> + <div class="dot load-dot"></div> + ${parseInt(loadAttr, 10)} % + </div> + </div>` : ''; + + const rightSide = showRight ? ` + ${circuit} + <div class="status-cluster"> + ${clusterRow} + ${loadRow} + </div> + ` : ''; + + return ` + <a class="banner" href="/dashboard/"> + <div class="logo-mark">${this._logoMark()}</div> + <div class="divider"></div> + <div class="text-block"> + <div class="wordmark">ZEN<span> ${this._subtitle ?? 'COMPUTE'}</span></div> + <div class="tagline">${this._tagline ?? (compact ? 
'Orchestrator' : 'Orchestrator Overview')}</div> + </div> + <div class="spacer"></div> + ${rightSide} + </a> + `; + } + + // ───────────────────────────────────────────── + // SVG logo mark + // ───────────────────────────────────────────── + + _logoMark() { + const src = this._logoSrc; + if (src) { + return `<img src="${src}" alt="zen">`; + } + return ` + <svg viewBox="0 0 52 52" fill="none" xmlns="http://www.w3.org/2000/svg"> + <circle cx="26" cy="26" r="22" stroke="#2a3a48" stroke-width="1.5"/> + <path d="M26 4 A22 22 0 1 1 12 43.1" stroke="#7ecfb8" stroke-width="2" stroke-linecap="round" fill="none"/> + <circle cx="17" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="26" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="35" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="17" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/> + <circle cx="26" cy="26" r="2.2" fill="#7ecfb8"/> + <circle cx="35" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/> + <circle cx="17" cy="35" r="1.6" fill="#7ecfb8"/> + <circle cx="26" cy="35" r="1.6" fill="#7ecfb8"/> + <circle cx="35" cy="35" r="1.6" fill="#7ecfb8"/> + <line x1="17" y1="17" x2="35" y2="17" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/> + <line x1="35" y1="17" x2="17" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/> + <line x1="17" y1="35" x2="35" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/> + <line x1="26" y1="17" x2="26" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/> + </svg> + `; + } +} + +customElements.define('zen-banner', ZenBanner); diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html new file mode 100644 index 000000000..66c20175f --- /dev/null +++ b/src/zenserver/frontend/html/compute/compute.html @@ -0,0 +1,929 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <title>Zen Compute 
Dashboard</title> + <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> + <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../theme.js"></script> + <script src="../banner.js" defer></script> + <script src="../nav.js" defer></script> + <style> + .grid { + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + } + + .chart-container { + position: relative; + height: 300px; + margin-top: 20px; + } + + .stats-row { + display: flex; + justify-content: space-between; + margin-bottom: 12px; + padding: 8px 0; + border-bottom: 1px solid var(--theme_border_subtle); + } + + .stats-row:last-child { + border-bottom: none; + margin-bottom: 0; + } + + .stats-label { + color: var(--theme_g1); + font-size: 13px; + } + + .stats-value { + color: var(--theme_bright); + font-weight: 600; + font-size: 13px; + } + + .rate-stats { + display: grid; + grid-template-columns: repeat(3, 1fr); + gap: 16px; + margin-top: 16px; + } + + .rate-item { + text-align: center; + } + + .rate-value { + font-size: 20px; + font-weight: 600; + color: var(--theme_p0); + } + + .rate-label { + font-size: 11px; + color: var(--theme_g1); + margin-top: 4px; + text-transform: uppercase; + } + + .worker-row { + cursor: pointer; + transition: background 0.15s; + } + + .worker-row:hover { + background: var(--theme_p4); + } + + .worker-row.selected { + background: var(--theme_p3); + } + + .worker-detail { + margin-top: 20px; + border-top: 1px solid var(--theme_g2); + padding-top: 16px; + } + + .worker-detail-title { + font-size: 15px; + font-weight: 600; + color: var(--theme_bright); + margin-bottom: 12px; + } + + .detail-section { + margin-bottom: 16px; + } + + .detail-section-label { + font-size: 11px; + font-weight: 600; + color: var(--theme_g1); + text-transform: uppercase; + letter-spacing: 0.5px; + margin-bottom: 6px; + } + + .detail-table { + width: 100%; + border-collapse: collapse; + font-size: 12px; + } + + .detail-table td { + 
padding: 4px 8px; + color: var(--theme_g0); + border-bottom: 1px solid var(--theme_border_subtle); + vertical-align: top; + } + + .detail-table td:first-child { + color: var(--theme_g1); + width: 40%; + font-family: monospace; + } + + .detail-table tr:last-child td { + border-bottom: none; + } + + .detail-mono { + font-family: monospace; + font-size: 11px; + color: var(--theme_g1); + } + + .detail-tag { + display: inline-block; + padding: 2px 8px; + border-radius: 4px; + background: var(--theme_border_subtle); + color: var(--theme_g0); + font-size: 11px; + margin: 2px 4px 2px 0; + } + </style> +</head> +<body> + <div class="container" style="max-width: 1400px; margin: 0 auto;"> + <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner> + <zen-nav> + <a href="/dashboard/">Home</a> + <a href="compute.html">Node</a> + <a href="orchestrator.html">Orchestrator</a> + </zen-nav> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> + + <div id="error-container"></div> + + <!-- Action Queue Stats --> + <div class="section-title">Action Queue</div> + <div class="grid"> + <div class="card"> + <div class="card-title">Pending Actions</div> + <div class="metric-value" id="actions-pending">-</div> + <div class="metric-label">Waiting to be scheduled</div> + </div> + <div class="card"> + <div class="card-title">Running Actions</div> + <div class="metric-value" id="actions-running">-</div> + <div class="metric-label">Currently executing</div> + </div> + <div class="card"> + <div class="card-title">Completed Actions</div> + <div class="metric-value" id="actions-complete">-</div> + <div class="metric-label">Results available</div> + </div> + </div> + + <!-- Action Queue Chart --> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Action Queue History</div> + <div class="chart-container"> + <canvas id="queue-chart"></canvas> + </div> + </div> + + <!-- Performance Metrics --> + <div 
class="section-title">Performance Metrics</div> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Completion Rate</div> + <div class="rate-stats"> + <div class="rate-item"> + <div class="rate-value" id="rate-1">-</div> + <div class="rate-label">1 min rate</div> + </div> + <div class="rate-item"> + <div class="rate-value" id="rate-5">-</div> + <div class="rate-label">5 min rate</div> + </div> + <div class="rate-item"> + <div class="rate-value" id="rate-15">-</div> + <div class="rate-label">15 min rate</div> + </div> + </div> + <div style="margin-top: 20px;"> + <div class="stats-row"> + <span class="stats-label">Total Retired</span> + <span class="stats-value" id="retired-count">-</span> + </div> + <div class="stats-row"> + <span class="stats-label">Mean Rate</span> + <span class="stats-value" id="rate-mean">-</span> + </div> + </div> + </div> + + <!-- Workers --> + <div class="section-title">Workers</div> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Worker Status</div> + <div class="stats-row"> + <span class="stats-label">Registered Workers</span> + <span class="stats-value" id="worker-count">-</span> + </div> + <div id="worker-table-container" style="margin-top: 16px; display: none;"> + <table id="worker-table"> + <thead> + <tr> + <th>Name</th> + <th>Platform</th> + <th style="text-align: right;">Cores</th> + <th style="text-align: right;">Timeout</th> + <th style="text-align: right;">Functions</th> + <th>Worker ID</th> + </tr> + </thead> + <tbody id="worker-table-body"></tbody> + </table> + <div id="worker-detail" class="worker-detail" style="display: none;"></div> + </div> + </div> + + <!-- Queues --> + <div class="section-title">Queues</div> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Queue Status</div> + <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div> + <div id="queue-list-container" style="display: none;"> + <table 
id="queue-list-table"> + <thead> + <tr> + <th style="text-align: right; width: 60px;">ID</th> + <th style="text-align: center; width: 80px;">Status</th> + <th style="text-align: right;">Active</th> + <th style="text-align: right;">Completed</th> + <th style="text-align: right;">Failed</th> + <th style="text-align: right;">Abandoned</th> + <th style="text-align: right;">Cancelled</th> + <th>Token</th> + </tr> + </thead> + <tbody id="queue-list-body"></tbody> + </table> + </div> + </div> + + <!-- Action History --> + <div class="section-title">Recent Actions</div> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Action History</div> + <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div> + <div id="action-history-container" style="display: none;"> + <table id="action-history-table"> + <thead> + <tr> + <th style="text-align: right; width: 60px;">LSN</th> + <th style="text-align: right; width: 60px;">Queue</th> + <th style="text-align: center; width: 70px;">Status</th> + <th>Function</th> + <th style="text-align: right; width: 80px;">Started</th> + <th style="text-align: right; width: 80px;">Finished</th> + <th style="text-align: right; width: 80px;">Duration</th> + <th>Worker ID</th> + <th>Action ID</th> + </tr> + </thead> + <tbody id="action-history-body"></tbody> + </table> + </div> + </div> + + <!-- System Resources --> + <div class="section-title">System Resources</div> + <div class="grid"> + <div class="card"> + <div class="card-title">CPU Usage</div> + <div class="metric-value" id="cpu-usage">-</div> + <div class="metric-label">Percent</div> + <div class="progress-bar"> + <div class="progress-fill" id="cpu-progress" style="width: 0%"></div> + </div> + <div style="position: relative; height: 60px; margin-top: 12px;"> + <canvas id="cpu-chart"></canvas> + </div> + <div style="margin-top: 12px;"> + <div class="stats-row"> + <span class="stats-label">Packages</span> + <span 
class="stats-value" id="cpu-packages">-</span> + </div> + <div class="stats-row"> + <span class="stats-label">Physical Cores</span> + <span class="stats-value" id="cpu-cores">-</span> + </div> + <div class="stats-row"> + <span class="stats-label">Logical Processors</span> + <span class="stats-value" id="cpu-lp">-</span> + </div> + </div> + </div> + <div class="card"> + <div class="card-title">Memory</div> + <div class="stats-row"> + <span class="stats-label">Used</span> + <span class="stats-value" id="memory-used">-</span> + </div> + <div class="stats-row"> + <span class="stats-label">Total</span> + <span class="stats-value" id="memory-total">-</span> + </div> + <div class="progress-bar"> + <div class="progress-fill" id="memory-progress" style="width: 0%"></div> + </div> + </div> + <div class="card"> + <div class="card-title">Disk</div> + <div class="stats-row"> + <span class="stats-label">Used</span> + <span class="stats-value" id="disk-used">-</span> + </div> + <div class="stats-row"> + <span class="stats-label">Total</span> + <span class="stats-value" id="disk-total">-</span> + </div> + <div class="progress-bar"> + <div class="progress-fill" id="disk-progress" style="width: 0%"></div> + </div> + </div> + </div> + </div> + + <script> + // Configuration + const BASE_URL = window.location.origin; + const REFRESH_INTERVAL = 2000; // 2 seconds + const MAX_HISTORY_POINTS = 60; // Show last 2 minutes + + // Data storage + const history = { + timestamps: [], + pending: [], + running: [], + completed: [], + cpu: [] + }; + + // CPU sparkline chart + const cpuCtx = document.getElementById('cpu-chart').getContext('2d'); + const cpuChart = new Chart(cpuCtx, { + type: 'line', + data: { + labels: [], + datasets: [{ + data: [], + borderColor: '#58a6ff', + backgroundColor: 'rgba(88, 166, 255, 0.15)', + borderWidth: 1.5, + tension: 0.4, + fill: true, + pointRadius: 0 + }] + }, + options: { + responsive: true, + maintainAspectRatio: false, + animation: false, + plugins: { legend: 
{ display: false }, tooltip: { enabled: false } }, + scales: { + x: { display: false }, + y: { display: false, min: 0, max: 100 } + } + } + }); + + // Queue chart setup + const ctx = document.getElementById('queue-chart').getContext('2d'); + const chart = new Chart(ctx, { + type: 'line', + data: { + labels: [], + datasets: [ + { + label: 'Pending', + data: [], + borderColor: '#f0883e', + backgroundColor: 'rgba(240, 136, 62, 0.1)', + tension: 0.4, + fill: true + }, + { + label: 'Running', + data: [], + borderColor: '#58a6ff', + backgroundColor: 'rgba(88, 166, 255, 0.1)', + tension: 0.4, + fill: true + }, + { + label: 'Completed', + data: [], + borderColor: '#3fb950', + backgroundColor: 'rgba(63, 185, 80, 0.1)', + tension: 0.4, + fill: true + } + ] + }, + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { + display: true, + labels: { + color: '#8b949e' + } + } + }, + scales: { + x: { + display: false + }, + y: { + beginAtZero: true, + ticks: { + color: '#8b949e' + }, + grid: { + color: '#21262d' + } + } + } + } + }); + + // Helper functions + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + function formatBytes(bytes) { + if (bytes === 0) return '0 B'; + const k = 1024; + const sizes = ['B', 'KB', 'MB', 'GB', 'TB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; + } + + function formatRate(rate) { + return rate.toFixed(2) + '/s'; + } + + function showError(message) { + const container = document.getElementById('error-container'); + container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`; + } + + function clearError() { + document.getElementById('error-container').innerHTML = ''; + } + + function updateTimestamp() { + const now = new Date(); + document.getElementById('last-update').textContent = now.toLocaleTimeString(); + } + + // Fetch functions + 
async function fetchJSON(endpoint) { + const response = await fetch(`${BASE_URL}${endpoint}`, { + headers: { + 'Accept': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + return await response.json(); + } + + async function fetchHealth() { + try { + const response = await fetch(`${BASE_URL}/compute/ready`); + const isHealthy = response.status === 200; + + const banner = document.querySelector('zen-banner'); + + if (isHealthy) { + banner.setAttribute('cluster-status', 'nominal'); + banner.setAttribute('load', '0'); + } else { + banner.setAttribute('cluster-status', 'degraded'); + banner.setAttribute('load', '0'); + } + + return isHealthy; + } catch (error) { + const banner = document.querySelector('zen-banner'); + banner.setAttribute('cluster-status', 'degraded'); + banner.setAttribute('load', '0'); + throw error; + } + } + + async function fetchStats() { + const data = await fetchJSON('/stats/compute'); + + // Update action counts + document.getElementById('actions-pending').textContent = data.actions_pending || 0; + document.getElementById('actions-running').textContent = data.actions_submitted || 0; + document.getElementById('actions-complete').textContent = data.actions_complete || 0; + + // Update completion rates + if (data.actions_retired) { + document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0); + document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0); + document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0); + document.getElementById('retired-count').textContent = data.actions_retired.count || 0; + document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0); + } + + // Update chart + const now = new Date().toLocaleTimeString(); + history.timestamps.push(now); + history.pending.push(data.actions_pending || 0); + 
history.running.push(data.actions_submitted || 0); + history.completed.push(data.actions_complete || 0); + + // Keep only last N points + if (history.timestamps.length > MAX_HISTORY_POINTS) { + history.timestamps.shift(); + history.pending.shift(); + history.running.shift(); + history.completed.shift(); + } + + chart.data.labels = history.timestamps; + chart.data.datasets[0].data = history.pending; + chart.data.datasets[1].data = history.running; + chart.data.datasets[2].data = history.completed; + chart.update('none'); + } + + async function fetchSysInfo() { + const data = await fetchJSON('/compute/sysinfo'); + + // Update CPU + const cpuUsage = data.cpu_usage || 0; + document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%'; + document.getElementById('cpu-progress').style.width = cpuUsage + '%'; + + const banner = document.querySelector('zen-banner'); + banner.setAttribute('load', cpuUsage.toFixed(1)); + + history.cpu.push(cpuUsage); + if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift(); + cpuChart.data.labels = history.cpu.map(() => ''); + cpuChart.data.datasets[0].data = history.cpu; + cpuChart.update('none'); + + document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-'; + document.getElementById('cpu-cores').textContent = data.core_count ?? '-'; + document.getElementById('cpu-lp').textContent = data.lp_count ?? 
'-'; + + // Update Memory + const memUsed = data.memory_used || 0; + const memTotal = data.memory_total || 1; + const memPercent = (memUsed / memTotal) * 100; + document.getElementById('memory-used').textContent = formatBytes(memUsed); + document.getElementById('memory-total').textContent = formatBytes(memTotal); + document.getElementById('memory-progress').style.width = memPercent + '%'; + + // Update Disk + const diskUsed = data.disk_used || 0; + const diskTotal = data.disk_total || 1; + const diskPercent = (diskUsed / diskTotal) * 100; + document.getElementById('disk-used').textContent = formatBytes(diskUsed); + document.getElementById('disk-total').textContent = formatBytes(diskTotal); + document.getElementById('disk-progress').style.width = diskPercent + '%'; + } + + // Persists the selected worker ID across refreshes + let selectedWorkerId = null; + + function renderWorkerDetail(id, desc) { + const panel = document.getElementById('worker-detail'); + + if (!desc) { + panel.style.display = 'none'; + return; + } + + function field(label, value) { + return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`; + } + + function monoField(label, value) { + return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`; + } + + // Functions + const functions = desc.functions || []; + const functionsHtml = functions.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : + `<table class="detail-table">${functions.map(f => + `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>` + ).join('')}</table>`; + + // Executables + const executables = desc.executables || []; + const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0); + const execHtml = executables.length === 0 ? 
'<span style="color:var(--theme_faint);font-size:12px;">none</span>' : + `<table class="detail-table"> + <tr style="font-size:11px;"> + <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td> + <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td> + <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td> + </tr> + ${executables.map(e => + `<tr> + <td>${escapeHtml(e.name || '-')}</td> + <td class="detail-mono">${escapeHtml(e.hash || '-')}</td> + <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td> + </tr>` + ).join('')} + <tr style="border-top:1px solid var(--theme_g2);"> + <td style="color:var(--theme_g1);padding-top:6px;">Total</td> + <td></td> + <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td> + </tr> + </table>`; + + // Files + const files = desc.files || []; + const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : + `<table class="detail-table">${files.map(f => + `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>` + ).join('')}</table>`; + + // Dirs + const dirs = desc.dirs || []; + const dirsHtml = dirs.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : + dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join(''); + + // Environment + const env = desc.environment || []; + const envHtml = env.length === 0 ? 
'<span style="color:var(--theme_faint);font-size:12px;">none</span>' : + env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join(''); + + panel.innerHTML = ` + <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div> + <div class="detail-section"> + <table class="detail-table"> + ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)} + ${field('Path', escapeHtml(desc.path || '-'))} + ${field('Platform', escapeHtml(desc.host || '-'))} + ${monoField('Build System', desc.buildsystem_version)} + ${field('Cores', desc.cores)} + ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)} + </table> + </div> + <div class="detail-section"> + <div class="detail-section-label">Functions</div> + ${functionsHtml} + </div> + <div class="detail-section"> + <div class="detail-section-label">Executables</div> + ${execHtml} + </div> + <div class="detail-section"> + <div class="detail-section-label">Files</div> + ${filesHtml} + </div> + <div class="detail-section"> + <div class="detail-section-label">Directories</div> + ${dirsHtml} + </div> + <div class="detail-section"> + <div class="detail-section-label">Environment</div> + ${envHtml} + </div> + `; + panel.style.display = 'block'; + } + + async function fetchWorkers() { + const data = await fetchJSON('/compute/workers'); + const workerIds = data.workers || []; + + document.getElementById('worker-count').textContent = workerIds.length; + + const container = document.getElementById('worker-table-container'); + const tbody = document.getElementById('worker-table-body'); + + if (workerIds.length === 0) { + container.style.display = 'none'; + selectedWorkerId = null; + return; + } + + const descriptors = await Promise.all( + workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null)) + ); + + // Build a map for quick lookup by ID + const descriptorMap = {}; + workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; }); + + tbody.innerHTML = ''; + 
descriptors.forEach((desc, i) => { + const id = workerIds[i]; + const name = desc ? (desc.name || '-') : '-'; + const host = desc ? (desc.host || '-') : '-'; + const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-'; + const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-'; + const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-'; + + const tr = document.createElement('tr'); + tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : ''); + tr.dataset.workerId = id; + tr.innerHTML = ` + <td style="color: var(--theme_bright);">${escapeHtml(name)}</td> + <td>${escapeHtml(host)}</td> + <td style="text-align: right;">${escapeHtml(String(cores))}</td> + <td style="text-align: right;">${escapeHtml(String(timeout))}</td> + <td style="text-align: right;">${escapeHtml(String(functions))}</td> + <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td> + `; + tr.addEventListener('click', () => { + document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected')); + if (selectedWorkerId === id) { + // Toggle off + selectedWorkerId = null; + document.getElementById('worker-detail').style.display = 'none'; + } else { + selectedWorkerId = id; + tr.classList.add('selected'); + renderWorkerDetail(id, descriptorMap[id]); + } + }); + tbody.appendChild(tr); + }); + + // Re-render detail if selected worker is still present + if (selectedWorkerId && descriptorMap[selectedWorkerId]) { + renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]); + } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) { + selectedWorkerId = null; + document.getElementById('worker-detail').style.display = 'none'; + } + + container.style.display = 'block'; + } + + // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date. 
+ const FILETIME_EPOCH_OFFSET_MS = 11644473600000n; + function filetimeToDate(ticks) { + if (!ticks) return null; + const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS; + return new Date(Number(ms)); + } + + function formatTime(date) { + if (!date) return '-'; + return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }); + } + + function formatDuration(startDate, endDate) { + if (!startDate || !endDate) return '-'; + const ms = endDate - startDate; + if (ms < 0) return '-'; + if (ms < 1000) return ms + ' ms'; + if (ms < 60000) return (ms / 1000).toFixed(2) + ' s'; + const m = Math.floor(ms / 60000); + const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0'); + return `${m}m ${s}s`; + } + + async function fetchQueues() { + const data = await fetchJSON('/compute/queues'); + const queues = data.queues || []; + + const empty = document.getElementById('queue-list-empty'); + const container = document.getElementById('queue-list-container'); + const tbody = document.getElementById('queue-list-body'); + + if (queues.length === 0) { + empty.style.display = ''; + container.style.display = 'none'; + return; + } + + empty.style.display = 'none'; + tbody.innerHTML = ''; + + for (const q of queues) { + const id = q.queue_id ?? '-'; + const badge = q.state === 'cancelled' + ? '<span class="status-badge failure">cancelled</span>' + : q.state === 'draining' + ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>' + : q.is_complete + ? '<span class="status-badge success">complete</span>' + : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>'; + const token = q.queue_token + ? 
`<span class="detail-mono">${escapeHtml(q.queue_token)}</span>` + : '<span style="color:var(--theme_faint);">-</span>'; + + const tr = document.createElement('tr'); + tr.innerHTML = ` + <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td> + <td style="text-align: center;">${badge}</td> + <td style="text-align: right;">${q.active_count ?? 0}</td> + <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td> + <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td> + <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td> + <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td> + <td>${token}</td> + `; + tbody.appendChild(tr); + } + + container.style.display = 'block'; + } + + async function fetchActionHistory() { + const data = await fetchJSON('/compute/jobs/history?limit=50'); + const entries = data.history || []; + + const empty = document.getElementById('action-history-empty'); + const container = document.getElementById('action-history-container'); + const tbody = document.getElementById('action-history-body'); + + if (entries.length === 0) { + empty.style.display = ''; + container.style.display = 'none'; + return; + } + + empty.style.display = 'none'; + tbody.innerHTML = ''; + + // Entries arrive oldest-first; reverse to show newest at top + for (const entry of [...entries].reverse()) { + const lsn = entry.lsn ?? '-'; + const succeeded = entry.succeeded; + const badge = succeeded == null + ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>' + : succeeded + ? 
'<span class="status-badge success">ok</span>' + : '<span class="status-badge failure">failed</span>'; + const desc = entry.actionDescriptor || {}; + const fn = desc.Function || '-'; + const workerId = entry.workerId || '-'; + const actionId = entry.actionId || '-'; + + const startDate = filetimeToDate(entry.time_Running); + const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); + + const queueId = entry.queueId || 0; + const queueCell = queueId + ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>` + : '<span style="color: var(--theme_faint);">-</span>'; + + const tr = document.createElement('tr'); + tr.innerHTML = ` + <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td> + <td style="text-align: right;">${queueCell}</td> + <td style="text-align: center;">${badge}</td> + <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td> + <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td> + <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td> + <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td> + <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td> + <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td> + `; + tbody.appendChild(tr); + } + + container.style.display = 'block'; + } + + async function updateDashboard() { + try { + await Promise.all([ + fetchHealth(), + fetchStats(), + fetchSysInfo(), + fetchWorkers(), + fetchQueues(), + fetchActionHistory() + ]); + + clearError(); + updateTimestamp(); + } catch (error) { + console.error('Error updating dashboard:', error); + 
showError(error.message); + } + } + + // Start updating + updateDashboard(); + setInterval(updateDashboard, REFRESH_INTERVAL); + </script> +</body> +</html> diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html new file mode 100644 index 000000000..32e1b05db --- /dev/null +++ b/src/zenserver/frontend/html/compute/hub.html @@ -0,0 +1,170 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../theme.js"></script> + <script src="../banner.js" defer></script> + <script src="../nav.js" defer></script> + <title>Zen Hub Dashboard</title> +</head> +<body> + <div class="container" style="max-width: 1400px; margin: 0 auto;"> + <zen-banner cluster-status="nominal" subtitle="HUB" tagline="Overview" logo-src="../favicon.ico"></zen-banner> + <zen-nav> + <a href="/dashboard/">Home</a> + <a href="hub.html">Hub</a> + </zen-nav> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> + + <div id="error-container"></div> + + <div class="section-title">Capacity</div> + <div class="grid"> + <div class="card"> + <div class="card-title">Active Modules</div> + <div class="metric-value" id="instance-count">-</div> + <div class="metric-label">Currently provisioned</div> + </div> + <div class="card"> + <div class="card-title">Peak Modules</div> + <div class="metric-value" id="max-instance-count">-</div> + <div class="metric-label">High watermark</div> + </div> + <div class="card"> + <div class="card-title">Instance Limit</div> + <div class="metric-value" id="instance-limit">-</div> + <div class="metric-label">Maximum allowed</div> + <div class="progress-bar"> + <div class="progress-fill" id="capacity-progress" style="width: 0%"></div> + </div> + </div> + </div> + + <div class="section-title">Modules</div> + <div class="card"> + <div 
class="card-title">Storage Server Instances</div> + <div id="empty-state" class="empty-state">No modules provisioned.</div> + <table id="module-table" style="display: none;"> + <thead> + <tr> + <th>Module ID</th> + <th style="text-align: center;">Status</th> + </tr> + </thead> + <tbody id="module-table-body"></tbody> + </table> + </div> + </div> + + <script> + const BASE_URL = window.location.origin; + const REFRESH_INTERVAL = 2000; + + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + function showError(message) { + document.getElementById('error-container').innerHTML = + '<div class="error">Error: ' + escapeHtml(message) + '</div>'; + } + + function clearError() { + document.getElementById('error-container').innerHTML = ''; + } + + async function fetchJSON(endpoint) { + var response = await fetch(BASE_URL + endpoint, { + headers: { 'Accept': 'application/json' } + }); + if (!response.ok) { + throw new Error('HTTP ' + response.status + ': ' + response.statusText); + } + return await response.json(); + } + + async function fetchStats() { + var data = await fetchJSON('/hub/stats'); + + var current = data.currentInstanceCount || 0; + var max = data.maxInstanceCount || 0; + var limit = data.instanceLimit || 0; + + document.getElementById('instance-count').textContent = current; + document.getElementById('max-instance-count').textContent = max; + document.getElementById('instance-limit').textContent = limit; + + var pct = limit > 0 ? 
(current / limit) * 100 : 0; + document.getElementById('capacity-progress').style.width = pct + '%'; + + var banner = document.querySelector('zen-banner'); + if (current === 0) { + banner.setAttribute('cluster-status', 'nominal'); + } else if (limit > 0 && current >= limit * 0.9) { + banner.setAttribute('cluster-status', 'degraded'); + } else { + banner.setAttribute('cluster-status', 'nominal'); + } + } + + async function fetchModules() { + var data = await fetchJSON('/hub/status'); + var modules = data.modules || []; + + var emptyState = document.getElementById('empty-state'); + var table = document.getElementById('module-table'); + var tbody = document.getElementById('module-table-body'); + + if (modules.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + + tbody.innerHTML = ''; + for (var i = 0; i < modules.length; i++) { + var m = modules[i]; + var moduleId = m.moduleId || ''; + var provisioned = m.provisioned; + + var badge = provisioned + ? 
'<span class="status-badge active">Provisioned</span>' + : '<span class="status-badge inactive">Inactive</span>'; + + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="font-family: monospace; font-size: 12px;">' + escapeHtml(moduleId) + '</td>' + + '<td style="text-align: center;">' + badge + '</td>'; + tbody.appendChild(tr); + } + } + + async function updateDashboard() { + var banner = document.querySelector('zen-banner'); + try { + await Promise.all([ + fetchStats(), + fetchModules() + ]); + + clearError(); + document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); + } catch (error) { + console.error('Error updating dashboard:', error); + showError(error.message); + banner.setAttribute('cluster-status', 'offline'); + } + } + + updateDashboard(); + setInterval(updateDashboard, REFRESH_INTERVAL); + </script> +</body> +</html> diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html new file mode 100644 index 000000000..9597fd7f3 --- /dev/null +++ b/src/zenserver/frontend/html/compute/index.html @@ -0,0 +1 @@ +<meta http-equiv="refresh" content="0; url=compute.html" />
\ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html new file mode 100644 index 000000000..a519dee18 --- /dev/null +++ b/src/zenserver/frontend/html/compute/orchestrator.html @@ -0,0 +1,674 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../theme.js"></script> + <script src="../banner.js" defer></script> + <script src="../nav.js" defer></script> + <title>Zen Orchestrator Dashboard</title> + <style> + .agent-count { + display: flex; + align-items: center; + gap: 8px; + font-size: 14px; + padding: 8px 16px; + border-radius: 6px; + background: var(--theme_g3); + border: 1px solid var(--theme_g2); + } + + .agent-count .count { + font-size: 20px; + font-weight: 600; + color: var(--theme_bright); + } + </style> +</head> +<body> + <div class="container" style="max-width: 1400px; margin: 0 auto;"> + <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner> + <zen-nav> + <a href="/dashboard/">Home</a> + <a href="compute.html">Node</a> + <a href="orchestrator.html">Orchestrator</a> + </zen-nav> + <div class="header"> + <div> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> + </div> + <div class="agent-count"> + <span>Agents:</span> + <span class="count" id="agent-count">-</span> + </div> + </div> + + <div id="error-container"></div> + + <div class="card"> + <div class="card-title">Compute Agents</div> + <div id="empty-state" class="empty-state">No agents registered.</div> + <table id="agent-table" style="display: none;"> + <thead> + <tr> + <th style="width: 40px; text-align: center;">Health</th> + <th>Hostname</th> + <th style="text-align: right;">CPUs</th> + <th style="text-align: right;">CPU Usage</th> + <th style="text-align: right;">Memory</th> + 
<th style="text-align: right;">Queues</th> + <th style="text-align: right;">Pending</th> + <th style="text-align: right;">Running</th> + <th style="text-align: right;">Completed</th> + <th style="text-align: right;">Traffic</th> + <th style="text-align: right;">Last Seen</th> + </tr> + </thead> + <tbody id="agent-table-body"></tbody> + </table> + </div> + <div class="card" style="margin-top: 20px;"> + <div class="card-title">Connected Clients</div> + <div id="clients-empty" class="empty-state">No clients connected.</div> + <table id="clients-table" style="display: none;"> + <thead> + <tr> + <th style="width: 40px; text-align: center;">Health</th> + <th>Client ID</th> + <th>Hostname</th> + <th>Address</th> + <th style="text-align: right;">Last Seen</th> + </tr> + </thead> + <tbody id="clients-table-body"></tbody> + </table> + </div> + <div class="card" style="margin-top: 20px;"> + <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;"> + <div class="card-title" style="margin-bottom: 0;">Event History</div> + <div class="history-tabs"> + <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button> + <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button> + </div> + </div> + <div id="history-panel-workers"> + <div id="history-empty" class="empty-state">No provisioning events recorded.</div> + <table id="history-table" style="display: none;"> + <thead> + <tr> + <th>Time</th> + <th>Event</th> + <th>Worker</th> + <th>Hostname</th> + </tr> + </thead> + <tbody id="history-table-body"></tbody> + </table> + </div> + <div id="history-panel-clients" style="display: none;"> + <div id="client-history-empty" class="empty-state">No client events recorded.</div> + <table id="client-history-table" style="display: none;"> + <thead> + <tr> + <th>Time</th> + <th>Event</th> + <th>Client</th> + <th>Hostname</th> + </tr> + </thead> + <tbody 
id="client-history-table-body"></tbody> + </table> + </div> + </div> + </div> + + <script> + const BASE_URL = window.location.origin; + const REFRESH_INTERVAL = 2000; + + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + function showError(message) { + document.getElementById('error-container').innerHTML = + '<div class="error">Error: ' + escapeHtml(message) + '</div>'; + } + + function clearError() { + document.getElementById('error-container').innerHTML = ''; + } + + function formatLastSeen(dtMs) { + if (dtMs == null) return '-'; + var seconds = Math.floor(dtMs / 1000); + if (seconds < 60) return seconds + 's ago'; + var minutes = Math.floor(seconds / 60); + if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago'; + var hours = Math.floor(minutes / 60); + return hours + 'h ' + (minutes % 60) + 'm ago'; + } + + function healthClass(dtMs, reachable) { + if (reachable === false) return 'health-red'; + if (dtMs == null) return 'health-red'; + var seconds = dtMs / 1000; + if (seconds < 30 && reachable === true) return 'health-green'; + if (seconds < 120) return 'health-yellow'; + return 'health-red'; + } + + function healthTitle(dtMs, reachable) { + var seenStr = dtMs != null ? 
'Last seen ' + formatLastSeen(dtMs) : 'Never seen'; + if (reachable === true) return seenStr + ' · Reachable'; + if (reachable === false) return seenStr + ' · Unreachable'; + return seenStr + ' · Reachability unknown'; + } + + function formatCpuUsage(percent) { + if (percent == null || percent === 0) return '-'; + return percent.toFixed(1) + '%'; + } + + function formatMemory(usedBytes, totalBytes) { + if (!totalBytes) return '-'; + var usedGiB = usedBytes / (1024 * 1024 * 1024); + var totalGiB = totalBytes / (1024 * 1024 * 1024); + return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB'; + } + + function formatBytes(bytes) { + if (!bytes) return '-'; + if (bytes < 1024) return bytes + ' B'; + if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB'; + if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB'; + if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB'; + return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB'; + } + + function formatTraffic(recv, sent) { + if (!recv && !sent) return '-'; + return formatBytes(recv) + ' / ' + formatBytes(sent); + } + + function parseIpFromUri(uri) { + try { + var url = new URL(uri); + var host = url.hostname; + // Strip IPv6 brackets + if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1); + // Only handle IPv4 + var parts = host.split('.'); + if (parts.length !== 4) return null; + var octets = parts.map(Number); + if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null; + return octets; + } catch (e) { + return null; + } + } + + function computeCidr(ips) { + if (ips.length === 0) return null; + if (ips.length === 1) return ips[0].join('.') + '/32'; + + // Convert each IP to a 32-bit integer + var ints = ips.map(function(o) { + return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0; + }); + + // Find common prefix length by ANDing all identical high bits + var common = ~0 
>>> 0; + for (var i = 1; i < ints.length; i++) { + // XOR to find differing bits, then mask away everything from the first difference down + var diff = (ints[0] ^ ints[i]) >>> 0; + if (diff !== 0) { + var bit = 31 - Math.floor(Math.log2(diff)); + var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0; + common = (common & mask) >>> 0; + } + } + + // Count leading ones in the common mask + var prefix = 0; + for (var b = 31; b >= 0; b--) { + if ((common >>> b) & 1) prefix++; + else break; + } + + // Network address + var net = (ints[0] & common) >>> 0; + var a = (net >>> 24) & 0xff; + var bv = (net >>> 16) & 0xff; + var c = (net >>> 8) & 0xff; + var d = net & 0xff; + return a + '.' + bv + '.' + c + '.' + d + '/' + prefix; + } + + function renderDashboard(data) { + var banner = document.querySelector('zen-banner'); + if (data.hostname) { + banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname); + } + var workers = data.workers || []; + + document.getElementById('agent-count').textContent = workers.length; + + if (workers.length === 0) { + banner.setAttribute('cluster-status', 'degraded'); + banner.setAttribute('load', '0'); + } else { + banner.setAttribute('cluster-status', 'nominal'); + } + + var emptyState = document.getElementById('empty-state'); + var table = document.getElementById('agent-table'); + var tbody = document.getElementById('agent-table-body'); + + if (workers.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + } else { + emptyState.style.display = 'none'; + table.style.display = ''; + + tbody.innerHTML = ''; + var totalCpus = 0; + var totalWeightedCpuUsage = 0; + var totalMemUsed = 0; + var totalMemTotal = 0; + var totalQueues = 0; + var totalPending = 0; + var totalRunning = 0; + var totalCompleted = 0; + var totalBytesRecv = 0; + var totalBytesSent = 0; + var allIps = []; + for (var i = 0; i < workers.length; i++) { + var w = workers[i]; + var uri = w.uri || ''; + var dt = w.dt; + var dashboardUrl = uri + 
'/dashboard/compute/'; + + var id = w.id || ''; + + var hostname = w.hostname || ''; + var cpus = w.cpus || 0; + totalCpus += cpus; + if (cpus > 0 && typeof w.cpu_usage === 'number') { + totalWeightedCpuUsage += w.cpu_usage * cpus; + } + + var memTotal = w.memory_total || 0; + var memUsed = w.memory_used || 0; + totalMemTotal += memTotal; + totalMemUsed += memUsed; + + var activeQueues = w.active_queues || 0; + totalQueues += activeQueues; + + var actionsPending = w.actions_pending || 0; + var actionsRunning = w.actions_running || 0; + var actionsCompleted = w.actions_completed || 0; + totalPending += actionsPending; + totalRunning += actionsRunning; + totalCompleted += actionsCompleted; + + var bytesRecv = w.bytes_received || 0; + var bytesSent = w.bytes_sent || 0; + totalBytesRecv += bytesRecv; + totalBytesSent += bytesSent; + + var ip = parseIpFromUri(uri); + if (ip) allIps.push(ip); + + var reachable = w.reachable; + var hClass = healthClass(dt, reachable); + var hTitle = healthTitle(dt, reachable); + + var platform = w.platform || ''; + var badges = ''; + if (platform) { + var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' }; + var platColor = platColors[platform] || '#8b949e'; + badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>'; + } + var provisioner = w.provisioner || ''; + if (provisioner) { + var provColors = { horde: '#8957e5', nomad: '#3fb950' }; + var provColor = provColors[provisioner] || '#8b949e'; + badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>'; + } + + var tr = document.createElement('tr'); + tr.title = id; + tr.innerHTML = + '<td style="text-align: 
center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + + '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' + + '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' + + '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' + + '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' + + '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' + + '<td style="text-align: right;">' + actionsPending + '</td>' + + '<td style="text-align: right;">' + actionsRunning + '</td>' + + '<td style="text-align: right;">' + actionsCompleted + '</td>' + + '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' + + '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; + tbody.appendChild(tr); + } + + var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0; + banner.setAttribute('load', clusterLoad.toFixed(1)); + + // Total row + var cidr = computeCidr(allIps); + var totalTr = document.createElement('tr'); + totalTr.className = 'total-row'; + totalTr.innerHTML = + '<td></td>' + + '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? 
' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' + + '<td style="text-align: right;">' + totalCpus + '</td>' + + '<td></td>' + + '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' + + '<td style="text-align: right;">' + totalQueues + '</td>' + + '<td style="text-align: right;">' + totalPending + '</td>' + + '<td style="text-align: right;">' + totalRunning + '</td>' + + '<td style="text-align: right;">' + totalCompleted + '</td>' + + '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' + + '<td></td>'; + tbody.appendChild(totalTr); + } + + clearError(); + document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); + + // Render provisioning history if present in WebSocket payload + if (data.events) { + renderProvisioningHistory(data.events); + } + + // Render connected clients if present + if (data.clients) { + renderClients(data.clients); + } + + // Render client history if present + if (data.client_events) { + renderClientHistory(data.client_events); + } + } + + function eventBadge(type) { + var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' }; + var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' }; + var color = colors[type] || 'var(--theme_g1)'; + var label = labels[type] || type; + return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; + } + + function formatTimestamp(ts) { + if (!ts) return '-'; + // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string + var date; + if (typeof ts === 'number') { + // .NET-style ticks: convert to Unix ms + var unixMs = (ts - 621355968000000000) / 10000; + date = new Date(unixMs); + } else { + date = new Date(ts); + } + if 
(isNaN(date.getTime())) return '-'; + return date.toLocaleTimeString(); + } + + var activeHistoryTab = 'workers'; + + function switchHistoryTab(tab) { + activeHistoryTab = tab; + var tabs = document.querySelectorAll('.history-tab'); + for (var i = 0; i < tabs.length; i++) { + tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab); + } + document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none'; + document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none'; + } + + function renderProvisioningHistory(events) { + var emptyState = document.getElementById('history-empty'); + var table = document.getElementById('history-table'); + var tbody = document.getElementById('history-table-body'); + + if (!events || events.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < events.length; i++) { + var evt = events[i]; + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + + '<td>' + eventBadge(evt.type) + '</td>' + + '<td>' + escapeHtml(evt.worker_id || '') + '</td>' + + '<td>' + escapeHtml(evt.hostname || '') + '</td>'; + tbody.appendChild(tr); + } + } + + function clientHealthClass(dtMs) { + if (dtMs == null) return 'health-red'; + var seconds = dtMs / 1000; + if (seconds < 30) return 'health-green'; + if (seconds < 120) return 'health-yellow'; + return 'health-red'; + } + + function renderClients(clients) { + var emptyState = document.getElementById('clients-empty'); + var table = document.getElementById('clients-table'); + var tbody = document.getElementById('clients-table-body'); + + if (!clients || clients.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + 
table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < clients.length; i++) { + var c = clients[i]; + var dt = c.dt; + var hClass = clientHealthClass(dt); + var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen'; + + var sessionBadge = ''; + if (c.session_id) { + sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>'; + } + + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + + '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' + + '<td>' + escapeHtml(c.hostname || '') + '</td>' + + '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' + + '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; + tbody.appendChild(tr); + } + } + + function clientEventBadge(type) { + var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' }; + var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' }; + var color = colors[type] || 'var(--theme_g1)'; + var label = labels[type] || type; + return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; + } + + function renderClientHistory(events) { + var emptyState = document.getElementById('client-history-empty'); + var table = document.getElementById('client-history-table'); + var tbody = document.getElementById('client-history-table-body'); + + if (!events || events.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + 
table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < events.length; i++) { + var evt = events[i]; + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + + '<td>' + clientEventBadge(evt.type) + '</td>' + + '<td>' + escapeHtml(evt.client_id || '') + '</td>' + + '<td>' + escapeHtml(evt.hostname || '') + '</td>'; + tbody.appendChild(tr); + } + } + + // Fetch-based polling fallback + var pollTimer = null; + + async function fetchProvisioningHistory() { + try { + var response = await fetch(BASE_URL + '/orch/history?limit=50', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderProvisioningHistory(data.events || []); + } + } catch (e) { + console.error('Error fetching provisioning history:', e); + } + } + + async function fetchClients() { + try { + var response = await fetch(BASE_URL + '/orch/clients', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderClients(data.clients || []); + } + } catch (e) { + console.error('Error fetching clients:', e); + } + } + + async function fetchClientHistory() { + try { + var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderClientHistory(data.client_events || []); + } + } catch (e) { + console.error('Error fetching client history:', e); + } + } + + async function fetchDashboard() { + var banner = document.querySelector('zen-banner'); + try { + var response = await fetch(BASE_URL + '/orch/agents', { + headers: { 'Accept': 'application/json' } + }); + + if (!response.ok) { + banner.setAttribute('cluster-status', 'degraded'); + throw new Error('HTTP ' + response.status + ': ' + response.statusText); + } + + renderDashboard(await response.json()); + fetchProvisioningHistory(); + 
fetchClients(); + fetchClientHistory(); + } catch (error) { + console.error('Error updating dashboard:', error); + showError(error.message); + banner.setAttribute('cluster-status', 'offline'); + } + } + + function startPolling() { + if (pollTimer) return; + fetchDashboard(); + pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL); + } + + function stopPolling() { + if (pollTimer) { + clearInterval(pollTimer); + pollTimer = null; + } + } + + // WebSocket connection with automatic reconnect and polling fallback + var ws = null; + + function connectWebSocket() { + var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws'); + + ws.onopen = function() { + stopPolling(); + clearError(); + }; + + ws.onmessage = function(event) { + try { + renderDashboard(JSON.parse(event.data)); + } catch (e) { + console.error('WebSocket message parse error:', e); + } + }; + + ws.onclose = function() { + ws = null; + startPolling(); + setTimeout(connectWebSocket, 3000); + }; + + ws.onerror = function() { + // onclose will fire after onerror + }; + } + + // Fetch orchestrator hostname for the banner + fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } }) + .then(function(r) { return r.ok ? 
r.json() : null; }) + .then(function(d) { + if (d && d.hostname) { + document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname); + } + }) + .catch(function() {}); + + // Initial load via fetch, then try WebSocket + fetchDashboard(); + connectWebSocket(); + </script> +</body> +</html> diff --git a/src/UnrealEngine.ico b/src/zenserver/frontend/html/epicgames.ico Binary files differindex 1cfa301a2..1cfa301a2 100644 --- a/src/UnrealEngine.ico +++ b/src/zenserver/frontend/html/epicgames.ico diff --git a/src/zenserver/frontend/html/favicon.ico b/src/zenserver/frontend/html/favicon.ico Binary files differindex 1cfa301a2..f7fb251b5 100644 --- a/src/zenserver/frontend/html/favicon.ico +++ b/src/zenserver/frontend/html/favicon.ico diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html index 6a736e914..24a136a30 100644 --- a/src/zenserver/frontend/html/index.html +++ b/src/zenserver/frontend/html/index.html @@ -10,6 +10,9 @@ </script> <link rel="shortcut icon" href="favicon.ico"> <link rel="stylesheet" type="text/css" href="zen.css" /> + <script src="theme.js"></script> + <script src="banner.js" defer></script> + <script src="nav.js" defer></script> <script type="module" src="zen.js"></script> </head> </html> diff --git a/src/zenserver/frontend/html/nav.js b/src/zenserver/frontend/html/nav.js new file mode 100644 index 000000000..a5de203f2 --- /dev/null +++ b/src/zenserver/frontend/html/nav.js @@ -0,0 +1,79 @@ +/** + * zen-nav.js — Zen dashboard navigation bar Web Component + * + * Usage: + * <script src="nav.js" defer></script> + * + * <zen-nav> + * <a href="compute.html">Node</a> + * <a href="orchestrator.html">Orchestrator</a> + * </zen-nav> + * + * Each child <a> becomes a nav link. The current page is + * highlighted automatically based on the href. 
+ */ + +class ZenNav extends HTMLElement { + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + _render() { + const currentPath = window.location.pathname; + const items = Array.from(this.querySelectorAll(':scope > a')); + + const links = items.map(a => { + const href = a.getAttribute('href') || ''; + const label = a.textContent.trim(); + const active = currentPath.endsWith(href); + return `<a class="nav-link${active ? ' active' : ''}" href="${href}">${label}</a>`; + }).join(''); + + this.shadowRoot.innerHTML = ` + <style> + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + margin-bottom: 16px; + } + + .nav-bar { + display: flex; + align-items: center; + gap: 4px; + padding: 4px; + background: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 6px; + } + + .nav-link { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + font-size: 13px; + font-weight: 500; + color: var(--theme_g1); + text-decoration: none; + padding: 6px 14px; + border-radius: 4px; + transition: color 0.15s, background 0.15s; + } + + .nav-link:hover { + color: var(--theme_g0); + background: var(--theme_p4); + } + + .nav-link.active { + color: var(--theme_bright); + background: var(--theme_g2); + } + </style> + <nav class="nav-bar">${links}</nav> + `; + } +} + +customElements.define('zen-nav', ZenNav); diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js new file mode 100644 index 000000000..3b838958a --- /dev/null +++ b/src/zenserver/frontend/html/pages/cache.js @@ -0,0 +1,690 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { Modal } from "../util/modal.js" +import { Table, Toolbar } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("cache"); + + // Cache Service Stats + const stats_section = this._collapsible_section("Cache Service Stats"); + stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => { + window.open("/stats/z$.yaml?cidstorestats=true&cachestorestats=true", "_blank"); + }); + this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); + this._details_host = stats_section; + this._details_container = null; + this._selected_category = null; + + const stats = await new Fetcher().resource("stats", "z$").json(); + if (stats) + { + this._render_stats(stats); + } + + this._connect_stats_ws(); + + // Cache Namespaces + var section = this._collapsible_section("Cache Namespaces"); + + section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all()); + + var columns = [ + "namespace", + "dir", + "buckets", + "entries", + "size disk", + "size mem", + "actions", + ]; + + var zcache_info = await new Fetcher().resource("/z$/").json(); + this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric); + + for (const namespace of zcache_info["Namespaces"] || []) + { + new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => { + const row = this._cache_table.add_row( + "", + data["Configuration"]["RootDir"], + data["Buckets"].length, + data["EntryCount"], + Friendly.bytes(data["StorageSize"].DiskSize), + Friendly.bytes(data["StorageSize"].MemorySize) + ); + var cell = row.get_cell(0); + cell.tag().text(namespace).on_click(() => this.view_namespace(namespace)); + + cell = 
row.get_cell(-1); + const action_tb = new Toolbar(cell, true); + action_tb.left().add("view").on_click(() => this.view_namespace(namespace)); + action_tb.left().add("drop").on_click(() => this.drop_namespace(namespace)); + + row.attr("zs_name", namespace); + }); + } + + // Namespace detail area (inside namespaces section so it collapses together) + this._namespace_host = section; + this._namespace_container = null; + this._selected_namespace = null; + + // Restore namespace from URL if present + const ns_param = this.get_param("namespace"); + if (ns_param) + { + this.view_namespace(ns_param); + } + } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = "0.7em"; + heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? "none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } + + _connect_stats_ws() + { + try + { + const proto = location.protocol === "https:" ? 
"wss:" : "ws:"; + const ws = new WebSocket(`${proto}//${location.host}/stats`); + + try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; } + document.addEventListener("zen-ws-toggle", (e) => { + this._ws_paused = e.detail.paused; + }); + + ws.onmessage = (ev) => { + if (this._ws_paused) + { + return; + } + try + { + const all_stats = JSON.parse(ev.data); + const stats = all_stats["z$"]; + if (stats) + { + this._render_stats(stats); + } + } + catch (e) { /* ignore parse errors */ } + }; + + ws.onclose = () => { this._stats_ws = null; }; + ws.onerror = () => { ws.close(); }; + + this._stats_ws = ws; + } + catch (e) { /* WebSocket not available */ } + } + + _render_stats(stats) + { + const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); + const grid = this._stats_grid; + + this._last_stats = stats; + grid.inner().innerHTML = ""; + + // Store I/O tile + { + const store = safe(stats, "cache.store"); + if (store) + { + const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed"); + if (this._selected_category === "store") tile.classify("stats-tile-selected"); + tile.on_click(() => this._select_category("store")); + tile.tag().classify("card-title").text("Store I/O"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const storeHits = store.hits || 0; + const storeMisses = store.misses || 0; + const storeTotal = storeHits + storeMisses; + const storeRatio = storeTotal > 0 ? 
((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(left, storeRatio, "store hit ratio", true); + this._metric(left, Friendly.sep(storeHits), "hits"); + this._metric(left, Friendly.sep(storeMisses), "misses"); + this._metric(left, Friendly.sep(store.writes || 0), "writes"); + this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads"); + this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes"); + + const right = columns.tag().classify("tile-metrics"); + const readRateMean = safe(store, "read.bytes.rate_mean") || 0; + const readRate1 = safe(store, "read.bytes.rate_1") || 0; + const readRate5 = safe(store, "read.bytes.rate_5") || 0; + const writeRateMean = safe(store, "write.bytes.rate_mean") || 0; + const writeRate1 = safe(store, "write.bytes.rate_1") || 0; + const writeRate5 = safe(store, "write.bytes.rate_5") || 0; + this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true); + this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)"); + this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)"); + this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)"); + this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)"); + this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)"); + } + } + + // Hit/Miss tile + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Hit Ratio"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const hits = safe(stats, "cache.hits") || 0; + const misses = safe(stats, "cache.misses") || 0; + const writes = safe(stats, "cache.writes") || 0; + const total = hits + misses; + const ratio = total > 0 ? 
((hits / total) * 100).toFixed(1) + "%" : "-"; + + this._metric(left, ratio, "hit ratio", true); + this._metric(left, Friendly.sep(hits), "hits"); + this._metric(left, Friendly.sep(misses), "misses"); + this._metric(left, Friendly.sep(writes), "writes"); + + const right = columns.tag().classify("tile-metrics"); + const cidHits = safe(stats, "cache.cidhits") || 0; + const cidMisses = safe(stats, "cache.cidmisses") || 0; + const cidWrites = safe(stats, "cache.cidwrites") || 0; + const cidTotal = cidHits + cidMisses; + const cidRatio = cidTotal > 0 ? ((cidHits / cidTotal) * 100).toFixed(1) + "%" : "-"; + + this._metric(right, cidRatio, "cid hit ratio", true); + this._metric(right, Friendly.sep(cidHits), "cid hits"); + this._metric(right, Friendly.sep(cidMisses), "cid misses"); + this._metric(right, Friendly.sep(cidWrites), "cid writes"); + } + + // HTTP Requests tile + { + const req = safe(stats, "requests"); + if (req) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("HTTP Requests"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const reqData = req.requests || req; + this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true); + if (reqData.rate_mean > 0) + { + this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)"); + } + if (reqData.rate_1 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)"); + } + if (reqData.rate_5 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)"); + } + if (reqData.rate_15 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)"); + } + const badRequests = safe(stats, "cache.badrequestcount") || 0; + this._metric(left, Friendly.sep(badRequests), "bad requests"); + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.duration(reqData.t_avg || 0), 
"avg latency", true); + if (reqData.t_p75) + { + this._metric(right, Friendly.duration(reqData.t_p75), "p75"); + } + if (reqData.t_p95) + { + this._metric(right, Friendly.duration(reqData.t_p95), "p95"); + } + if (reqData.t_p99) + { + this._metric(right, Friendly.duration(reqData.t_p99), "p99"); + } + if (reqData.t_p999) + { + this._metric(right, Friendly.duration(reqData.t_p999), "p999"); + } + if (reqData.t_max) + { + this._metric(right, Friendly.duration(reqData.t_max), "max"); + } + } + } + + // RPC tile + { + const rpc = safe(stats, "cache.rpc"); + if (rpc) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("RPC"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true); + this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops"); + + const right = columns.tag().classify("tile-metrics"); + if (rpc.records) + { + this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls"); + this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops"); + } + if (rpc.values) + { + this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls"); + this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops"); + } + if (rpc.chunks) + { + this._metric(right, Friendly.sep(rpc.chunks.count || 0), "chunk calls"); + this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops"); + } + } + } + + // Storage tile + { + const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed"); + if (this._selected_category === "storage") tile.classify("stats-tile-selected"); + tile.on_click(() => this._select_category("storage")); + tile.tag().classify("card-title").text("Storage"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + this._metric(left, safe(stats, 
"cache.size.disk") != null ? Friendly.bytes(safe(stats, "cache.size.disk")) : "-", "cache disk", true); + this._metric(left, safe(stats, "cache.size.memory") != null ? Friendly.bytes(safe(stats, "cache.size.memory")) : "-", "cache memory"); + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, safe(stats, "cid.size.total") != null ? Friendly.bytes(safe(stats, "cid.size.total")) : "-", "cid total", true); + this._metric(right, safe(stats, "cid.size.tiny") != null ? Friendly.bytes(safe(stats, "cid.size.tiny")) : "-", "cid tiny"); + this._metric(right, safe(stats, "cid.size.small") != null ? Friendly.bytes(safe(stats, "cid.size.small")) : "-", "cid small"); + this._metric(right, safe(stats, "cid.size.large") != null ? Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large"); + } + + // Upstream tile (only if upstream is active) + { + const upstream = safe(stats, "upstream"); + if (upstream) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Upstream"); + const body = tile.tag().classify("tile-metrics"); + + const upstreamHits = safe(stats, "cache.upstream_hits") || 0; + this._metric(body, Friendly.sep(upstreamHits), "upstream hits", true); + + if (upstream.url) + { + this._metric(body, upstream.url, "endpoint"); + } + } + } + } + + _metric(parent, value, label, hero = false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); + } + + async _select_category(category) + { + // Toggle off if already selected + if (this._selected_category === category) + { + this._selected_category = null; + this._clear_details(); + this._render_stats(this._last_stats); + return; + } + + this._selected_category = category; + this._render_stats(this._last_stats); + + // Fetch detailed stats + const detailed = await new Fetcher() + 
.resource("stats", "z$") + .param("cachestorestats", "true") + .param("cidstorestats", "true") + .json(); + + if (!detailed || this._selected_category !== category) + { + return; + } + + this._clear_details(); + + const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); + + if (category === "store") + { + this._render_store_details(detailed, safe); + } + else if (category === "storage") + { + this._render_storage_details(detailed, safe); + } + } + + _clear_details() + { + if (this._details_container) + { + this._details_container.inner().remove(); + this._details_container = null; + } + } + + _render_store_details(stats, safe) + { + const namespaces = safe(stats, "cache.store.namespaces") || []; + if (namespaces.length === 0) + { + return; + } + + const container = this._details_host.tag(); + this._details_container = container; + + const columns = [ + "namespace", + "bucket", + "hits", + "misses", + "writes", + "hit ratio", + "read count", + "read bandwidth", + "write count", + "write bandwidth", + ]; + const table = new Table(container, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric); + + for (const ns of namespaces) + { + const nsHits = ns.hits || 0; + const nsMisses = ns.misses || 0; + const nsTotal = nsHits + nsMisses; + const nsRatio = nsTotal > 0 ? 
((nsHits / nsTotal) * 100).toFixed(1) + "%" : "-"; + + const readCount = safe(ns, "read.request.count") || 0; + const readBytes = safe(ns, "read.bytes.count") || 0; + const writeCount = safe(ns, "write.request.count") || 0; + const writeBytes = safe(ns, "write.bytes.count") || 0; + + table.add_row( + ns.namespace, + "", + Friendly.sep(nsHits), + Friendly.sep(nsMisses), + Friendly.sep(ns.writes || 0), + nsRatio, + Friendly.sep(readCount), + Friendly.bytes(readBytes), + Friendly.sep(writeCount), + Friendly.bytes(writeBytes), + ); + + if (ns.buckets && ns.buckets.length > 0) + { + for (const bucket of ns.buckets) + { + const bHits = bucket.hits || 0; + const bMisses = bucket.misses || 0; + const bTotal = bHits + bMisses; + const bRatio = bTotal > 0 ? ((bHits / bTotal) * 100).toFixed(1) + "%" : "-"; + + const bReadCount = safe(bucket, "read.request.count") || 0; + const bReadBytes = safe(bucket, "read.bytes.count") || 0; + const bWriteCount = safe(bucket, "write.request.count") || 0; + const bWriteBytes = safe(bucket, "write.bytes.count") || 0; + + table.add_row( + ns.namespace, + bucket.bucket, + Friendly.sep(bHits), + Friendly.sep(bMisses), + Friendly.sep(bucket.writes || 0), + bRatio, + Friendly.sep(bReadCount), + Friendly.bytes(bReadBytes), + Friendly.sep(bWriteCount), + Friendly.bytes(bWriteBytes), + ); + } + } + } + } + + _render_storage_details(stats, safe) + { + const namespaces = safe(stats, "cache.store.namespaces") || []; + if (namespaces.length === 0) + { + return; + } + + const container = this._details_host.tag(); + this._details_container = container; + + const columns = [ + "namespace", + "bucket", + "disk", + "memory", + ]; + const table = new Table(container, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric); + + for (const ns of namespaces) + { + const diskSize = safe(ns, "size.disk") || 0; + const memSize = safe(ns, "size.memory") || 0; + + table.add_row( + ns.namespace, + "", + Friendly.bytes(diskSize), + 
Friendly.bytes(memSize), + ); + + if (ns.buckets && ns.buckets.length > 0) + { + for (const bucket of ns.buckets) + { + const bDisk = safe(bucket, "size.disk") || 0; + const bMem = safe(bucket, "size.memory") || 0; + + table.add_row( + ns.namespace, + bucket.bucket, + Friendly.bytes(bDisk), + Friendly.bytes(bMem), + ); + } + } + } + } + + async view_namespace(namespace) + { + // Toggle off if already selected + if (this._selected_namespace === namespace) + { + this._selected_namespace = null; + this._clear_namespace(); + this._clear_param("namespace"); + return; + } + + this._selected_namespace = namespace; + this._clear_namespace(); + this.set_param("namespace", namespace); + + const info = await new Fetcher().resource(`/z$/${namespace}/`).json(); + if (this._selected_namespace !== namespace) + { + return; + } + + const section = this._namespace_host.add_section(namespace); + this._namespace_container = section; + + // Buckets table + const bucket_section = section.add_section("Buckets"); + const bucket_columns = ["name", "disk", "memory", "entries", "actions"]; + const bucket_table = bucket_section.add_widget( + Table, + bucket_columns, + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric + ); + + // Right-align header for numeric columns (skip # and name) + const header = bucket_table._element.firstElementChild; + for (let i = 2; i < header.children.length - 1; i++) + { + header.children[i].style.textAlign = "right"; + } + + let totalDisk = 0, totalMem = 0, totalEntries = 0; + const total_row = bucket_table.add_row("TOTAL"); + total_row.get_cell(0).style("fontWeight", "bold"); + total_row.get_cell(1).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(2).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(3).style("textAlign", "right").style("fontWeight", "bold"); + + for (const bucket of info["Buckets"]) + { + const row = bucket_table.add_row(bucket); + new 
Fetcher().resource(`/z$/${namespace}/${bucket}`).json().then((data) => { + row.get_cell(1).text(Friendly.bytes(data["StorageSize"]["DiskSize"])).style("textAlign", "right"); + row.get_cell(2).text(Friendly.bytes(data["StorageSize"]["MemorySize"])).style("textAlign", "right"); + row.get_cell(3).text(Friendly.sep(data["DiskEntryCount"])).style("textAlign", "right"); + + const cell = row.get_cell(-1); + const action_tb = new Toolbar(cell, true); + action_tb.left().add("drop").on_click(() => this.drop_bucket(namespace, bucket)); + + totalDisk += data["StorageSize"]["DiskSize"]; + totalMem += data["StorageSize"]["MemorySize"]; + totalEntries += data["DiskEntryCount"]; + total_row.get_cell(1).text(Friendly.bytes(totalDisk)).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(2).text(Friendly.bytes(totalMem)).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(3).text(Friendly.sep(totalEntries)).style("textAlign", "right").style("fontWeight", "bold"); + }); + } + + } + + _clear_param(name) + { + this._params.delete(name); + const url = new URL(window.location); + url.searchParams.delete(name); + history.replaceState(null, "", url); + } + + _clear_namespace() + { + if (this._namespace_container) + { + this._namespace_container._parent.inner().remove(); + this._namespace_container = null; + } + } + + drop_bucket(namespace, bucket) + { + const drop = async () => { + await new Fetcher().resource("z$", namespace, bucket).delete(); + // Refresh the namespace view + this._selected_namespace = null; + this._clear_namespace(); + this.view_namespace(namespace); + }; + + new Modal() + .title("Confirmation") + .message(`Drop bucket '${bucket}'?`) + .option("Yes", () => drop()) + .option("No"); + } + + drop_namespace(namespace) + { + const drop = async () => { + await new Fetcher().resource("z$", namespace).delete(); + this.reload(); + }; + + new Modal() + .title("Confirmation") + .message(`Drop cache namespace '${namespace}'?`) + 
.option("Yes", () => drop()) + .option("No"); + } + + async drop_all() + { + const drop = async () => { + for (const row of this._cache_table) + { + const namespace = row.attr("zs_name"); + await new Fetcher().resource("z$", namespace).delete(); + } + this.reload(); + }; + + new Modal() + .title("Confirmation") + .message("Drop every cache namespace?") + .option("Yes", () => drop()) + .option("No"); + } +} diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js new file mode 100644 index 000000000..ab3d49c27 --- /dev/null +++ b/src/zenserver/frontend/html/pages/compute.js @@ -0,0 +1,693 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { Table } from "../util/widgets.js" + +const MAX_HISTORY_POINTS = 60; + +// Windows FILETIME: 100ns ticks since 1601-01-01 +const FILETIME_EPOCH_OFFSET_MS = 11644473600000n; +function filetimeToDate(ticks) +{ + if (!ticks) return null; + const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS; + return new Date(Number(ms)); +} + +function formatTime(date) +{ + if (!date) return "-"; + return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit", second: "2-digit" }); +} + +function formatDuration(startDate, endDate) +{ + if (!startDate || !endDate) return "-"; + const ms = endDate - startDate; + if (ms < 0) return "-"; + if (ms < 1000) return ms + " ms"; + if (ms < 60000) return (ms / 1000).toFixed(2) + " s"; + const m = Math.floor(ms / 60000); + const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, "0"); + return `${m}m ${s}s`; +} + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("compute"); + + this._history = { timestamps: [], pending: [], running: [], completed: [], cpu: [] }; + 
this._selected_worker = null; + this._chart_js = null; + this._queue_chart = null; + this._cpu_chart = null; + + this._ws_paused = false; + try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) {} + document.addEventListener("zen-ws-toggle", (e) => { + this._ws_paused = e.detail.paused; + }); + + // Action Queue section + const queue_section = this._collapsible_section("Action Queue"); + this._queue_grid = queue_section.tag().classify("grid").classify("stats-tiles"); + this._chart_host = queue_section; + + // Performance Metrics section + const perf_section = this._collapsible_section("Performance Metrics"); + this._perf_host = perf_section; + this._perf_grid = null; + + // Workers section + const workers_section = this._collapsible_section("Workers"); + this._workers_host = workers_section; + this._workers_table = null; + this._worker_detail_container = null; + + // Queues section + const queues_section = this._collapsible_section("Queues"); + this._queues_host = queues_section; + this._queues_table = null; + + // Action History section + const history_section = this._collapsible_section("Recent Actions"); + this._history_host = history_section; + this._history_table = null; + + // System Resources section + const sys_section = this._collapsible_section("System Resources"); + this._sys_grid = sys_section.tag().classify("grid").classify("stats-tiles"); + + // Load Chart.js dynamically + this._load_chartjs(); + + // Initial fetch + await this._fetch_all(); + + // Poll + this._poll_timer = setInterval(() => { + if (!this._ws_paused) + { + this._fetch_all(); + } + }, 2000); + } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = 
"0.7em"; + heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? "none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } + + async _load_chartjs() + { + if (window.Chart) + { + this._chart_js = window.Chart; + this._init_charts(); + return; + } + + try + { + const script = document.createElement("script"); + script.src = "https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"; + script.onload = () => { + this._chart_js = window.Chart; + this._init_charts(); + }; + document.head.appendChild(script); + } + catch (e) { /* Chart.js not available */ } + } + + _init_charts() + { + if (!this._chart_js) + { + return; + } + + // Queue history chart + { + const card = this._chart_host.tag().classify("card"); + card.tag().classify("card-title").text("Action Queue History"); + const container = card.tag(); + container.style("position", "relative").style("height", "300px").style("marginTop", "20px"); + const canvas = document.createElement("canvas"); + container.inner().appendChild(canvas); + + this._queue_chart = new this._chart_js(canvas.getContext("2d"), { + type: "line", + data: { + labels: [], + datasets: [ + { label: "Pending", data: [], borderColor: "#f0883e", backgroundColor: "rgba(240, 136, 62, 0.1)", tension: 0.4, fill: true }, + { label: "Running", data: [], borderColor: "#58a6ff", backgroundColor: "rgba(88, 166, 255, 0.1)", tension: 0.4, fill: true }, + { label: "Completed", data: [], borderColor: "#3fb950", backgroundColor: "rgba(63, 185, 80, 0.1)", tension: 0.4, fill: true }, + ] + }, + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { legend: { display: true, labels: { color: "#8b949e" } 
} }, + scales: { x: { display: false }, y: { beginAtZero: true, ticks: { color: "#8b949e" }, grid: { color: "#21262d" } } } + } + }); + } + + // CPU sparkline (will be appended to CPU card later) + this._cpu_canvas = document.createElement("canvas"); + this._cpu_chart = new this._chart_js(this._cpu_canvas.getContext("2d"), { + type: "line", + data: { + labels: [], + datasets: [{ + data: [], + borderColor: "#58a6ff", + backgroundColor: "rgba(88, 166, 255, 0.15)", + borderWidth: 1.5, + tension: 0.4, + fill: true, + pointRadius: 0 + }] + }, + options: { + responsive: true, + maintainAspectRatio: false, + animation: false, + plugins: { legend: { display: false }, tooltip: { enabled: false } }, + scales: { x: { display: false }, y: { display: false, min: 0, max: 100 } } + } + }); + } + + async _fetch_all() + { + try + { + const [stats, sysinfo, workers_data, queues_data, history_data] = await Promise.all([ + new Fetcher().resource("/stats/compute").json().catch(() => null), + new Fetcher().resource("/compute/sysinfo").json().catch(() => null), + new Fetcher().resource("/compute/workers").json().catch(() => null), + new Fetcher().resource("/compute/queues").json().catch(() => null), + new Fetcher().resource("/compute/jobs/history").param("limit", "50").json().catch(() => null), + ]); + + if (stats) + { + this._render_queue_stats(stats); + this._update_queue_chart(stats); + this._render_perf(stats); + } + if (sysinfo) + { + this._render_sysinfo(sysinfo); + } + if (workers_data) + { + this._render_workers(workers_data); + } + if (queues_data) + { + this._render_queues(queues_data); + } + if (history_data) + { + this._render_action_history(history_data); + } + } + catch (e) { /* service unavailable */ } + } + + _render_queue_stats(data) + { + const grid = this._queue_grid; + grid.inner().innerHTML = ""; + + const tiles = [ + { title: "Pending Actions", value: data.actions_pending || 0, label: "waiting to be scheduled" }, + { title: "Running Actions", value: 
data.actions_submitted || 0, label: "currently executing" }, + { title: "Completed Actions", value: data.actions_complete || 0, label: "results available" }, + ]; + + for (const t of tiles) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text(t.title); + const body = tile.tag().classify("tile-metrics"); + this._metric(body, Friendly.sep(t.value), t.label, true); + } + } + + _update_queue_chart(data) + { + const h = this._history; + h.timestamps.push(new Date().toLocaleTimeString()); + h.pending.push(data.actions_pending || 0); + h.running.push(data.actions_submitted || 0); + h.completed.push(data.actions_complete || 0); + + while (h.timestamps.length > MAX_HISTORY_POINTS) + { + h.timestamps.shift(); + h.pending.shift(); + h.running.shift(); + h.completed.shift(); + } + + if (this._queue_chart) + { + this._queue_chart.data.labels = h.timestamps; + this._queue_chart.data.datasets[0].data = h.pending; + this._queue_chart.data.datasets[1].data = h.running; + this._queue_chart.data.datasets[2].data = h.completed; + this._queue_chart.update("none"); + } + } + + _render_perf(data) + { + if (!this._perf_grid) + { + this._perf_grid = this._perf_host.tag().classify("grid").classify("stats-tiles"); + } + const grid = this._perf_grid; + grid.inner().innerHTML = ""; + + const retired = data.actions_retired || {}; + + // Completion rate card + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Completion Rate"); + const body = tile.tag().classify("tile-columns"); + + const left = body.tag().classify("tile-metrics"); + this._metric(left, this._fmt_rate(retired.rate_1), "1 min rate", true); + this._metric(left, this._fmt_rate(retired.rate_5), "5 min rate"); + this._metric(left, this._fmt_rate(retired.rate_15), "15 min rate"); + + const right = body.tag().classify("tile-metrics"); + this._metric(right, Friendly.sep(retired.count || 0), "total retired", true); + 
this._metric(right, this._fmt_rate(retired.rate_mean), "mean rate"); + } + } + + _fmt_rate(rate) + { + if (rate == null) return "-"; + return rate.toFixed(2) + "/s"; + } + + _render_workers(data) + { + const workerIds = data.workers || []; + + if (this._workers_table) + { + this._workers_table.clear(); + } + else + { + this._workers_table = this._workers_host.add_widget( + Table, + ["name", "platform", "cores", "timeout", "functions", "worker ID"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 + ); + } + + if (workerIds.length === 0) + { + return; + } + + // Fetch each worker's descriptor + Promise.all( + workerIds.map(id => + new Fetcher().resource("/compute/workers", id).json() + .then(desc => ({ id, desc })) + .catch(() => ({ id, desc: null })) + ) + ).then(results => { + this._workers_table.clear(); + for (const { id, desc } of results) + { + const name = desc ? (desc.name || "-") : "-"; + const host = desc ? (desc.host || "-") : "-"; + const cores = desc ? (desc.cores != null ? desc.cores : "-") : "-"; + const timeout = desc ? (desc.timeout != null ? desc.timeout + "s" : "-") : "-"; + const functions = desc ? (desc.functions ? 
desc.functions.length : 0) : "-"; + + const row = this._workers_table.add_row( + "", + host, + String(cores), + String(timeout), + String(functions), + id, + ); + + // Make name clickable to expand detail + const cell = row.get_cell(0); + cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc)); + + // Highlight selected + if (id === this._selected_worker) + { + row.style("background", "var(--theme_p3)"); + } + } + + this._worker_descriptors = Object.fromEntries(results.map(r => [r.id, r.desc])); + + // Re-render detail if still selected + if (this._selected_worker && this._worker_descriptors[this._selected_worker]) + { + this._show_worker_detail(this._selected_worker, this._worker_descriptors[this._selected_worker]); + } + else if (this._selected_worker) + { + this._selected_worker = null; + this._clear_worker_detail(); + } + }); + } + + _toggle_worker_detail(id, desc) + { + if (this._selected_worker === id) + { + this._selected_worker = null; + this._clear_worker_detail(); + return; + } + this._selected_worker = id; + this._show_worker_detail(id, desc); + } + + _clear_worker_detail() + { + if (this._worker_detail_container) + { + this._worker_detail_container._parent.inner().remove(); + this._worker_detail_container = null; + } + } + + _show_worker_detail(id, desc) + { + this._clear_worker_detail(); + if (!desc) + { + return; + } + + const section = this._workers_host.add_section(desc.name || id); + this._worker_detail_container = section; + + // Basic info table + const info_table = section.add_widget( + Table, ["property", "value"], Table.Flag_FitLeft|Table.Flag_PackRight + ); + const fields = [ + ["Worker ID", id], + ["Path", desc.path || "-"], + ["Platform", desc.host || "-"], + ["Build System", desc.buildsystem_version || "-"], + ["Cores", desc.cores != null ? String(desc.cores) : "-"], + ["Timeout", desc.timeout != null ? 
desc.timeout + "s" : "-"], + ]; + for (const [label, value] of fields) + { + info_table.add_row(label, value); + } + + // Functions + const functions = desc.functions || []; + if (functions.length > 0) + { + const fn_section = section.add_section("Functions"); + const fn_table = fn_section.add_widget( + Table, ["name", "version"], Table.Flag_FitLeft|Table.Flag_PackRight + ); + for (const f of functions) + { + fn_table.add_row(f.name || "-", f.version || "-"); + } + } + + // Executables + const executables = desc.executables || []; + if (executables.length > 0) + { + const exec_section = section.add_section("Executables"); + const exec_table = exec_section.add_widget( + Table, ["path", "hash", "size"], Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric + ); + let totalSize = 0; + for (const e of executables) + { + exec_table.add_row(e.name || "-", e.hash || "-", e.size != null ? Friendly.bytes(e.size) : "-"); + totalSize += e.size || 0; + } + const total_row = exec_table.add_row("TOTAL", "", Friendly.bytes(totalSize)); + total_row.get_cell(0).style("fontWeight", "bold"); + total_row.get_cell(2).style("fontWeight", "bold"); + } + + // Files + const files = desc.files || []; + if (files.length > 0) + { + const files_section = section.add_section("Files"); + const files_table = files_section.add_widget( + Table, ["name", "hash"], Table.Flag_FitLeft|Table.Flag_PackRight + ); + for (const f of files) + { + files_table.add_row(typeof f === "string" ? f : (f.name || "-"), typeof f === "string" ? 
"" : (f.hash || "")); + } + } + + // Directories + const dirs = desc.dirs || []; + if (dirs.length > 0) + { + const dirs_section = section.add_section("Directories"); + for (const d of dirs) + { + dirs_section.tag().classify("detail-tag").text(d); + } + } + + // Environment + const env = desc.environment || []; + if (env.length > 0) + { + const env_section = section.add_section("Environment"); + for (const e of env) + { + env_section.tag().classify("detail-tag").text(e); + } + } + } + + _render_queues(data) + { + const queues = data.queues || []; + + if (this._queues_table) + { + this._queues_table.clear(); + } + else + { + this._queues_table = this._queues_host.add_widget( + Table, + ["ID", "status", "active", "completed", "failed", "abandoned", "cancelled", "token"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 + ); + } + + for (const q of queues) + { + const id = q.queue_id != null ? String(q.queue_id) : "-"; + const status = q.state === "cancelled" ? "cancelled" + : q.state === "draining" ? "draining" + : q.is_complete ? "complete" : "active"; + + this._queues_table.add_row( + id, + status, + String(q.active_count ?? 0), + String(q.completed_count ?? 0), + String(q.failed_count ?? 0), + String(q.abandoned_count ?? 0), + String(q.cancelled_count ?? 0), + q.queue_token || "-", + ); + } + } + + _render_action_history(data) + { + const entries = data.history || []; + + if (this._history_table) + { + this._history_table.clear(); + } + else + { + this._history_table = this._history_host.add_widget( + Table, + ["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 + ); + } + + // Entries arrive oldest-first; reverse to show newest at top + for (const entry of [...entries].reverse()) + { + const lsn = entry.lsn != null ? String(entry.lsn) : "-"; + const queueId = entry.queueId ? 
String(entry.queueId) : "-"; + const status = entry.succeeded == null ? "unknown" + : entry.succeeded ? "ok" : "failed"; + const desc = entry.actionDescriptor || {}; + const fn = desc.Function || "-"; + const startDate = filetimeToDate(entry.time_Running); + const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); + + this._history_table.add_row( + lsn, + queueId, + status, + fn, + formatTime(startDate), + formatTime(endDate), + formatDuration(startDate, endDate), + entry.workerId || "-", + entry.actionId || "-", + ); + } + } + + _render_sysinfo(data) + { + const grid = this._sys_grid; + grid.inner().innerHTML = ""; + + // CPU card + { + const cpuUsage = data.cpu_usage || 0; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("CPU Usage"); + const body = tile.tag().classify("tile-metrics"); + this._metric(body, cpuUsage.toFixed(1) + "%", "percent", true); + + // Progress bar + const bar = body.tag().classify("progress-bar"); + bar.tag().classify("progress-fill").style("width", cpuUsage + "%"); + + // CPU sparkline + this._history.cpu.push(cpuUsage); + while (this._history.cpu.length > MAX_HISTORY_POINTS) this._history.cpu.shift(); + if (this._cpu_chart) + { + const sparkContainer = body.tag(); + sparkContainer.style("position", "relative").style("height", "60px").style("marginTop", "12px"); + sparkContainer.inner().appendChild(this._cpu_canvas); + + this._cpu_chart.data.labels = this._history.cpu.map(() => ""); + this._cpu_chart.data.datasets[0].data = this._history.cpu; + this._cpu_chart.update("none"); + } + + // CPU details + this._stat_row(body, "Packages", data.cpu_count != null ? String(data.cpu_count) : "-"); + this._stat_row(body, "Physical Cores", data.core_count != null ? String(data.core_count) : "-"); + this._stat_row(body, "Logical Processors", data.lp_count != null ? 
String(data.lp_count) : "-"); + } + + // Memory card + { + const memUsed = data.memory_used || 0; + const memTotal = data.memory_total || 1; + const memPercent = (memUsed / memTotal) * 100; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Memory"); + const body = tile.tag().classify("tile-metrics"); + this._stat_row(body, "Used", Friendly.bytes(memUsed)); + this._stat_row(body, "Total", Friendly.bytes(memTotal)); + const bar = body.tag().classify("progress-bar"); + bar.tag().classify("progress-fill").style("width", memPercent + "%"); + } + + // Disk card + { + const diskUsed = data.disk_used || 0; + const diskTotal = data.disk_total || 1; + const diskPercent = (diskUsed / diskTotal) * 100; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Disk"); + const body = tile.tag().classify("tile-metrics"); + this._stat_row(body, "Used", Friendly.bytes(diskUsed)); + this._stat_row(body, "Total", Friendly.bytes(diskTotal)); + const bar = body.tag().classify("progress-bar"); + bar.tag().classify("progress-fill").style("width", diskPercent + "%"); + } + } + + _stat_row(parent, label, value) + { + const row = parent.tag().classify("stats-row"); + row.tag().classify("stats-label").text(label); + row.tag().classify("stats-value").text(value); + } + + _metric(parent, value, label, hero = false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); + } +} diff --git a/src/zenserver/frontend/html/pages/cookartifacts.js b/src/zenserver/frontend/html/pages/cookartifacts.js new file mode 100644 index 000000000..f2ae094b9 --- /dev/null +++ b/src/zenserver/frontend/html/pages/cookartifacts.js @@ -0,0 +1,397 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Table, Toolbar, PropTable } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + main() + { + this.set_title("cook artifacts"); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); + const artifact_hash = this.get_param("hash"); + + // Fetch the artifact content as JSON + this._artifact = new Fetcher() + .resource("prj", project, "oplog", oplog, artifact_hash + ".json") + .json(); + + // Optionally fetch entry info for display context + if (opkey) + { + this._entry = new Fetcher() + .resource("prj", project, "oplog", oplog, "entries") + .param("opkey", opkey) + .cbo(); + } + + this._build_page(); + } + + // Map CookDependency enum values to display names + _get_dependency_type_name(type_value) + { + const type_names = { + 0: "None", + 1: "File", + 2: "Function", + 3: "TransitiveBuild", + 4: "Package", + 5: "ConsoleVariable", + 6: "Config", + 7: "SettingsObject", + 8: "NativeClass", + 9: "AssetRegistryQuery", + 10: "RedirectionTarget" + }; + return type_names[type_value] || `Unknown (${type_value})`; + } + + // Check if Data content should be expandable + _should_make_expandable(data_string) + { + if (!data_string || data_string.length < 40) + return false; + + // Check if it's JSON array or object + if (!data_string.startsWith('[') && !data_string.startsWith('{')) + return false; + + // Check if formatting would add newlines + try { + const parsed = JSON.parse(data_string); + const formatted = JSON.stringify(parsed, null, 2); + return formatted.includes('\n'); + } catch (e) { + return false; + } + } + + // Get first line of content for collapsed state + _get_first_line(data_string) + { + if (!data_string) + return ""; + + const newline_index = data_string.indexOf('\n'); + if 
(newline_index === -1) + { + // No newline, truncate if too long + return data_string.length > 80 ? data_string.substring(0, 77) + "..." : data_string; + } + return data_string.substring(0, newline_index) + "..."; + } + + // Format JSON with indentation + _format_json(data_string) + { + try { + const parsed = JSON.parse(data_string); + return JSON.stringify(parsed, null, 2); + } catch (e) { + return data_string; + } + } + + // Toggle expand/collapse state + _toggle_data_cell(cell) + { + const is_expanded = cell.attr("expanded") !== null; + const full_data = cell.attr("data-full"); + + // Find the text wrapper span + const text_wrapper = cell.first_child().next_sibling(); + + if (is_expanded) + { + // Collapse: show first line only + const first_line = this._get_first_line(full_data); + text_wrapper.text(first_line); + cell.attr("expanded", null); + } + else + { + // Expand: show formatted JSON + const formatted = this._format_json(full_data); + text_wrapper.text(formatted); + cell.attr("expanded", ""); + } + } + + // Format dependency data based on its structure + _format_dependency(dep_array) + { + const type = dep_array[0]; + const formatted = {}; + + // Common patterns based on the example data: + // Type 2 (Function): [type, name, array, hash] + // Type 4 (Package): [type, path, hash] + // Type 5 (ConsoleVariable): [type, bool, array, hash] + // Type 8 (NativeClass): [type, path, hash] + // Type 9 (AssetRegistryQuery): [type, bool, object, hash] + // Type 10 (RedirectionTarget): [type, path, hash] + + if (dep_array.length > 1) + { + // Most types have a name/path as second element + if (typeof dep_array[1] === "string") + { + formatted.Name = dep_array[1]; + } + else if (typeof dep_array[1] === "boolean") + { + formatted.Value = dep_array[1].toString(); + } + } + + if (dep_array.length > 2) + { + // Third element varies + if (Array.isArray(dep_array[2])) + { + formatted.Data = JSON.stringify(dep_array[2]); + } + else if (typeof dep_array[2] === "object") + { + 
formatted.Data = JSON.stringify(dep_array[2]); + } + else if (typeof dep_array[2] === "string") + { + formatted.Hash = dep_array[2]; + } + } + + if (dep_array.length > 3) + { + // Fourth element is usually the hash + if (typeof dep_array[3] === "string") + { + formatted.Hash = dep_array[3]; + } + } + + return formatted; + } + + async _build_page() + { + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); + const artifact_hash = this.get_param("hash"); + + // Build page title + let title = "Cook Artifacts"; + if (this._entry) + { + try + { + const entry = await this._entry; + const entry_obj = entry.as_object().find("entry").as_object(); + const key = entry_obj.find("key").as_value(); + title = `Cook Artifacts`; + } + catch (e) + { + console.error("Failed to fetch entry:", e); + } + } + + const section = this.add_section(title); + + // Fetch and parse artifact + let artifact; + try + { + artifact = await this._artifact; + } + catch (e) + { + section.text(`Failed to load artifact: ${e.message}`); + return; + } + + // Display artifact info + const info_section = section.add_section("Artifact Info"); + const info_table = info_section.add_widget(Table, ["Property", "Value"], Table.Flag_PackRight); + + if (artifact.Version !== undefined) + info_table.add_row("Version", artifact.Version.toString()); + if (artifact.HasSaveResults !== undefined) + info_table.add_row("HasSaveResults", artifact.HasSaveResults.toString()); + if (artifact.PackageSavedHash !== undefined) + info_table.add_row("PackageSavedHash", artifact.PackageSavedHash); + + // Process SaveBuildDependencies + if (artifact.SaveBuildDependencies && artifact.SaveBuildDependencies.Dependencies) + { + this._build_dependency_section( + section, + "Save Build Dependencies", + artifact.SaveBuildDependencies.Dependencies, + artifact.SaveBuildDependencies.StoredKey + ); + } + + // Process LoadBuildDependencies + if (artifact.LoadBuildDependencies && 
artifact.LoadBuildDependencies.Dependencies) + { + this._build_dependency_section( + section, + "Load Build Dependencies", + artifact.LoadBuildDependencies.Dependencies, + artifact.LoadBuildDependencies.StoredKey + ); + } + + // Process RuntimeDependencies + if (artifact.RuntimeDependencies && artifact.RuntimeDependencies.length > 0) + { + const runtime_section = section.add_section("Runtime Dependencies"); + const runtime_table = runtime_section.add_widget(Table, ["Path"], Table.Flag_PackRight); + for (const dep of artifact.RuntimeDependencies) + { + const row = runtime_table.add_row(dep); + // Make Path clickable to navigate to entry + if (this._should_link_dependency(dep)) + { + row.get_cell(0).text(dep).on_click((opkey) => { + window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; + }, dep); + } + } + } + } + + _should_link_dependency(name) + { + // Exclude dependencies starting with /Script/ (code-defined entries) - case insensitive + if (name && name.toLowerCase().startsWith("/script/")) + return false; + + return true; + } + + _build_dependency_section(parent_section, title, dependencies, stored_key) + { + const section = parent_section.add_section(title); + + // Add stored key info + if (stored_key) + { + const key_toolbar = section.add_widget(Toolbar); + key_toolbar.left().add(`Key: ${stored_key}`); + } + + // Group dependencies by type + const dependencies_by_type = {}; + + for (const dep_array of dependencies) + { + if (!Array.isArray(dep_array) || dep_array.length === 0) + continue; + + const type = dep_array[0]; + if (!dependencies_by_type[type]) + dependencies_by_type[type] = []; + + dependencies_by_type[type].push(this._format_dependency(dep_array)); + } + + // Sort types numerically + const sorted_types = Object.keys(dependencies_by_type).map(Number).sort((a, b) => a - b); + + for (const type_value of sorted_types) + { + const type_name = this._get_dependency_type_name(type_value); + const deps = 
dependencies_by_type[type_value]; + + const type_section = section.add_section(type_name); + + // Determine columns based on available fields + const all_fields = new Set(); + for (const dep of deps) + { + for (const field in dep) + all_fields.add(field); + } + let columns = Array.from(all_fields); + + // Remove Hash column for RedirectionTarget as it's not useful + if (type_value === 10) + { + columns = columns.filter(col => col !== "Hash"); + } + + if (columns.length === 0) + { + type_section.text("No data fields"); + continue; + } + + // Create table with dynamic columns + const table = type_section.add_widget(Table, columns, Table.Flag_PackRight); + + // Check if this type should have clickable Name links + const should_link = (type_value === 3 || type_value === 4 || type_value === 10); + const name_col_index = columns.indexOf("Name"); + + for (const dep of deps) + { + const row_values = columns.map(col => dep[col] || ""); + const row = table.add_row(...row_values); + + // Make Name field clickable for Package, TransitiveBuild, and RedirectionTarget + if (should_link && name_col_index >= 0 && dep.Name && this._should_link_dependency(dep.Name)) + { + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + row.get_cell(name_col_index).text(dep.Name).on_click((opkey) => { + window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; + }, dep.Name); + } + + // Make Data field expandable/collapsible if needed + const data_col_index = columns.indexOf("Data"); + if (data_col_index >= 0 && dep.Data) + { + const data_cell = row.get_cell(data_col_index); + + if (this._should_make_expandable(dep.Data)) + { + // Store full data in attribute + data_cell.attr("data-full", dep.Data); + + // Clear the cell and rebuild with icon + text + data_cell.inner().innerHTML = ""; + + // Create expand/collapse icon + const icon = data_cell.tag("span").classify("zen_expand_icon").text("+"); + icon.on_click(() => { + 
this._toggle_data_cell(data_cell); + // Update icon text + const is_expanded = data_cell.attr("expanded") !== null; + icon.text(is_expanded ? "-" : "+"); + }); + + // Add text content wrapper + const text_wrapper = data_cell.tag("span").classify("zen_data_text"); + const first_line = this._get_first_line(dep.Data); + text_wrapper.text(first_line); + + // Store reference to text wrapper for updates + data_cell.attr("data-text-wrapper", "true"); + } + } + } + } + } +} diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 08589b090..1e4c82e3f 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -26,6 +26,9 @@ export class Page extends ZenPage this._indexer = this.load_indexer(project, oplog); + this._files_index_start = Number(this.get_param("files_start", 0)) || 0; + this._files_index_count = Number(this.get_param("files_count", 50)) || 0; + this._build_page(); } @@ -40,25 +43,39 @@ export class Page extends ZenPage return indexer; } - async _build_deps(section, tree) + _build_deps(section, tree) { - const indexer = await this._indexer; + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); for (const dep_name in tree) { const dep_section = section.add_section(dep_name); const table = dep_section.add_widget(Table, ["name", "id"], Table.Flag_PackRight); + for (const dep_id of tree[dep_name]) { - const cell_values = ["", dep_id.toString(16).padStart(16, "0")]; + const hex_id = dep_id.toString(16).padStart(16, "0"); + const cell_values = ["loading...", hex_id]; const row = table.add_row(...cell_values); - var opkey = indexer.lookup_id(dep_id); - row.get_cell(0).text(opkey).on_click((k) => this.view_opkey(k), opkey); + // Asynchronously resolve the name + this._resolve_dep_name(row.get_cell(0), dep_id, project, oplog); } } } + async _resolve_dep_name(cell, dep_id, project, oplog) + { + const indexer = await this._indexer; + const 
opkey = indexer.lookup_id(dep_id); + + if (opkey) + { + cell.text(opkey).on_click((k) => this.view_opkey(k), opkey); + } + } + _find_iohash_field(container, name) { const found_field = container.find(name); @@ -76,6 +93,21 @@ export class Page extends ZenPage return null; } + _is_null_io_hash_string(io_hash) + { + if (!io_hash) + return true; + + for (let char of io_hash) + { + if (char != '0') + { + return false; + } + } + return true; + } + async _build_meta(section, entry) { var tree = {} @@ -123,11 +155,23 @@ export class Page extends ZenPage const project = this.get_param("project"); const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); const link = row.get_cell(0).link( - "/" + ["prj", project, "oplog", oplog, value+".json"].join("/") + (key === "cook.artifacts") ? + `?page=cookartifacts&project=${project}&oplog=${oplog}&opkey=${opkey}&hash=${value}` + : "/" + ["prj", project, "oplog", oplog, value+".json"].join("/") ); const action_tb = new Toolbar(row.get_cell(-1), true); + + // Add "view-raw" button for cook.artifacts + if (key === "cook.artifacts") + { + action_tb.left().add("view-raw").on_click(() => { + window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/"); + }); + } + action_tb.left().add("copy-hash").on_click(async (v) => { await navigator.clipboard.writeText(v); }, value); @@ -137,35 +181,55 @@ export class Page extends ZenPage async _build_page() { var entry = await this._entry; + + // Check if entry exists + if (!entry || entry.as_object().find("entry") == null) + { + const opkey = this.get_param("opkey"); + var section = this.add_section("Entry Not Found"); + section.tag("p").text(`The entry "${opkey}" is not present in this dataset.`); + section.tag("p").text("This could mean:"); + const list = section.tag("ul"); + list.tag("li").text("The entry is for an instance defined in code"); + list.tag("li").text("The entry has not been added to the oplog yet"); + list.tag("li").text("The entry key is 
misspelled"); + list.tag("li").text("The entry was removed or never existed"); + return; + } + entry = entry.as_object().find("entry").as_object(); const name = entry.find("key").as_value(); var section = this.add_section(name); + var has_package_data = false; // tree { var tree = entry.find("$tree"); if (tree == undefined) tree = this._convert_legacy_to_tree(entry); - if (tree == undefined) - return this._display_unsupported(section, entry); - - delete tree["$id"]; - - if (Object.keys(tree).length != 0) + if (tree != undefined) { - const sub_section = section.add_section("deps"); - this._build_deps(sub_section, tree); + delete tree["$id"]; + + if (Object.keys(tree).length != 0) + { + const sub_section = section.add_section("dependencies"); + this._build_deps(sub_section, tree); + } + has_package_data = true; } } // meta + if (has_package_data) { this._build_meta(section, entry); } // data + if (has_package_data) { const sub_section = section.add_section("data"); const table = sub_section.add_widget( @@ -181,7 +245,7 @@ export class Page extends ZenPage for (const item of pkg_data.as_array()) { - var io_hash, size, raw_size, file_name; + var io_hash = undefined, size = undefined, raw_size = undefined, file_name = undefined; for (const field of item.as_object()) { if (field.is_named("data")) io_hash = field.as_value(); @@ -198,8 +262,8 @@ export class Page extends ZenPage io_hash = ret; } - size = (size !== undefined) ? Friendly.kib(size) : ""; - raw_size = (raw_size !== undefined) ? Friendly.kib(raw_size) : ""; + size = (size !== undefined) ? Friendly.bytes(size) : ""; + raw_size = (raw_size !== undefined) ? 
Friendly.bytes(raw_size) : ""; const row = table.add_row(file_name, size, raw_size); @@ -219,12 +283,76 @@ export class Page extends ZenPage } } + // files + var has_file_data = false; + { + var file_data = entry.find("files"); + if (file_data != undefined) + { + has_file_data = true; + + // Extract files into array + this._files_data = []; + for (const item of file_data.as_array()) + { + var io_hash = undefined, cid = undefined, server_path = undefined, client_path = undefined; + for (const field of item.as_object()) + { + if (field.is_named("data")) io_hash = field.as_value(); + else if (field.is_named("id")) cid = field.as_value(); + else if (field.is_named("serverpath")) server_path = field.as_value(); + else if (field.is_named("clientpath")) client_path = field.as_value(); + } + + if (io_hash instanceof Uint8Array) + { + var ret = ""; + for (var x of io_hash) + ret += x.toString(16).padStart(2, "0"); + io_hash = ret; + } + + if (cid instanceof Uint8Array) + { + var ret = ""; + for (var x of cid) + ret += x.toString(16).padStart(2, "0"); + cid = ret; + } + + this._files_data.push({ + server_path: server_path, + client_path: client_path, + io_hash: io_hash, + cid: cid + }); + } + + this._files_index_max = this._files_data.length; + + const sub_section = section.add_section("files"); + this._build_files_nav(sub_section); + + this._files_table = sub_section.add_widget( + Table, + ["name", "actions"], Table.Flag_PackRight + ); + this._files_table.id("filetable"); + + this._build_files_table(this._files_index_start); + } + } + // props + if (has_package_data) { const object = entry.to_js_object(); var sub_section = section.add_section("props"); sub_section.add_widget(PropTable).add_object(object); } + + if (!has_package_data && !has_file_data) + return this._display_unsupported(section, entry); } _display_unsupported(section, entry) @@ -271,16 +399,30 @@ export class Page extends ZenPage for (const field of pkgst_entry) { const field_name = field.get_name(); - if 
(!field_name.endsWith("importedpackageids")) - continue; - - var dep_name = field_name.slice(0, -18); - if (dep_name.length == 0) - dep_name = "imported"; - - var out = tree[dep_name] = []; - for (var item of field.as_array()) - out.push(item.as_value(BigInt)); + if (field_name.endsWith("importedpackageids")) + { + var dep_name = field_name.slice(0, -18); + if (dep_name.length == 0) + dep_name = "hard"; + else + dep_name = "hard." + dep_name; + + var out = tree[dep_name] = []; + for (var item of field.as_array()) + out.push(item.as_value(BigInt)); + } + else if (field_name.endsWith("softpackagereferences")) + { + var dep_name = field_name.slice(0, -21); + if (dep_name.length == 0) + dep_name = "soft"; + else + dep_name = "soft." + dep_name; + + var out = tree[dep_name] = []; + for (var item of field.as_array()) + out.push(item.as_value(BigInt)); + } } return tree; @@ -292,4 +434,149 @@ export class Page extends ZenPage params.set("opkey", opkey); window.location.search = params; } + + _build_files_nav(section) + { + const nav = section.add_widget(Toolbar); + const left = nav.left(); + left.add("|<") .on_click(() => this._on_files_next_prev(-10e10)); + left.add("<<") .on_click(() => this._on_files_next_prev(-10)); + left.add("prev").on_click(() => this._on_files_next_prev( -1)); + left.add("next").on_click(() => this._on_files_next_prev( 1)); + left.add(">>") .on_click(() => this._on_files_next_prev( 10)); + left.add(">|") .on_click(() => this._on_files_next_prev( 10e10)); + + left.sep(); + for (var count of [10, 25, 50, 100]) + { + var handler = (n) => this._on_files_change_count(n); + left.add(count).on_click(handler, count); + } + + const right = nav.right(); + right.add(Friendly.sep(this._files_index_max)); + + right.sep(); + var search_input = right.add("search:", "label").tag("input"); + search_input.on("change", (x) => this._search_files(x.inner().value), search_input); + } + + _build_files_table(index) + { + this._files_index_count = 
Math.max(this._files_index_count, 1); + index = Math.min(index, this._files_index_max - this._files_index_count); + index = Math.max(index, 0); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + + const end_index = Math.min(index + this._files_index_count, this._files_index_max); + + this._files_table.clear(index); + for (var i = index; i < end_index; i++) + { + const file_item = this._files_data[i]; + const row = this._files_table.add_row(file_item.server_path); + + var base_name = file_item.server_path.split("/").pop().split("\\").pop(); + if (this._is_null_io_hash_string(file_item.io_hash)) + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/") + ); + link.first_child().attr("download", `${file_item.cid}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-id").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.cid); + } + else + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/") + ); + link.first_child().attr("download", `${file_item.io_hash}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-hash").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.io_hash); + } + } + + this.set_param("files_start", index); + this.set_param("files_count", this._files_index_count); + this._files_index_start = index; + } + + _on_files_change_count(value) + { + this._files_index_count = parseInt(value); + this._build_files_table(this._files_index_start); + } + + _on_files_next_prev(direction) + { + var index = this._files_index_start + (this._files_index_count * direction); + index = Math.max(0, index); + this._build_files_table(index); + } + + _search_files(needle) + { + if (needle.length == 0) + { + this._build_files_table(this._files_index_start); + return; + } + needle 
= needle.trim().toLowerCase(); + + this._files_table.clear(this._files_index_start); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + + var added = 0; + const truncate_at = this.get_param("searchmax") || 250; + for (const file_item of this._files_data) + { + if (!file_item.server_path.toLowerCase().includes(needle)) + continue; + + const row = this._files_table.add_row(file_item.server_path); + + var base_name = file_item.server_path.split("/").pop().split("\\").pop(); + if (this._is_null_io_hash_string(file_item.io_hash)) + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/") + ); + link.first_child().attr("download", `${file_item.cid}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-id").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.cid); + } + else + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/") + ); + link.first_child().attr("download", `${file_item.io_hash}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-hash").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.io_hash); + } + + if (++added >= truncate_at) + { + this._files_table.add_row("...truncated"); + break; + } + } + } } diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js new file mode 100644 index 000000000..f9e4fff33 --- /dev/null +++ b/src/zenserver/frontend/html/pages/hub.js @@ -0,0 +1,122 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { Table } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("hub"); + + // Capacity + const stats_section = this.add_section("Capacity"); + this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); + + // Modules + const mod_section = this.add_section("Modules"); + this._mod_host = mod_section; + this._mod_table = null; + + await this._update(); + this._poll_timer = setInterval(() => this._update(), 2000); + } + + async _update() + { + try + { + const [stats, status] = await Promise.all([ + new Fetcher().resource("/hub/stats").json(), + new Fetcher().resource("/hub/status").json(), + ]); + + this._render_capacity(stats); + this._render_modules(status); + } + catch (e) { /* service unavailable */ } + } + + _render_capacity(data) + { + const grid = this._stats_grid; + grid.inner().innerHTML = ""; + + const current = data.currentInstanceCount || 0; + const max = data.maxInstanceCount || 0; + const limit = data.instanceLimit || 0; + + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Active Modules"); + const body = tile.tag().classify("tile-metrics"); + this._metric(body, Friendly.sep(current), "currently provisioned", true); + } + + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Peak Modules"); + const body = tile.tag().classify("tile-metrics"); + this._metric(body, Friendly.sep(max), "high watermark", true); + } + + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Instance Limit"); + const body = tile.tag().classify("tile-metrics"); + this._metric(body, Friendly.sep(limit), 
"maximum allowed", true); + if (limit > 0) + { + const pct = ((current / limit) * 100).toFixed(0) + "%"; + this._metric(body, pct, "utilization"); + } + } + } + + _render_modules(data) + { + const modules = data.modules || []; + + if (this._mod_table) + { + this._mod_table.clear(); + } + else + { + this._mod_table = this._mod_host.add_widget( + Table, + ["module ID", "status"], + Table.Flag_FitLeft|Table.Flag_PackRight + ); + } + + if (modules.length === 0) + { + return; + } + + for (const m of modules) + { + this._mod_table.add_row( + m.moduleId || "", + m.provisioned ? "provisioned" : "inactive", + ); + } + } + + _metric(parent, value, label, hero = false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); + } +} diff --git a/src/zenserver/frontend/html/pages/info.js b/src/zenserver/frontend/html/pages/info.js new file mode 100644 index 000000000..f92765c78 --- /dev/null +++ b/src/zenserver/frontend/html/pages/info.js @@ -0,0 +1,261 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("info"); + + const [info, gc, services, version] = await Promise.all([ + new Fetcher().resource("/health/info").json(), + new Fetcher().resource("/admin/gc").json().catch(() => null), + new Fetcher().resource("/api/").json().catch(() => ({})), + new Fetcher().resource("/health/version").param("detailed", "true").text(), + ]); + + const section = this.add_section("Server Info"); + const grid = section.tag().classify("grid").classify("info-tiles"); + + // Application + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("Application"); + const list = tile.tag().classify("info-props"); + + this._prop(list, "version", version || info.BuildVersion || "-"); + this._prop(list, "http server", info.HttpServerClass || "-"); + this._prop(list, "port", info.Port || "-"); + this._prop(list, "pid", info.Pid || "-"); + this._prop(list, "dedicated", info.IsDedicated ? 
"yes" : "no"); + + if (info.StartTimeMs) + { + const start = new Date(info.StartTimeMs); + const elapsed = Date.now() - info.StartTimeMs; + this._prop(list, "started", start.toLocaleString()); + this._prop(list, "uptime", this._format_duration(elapsed)); + } + + this._prop(list, "data root", info.DataRoot || "-"); + this._prop(list, "log path", info.AbsLogPath || "-"); + } + + // System + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("System"); + const list = tile.tag().classify("info-props"); + + this._prop(list, "hostname", info.Hostname || "-"); + this._prop(list, "platform", info.Platform || "-"); + this._prop(list, "os", info.OS || "-"); + this._prop(list, "arch", info.Arch || "-"); + + const sys = info.System; + if (sys) + { + this._prop(list, "cpus", sys.cpu_count || "-"); + this._prop(list, "cores", sys.core_count || "-"); + this._prop(list, "logical processors", sys.lp_count || "-"); + this._prop(list, "total memory", sys.total_memory_mb ? Friendly.bytes(sys.total_memory_mb * 1048576) : "-"); + this._prop(list, "available memory", sys.avail_memory_mb ? 
Friendly.bytes(sys.avail_memory_mb * 1048576) : "-"); + if (sys.uptime_seconds) + { + this._prop(list, "system uptime", this._format_duration(sys.uptime_seconds * 1000)); + } + } + } + + // Runtime Configuration + if (info.RuntimeConfig) + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("Runtime Configuration"); + const list = tile.tag().classify("info-props"); + + for (const key in info.RuntimeConfig) + { + this._prop(list, key, info.RuntimeConfig[key] || "-"); + } + } + + // Build Configuration + if (info.BuildConfig) + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("Build Configuration"); + const list = tile.tag().classify("info-props"); + + for (const key in info.BuildConfig) + { + this._prop(list, key, info.BuildConfig[key] ? "yes" : "no"); + } + } + + // Services + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("Services"); + const list = tile.tag().classify("info-props"); + + const svc_list = (services.services || []).map(s => s.base_uri).sort(); + for (const uri of svc_list) + { + this._prop(list, uri, "registered"); + } + } + + // Garbage Collection + if (gc) + { + const tile = grid.tag().classify("card").classify("info-tile"); + tile.tag().classify("card-title").text("Garbage Collection"); + const list = tile.tag().classify("info-props"); + + this._prop(list, "status", gc.Status || "-"); + + if (gc.AreDiskWritesBlocked !== undefined) + { + this._prop(list, "disk writes blocked", gc.AreDiskWritesBlocked ? "yes" : "no"); + } + + if (gc.DiskSize) + { + this._prop(list, "disk size", gc.DiskSize); + this._prop(list, "disk used", gc.DiskUsed); + this._prop(list, "disk free", gc.DiskFree); + } + + const cfg = gc.Config; + if (cfg) + { + this._prop(list, "gc enabled", cfg.Enabled ? 
"yes" : "no"); + if (cfg.Interval) + { + this._prop(list, "interval", this._friendly_duration(cfg.Interval)); + } + if (cfg.LightweightInterval) + { + this._prop(list, "lightweight interval", this._friendly_duration(cfg.LightweightInterval)); + } + if (cfg.MaxCacheDuration) + { + this._prop(list, "max cache duration", this._friendly_duration(cfg.MaxCacheDuration)); + } + if (cfg.MaxProjectStoreDuration) + { + this._prop(list, "max project duration", this._friendly_duration(cfg.MaxProjectStoreDuration)); + } + if (cfg.MaxBuildStoreDuration) + { + this._prop(list, "max build duration", this._friendly_duration(cfg.MaxBuildStoreDuration)); + } + } + + if (gc.FullGC) + { + if (gc.FullGC.LastTime) + { + this._prop(list, "last full gc", this._friendly_timestamp(gc.FullGC.LastTime)); + } + if (gc.FullGC.TimeToNext) + { + this._prop(list, "next full gc", this._friendly_duration(gc.FullGC.TimeToNext)); + } + } + + if (gc.LightweightGC) + { + if (gc.LightweightGC.LastTime) + { + this._prop(list, "last lightweight gc", this._friendly_timestamp(gc.LightweightGC.LastTime)); + } + if (gc.LightweightGC.TimeToNext) + { + this._prop(list, "next lightweight gc", this._friendly_duration(gc.LightweightGC.TimeToNext)); + } + } + } + } + + _prop(parent, label, value) + { + const row = parent.tag().classify("info-prop"); + row.tag().classify("info-prop-label").text(label); + const val = row.tag().classify("info-prop-value"); + const str = String(value); + if (str.match(/^[A-Za-z]:[\\/]/) || str.startsWith("/")) + { + val.tag("a").text(str).attr("href", "vscode://" + str.replace(/\\/g, "/")); + } + else + { + val.text(str); + } + } + + _friendly_timestamp(value) + { + const d = new Date(value); + if (isNaN(d.getTime())) + { + return String(value); + } + return d.toLocaleString(undefined, { + year: "numeric", month: "short", day: "numeric", + hour: "2-digit", minute: "2-digit", second: "2-digit", + }); + } + + _friendly_duration(value) + { + if (typeof value === "number") + { + return 
this._format_duration(value); + } + + const str = String(value); + const match = str.match(/^[+-]?(?:(\d+)\.)?(\d+):(\d+):(\d+)(?:\.(\d+))?$/); + if (!match) + { + return str; + } + + const days = parseInt(match[1] || "0", 10); + const hours = parseInt(match[2], 10); + const minutes = parseInt(match[3], 10); + const seconds = parseInt(match[4], 10); + const total_seconds = days * 86400 + hours * 3600 + minutes * 60 + seconds; + + return this._format_duration(total_seconds * 1000); + } + + _format_duration(ms) + { + const seconds = Math.floor(ms / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + const days = Math.floor(hours / 24); + + if (days > 0) + { + return `${days}d ${hours % 24}h ${minutes % 60}m`; + } + if (hours > 0) + { + return `${hours}h ${minutes % 60}m`; + } + if (minutes > 0) + { + return `${minutes}m ${seconds % 60}s`; + } + return `${seconds}s`; + } +} diff --git a/src/zenserver/frontend/html/pages/map.js b/src/zenserver/frontend/html/pages/map.js index 58046b255..ac8f298aa 100644 --- a/src/zenserver/frontend/html/pages/map.js +++ b/src/zenserver/frontend/html/pages/map.js @@ -116,9 +116,9 @@ export class Page extends ZenPage for (const name of sorted_keys) nodes.push(new_nodes[name] / branch_size); - var stats = Friendly.kib(branch_size); + var stats = Friendly.bytes(branch_size); stats += " / "; - stats += Friendly.kib(total_size); + stats += Friendly.bytes(total_size); stats += " ("; stats += 0|((branch_size * 100) / total_size); stats += "%)"; diff --git a/src/zenserver/frontend/html/pages/metrics.js b/src/zenserver/frontend/html/pages/metrics.js new file mode 100644 index 000000000..e7a2eca67 --- /dev/null +++ b/src/zenserver/frontend/html/pages/metrics.js @@ -0,0 +1,232 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { PropTable, Toolbar } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +class TemporalStat +{ + constructor(data, as_bytes) + { + this._data = data; + this._as_bytes = as_bytes; + } + + toString() + { + const columns = [ + /* count */ {}, + /* rate */ {}, + /* t */ {}, {}, + ]; + const data = this._data; + for (var key in data) + { + var out = columns[0]; + if (key.startsWith("rate_")) out = columns[1]; + else if (key.startsWith("t_p")) out = columns[3]; + else if (key.startsWith("t_")) out = columns[2]; + out[key] = data[key]; + } + + var friendly = this._as_bytes ? Friendly.bytes : Friendly.sep; + + var content = ""; + for (var i = 0; i < columns.length; ++i) + { + const column = columns[i]; + for (var key in column) + { + var value = column[key]; + if (i) + { + value = Friendly.sep(value, 2); + key = key.padStart(9); + content += key + ": " + value; + } + else + content += friendly(value); + content += "\r\n"; + } + } + + return content; + } + + tag() + { + return "pre"; + } +} + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("metrics"); + + const metrics_section = this.add_section("metrics"); + const top_toolbar = metrics_section.add_widget(Toolbar); + const tb_right = top_toolbar.right(); + this._refresh_label = tb_right.add("", "label"); + this._pause_btn = tb_right.add("pause").on_click(() => this._toggle_pause()); + + this._paused = false; + this._last_refresh = Date.now(); + this._provider_views = []; + + const providers_data = await new Fetcher().resource("stats").json(); + const providers = providers_data["providers"] || []; + + const stats_list = await Promise.all(providers.map((provider) => + new Fetcher() + .resource("stats", 
provider) + .param("cidstorestats", "true") + .param("cachestorestats", "true") + .json() + .then((stats) => ({ provider, stats })) + )); + + for (const { provider, stats } of stats_list) + { + this._condense(stats); + this._provider_views.push(this._render_provider(provider, stats)); + } + + this._last_refresh = Date.now(); + this._update_refresh_label(); + + this._timer_id = setInterval(() => this._refresh(), 5000); + this._tick_id = setInterval(() => this._update_refresh_label(), 1000); + + document.addEventListener("visibilitychange", () => { + if (document.hidden) + this._pause_timer(false); + else if (!this._paused) + this._resume_timer(); + }); + } + + _render_provider(provider, stats) + { + const section = this.add_section(provider); + const toolbar = section.add_widget(Toolbar); + + toolbar.right().add("detailed →").on_click(() => { + window.location = "?page=stat&provider=" + provider; + }); + + const table = section.add_widget(PropTable); + let current_stats = stats; + let current_category = undefined; + + const show_category = (cat) => { + current_category = cat; + table.clear(); + table.add_object(current_stats[cat], true, 3); + }; + + var first = undefined; + for (var name in stats) + { + first = first || name; + toolbar.left().add(name).on_click(show_category, name); + } + + if (first) + show_category(first); + + return { + provider, + set_stats: (new_stats) => { + current_stats = new_stats; + if (current_category && current_stats[current_category]) + show_category(current_category); + }, + }; + } + + async _refresh() + { + const updates = await Promise.all(this._provider_views.map((view) => + new Fetcher() + .resource("stats", view.provider) + .param("cidstorestats", "true") + .param("cachestorestats", "true") + .json() + .then((stats) => ({ view, stats })) + )); + + for (const { view, stats } of updates) + { + this._condense(stats); + view.set_stats(stats); + } + + this._last_refresh = Date.now(); + this._update_refresh_label(); + } + + 
_update_refresh_label() + { + const elapsed = Math.floor((Date.now() - this._last_refresh) / 1000); + this._refresh_label.inner().textContent = "refreshed " + elapsed + "s ago"; + } + + _toggle_pause() + { + if (this._paused) + this._resume_timer(); + else + this._pause_timer(true); + } + + _pause_timer(user_paused=true) + { + clearInterval(this._timer_id); + this._timer_id = undefined; + if (user_paused) + { + this._paused = true; + this._pause_btn.inner().textContent = "resume"; + } + } + + _resume_timer() + { + this._paused = false; + this._pause_btn.inner().textContent = "pause"; + this._timer_id = setInterval(() => this._refresh(), 5000); + this._refresh(); + } + + _condense(stats) + { + const impl = function(node) + { + for (var name in node) + { + const candidate = node[name]; + if (!(candidate instanceof Object)) + continue; + + if (candidate["rate_mean"] != undefined) + { + const as_bytes = (name.indexOf("bytes") >= 0); + node[name] = new TemporalStat(candidate, as_bytes); + continue; + } + + impl(candidate); + } + } + + for (var name in stats) + impl(stats[name]); + } +} diff --git a/src/zenserver/frontend/html/pages/oplog.js b/src/zenserver/frontend/html/pages/oplog.js index 879fc4c97..fb857affb 100644 --- a/src/zenserver/frontend/html/pages/oplog.js +++ b/src/zenserver/frontend/html/pages/oplog.js @@ -32,7 +32,7 @@ export class Page extends ZenPage this.set_title("oplog - " + oplog); - var section = this.add_section(project + " - " + oplog); + var section = this.add_section(oplog); oplog_info = await oplog_info; this._index_max = oplog_info["opcount"]; @@ -81,7 +81,7 @@ export class Page extends ZenPage const right = nav.right(); right.add(Friendly.sep(oplog_info["opcount"])); - right.add("(" + Friendly.kib(oplog_info["totalsize"]) + ")"); + right.add("(" + Friendly.bytes(oplog_info["totalsize"]) + ")"); right.sep(); var search_input = right.add("search:", "label").tag("input") diff --git a/src/zenserver/frontend/html/pages/orchestrator.js 
b/src/zenserver/frontend/html/pages/orchestrator.js new file mode 100644 index 000000000..24805c722 --- /dev/null +++ b/src/zenserver/frontend/html/pages/orchestrator.js @@ -0,0 +1,405 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { Table } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("orchestrator"); + + // Agents section + const agents_section = this._collapsible_section("Compute Agents"); + this._agents_host = agents_section; + this._agents_table = null; + + // Clients section + const clients_section = this._collapsible_section("Connected Clients"); + this._clients_host = clients_section; + this._clients_table = null; + + // Event history + const history_section = this._collapsible_section("Worker Events"); + this._history_host = history_section; + this._history_table = null; + + const client_history_section = this._collapsible_section("Client Events"); + this._client_history_host = client_history_section; + this._client_history_table = null; + + this._ws_paused = false; + try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) {} + document.addEventListener("zen-ws-toggle", (e) => { + this._ws_paused = e.detail.paused; + }); + + // Initial fetch + await this._fetch_all(); + + // Connect WebSocket for live updates, fall back to polling + this._connect_ws(); + } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = "0.7em"; + 
heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? "none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } + + async _fetch_all() + { + try + { + const [agents, history, clients, client_history] = await Promise.all([ + new Fetcher().resource("/orch/agents").json(), + new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null), + new Fetcher().resource("/orch/clients").json().catch(() => null), + new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null), + ]); + + this._render_agents(agents); + if (history) + { + this._render_history(history.events || []); + } + if (clients) + { + this._render_clients(clients.clients || []); + } + if (client_history) + { + this._render_client_history(client_history.client_events || []); + } + } + catch (e) { /* service unavailable */ } + } + + _connect_ws() + { + try + { + const proto = location.protocol === "https:" ? 
"wss:" : "ws:"; + const ws = new WebSocket(`${proto}//${location.host}/orch/ws`); + + ws.onopen = () => { + if (this._poll_timer) + { + clearInterval(this._poll_timer); + this._poll_timer = null; + } + }; + + ws.onmessage = (ev) => { + if (this._ws_paused) + { + return; + } + try + { + const data = JSON.parse(ev.data); + this._render_agents(data); + if (data.events) + { + this._render_history(data.events); + } + if (data.clients) + { + this._render_clients(data.clients); + } + if (data.client_events) + { + this._render_client_history(data.client_events); + } + } + catch (e) { /* ignore parse errors */ } + }; + + ws.onclose = () => { + this._start_polling(); + setTimeout(() => this._connect_ws(), 3000); + }; + + ws.onerror = () => { /* onclose will fire */ }; + } + catch (e) + { + this._start_polling(); + } + } + + _start_polling() + { + if (!this._poll_timer) + { + this._poll_timer = setInterval(() => this._fetch_all(), 2000); + } + } + + _render_agents(data) + { + const workers = data.workers || []; + + if (this._agents_table) + { + this._agents_table.clear(); + } + else + { + this._agents_table = this._agents_host.add_widget( + Table, + ["hostname", "CPUs", "CPU usage", "memory", "queues", "pending", "running", "completed", "traffic", "last seen"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 + ); + } + + if (workers.length === 0) + { + return; + } + + let totalCpus = 0, totalWeightedCpu = 0; + let totalMemUsed = 0, totalMemTotal = 0; + let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0; + let totalRecv = 0, totalSent = 0; + + for (const w of workers) + { + const cpus = w.cpus || 0; + const cpuUsage = w.cpu_usage; + const memUsed = w.memory_used || 0; + const memTotal = w.memory_total || 0; + const queues = w.active_queues || 0; + const pending = w.actions_pending || 0; + const running = w.actions_running || 0; + const completed = w.actions_completed || 0; + const recv = w.bytes_received || 0; 
+ const sent = w.bytes_sent || 0; + + totalCpus += cpus; + if (cpus > 0 && typeof cpuUsage === "number") + { + totalWeightedCpu += cpuUsage * cpus; + } + totalMemUsed += memUsed; + totalMemTotal += memTotal; + totalQueues += queues; + totalPending += pending; + totalRunning += running; + totalCompleted += completed; + totalRecv += recv; + totalSent += sent; + + const hostname = w.hostname || ""; + const row = this._agents_table.add_row( + hostname, + cpus > 0 ? Friendly.sep(cpus) : "-", + typeof cpuUsage === "number" ? cpuUsage.toFixed(1) + "%" : "-", + memTotal > 0 ? Friendly.bytes(memUsed) + " / " + Friendly.bytes(memTotal) : "-", + queues > 0 ? Friendly.sep(queues) : "-", + Friendly.sep(pending), + Friendly.sep(running), + Friendly.sep(completed), + this._format_traffic(recv, sent), + this._format_last_seen(w.dt), + ); + + // Link hostname to worker dashboard + if (w.uri) + { + const cell = row.get_cell(0); + cell.inner().textContent = ""; + cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank"); + } + } + + // Total row + const total = this._agents_table.add_row( + "TOTAL", + Friendly.sep(totalCpus), + "", + totalMemTotal > 0 ? 
Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-", + Friendly.sep(totalQueues), + Friendly.sep(totalPending), + Friendly.sep(totalRunning), + Friendly.sep(totalCompleted), + this._format_traffic(totalRecv, totalSent), + "", + ); + total.get_cell(0).style("fontWeight", "bold"); + } + + _render_clients(clients) + { + if (this._clients_table) + { + this._clients_table.clear(); + } + else + { + this._clients_table = this._clients_host.add_widget( + Table, + ["client ID", "hostname", "address", "last seen"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1 + ); + } + + for (const c of clients) + { + this._clients_table.add_row( + c.id || "", + c.hostname || "", + c.address || "", + this._format_last_seen(c.dt), + ); + } + } + + _render_history(events) + { + if (this._history_table) + { + this._history_table.clear(); + } + else + { + this._history_table = this._history_host.add_widget( + Table, + ["time", "event", "worker", "hostname"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1 + ); + } + + for (const evt of events) + { + this._history_table.add_row( + this._format_timestamp(evt.ts), + evt.type || "", + evt.worker_id || "", + evt.hostname || "", + ); + } + } + + _render_client_history(events) + { + if (this._client_history_table) + { + this._client_history_table.clear(); + } + else + { + this._client_history_table = this._client_history_host.add_widget( + Table, + ["time", "event", "client", "hostname"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1 + ); + } + + for (const evt of events) + { + this._client_history_table.add_row( + this._format_timestamp(evt.ts), + evt.type || "", + evt.client_id || "", + evt.hostname || "", + ); + } + } + + _metric(parent, value, label, hero = false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); + 
} + + _format_last_seen(dtMs) + { + if (dtMs == null) + { + return "-"; + } + const seconds = Math.floor(dtMs / 1000); + if (seconds < 60) + { + return seconds + "s ago"; + } + const minutes = Math.floor(seconds / 60); + if (minutes < 60) + { + return minutes + "m " + (seconds % 60) + "s ago"; + } + const hours = Math.floor(minutes / 60); + return hours + "h " + (minutes % 60) + "m ago"; + } + + _format_traffic(recv, sent) + { + if (!recv && !sent) + { + return "-"; + } + return Friendly.bytes(recv) + " / " + Friendly.bytes(sent); + } + + _format_timestamp(ts) + { + if (!ts) + { + return "-"; + } + let date; + if (typeof ts === "number") + { + // .NET-style ticks: convert to Unix ms + const unixMs = (ts - 621355968000000000) / 10000; + date = new Date(unixMs); + } + else + { + date = new Date(ts); + } + if (isNaN(date.getTime())) + { + return "-"; + } + return date.toLocaleTimeString(); + } +} diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index 9a9541904..dd8032c28 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -3,6 +3,7 @@ "use strict"; import { WidgetHost } from "../util/widgets.js" +import { Fetcher } from "../util/fetcher.js" //////////////////////////////////////////////////////////////////////////////// export class PageBase extends WidgetHost @@ -63,31 +64,85 @@ export class ZenPage extends PageBase super(parent, ...args); super.set_title("zen"); this.add_branding(parent); + this.add_service_nav(parent); this.generate_crumbs(); } add_branding(parent) { - var root = parent.tag().id("branding"); - - const zen_store = root.tag("pre").id("logo").text( - "_________ _______ __\n" + - "\\____ /___ ___ / ___// |__ ___ ______ ____\n" + - " / __/ __ \\ / \\ \\___ \\\\_ __// \\\\_ \\/ __ \\\n" + - " / \\ __// | \\/ \\| | ( - )| |\\/\\ __/\n" + - "/______/\\___/\\__|__/\\______/|__| \\___/ |__| \\___|" - ); - zen_store.tag().id("go_home").on_click(() => 
window.location.search = ""); - - root.tag("img").attr("src", "favicon.ico").id("ue_logo"); - - /* - _________ _______ __ - \____ /___ ___ / ___// |__ ___ ______ ____ - / __/ __ \ / \ \___ \\_ __// \\_ \/ __ \ - / \ __// | \/ \| | ( - )| |\/\ __/ - /______/\___/\__|__/\______/|__| \___/ |__| \___| - */ + var banner = parent.tag("zen-banner"); + banner.attr("subtitle", "SERVER"); + banner.attr("tagline", "Local Storage Service"); + banner.attr("logo-src", "favicon.ico"); + banner.attr("load", "0"); + + this._banner = banner; + this._poll_status(); + } + + async _poll_status() + { + try + { + var cbo = await new Fetcher().resource("/status/status").cbo(); + if (cbo) + { + var obj = cbo.as_object(); + + var hostname = obj.find("hostname"); + if (hostname) + { + this._banner.attr("tagline", "Local Storage Service \u2014 " + hostname.as_value()); + } + + var cpu = obj.find("cpuUsagePercent"); + if (cpu) + { + this._banner.attr("load", cpu.as_value().toFixed(1)); + } + } + } + catch (e) { console.warn("status poll:", e); } + + setTimeout(() => this._poll_status(), 2000); + } + + add_service_nav(parent) + { + const nav = parent.tag().id("service_nav"); + + // Map service base URIs to dashboard links, this table is also used to detemine + // which links to show based on the services that are currently registered. 
+ + const service_dashboards = [ + { base_uri: "/compute/", label: "Compute", href: "/dashboard/?page=compute" }, + { base_uri: "/orch/", label: "Orchestrator", href: "/dashboard/?page=orchestrator" }, + { base_uri: "/hub/", label: "Hub", href: "/dashboard/?page=hub" }, + ]; + + nav.tag("a").text("Home").attr("href", "/dashboard/"); + + nav.tag("a").text("Sessions").attr("href", "/dashboard/?page=sessions"); + nav.tag("a").text("Cache").attr("href", "/dashboard/?page=cache"); + nav.tag("a").text("Projects").attr("href", "/dashboard/?page=projects"); + this._info_link = nav.tag("a").text("Info").attr("href", "/dashboard/?page=info"); + + new Fetcher().resource("/api/").json().then((data) => { + const services = data.services || []; + const uris = new Set(services.map(s => s.base_uri)); + + const links = service_dashboards.filter(d => uris.has(d.base_uri)); + + // Insert service links before the Info link + const info_elem = this._info_link.inner(); + for (const link of links) + { + const a = document.createElement("a"); + a.textContent = link.label; + a.href = link.href; + info_elem.parentNode.insertBefore(a, info_elem); + } + }).catch(() => {}); } set_title(...args) @@ -97,7 +152,7 @@ export class ZenPage extends PageBase generate_crumbs() { - const auto_name = this.get_param("page") || "start"; + var auto_name = this.get_param("page") || "start"; if (auto_name == "start") return; @@ -114,15 +169,30 @@ export class ZenPage extends PageBase var project = this.get_param("project"); if (project != undefined) { + auto_name = project; var oplog = this.get_param("oplog"); if (oplog != undefined) { - new_crumb("project", `?page=project&project=${project}`); - if (this.get_param("opkey")) - new_crumb("oplog", `?page=oplog&project=${project}&oplog=${oplog}`); + new_crumb(auto_name, `?page=project&project=${project}`); + auto_name = oplog; + var opkey = this.get_param("opkey") + if (opkey != undefined) + { + new_crumb(auto_name, 
`?page=oplog&project=${project}&oplog=${oplog}`); + auto_name = opkey.split("/").pop().split("\\").pop(); + + // Check if we're viewing cook artifacts + var page = this.get_param("page"); + var hash = this.get_param("hash"); + if (hash != undefined && page == "cookartifacts") + { + new_crumb(auto_name, `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey}`); + auto_name = "cook artifacts"; + } + } } } - new_crumb(auto_name.toLowerCase()); + new_crumb(auto_name); } } diff --git a/src/zenserver/frontend/html/pages/project.js b/src/zenserver/frontend/html/pages/project.js index 42ae30c8c..3a7a45527 100644 --- a/src/zenserver/frontend/html/pages/project.js +++ b/src/zenserver/frontend/html/pages/project.js @@ -59,7 +59,7 @@ export class Page extends ZenPage info = await info; row.get_cell(1).text(info["markerpath"]); - row.get_cell(2).text(Friendly.kib(info["totalsize"])); + row.get_cell(2).text(Friendly.bytes(info["totalsize"])); row.get_cell(3).text(Friendly.sep(info["opcount"])); row.get_cell(4).text(info["expired"]); } diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js new file mode 100644 index 000000000..9c1e519d4 --- /dev/null +++ b/src/zenserver/frontend/html/pages/projects.js @@ -0,0 +1,447 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Friendly } from "../util/friendly.js" +import { Modal } from "../util/modal.js" +import { Table, Toolbar } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("projects"); + + // Project Service Stats + const stats_section = this._collapsible_section("Project Service Stats"); + stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => { + window.open("/stats/prj.yaml", "_blank"); + }); + this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); + + const stats = await new Fetcher().resource("stats", "prj").json(); + if (stats) + { + this._render_stats(stats); + } + + this._connect_stats_ws(); + + // Projects list + var section = this._collapsible_section("Projects"); + + section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all()); + + var columns = [ + "name", + "project dir", + "engine dir", + "oplogs", + "actions", + ]; + + this._project_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric); + + var projects = await new Fetcher().resource("/prj/list").json(); + projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0)); + + for (const project of projects) + { + var row = this._project_table.add_row( + "", + "", + "", + "", + ); + + var cell = row.get_cell(0); + cell.tag().text(project.Id).on_click(() => this.view_project(project.Id)); + + if (project.ProjectRootDir) + { + row.get_cell(1).tag("a").text(project.ProjectRootDir) + .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/")); + } + if (project.EngineRootDir) + { + row.get_cell(2).tag("a").text(project.EngineRootDir) + .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/")); + } + + 
cell = row.get_cell(-1); + const action_tb = new Toolbar(cell, true).left(); + action_tb.add("view").on_click(() => this.view_project(project.Id)); + action_tb.add("drop").on_click(() => this.drop_project(project.Id)); + + row.attr("zs_name", project.Id); + + // Fetch project details to get oplog count + new Fetcher().resource("prj", project.Id).json().then((info) => { + const oplogs = info["oplogs"] || []; + row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right"); + // Right-align the corresponding header cell + const header = this._project_table._element.firstElementChild; + if (header && header.children[4]) + { + header.children[4].style.textAlign = "right"; + } + }).catch(() => {}); + } + + // Project detail area (inside projects section so it collapses together) + this._project_host = section; + this._project_container = null; + this._selected_project = null; + + // Restore project from URL if present + const prj_param = this.get_param("project"); + if (prj_param) + { + this.view_project(prj_param); + } + } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = "0.7em"; + heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? 
"none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } + + _clear_param(name) + { + this._params.delete(name); + const url = new URL(window.location); + url.searchParams.delete(name); + history.replaceState(null, "", url); + } + + _connect_stats_ws() + { + try + { + const proto = location.protocol === "https:" ? "wss:" : "ws:"; + const ws = new WebSocket(`${proto}//${location.host}/stats`); + + try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; } + document.addEventListener("zen-ws-toggle", (e) => { + this._ws_paused = e.detail.paused; + }); + + ws.onmessage = (ev) => { + if (this._ws_paused) + { + return; + } + try + { + const all_stats = JSON.parse(ev.data); + const stats = all_stats["prj"]; + if (stats) + { + this._render_stats(stats); + } + } + catch (e) { /* ignore parse errors */ } + }; + + ws.onclose = () => { this._stats_ws = null; }; + ws.onerror = () => { ws.close(); }; + + this._stats_ws = ws; + } + catch (e) { /* WebSocket not available */ } + } + + _render_stats(stats) + { + const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); + const grid = this._stats_grid; + + grid.inner().innerHTML = ""; + + // HTTP Requests tile + { + const req = safe(stats, "requests"); + if (req) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("HTTP Requests"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const reqData = req.requests || req; + this._metric(left, Friendly.sep(safe(stats, "store.requestcount") || 0), "total requests", true); + if (reqData.rate_mean > 0) + { + this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)"); + } + if (reqData.rate_1 > 0) + { + this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)"); + } + const badRequests = safe(stats, "store.badrequestcount") || 0; 
+ this._metric(left, Friendly.sep(badRequests), "bad requests"); + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true); + if (reqData.t_p75) + { + this._metric(right, Friendly.duration(reqData.t_p75), "p75"); + } + if (reqData.t_p95) + { + this._metric(right, Friendly.duration(reqData.t_p95), "p95"); + } + if (reqData.t_p99) + { + this._metric(right, Friendly.duration(reqData.t_p99), "p99"); + } + } + } + + // Store Operations tile + { + const store = safe(stats, "store"); + if (store) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Store Operations"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const proj = store.project || {}; + this._metric(left, Friendly.sep(proj.readcount || 0), "project reads", true); + this._metric(left, Friendly.sep(proj.writecount || 0), "project writes"); + this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes"); + + const right = columns.tag().classify("tile-metrics"); + const oplog = store.oplog || {}; + this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true); + this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes"); + this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes"); + } + } + + // Op & Chunk tile + { + const store = safe(stats, "store"); + if (store) + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Ops & Chunks"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + const op = store.op || {}; + const opTotal = (op.hitcount || 0) + (op.misscount || 0); + const opRatio = opTotal > 0 ? 
(((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(left, opRatio, "op hit ratio", true); + this._metric(left, Friendly.sep(op.hitcount || 0), "op hits"); + this._metric(left, Friendly.sep(op.misscount || 0), "op misses"); + this._metric(left, Friendly.sep(op.writecount || 0), "op writes"); + + const right = columns.tag().classify("tile-metrics"); + const chunk = store.chunk || {}; + const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0); + const chunkRatio = chunkTotal > 0 ? (((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-"; + this._metric(right, chunkRatio, "chunk hit ratio", true); + this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits"); + this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses"); + this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes"); + } + } + + // Storage tile + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Storage"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + this._metric(left, safe(stats, "store.size.disk") != null ? Friendly.bytes(safe(stats, "store.size.disk")) : "-", "store disk", true); + this._metric(left, safe(stats, "store.size.memory") != null ? Friendly.bytes(safe(stats, "store.size.memory")) : "-", "store memory"); + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, safe(stats, "cid.size.total") != null ? Friendly.bytes(safe(stats, "cid.size.total")) : "-", "cid total", true); + this._metric(right, safe(stats, "cid.size.tiny") != null ? Friendly.bytes(safe(stats, "cid.size.tiny")) : "-", "cid tiny"); + this._metric(right, safe(stats, "cid.size.small") != null ? Friendly.bytes(safe(stats, "cid.size.small")) : "-", "cid small"); + this._metric(right, safe(stats, "cid.size.large") != null ? 
Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large"); + } + } + + _metric(parent, value, label, hero = false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); + } + + async view_project(project_id) + { + // Toggle off if already selected + if (this._selected_project === project_id) + { + this._selected_project = null; + this._clear_project_detail(); + this._clear_param("project"); + return; + } + + this._selected_project = project_id; + this._clear_project_detail(); + this.set_param("project", project_id); + + const info = await new Fetcher().resource("prj", project_id).json(); + if (this._selected_project !== project_id) + { + return; + } + + const section = this._project_host.add_section(project_id); + this._project_container = section; + + // Oplogs table + const oplog_section = section.add_section("Oplogs"); + const oplog_table = oplog_section.add_widget( + Table, + ["name", "marker", "size", "ops", "expired", "actions"], + Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric + ); + + let totalSize = 0, totalOps = 0; + const total_row = oplog_table.add_row("TOTAL"); + total_row.get_cell(0).style("fontWeight", "bold"); + total_row.get_cell(2).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(3).style("textAlign", "right").style("fontWeight", "bold"); + + // Right-align header for numeric columns (size, ops) + const header = oplog_table._element.firstElementChild; + for (let i = 3; i < header.children.length - 1; i++) + { + header.children[i].style.textAlign = "right"; + } + + for (const oplog of info["oplogs"] || []) + { + const name = oplog["id"]; + const row = oplog_table.add_row(""); + + var cell = row.get_cell(0); + cell.tag().text(name).link("", { + "page": "oplog", + "project": project_id, + "oplog": name, + }); + + cell = 
row.get_cell(-1); + const action_tb = new Toolbar(cell, true).left(); + action_tb.add("list").link("", { "page": "oplog", "project": project_id, "oplog": name }); + action_tb.add("tree").link("", { "page": "tree", "project": project_id, "oplog": name }); + action_tb.add("drop").on_click(() => this.drop_oplog(project_id, name)); + + new Fetcher().resource("prj", project_id, "oplog", name).json().then((data) => { + row.get_cell(1).text(data["markerpath"]); + row.get_cell(2).text(Friendly.bytes(data["totalsize"])).style("textAlign", "right"); + row.get_cell(3).text(Friendly.sep(data["opcount"])).style("textAlign", "right"); + row.get_cell(4).text(data["expired"]); + + totalSize += data["totalsize"] || 0; + totalOps += data["opcount"] || 0; + total_row.get_cell(2).text(Friendly.bytes(totalSize)).style("textAlign", "right").style("fontWeight", "bold"); + total_row.get_cell(3).text(Friendly.sep(totalOps)).style("textAlign", "right").style("fontWeight", "bold"); + }).catch(() => {}); + } + } + + _clear_project_detail() + { + if (this._project_container) + { + this._project_container._parent.inner().remove(); + this._project_container = null; + } + } + + drop_oplog(project_id, oplog_id) + { + const drop = async () => { + await new Fetcher().resource("prj", project_id, "oplog", oplog_id).delete(); + // Refresh the project view + this._selected_project = null; + this._clear_project_detail(); + this.view_project(project_id); + }; + + new Modal() + .title("Confirmation") + .message(`Drop oplog '${oplog_id}'?`) + .option("Yes", () => drop()) + .option("No"); + } + + drop_project(project_id) + { + const drop = async () => { + await new Fetcher().resource("prj", project_id).delete(); + this.reload(); + }; + + new Modal() + .title("Confirmation") + .message(`Drop project '${project_id}'?`) + .option("Yes", () => drop()) + .option("No"); + } + + async drop_all() + { + const drop = async () => { + for (const row of this._project_table) + { + const project_id = row.attr("zs_name"); + 
await new Fetcher().resource("prj", project_id).delete(); + } + this.reload(); + }; + + new Modal() + .title("Confirmation") + .message("Drop every project?") + .option("Yes", () => drop()) + .option("No"); + } +} diff --git a/src/zenserver/frontend/html/pages/sessions.js b/src/zenserver/frontend/html/pages/sessions.js new file mode 100644 index 000000000..95533aa96 --- /dev/null +++ b/src/zenserver/frontend/html/pages/sessions.js @@ -0,0 +1,61 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Table } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + async main() + { + this.set_title("sessions"); + + const data = await new Fetcher().resource("/sessions/").json(); + const sessions = data.sessions || []; + + const section = this.add_section("Sessions"); + + if (sessions.length === 0) + { + section.tag().classify("empty-state").text("No active sessions."); + return; + } + + const columns = [ + "id", + "created", + "updated", + "metadata", + ]; + const table = section.add_widget(Table, columns, Table.Flag_FitLeft); + + for (const session of sessions) + { + const created = session.created_at ? new Date(session.created_at).toLocaleString() : "-"; + const updated = session.updated_at ? 
new Date(session.updated_at).toLocaleString() : "-"; + const meta = this._format_metadata(session.metadata); + + const row = table.add_row( + session.id || "-", + created, + updated, + meta, + ); + } + } + + _format_metadata(metadata) + { + if (!metadata || Object.keys(metadata).length === 0) + { + return "-"; + } + + return Object.entries(metadata) + .map(([k, v]) => `${k}: ${v}`) + .join(", "); + } +} diff --git a/src/zenserver/frontend/html/pages/start.js b/src/zenserver/frontend/html/pages/start.js index 4c8789431..3a68a725d 100644 --- a/src/zenserver/frontend/html/pages/start.js +++ b/src/zenserver/frontend/html/pages/start.js @@ -13,109 +13,117 @@ export class Page extends ZenPage { async main() { + // Discover which services are available + const api_data = await new Fetcher().resource("/api/").json(); + const available = new Set((api_data.services || []).map(s => s.base_uri)); + // project list - var section = this.add_section("projects"); + var project_table = null; + if (available.has("/prj/")) + { + var section = this.add_section("Cooked Projects"); - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects")); + section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects")); - var columns = [ - "name", - "project_dir", - "engine_dir", - "actions", - ]; - var project_table = section.add_widget(Table, columns); + var columns = [ + "name", + "project_dir", + "engine_dir", + "actions", + ]; + project_table = section.add_widget(Table, columns); - for (const project of await new Fetcher().resource("/prj/list").json()) - { - var row = project_table.add_row( - "", - project.ProjectRootDir, - project.EngineRootDir, - ); + var projects = await new Fetcher().resource("/prj/list").json(); + projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0)); + projects = projects.slice(0, 25); + projects.sort((a, b) => a.Id.localeCompare(b.Id)); - var cell = row.get_cell(0); - 
cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id); + for (const project of projects) + { + var row = project_table.add_row( + "", + project.ProjectRootDir, + project.EngineRootDir, + ); + + var cell = row.get_cell(0); + cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id); - var cell = row.get_cell(-1); - var action_tb = new Toolbar(cell, true); - action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id); - action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id); + var cell = row.get_cell(-1); + var action_tb = new Toolbar(cell, true); + action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id); + action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id); - row.attr("zs_name", project.Id); + row.attr("zs_name", project.Id); + } } // cache - var section = this.add_section("z$"); - - section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$")); - - columns = [ - "namespace", - "dir", - "buckets", - "entries", - "size disk", - "size mem", - "actions", - ] - var zcache_info = new Fetcher().resource("/z$/").json(); - const cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight); - for (const namespace of (await zcache_info)["Namespaces"]) + var cache_table = null; + if (available.has("/z$/")) { - new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => { - const row = cache_table.add_row( - "", - data["Configuration"]["RootDir"], - data["Buckets"].length, - data["EntryCount"], - Friendly.kib(data["StorageSize"].DiskSize), - Friendly.kib(data["StorageSize"].MemorySize) - ); - var cell = row.get_cell(0); - cell.tag().text(namespace).on_click(() => this.view_zcache(namespace)); - row.get_cell(1).tag().text(namespace); + var section = this.add_section("Cache"); - cell = row.get_cell(-1); - const action_tb = new Toolbar(cell, true); - 
action_tb.left().add("view").on_click(() => this.view_zcache(namespace)); - action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace)); + section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$")); - row.attr("zs_name", namespace); - }); + var columns = [ + "namespace", + "dir", + "buckets", + "entries", + "size disk", + "size mem", + "actions", + ]; + var zcache_info = await new Fetcher().resource("/z$/").json(); + cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight); + for (const namespace of zcache_info["Namespaces"] || []) + { + new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => { + const row = cache_table.add_row( + "", + data["Configuration"]["RootDir"], + data["Buckets"].length, + data["EntryCount"], + Friendly.bytes(data["StorageSize"].DiskSize), + Friendly.bytes(data["StorageSize"].MemorySize) + ); + var cell = row.get_cell(0); + cell.tag().text(namespace).on_click(() => this.view_zcache(namespace)); + row.get_cell(1).tag().text(namespace); + + cell = row.get_cell(-1); + const action_tb = new Toolbar(cell, true); + action_tb.left().add("view").on_click(() => this.view_zcache(namespace)); + action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace)); + + row.attr("zs_name", namespace); + }); + } } - // stats + // stats tiles const safe_lookup = (obj, path, pretty=undefined) => { const ret = path.split(".").reduce((a,b) => a && a[b], obj); - if (ret === undefined) return "-"; + if (ret === undefined) return undefined; return pretty ? 
pretty(ret) : ret; }; - section = this.add_section("stats"); - columns = [ - "name", - "req count", - "size disk", - "size mem", - "cid total", - ]; - const stats_table = section.add_widget(Table, columns, Table.Flag_PackRight); - var providers = new Fetcher().resource("stats").json(); - for (var provider of (await providers)["providers"]) - { - var stats = await new Fetcher().resource("stats", provider).json(); - var size_stat = (stats.store || stats.cache); - var values = [ - "", - safe_lookup(stats, "requests.count"), - safe_lookup(size_stat, "size.disk", Friendly.kib), - safe_lookup(size_stat, "size.memory", Friendly.kib), - safe_lookup(stats, "cid.size.total"), - ]; - row = stats_table.add_row(...values); - row.get_cell(0).tag().text(provider).on_click((x) => this.view_stat(x), provider); - } + var section = this.add_section("Stats"); + section.tag().classify("dropall").text("metrics dashboard →").on_click(() => { + window.location = "?page=metrics"; + }); + + var providers_data = await new Fetcher().resource("stats").json(); + var provider_list = providers_data["providers"] || []; + var all_stats = {}; + await Promise.all(provider_list.map(async (provider) => { + all_stats[provider] = await new Fetcher().resource("stats", provider).json(); + })); + + this._stats_grid = section.tag().classify("grid").classify("stats-tiles"); + this._safe_lookup = safe_lookup; + this._render_stats(all_stats); // version var ver_tag = this.tag().id("version"); @@ -125,6 +133,159 @@ export class Page extends ZenPage this._project_table = project_table; this._cache_table = cache_table; + + // WebSocket for live stats updates + this._connect_stats_ws(); + } + + _connect_stats_ws() + { + try + { + const proto = location.protocol === "https:" ? 
"wss:" : "ws:"; + const ws = new WebSocket(`${proto}//${location.host}/stats`); + + try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; } + document.addEventListener("zen-ws-toggle", (e) => { + this._ws_paused = e.detail.paused; + }); + + ws.onmessage = (ev) => { + if (this._ws_paused) + { + return; + } + try + { + const all_stats = JSON.parse(ev.data); + this._render_stats(all_stats); + } + catch (e) { /* ignore parse errors */ } + }; + + ws.onclose = () => { this._stats_ws = null; }; + ws.onerror = () => { ws.close(); }; + + this._stats_ws = ws; + } + catch (e) { /* WebSocket not available */ } + } + + _render_stats(all_stats) + { + const grid = this._stats_grid; + const safe_lookup = this._safe_lookup; + + // Clear existing tiles + grid.inner().innerHTML = ""; + + // HTTP tile — aggregate request stats across all providers + { + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("HTTP"); + const columns = tile.tag().classify("tile-columns"); + + // Left column: request stats + const left = columns.tag().classify("tile-metrics"); + + let total_requests = 0; + let total_rate = 0; + for (const p in all_stats) + { + total_requests += (safe_lookup(all_stats[p], "requests.count") || 0); + total_rate += (safe_lookup(all_stats[p], "requests.rate_1") || 0); + } + + this._add_tile_metric(left, Friendly.sep(total_requests), "total requests", true); + if (total_rate > 0) + this._add_tile_metric(left, Friendly.sep(total_rate, 1) + "/s", "req/sec (1m)"); + + // Right column: websocket stats + const ws = all_stats["http"] ? 
(all_stats["http"]["websockets"] || {}) : {}; + const right = columns.tag().classify("tile-metrics"); + + this._add_tile_metric(right, Friendly.sep(ws.active_connections || 0), "ws connections", true); + const ws_frames = (ws.frames_received || 0) + (ws.frames_sent || 0); + if (ws_frames > 0) + this._add_tile_metric(right, Friendly.sep(ws_frames), "ws frames"); + const ws_bytes = (ws.bytes_received || 0) + (ws.bytes_sent || 0); + if (ws_bytes > 0) + this._add_tile_metric(right, Friendly.bytes(ws_bytes), "ws traffic"); + + tile.on_click(() => { window.location = "?page=metrics"; }); + } + + // Cache tile (z$) + if (all_stats["z$"]) + { + const s = all_stats["z$"]; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Cache"); + const body = tile.tag().classify("tile-metrics"); + + const hits = safe_lookup(s, "cache.hits") || 0; + const misses = safe_lookup(s, "cache.misses") || 0; + const ratio = (hits + misses) > 0 ? ((hits / (hits + misses)) * 100).toFixed(1) + "%" : "-"; + + this._add_tile_metric(body, ratio, "hit ratio", true); + this._add_tile_metric(body, safe_lookup(s, "cache.size.disk", Friendly.bytes) || "-", "disk"); + this._add_tile_metric(body, safe_lookup(s, "cache.size.memory", Friendly.bytes) || "-", "memory"); + + tile.on_click(() => { window.location = "?page=stat&provider=z$"; }); + } + + // Project Store tile (prj) + if (all_stats["prj"]) + { + const s = all_stats["prj"]; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Project Store"); + const body = tile.tag().classify("tile-metrics"); + + this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true); + this._add_tile_metric(body, safe_lookup(s, "store.size.disk", Friendly.bytes) || "-", "disk"); + + tile.on_click(() => { window.location = "?page=stat&provider=prj"; }); + } + + // Build Store tile (builds) + if (all_stats["builds"]) + { + const 
s = all_stats["builds"]; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Build Store"); + const body = tile.tag().classify("tile-metrics"); + + this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true); + this._add_tile_metric(body, safe_lookup(s, "store.size.disk", Friendly.bytes) || "-", "disk"); + + tile.on_click(() => { window.location = "?page=stat&provider=builds"; }); + } + + // Workspace tile (ws) + if (all_stats["ws"]) + { + const s = all_stats["ws"]; + const tile = grid.tag().classify("card").classify("stats-tile"); + tile.tag().classify("card-title").text("Workspace"); + const body = tile.tag().classify("tile-metrics"); + + this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true); + this._add_tile_metric(body, safe_lookup(s, "workspaces.filescount", Friendly.sep) || "-", "files"); + + tile.on_click(() => { window.location = "?page=stat&provider=ws"; }); + } + } + + _add_tile_metric(parent, value, label, hero=false) + { + const m = parent.tag().classify("tile-metric"); + if (hero) + { + m.classify("tile-metric-hero"); + } + m.tag().classify("metric-value").text(value); + m.tag().classify("metric-label").text(label); } view_stat(provider) diff --git a/src/zenserver/frontend/html/pages/stat.js b/src/zenserver/frontend/html/pages/stat.js index d6c7fa8e8..4f020ac5e 100644 --- a/src/zenserver/frontend/html/pages/stat.js +++ b/src/zenserver/frontend/html/pages/stat.js @@ -33,7 +33,7 @@ class TemporalStat out[key] = data[key]; } - var friendly = this._as_bytes ? Friendly.kib : Friendly.sep; + var friendly = this._as_bytes ? 
Friendly.bytes : Friendly.sep; var content = ""; for (var i = 0; i < columns.length; ++i) diff --git a/src/zenserver/frontend/html/pages/tree.js b/src/zenserver/frontend/html/pages/tree.js index 08a578492..b5fece5a3 100644 --- a/src/zenserver/frontend/html/pages/tree.js +++ b/src/zenserver/frontend/html/pages/tree.js @@ -106,7 +106,7 @@ export class Page extends ZenPage for (var i = 0; i < 2; ++i) { - const size = Friendly.kib(new_nodes[name][i]); + const size = Friendly.bytes(new_nodes[name][i]); info.tag().text(size); } diff --git a/src/zenserver/frontend/html/pages/zcache.js b/src/zenserver/frontend/html/pages/zcache.js index 974893b21..d8bdc892a 100644 --- a/src/zenserver/frontend/html/pages/zcache.js +++ b/src/zenserver/frontend/html/pages/zcache.js @@ -27,8 +27,8 @@ export class Page extends ZenPage cfg_table.add_object(info["Configuration"], true); - storage_table.add_property("disk", Friendly.kib(info["StorageSize"]["DiskSize"])); - storage_table.add_property("mem", Friendly.kib(info["StorageSize"]["MemorySize"])); + storage_table.add_property("disk", Friendly.bytes(info["StorageSize"]["DiskSize"])); + storage_table.add_property("mem", Friendly.bytes(info["StorageSize"]["MemorySize"])); storage_table.add_property("entries", Friendly.sep(info["EntryCount"])); var column_names = ["name", "disk", "mem", "entries", "actions"]; @@ -41,8 +41,8 @@ export class Page extends ZenPage { const row = bucket_table.add_row(bucket); new Fetcher().resource(`/z$/${namespace}/${bucket}`).json().then((data) => { - row.get_cell(1).text(Friendly.kib(data["StorageSize"]["DiskSize"])); - row.get_cell(2).text(Friendly.kib(data["StorageSize"]["MemorySize"])); + row.get_cell(1).text(Friendly.bytes(data["StorageSize"]["DiskSize"])); + row.get_cell(2).text(Friendly.bytes(data["StorageSize"]["MemorySize"])); row.get_cell(3).text(Friendly.sep(data["DiskEntryCount"])); const cell = row.get_cell(-1); diff --git a/src/zenserver/frontend/html/theme.js b/src/zenserver/frontend/html/theme.js 
new file mode 100644 index 000000000..52ca116ab --- /dev/null +++ b/src/zenserver/frontend/html/theme.js @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// Theme toggle: cycles system → light → dark → system. +// Persists choice in localStorage. Applies data-theme attribute on <html>. + +(function() { + var KEY = 'zen-theme'; + + function getStored() { + try { return localStorage.getItem(KEY); } catch (e) { return null; } + } + + function setStored(value) { + try { + if (value) localStorage.setItem(KEY, value); + else localStorage.removeItem(KEY); + } catch (e) {} + } + + function apply(theme) { + if (theme) + document.documentElement.setAttribute('data-theme', theme); + else + document.documentElement.removeAttribute('data-theme'); + } + + function getEffective(stored) { + if (stored) return stored; + return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light'; + } + + // Apply stored preference immediately (before paint) + var stored = getStored(); + apply(stored); + + // Create toggle button once DOM is ready + function createToggle() { + var btn = document.createElement('button'); + btn.id = 'zen_theme_toggle'; + btn.title = 'Toggle theme'; + + function updateIcon() { + var effective = getEffective(getStored()); + // Show sun in dark mode (click to go light), moon in light mode (click to go dark) + btn.textContent = effective === 'dark' ? '\u2600' : '\u263E'; + + var isManual = getStored() != null; + btn.title = isManual + ? 'Theme: ' + effective + ' (click to change, double-click for system)' + : 'Theme: system (click to change)'; + } + + btn.addEventListener('click', function() { + var current = getStored(); + var effective = getEffective(current); + // Toggle to the opposite + var next = effective === 'dark' ? 
'light' : 'dark'; + setStored(next); + apply(next); + updateIcon(); + }); + + btn.addEventListener('dblclick', function(e) { + e.preventDefault(); + // Reset to system preference + setStored(null); + apply(null); + updateIcon(); + }); + + // Update icon when system preference changes + window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', function() { + if (!getStored()) updateIcon(); + }); + + updateIcon(); + document.body.appendChild(btn); + + // WebSocket pause/play toggle + var WS_KEY = 'zen-ws-paused'; + var wsBtn = document.createElement('button'); + wsBtn.id = 'zen_ws_toggle'; + + var initialPaused = false; + try { initialPaused = localStorage.getItem(WS_KEY) === 'true'; } catch (e) {} + + function updateWsIcon(paused) { + wsBtn.dataset.paused = paused ? 'true' : 'false'; + wsBtn.textContent = paused ? '\u25B6' : '\u23F8'; + wsBtn.title = paused ? 'Resume live updates' : 'Pause live updates'; + } + + updateWsIcon(initialPaused); + + // Fire initial event so pages pick up persisted state + document.addEventListener('DOMContentLoaded', function() { + if (initialPaused) { + document.dispatchEvent(new CustomEvent('zen-ws-toggle', { detail: { paused: true } })); + } + }); + + wsBtn.addEventListener('click', function() { + var paused = wsBtn.dataset.paused !== 'true'; + try { localStorage.setItem(WS_KEY, paused ? 
'true' : 'false'); } catch (e) {} + updateWsIcon(paused); + document.dispatchEvent(new CustomEvent('zen-ws-toggle', { detail: { paused: paused } })); + }); + + document.body.appendChild(wsBtn); + } + + if (document.readyState === 'loading') + document.addEventListener('DOMContentLoaded', createToggle); + else + createToggle(); +})(); diff --git a/src/zenserver/frontend/html/util/compactbinary.js b/src/zenserver/frontend/html/util/compactbinary.js index 90e4249f6..415fa4be8 100644 --- a/src/zenserver/frontend/html/util/compactbinary.js +++ b/src/zenserver/frontend/html/util/compactbinary.js @@ -310,8 +310,8 @@ CbFieldView.prototype.as_value = function(int_type=BigInt) case CbFieldType.IntegerPositive: return VarInt.read_uint(this._data_view, int_type)[0]; case CbFieldType.IntegerNegative: return VarInt.read_int(this._data_view, int_type)[0]; - case CbFieldType.Float32: return new DataView(this._data_view.subarray(0, 4)).getFloat32(0, false); - case CbFieldType.Float64: return new DataView(this._data_view.subarray(0, 8)).getFloat64(0, false); + case CbFieldType.Float32: { const s = this._data_view; return new DataView(s.buffer, s.byteOffset, 4).getFloat32(0, false); } + case CbFieldType.Float64: { const s = this._data_view; return new DataView(s.buffer, s.byteOffset, 8).getFloat64(0, false); } case CbFieldType.BoolFalse: return false; case CbFieldType.BoolTrue: return true; diff --git a/src/zenserver/frontend/html/util/friendly.js b/src/zenserver/frontend/html/util/friendly.js index a15252faf..5d4586165 100644 --- a/src/zenserver/frontend/html/util/friendly.js +++ b/src/zenserver/frontend/html/util/friendly.js @@ -20,4 +20,25 @@ export class Friendly static kib(x, p=0) { return Friendly.sep((BigInt(x) + 1023n) / (1n << 10n)|0n, p) + " KiB"; } static mib(x, p=1) { return Friendly.sep( BigInt(x) / (1n << 20n), p) + " MiB"; } static gib(x, p=2) { return Friendly.sep( BigInt(x) / (1n << 30n), p) + " GiB"; } + + static duration(s) + { + const v = Number(s); + if (v >= 1) 
return Friendly.sep(v, 2) + " s"; + if (v >= 0.001) return Friendly.sep(v * 1000, 2) + " ms"; + if (v >= 0.000001) return Friendly.sep(v * 1000000, 1) + " µs"; + return Friendly.sep(v * 1000000000, 0) + " ns"; + } + + static bytes(x) + { + const v = BigInt(Math.trunc(Number(x))); + if (v >= (1n << 60n)) return Friendly.sep(Number(v) / Number(1n << 60n), 2) + " EiB"; + if (v >= (1n << 50n)) return Friendly.sep(Number(v) / Number(1n << 50n), 2) + " PiB"; + if (v >= (1n << 40n)) return Friendly.sep(Number(v) / Number(1n << 40n), 2) + " TiB"; + if (v >= (1n << 30n)) return Friendly.sep(Number(v) / Number(1n << 30n), 2) + " GiB"; + if (v >= (1n << 20n)) return Friendly.sep(Number(v) / Number(1n << 20n), 1) + " MiB"; + if (v >= (1n << 10n)) return Friendly.sep(Number(v) / Number(1n << 10n), 0) + " KiB"; + return Friendly.sep(Number(v), 0) + " B"; + } } diff --git a/src/zenserver/frontend/html/util/widgets.js b/src/zenserver/frontend/html/util/widgets.js index 32a3f4d28..2964f92f2 100644 --- a/src/zenserver/frontend/html/util/widgets.js +++ b/src/zenserver/frontend/html/util/widgets.js @@ -54,6 +54,8 @@ export class Table extends Widget static Flag_PackRight = 1 << 1; static Flag_BiasLeft = 1 << 2; static Flag_FitLeft = 1 << 3; + static Flag_Sortable = 1 << 4; + static Flag_AlignNumeric = 1 << 5; constructor(parent, column_names, flags=Table.Flag_EvenSpacing, index_base=0) { @@ -81,11 +83,108 @@ export class Table extends Widget root.style("gridTemplateColumns", column_style); - this._add_row(column_names, false); + this._header_row = this._add_row(column_names, false); this._index = index_base; this._num_columns = column_names.length; this._rows = []; + this._flags = flags; + this._sort_column = -1; + this._sort_ascending = true; + + if (flags & Table.Flag_Sortable) + { + this._init_sortable(); + } + } + + _init_sortable() + { + const header_elem = this._element.firstElementChild; + if (!header_elem) + { + return; + } + + const cells = header_elem.children; + for (let i 
= 0; i < cells.length; i++) + { + const cell = cells[i]; + cell.style.cursor = "pointer"; + cell.style.userSelect = "none"; + cell.addEventListener("click", () => this._sort_by(i)); + } + } + + _sort_by(column_index) + { + if (this._sort_column === column_index) + { + this._sort_ascending = !this._sort_ascending; + } + else + { + this._sort_column = column_index; + this._sort_ascending = true; + } + + // Update header indicators + const header_elem = this._element.firstElementChild; + for (const cell of header_elem.children) + { + const text = cell.textContent.replace(/ [▲▼]$/, ""); + cell.textContent = text; + } + const active_cell = header_elem.children[column_index]; + active_cell.textContent += this._sort_ascending ? " ▲" : " ▼"; + + // Sort rows by comparing cell text content + const dir = this._sort_ascending ? 1 : -1; + const unit_multipliers = { "B": 1, "KiB": 1024, "MiB": 1048576, "GiB": 1073741824, "TiB": 1099511627776, "PiB": 1125899906842624, "EiB": 1152921504606846976 }; + const parse_sortable = (text) => { + // Try byte units first (e.g. "1,234 KiB", "1.5 GiB") + const byte_match = text.match(/^([\d,.]+)\s*(B|[KMGTPE]iB)/); + if (byte_match) + { + const num = parseFloat(byte_match[1].replace(/,/g, "")); + const mult = unit_multipliers[byte_match[2]] || 1; + return num * mult; + } + // Try percentage (e.g. "95.5%") + const pct_match = text.match(/^([\d,.]+)%/); + if (pct_match) + { + return parseFloat(pct_match[1].replace(/,/g, "")); + } + // Try plain number (possibly with commas/separators) + const num = parseFloat(text.replace(/,/g, "")); + if (!isNaN(num)) + { + return num; + } + return null; + }; + this._rows.sort((a, b) => { + const aElem = a.inner().children[column_index]; + const bElem = b.inner().children[column_index]; + const aText = aElem ? aElem.textContent : ""; + const bText = bElem ? 
bElem.textContent : ""; + + const aNum = parse_sortable(aText); + const bNum = parse_sortable(bText); + + if (aNum !== null && bNum !== null) + { + return (aNum - bNum) * dir; + } + return aText.localeCompare(bText) * dir; + }); + + // Re-order DOM elements + for (const row of this._rows) + { + this._element.appendChild(row.inner()); + } } *[Symbol.iterator]() @@ -121,6 +220,18 @@ export class Table extends Widget ret.push(new TableCell(leaf, row)); } + if ((this._flags & Table.Flag_AlignNumeric) && indexed) + { + for (const c of ret) + { + const t = c.inner().textContent; + if (t && /^\d/.test(t)) + { + c.style("textAlign", "right"); + } + } + } + if (this._index >= 0) ret.shift(); @@ -131,9 +242,34 @@ export class Table extends Widget { var row = this._add_row(args); this._rows.push(row); + + if ((this._flags & Table.Flag_AlignNumeric) && this._rows.length === 1) + { + this._align_header(); + } + return row; } + _align_header() + { + const first_row = this._rows[0]; + if (!first_row) + { + return; + } + const header_elem = this._element.firstElementChild; + const header_cells = header_elem.children; + const data_cells = first_row.inner().children; + for (let i = 0; i < data_cells.length && i < header_cells.length; i++) + { + if (data_cells[i].style.textAlign === "right") + { + header_cells[i].style.textAlign = "right"; + } + } + } + clear(index=0) { const elem = this._element; diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index cc53c0519..a968aecab 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -2,66 +2,202 @@ /* theme -------------------------------------------------------------------- */ +/* system preference (default) */ @media (prefers-color-scheme: light) { :root { - --theme_g0: #000; - --theme_g4: #fff; - --theme_g1: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 45%); - --theme_g2: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 80%); - --theme_g3: 
color-mix(in oklab, var(--theme_g0), var(--theme_g4) 96%); - - --theme_p0: #069; - --theme_p4: hsl(210deg 40% 94%); + --theme_g0: #1f2328; + --theme_g1: #656d76; + --theme_g2: #d0d7de; + --theme_g3: #f6f8fa; + --theme_g4: #ffffff; + + --theme_p0: #0969da; + --theme_p4: #ddf4ff; --theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%); --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%); --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%); --theme_ln: var(--theme_p0); - --theme_er: #fcc; + --theme_er: #ffebe9; + + --theme_ok: #1a7f37; + --theme_warn: #9a6700; + --theme_fail: #cf222e; + + --theme_bright: #1f2328; + --theme_faint: #6e7781; + --theme_border_subtle: #d8dee4; } } @media (prefers-color-scheme: dark) { :root { - --theme_g0: #ddd; - --theme_g4: #222; - --theme_g1: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 35%); - --theme_g2: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 65%); - --theme_g3: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 88%); - - --theme_p0: #479; - --theme_p4: #333; + --theme_g0: #c9d1d9; + --theme_g1: #8b949e; + --theme_g2: #30363d; + --theme_g3: #161b22; + --theme_g4: #0d1117; + + --theme_p0: #58a6ff; + --theme_p4: #1c2128; --theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%); --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%); --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%); - --theme_ln: #feb; - --theme_er: #622; + --theme_ln: #58a6ff; + --theme_er: #1c1c1c; + + --theme_ok: #3fb950; + --theme_warn: #d29922; + --theme_fail: #f85149; + + --theme_bright: #f0f6fc; + --theme_faint: #6e7681; + --theme_border_subtle: #21262d; } } +/* manual overrides (higher specificity than media queries) */ +:root[data-theme="light"] { + --theme_g0: #1f2328; + --theme_g1: #656d76; + --theme_g2: #d0d7de; + --theme_g3: #f6f8fa; + --theme_g4: #ffffff; + + --theme_p0: #0969da; + --theme_p4: #ddf4ff; + --theme_p1: color-mix(in oklab, 
var(--theme_p0), var(--theme_p4) 35%); + --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%); + --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%); + + --theme_ln: var(--theme_p0); + --theme_er: #ffebe9; + + --theme_ok: #1a7f37; + --theme_warn: #9a6700; + --theme_fail: #cf222e; + + --theme_bright: #1f2328; + --theme_faint: #6e7781; + --theme_border_subtle: #d8dee4; +} + +:root[data-theme="dark"] { + --theme_g0: #c9d1d9; + --theme_g1: #8b949e; + --theme_g2: #30363d; + --theme_g3: #161b22; + --theme_g4: #0d1117; + + --theme_p0: #58a6ff; + --theme_p4: #1c2128; + --theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%); + --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%); + --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%); + + --theme_ln: #58a6ff; + --theme_er: #1c1c1c; + + --theme_ok: #3fb950; + --theme_warn: #d29922; + --theme_fail: #f85149; + + --theme_bright: #f0f6fc; + --theme_faint: #6e7681; + --theme_border_subtle: #21262d; +} + +/* theme toggle ------------------------------------------------------------- */ + +#zen_ws_toggle { + position: fixed; + top: 16px; + right: 60px; + z-index: 10; + width: 36px; + height: 36px; + border-radius: 6px; + border: 1px solid var(--theme_g2); + background: var(--theme_g3); + color: var(--theme_g1); + cursor: pointer; + display: flex; + align-items: center; + justify-content: center; + font-size: 18px; + line-height: 1; + transition: color 0.15s, background 0.15s, border-color 0.15s; + user-select: none; +} + +#zen_ws_toggle:hover { + color: var(--theme_g0); + background: var(--theme_p4); + border-color: var(--theme_g1); +} + +#zen_theme_toggle { + position: fixed; + top: 16px; + right: 16px; + z-index: 10; + width: 36px; + height: 36px; + border-radius: 6px; + border: 1px solid var(--theme_g2); + background: var(--theme_g3); + color: var(--theme_g1); + cursor: pointer; + display: flex; + align-items: center; + justify-content: center; + 
font-size: 18px; + line-height: 1; + transition: color 0.15s, background 0.15s, border-color 0.15s; + user-select: none; +} + +#zen_theme_toggle:hover { + color: var(--theme_g0); + background: var(--theme_p4); + border-color: var(--theme_g1); +} + /* page --------------------------------------------------------------------- */ -body, input { - font-family: consolas, monospace; - font-size: 11pt; +body, input, button { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + font-size: 14px; } body { overflow-y: scroll; margin: 0; + padding: 20px; background-color: var(--theme_g4); color: var(--theme_g0); } -pre { - margin: 0; +pre, code { + font-family: 'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace; + font-size: 13px; + margin: 0; } input { color: var(--theme_g0); background-color: var(--theme_g3); border: 1px solid var(--theme_g2); + border-radius: 4px; + padding: 4px 8px; +} + +button { + color: var(--theme_g0); + background: transparent; + border: none; + cursor: pointer; } * { @@ -69,17 +205,44 @@ input { } #container { - max-width: 130em; - min-width: 80em; + max-width: 1400px; margin: auto; > div { - margin: 0.0em 2.2em 0.0em 2.2em; padding-top: 1.0em; padding-bottom: 1.5em; } } +/* service nav -------------------------------------------------------------- */ + +#service_nav { + display: flex; + align-items: center; + gap: 4px; + margin-bottom: 16px; + padding: 4px; + background-color: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 6px; + + a { + padding: 6px 14px; + border-radius: 4px; + font-size: 13px; + font-weight: 500; + color: var(--theme_g1); + text-decoration: none; + transition: color 0.15s, background 0.15s; + } + + a:hover { + background-color: var(--theme_p4); + color: var(--theme_g0); + text-decoration: none; + } +} + /* links -------------------------------------------------------------------- */ a { @@ -103,28 +266,37 @@ a { } h1 { - font-size: 
1.5em; + font-size: 20px; + font-weight: 600; width: 100%; + color: var(--theme_bright); border-bottom: 1px solid var(--theme_g2); + padding-bottom: 0.4em; + margin-bottom: 16px; } h2 { - font-size: 1.25em; - margin-bottom: 0.5em; + font-size: 16px; + font-weight: 600; + margin-bottom: 12px; } h3 { - font-size: 1.1em; + font-size: 14px; + font-weight: 600; margin: 0em; - padding: 0.4em; - background-color: var(--theme_p4); - border-left: 5px solid var(--theme_p2); - font-weight: normal; + padding: 8px 12px; + background-color: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 6px 6px 0 0; + color: var(--theme_g1); + text-transform: uppercase; + letter-spacing: 0.5px; } - margin-bottom: 3em; + margin-bottom: 2em; > *:not(h1) { - margin-left: 2em; + margin-left: 0; } } @@ -134,23 +306,36 @@ a { .zen_table { display: grid; border: 1px solid var(--theme_g2); - border-left-style: none; + border-radius: 6px; + overflow: hidden; margin-bottom: 1.2em; + font-size: 13px; > div { display: contents; } - > div:nth-of-type(odd) { + > div:nth-of-type(odd) > div { + background-color: var(--theme_g4); + } + + > div:nth-of-type(even) > div { background-color: var(--theme_g3); } > div:first-of-type { - font-weight: bold; - background-color: var(--theme_p3); + font-weight: 600; + > div { + background-color: var(--theme_g3); + color: var(--theme_g1); + text-transform: uppercase; + letter-spacing: 0.5px; + font-size: 11px; + border-bottom: 1px solid var(--theme_g2); + } } - > div:hover { + > div:not(:first-of-type):hover > div { background-color: var(--theme_p4); } @@ -160,16 +345,37 @@ a { } > div > div { - padding: 0.3em; - padding-left: 0.75em; - padding-right: 0.75em; + padding: 8px 12px; align-content: center; - border-left: 1px solid var(--theme_g2); + border-left: 1px solid var(--theme_border_subtle); overflow: auto; overflow-wrap: break-word; - background-color: inherit; white-space: pre-wrap; } + + > div > div:first-child { + border-left: none; + } +} + 
+/* expandable cell ---------------------------------------------------------- */ + +.zen_expand_icon { + cursor: pointer; + margin-right: 0.5em; + color: var(--theme_g1); + font-weight: bold; + user-select: none; +} + +.zen_expand_icon:hover { + color: var(--theme_ln); +} + +.zen_data_text { + user-select: text; + font-family: 'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace; + font-size: 13px; } /* toolbar ------------------------------------------------------------------ */ @@ -178,6 +384,7 @@ a { display: flex; margin-top: 0.5em; margin-bottom: 0.6em; + font-size: 13px; > div { display: flex; @@ -225,15 +432,16 @@ a { z-index: -1; top: 0; left: 0; - width: 100%; + width: 100%; height: 100%; background: var(--theme_g0); opacity: 0.4; } > div { - border-radius: 0.5em; - background-color: var(--theme_g4); + border-radius: 6px; + background-color: var(--theme_g3); + border: 1px solid var(--theme_g2); opacity: 1.0; width: 35em; padding: 0em 2em 2em 2em; @@ -244,10 +452,11 @@ a { } .zen_modal_title { - font-size: 1.2em; + font-size: 16px; + font-weight: 600; border-bottom: 1px solid var(--theme_g2); padding: 1.2em 0em 0.5em 0em; - color: var(--theme_g1); + color: var(--theme_bright); } .zen_modal_buttons { @@ -257,16 +466,19 @@ a { > div { margin: 0em 1em 0em 1em; - padding: 1em; + padding: 10px 16px; align-content: center; - border-radius: 0.3em; - background-color: var(--theme_p3); + border-radius: 6px; + background-color: var(--theme_p4); + border: 1px solid var(--theme_g2); width: 6em; cursor: pointer; + font-weight: 500; + transition: background 0.15s; } > div:hover { - background-color: var(--theme_p4); + background-color: var(--theme_p3); } } @@ -284,15 +496,18 @@ a { top: 0; left: 0; width: 100%; - height: 0.5em; + height: 4px; + border-radius: 2px; + overflow: hidden; > div:first-of-type { /* label */ padding: 0.3em; - padding-top: 0.8em; - background-color: var(--theme_p4); + padding-top: 8px; + background-color: var(--theme_g3); width: 
max-content; - font-size: 0.8em; + font-size: 12px; + color: var(--theme_g1); } > div:last-of-type { @@ -302,7 +517,8 @@ a { left: 0; width: 0%; height: 100%; - background-color: var(--theme_p1); + background-color: var(--theme_p0); + transition: width 0.3s ease; } > div:nth-of-type(2) { @@ -312,7 +528,7 @@ a { left: 0; width: 100%; height: 100%; - background-color: var(--theme_p3); + background-color: var(--theme_g3); } } @@ -321,53 +537,25 @@ a { #crumbs { display: flex; position: relative; - top: -1em; + top: -0.5em; + font-size: 13px; + color: var(--theme_g1); > div { padding-right: 0.5em; } > div:nth-child(odd)::after { - content: ":"; - font-weight: bolder; - color: var(--theme_p2); + content: "/"; + color: var(--theme_g2); + padding-left: 0.5em; } } -/* branding ----------------------------------------------------------------- */ - -#branding { - font-size: 10pt; - font-weight: bolder; - margin-bottom: 2.6em; - position: relative; +/* banner ------------------------------------------------------------------- */ - #logo { - width: min-content; - margin: auto; - user-select: none; - position: relative; - - #go_home { - width: 100%; - height: 100%; - position: absolute; - top: 0; - left: 0; - } - } - - #logo:hover { - filter: drop-shadow(0 0.15em 0.1em var(--theme_p2)); - } - - #ue_logo { - position: absolute; - top: 1em; - right: 0; - width: 5em; - margin: auto; - } +zen-banner { + margin-bottom: 24px; } /* error -------------------------------------------------------------------- */ @@ -378,26 +566,23 @@ a { z-index: 1; color: var(--theme_g0); background-color: var(--theme_er); - padding: 1.0em 2em 2em 2em; + padding: 12px 20px 16px 20px; width: 100%; - border-top: 1px solid var(--theme_g0); + border-top: 1px solid var(--theme_g2); display: flex; + gap: 16px; + align-items: center; + font-size: 13px; > div:nth-child(1) { - font-size: 2.5em; - font-weight: bolder; - font-family: serif; - transform: rotate(-13deg); - color: var(--theme_p0); - } - - > 
div:nth-child(2) { - margin-left: 2em; + font-size: 24px; + font-weight: bold; + color: var(--theme_fail); } > div:nth-child(2) > pre:nth-child(2) { - margin-top: 0.5em; - font-size: 0.8em; + margin-top: 4px; + font-size: 12px; color: var(--theme_g1); } } @@ -409,18 +594,144 @@ a { min-width: 15%; } +/* sections ----------------------------------------------------------------- */ + +.zen_sector { + position: relative; +} + +.dropall { + position: absolute; + top: 16px; + right: 0; + font-size: 12px; + margin: 0; +} + +/* stats tiles -------------------------------------------------------------- */ + +.stats-tiles { + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); +} + +.stats-tile { + cursor: pointer; + transition: border-color 0.15s, background 0.15s; +} + +.stats-tile:hover { + border-color: var(--theme_p0); + background: var(--theme_p4); +} + +.stats-tile-detailed { + position: relative; +} + +.stats-tile-detailed::after { + content: "details \203A"; + position: absolute; + bottom: 12px; + right: 20px; + font-size: 11px; + color: var(--theme_g1); + opacity: 0.6; + transition: opacity 0.15s; +} + +.stats-tile-detailed:hover::after { + opacity: 1; + color: var(--theme_p0); +} + +.stats-tile-selected { + border-color: var(--theme_p0); + background: var(--theme_p4); + box-shadow: 0 0 0 1px var(--theme_p0); +} + +.stats-tile-selected::after { + content: "details \2039"; + opacity: 1; + color: var(--theme_p0); +} + +.tile-metrics { + display: flex; + flex-direction: column; + gap: 12px; +} + +.tile-columns { + display: flex; + gap: 24px; +} + +.tile-columns > .tile-metrics { + flex: 1; + min-width: 0; +} + +.tile-metric .metric-value { + font-size: 16px; +} + +.tile-metric-hero .metric-value { + font-size: 28px; +} + /* start -------------------------------------------------------------------- */ #start { - .dropall { - text-align: right; - font-size: 0.85em; - margin: -0.5em 0 0.5em 0; - } #version { - color: var(--theme_g1); + color: 
var(--theme_faint); text-align: center; - font-size: 0.85em; + font-size: 12px; + margin-top: 24px; + } +} + +/* info --------------------------------------------------------------------- */ + +#info { + .info-tiles { + grid-template-columns: repeat(auto-fit, minmax(320px, 1fr)); + } + + .info-tile { + overflow: hidden; + } + + .info-props { + display: flex; + flex-direction: column; + gap: 1px; + font-size: 13px; + } + + .info-prop { + display: flex; + gap: 12px; + padding: 4px 0; + border-bottom: 1px solid var(--theme_border_subtle); + } + + .info-prop:last-child { + border-bottom: none; + } + + .info-prop-label { + color: var(--theme_g1); + min-width: 140px; + flex-shrink: 0; + text-transform: capitalize; + } + + .info-prop-value { + color: var(--theme_bright); + word-break: break-all; + margin-left: auto; + text-align: right; } } @@ -437,6 +748,8 @@ a { /* tree --------------------------------------------------------------------- */ #tree { + font-size: 13px; + #tree_root > ul { margin-left: 0em; } @@ -448,29 +761,33 @@ a { li > div { display: flex; border-bottom: 1px solid transparent; - padding-left: 0.3em; - padding-right: 0.3em; + padding: 4px 6px; + border-radius: 4px; } li > div > div[active] { text-transform: uppercase; + color: var(--theme_p0); + font-weight: 600; } li > div > div:nth-last-child(3) { margin-left: auto; } li > div > div:nth-last-child(-n + 3) { - font-size: 0.8em; + font-size: 12px; width: 10em; text-align: right; + color: var(--theme_g1); + font-family: 'SF Mono', 'Cascadia Mono', Consolas, monospace; } li > div > div:nth-last-child(1) { width: 6em; } li > div:hover { background-color: var(--theme_p4); - border-bottom: 1px solid var(--theme_g2); + border-bottom: 1px solid var(--theme_border_subtle); } li a { - font-weight: bolder; + font-weight: 600; } li::marker { content: "+"; @@ -503,3 +820,262 @@ html:has(#map) { } } } + +/* ========================================================================== */ +/* Shared classes for compute 
/ dashboard pages */ +/* ========================================================================== */ + +/* cards -------------------------------------------------------------------- */ + +.card { + background: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 6px; + padding: 20px; +} + +.card-title { + font-size: 14px; + font-weight: 600; + color: var(--theme_g1); + margin-bottom: 12px; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +/* grid --------------------------------------------------------------------- */ + +.grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 20px; + margin-bottom: 24px; +} + +/* metrics ------------------------------------------------------------------ */ + +.metric-value { + font-size: 36px; + font-weight: 600; + color: var(--theme_bright); + line-height: 1; +} + +.metric-label { + font-size: 12px; + color: var(--theme_g1); + margin-top: 4px; +} + +/* section titles ----------------------------------------------------------- */ + +.section-title { + font-size: 20px; + font-weight: 600; + margin-bottom: 16px; + color: var(--theme_bright); +} + +/* html tables (compute pages) ---------------------------------------------- */ + +table { + width: 100%; + border-collapse: collapse; + font-size: 13px; +} + +th { + text-align: left; + color: var(--theme_g1); + padding: 8px 12px; + border-bottom: 1px solid var(--theme_g2); + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + font-size: 11px; +} + +td { + padding: 8px 12px; + border-bottom: 1px solid var(--theme_border_subtle); + color: var(--theme_g0); +} + +tr:last-child td { + border-bottom: none; +} + +.total-row td { + border-top: 2px solid var(--theme_g2); + font-weight: 600; + color: var(--theme_bright); +} + +/* status badges ------------------------------------------------------------ */ + +.status-badge { + display: inline-block; + padding: 2px 8px; + border-radius: 4px; + font-size: 
11px; + font-weight: 600; +} + +.status-badge.active, +.status-badge.success { + background: color-mix(in srgb, var(--theme_ok) 15%, transparent); + color: var(--theme_ok); +} + +.status-badge.inactive { + background: color-mix(in srgb, var(--theme_g1) 15%, transparent); + color: var(--theme_g1); +} + +.status-badge.failure { + background: color-mix(in srgb, var(--theme_fail) 15%, transparent); + color: var(--theme_fail); +} + +/* health dots -------------------------------------------------------------- */ + +.health-dot { + display: inline-block; + width: 10px; + height: 10px; + border-radius: 50%; + background: var(--theme_g1); +} + +.health-green { + background: var(--theme_ok); +} + +.health-yellow { + background: var(--theme_warn); +} + +.health-red { + background: var(--theme_fail); +} + +/* inline progress bar ------------------------------------------------------ */ + +.progress-bar { + width: 100%; + height: 8px; + background: var(--theme_border_subtle); + border-radius: 4px; + overflow: hidden; + margin-top: 8px; +} + +.progress-fill { + height: 100%; + background: var(--theme_p0); + transition: width 0.3s ease; +} + +/* stats row (label + value pair) ------------------------------------------- */ + +.stats-row { + display: flex; + justify-content: space-between; + margin-bottom: 12px; + padding: 8px 0; + border-bottom: 1px solid var(--theme_border_subtle); +} + +.stats-row:last-child { + border-bottom: none; + margin-bottom: 0; +} + +.stats-label { + color: var(--theme_g1); + font-size: 13px; +} + +.stats-value { + color: var(--theme_bright); + font-weight: 600; + font-size: 13px; +} + +/* detail tag (inline badge) ------------------------------------------------ */ + +.detail-tag { + display: inline-block; + padding: 2px 8px; + border-radius: 4px; + background: var(--theme_border_subtle); + color: var(--theme_g0); + font-size: 11px; + margin: 2px 4px 2px 0; +} + +/* timestamp ---------------------------------------------------------------- */ + 
+.timestamp { + font-size: 12px; + color: var(--theme_faint); +} + +/* inline error ------------------------------------------------------------- */ + +.error { + color: var(--theme_fail); + padding: 12px; + background: var(--theme_er); + border-radius: 6px; + margin: 20px 0; + font-size: 13px; +} + +/* empty state -------------------------------------------------------------- */ + +.empty-state { + color: var(--theme_faint); + font-size: 13px; + padding: 20px 0; + text-align: center; +} + +/* header layout ------------------------------------------------------------ */ + +.header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 24px; +} + +/* history tabs ------------------------------------------------------------- */ + +.history-tabs { + display: flex; + gap: 4px; + background: var(--theme_g4); + border-radius: 6px; + padding: 2px; +} + +.history-tab { + background: transparent; + color: var(--theme_g1); + font-size: 12px; + font-weight: 600; + padding: 4px 12px; + border-radius: 4px; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +.history-tab:hover { + color: var(--theme_g0); +} + +.history-tab.active { + background: var(--theme_g2); + color: var(--theme_bright); +} diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp index f9c2bc8ff..42df0520f 100644 --- a/src/zenserver/frontend/zipfs.cpp +++ b/src/zenserver/frontend/zipfs.cpp @@ -149,13 +149,25 @@ ZipFs::ZipFs(IoBuffer&& Buffer) IoBuffer ZipFs::GetFile(const std::string_view& FileName) const { - FileMap::iterator Iter = m_Files.find(FileName); - if (Iter == m_Files.end()) { - return {}; + RwLock::SharedLockScope _(m_FilesLock); + + FileMap::const_iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; + } + + const FileItem& Item = Iter->second; + if (Item.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); + } } - FileItem& Item = Iter->second; + 
RwLock::ExclusiveLockScope _(m_FilesLock); + + FileItem& Item = m_Files.find(FileName)->second; if (Item.GetSize() > 0) { return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h index 1fa7da451..645121693 100644 --- a/src/zenserver/frontend/zipfs.h +++ b/src/zenserver/frontend/zipfs.h @@ -3,23 +3,23 @@ #pragma once #include <zencore/iobuffer.h> +#include <zencore/thread.h> #include <unordered_map> namespace zen { -////////////////////////////////////////////////////////////////////////// class ZipFs { public: - ZipFs() = default; - ZipFs(IoBuffer&& Buffer); + explicit ZipFs(IoBuffer&& Buffer); + IoBuffer GetFile(const std::string_view& FileName) const; - inline operator bool() const { return !m_Files.empty(); } private: using FileItem = MemoryView; using FileMap = std::unordered_map<std::string_view, FileItem>; + mutable RwLock m_FilesLock; FileMap mutable m_Files; IoBuffer m_Buffer; }; diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index 4d9da3a57..7b999ae20 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -4,10 +4,12 @@ #include "hydration.h" +#include <zencore/assertfmt.h> #include <zencore/compactbinarybuilder.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/process.h> #include <zencore/scopeguard.h> #include <zencore/system.h> #include <zenutil/zenserverprocess.h> @@ -150,7 +152,12 @@ struct StorageServerInstance inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); } +#if ZEN_PLATFORM_WINDOWS + void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; } +#endif + private: + void WakeLocked(); RwLock m_Lock; std::string m_ModuleId; std::atomic<bool> m_IsProvisioned{false}; @@ -160,6 +167,9 @@ private: std::filesystem::path m_TempDir; std::filesystem::path m_HydrationPath; ResourceMetrics m_ResourceMetrics; 
+#if ZEN_PLATFORM_WINDOWS + JobObject* m_JobObject = nullptr; +#endif void SpawnServerProcess(); @@ -186,10 +196,13 @@ StorageServerInstance::~StorageServerInstance() void StorageServerInstance::SpawnServerProcess() { - ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); + ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath()); m_ServerInstance.SetDataDir(m_BaseDir); +#if ZEN_PLATFORM_WINDOWS + m_ServerInstance.SetJobObject(m_JobObject); +#endif const uint16_t BasePort = m_ServerInstance.SpawnServerAndWaitUntilReady(); ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, BasePort); @@ -211,7 +224,7 @@ StorageServerInstance::Provision() if (m_IsHibernated) { - Wake(); + WakeLocked(); } else { @@ -294,9 +307,14 @@ StorageServerInstance::Hibernate() void StorageServerInstance::Wake() { - // Start server in-place using existing data - RwLock::ExclusiveLockScope _(m_Lock); + WakeLocked(); +} + +void +StorageServerInstance::WakeLocked() +{ + // Start server in-place using existing data if (!m_IsHibernated) { @@ -305,7 +323,7 @@ StorageServerInstance::Wake() return; } - ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); + ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); try { @@ -374,6 +392,21 @@ struct HttpHubService::Impl // flexibility, and to allow running multiple hubs on the same host if // necessary. 
m_RunEnvironment.SetNextPortNumber(21000); + +#if ZEN_PLATFORM_WINDOWS + if (m_UseJobObject) + { + m_JobObject.Initialize(); + if (m_JobObject.IsValid()) + { + ZEN_INFO("Job object initialized for hub service child process management"); + } + else + { + ZEN_WARN("Failed to initialize job object; child processes will not be auto-terminated on hub crash"); + } + } +#endif } void Cleanup() @@ -416,6 +449,12 @@ struct HttpHubService::Impl IsNewInstance = true; auto NewInstance = std::make_unique<StorageServerInstance>(m_RunEnvironment, ModuleId, m_FileHydrationPath, m_HydrationTempPath); +#if ZEN_PLATFORM_WINDOWS + if (m_JobObject.IsValid()) + { + NewInstance->SetJobObject(&m_JobObject); + } +#endif Instance = NewInstance.get(); m_Instances.emplace(std::string(ModuleId), std::move(NewInstance)); @@ -573,10 +612,15 @@ struct HttpHubService::Impl inline int GetInstanceLimit() { return m_InstanceLimit; } inline int GetMaxInstanceCount() { return m_MaxInstanceCount; } + bool m_UseJobObject = true; + private: - ZenServerEnvironment m_RunEnvironment; - std::filesystem::path m_FileHydrationPath; - std::filesystem::path m_HydrationTempPath; + ZenServerEnvironment m_RunEnvironment; + std::filesystem::path m_FileHydrationPath; + std::filesystem::path m_HydrationTempPath; +#if ZEN_PLATFORM_WINDOWS + JobObject m_JobObject; +#endif RwLock m_Lock; std::unordered_map<std::string, std::unique_ptr<StorageServerInstance>> m_Instances; std::unordered_set<std::string> m_DeprovisioningModules; @@ -802,7 +846,7 @@ HttpHubService::HttpHubService(std::filesystem::path HubBaseDir, std::filesystem Obj << "currentInstanceCount" << m_Impl->GetInstanceCount(); Obj << "maxInstanceCount" << m_Impl->GetMaxInstanceCount(); Obj << "instanceLimit" << m_Impl->GetInstanceLimit(); - Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); }, HttpVerb::kGet); } @@ -811,6 +855,12 @@ HttpHubService::~HttpHubService() { } +void 
+HttpHubService::SetUseJobObject(bool Enable) +{ + m_Impl->m_UseJobObject = Enable; +} + const char* HttpHubService::BaseUri() const { diff --git a/src/zenserver/hub/hubservice.h b/src/zenserver/hub/hubservice.h index 1a5a8c57c..ef24bba69 100644 --- a/src/zenserver/hub/hubservice.h +++ b/src/zenserver/hub/hubservice.h @@ -28,6 +28,13 @@ public: void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId); + /** Enable or disable the use of a Windows Job Object for child process management. + * When enabled, all spawned child processes are assigned to a job object with + * JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, ensuring children are terminated if the hub + * crashes or is force-killed. Must be called before Initialize(). No-op on non-Windows. + */ + void SetUseJobObject(bool Enable); + private: HttpRequestRouter m_Router; diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 7a4ba951d..c6d2dc8d4 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -105,7 +105,7 @@ ZenHubServer::Initialize(const ZenHubServerConfig& ServerConfig, ZenServerState: void ZenHubServer::Cleanup() { - ZEN_TRACE_CPU("ZenStorageServer::Cleanup"); + ZEN_TRACE_CPU("ZenHubServer::Cleanup"); ZEN_INFO(ZEN_APP_NAME " cleaning up"); try { @@ -115,6 +115,8 @@ ZenHubServer::Cleanup() m_IoRunner.join(); } + ShutdownServices(); + if (m_Http) { m_Http->Close(); @@ -143,6 +145,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) ZEN_INFO("instantiating hub service"); m_HubService = std::make_unique<HttpHubService>(ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers"); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); + + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService); } void @@ -159,6 +163,11 @@ ZenHubServer::RegisterServices(const ZenHubServerConfig& 
ServerConfig) { m_Http->RegisterService(*m_ApiService); } + + if (m_FrontendService) + { + m_Http->RegisterService(*m_FrontendService); + } } void diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index ac14362f0..4c56fdce5 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -2,6 +2,7 @@ #pragma once +#include "frontend/frontend.h" #include "zenserver.h" namespace cxxopts { @@ -81,8 +82,9 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr<HttpHubService> m_HubService; - std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<HttpHubService> m_HubService; + std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<HttpFrontendService> m_FrontendService; void InitializeState(const ZenHubServerConfig& ServerConfig); void InitializeServices(const ZenHubServerConfig& ServerConfig); diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 3a58d1f4a..09ecc48e5 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -19,10 +19,13 @@ #include <zencore/thread.h> #include <zencore/trace.h> #include <zentelemetry/otlptrace.h> -#include <zenutil/commandlineoptions.h> +#include <zenutil/config/commandlineoptions.h> #include <zenutil/service.h> #include "diag/logging.h" + +#include "compute/computeserver.h" + #include "storage/storageconfig.h" #include "storage/zenstorageserver.h" @@ -38,7 +41,6 @@ // in some shared code into the executable #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include <zencore/testing.h> #endif @@ -61,11 +63,19 @@ namespace zen { #if ZEN_PLATFORM_WINDOWS -template<class T> +/** Windows Service wrapper for Zen servers + * + * This class wraps a Zen server main entry point (the Main template parameter) + * into a Windows Service by implementing the WindowsService interface. 
+ * + * The Main type needs to implement the virtual functions from the ZenServerMain + * base class, which provides the actual server logic. + */ +template<class Main> class ZenWindowsService : public WindowsService { public: - ZenWindowsService(typename T::Config& ServerOptions) : m_EntryPoint(ServerOptions) {} + ZenWindowsService(typename Main::Config& ServerOptions) : m_EntryPoint(ServerOptions) {} ZenWindowsService(const ZenWindowsService&) = delete; ZenWindowsService& operator=(const ZenWindowsService&) = delete; @@ -73,7 +83,7 @@ public: virtual int Run() override { return m_EntryPoint.Run(); } private: - T m_EntryPoint; + Main m_EntryPoint; }; #endif // ZEN_PLATFORM_WINDOWS @@ -84,6 +94,23 @@ private: namespace zen { +/** Application main entry point template + * + * This function handles common application startup tasks while allowing + * different server types to be plugged in via the Main template parameter. + * + * On Windows, this function also handles platform-specific service + * installation and uninstallation. + * + * The Main type needs to implement the virtual functions from the ZenServerMain + * base class, which provides the actual server logic. 
+ * + * The Main type is also expected to provide the following members: + * + * typedef Config -- Server configuration type, derived from ZenServerConfig + * typedef Configurator -- Server configuration handler type, implements ZenServerConfiguratorBase + * + */ template<class Main> int AppMain(int argc, char* argv[]) @@ -219,7 +246,7 @@ test_main(int argc, char** argv) # endif // ZEN_PLATFORM_WINDOWS zen::logging::InitializeLogging(); - zen::logging::SetLogLevel(zen::logging::level::Debug); + zen::logging::SetLogLevel(zen::logging::Debug); zen::MaximizeOpenFileCount(); @@ -239,16 +266,31 @@ main(int argc, char* argv[]) using namespace zen; using namespace std::literals; + // note: doctest has locally (in thirdparty) been fixed to not cause shutdown + // crashes due to TLS destructors + // + // mimalloc on the other hand might still be causing issues, in which case + // we should work out either how to eliminate the mimalloc dependency or how + // to configure it in a way that doesn't cause shutdown issues + +#if 0 auto _ = zen::MakeGuard([] { // Allow some time for worker threads to unravel, in an effort - // to prevent shutdown races in TLS object destruction + // to prevent shutdown races in TLS object destruction, mainly due to + // threads which we don't directly control (Windows thread pool) and + // therefore can't join. + // + // This isn't a great solution, but for now it seems to help reduce + // shutdown crashes observed in some situations. 
WaitForThreads(1000); }); +#endif enum { kHub, kStore, + kCompute, kTest } ServerMode = kStore; @@ -258,10 +300,14 @@ main(int argc, char* argv[]) { ServerMode = kHub; } - else if (argv[1] == "store"sv) + else if ((argv[1] == "store"sv) || (argv[1] == "storage"sv)) { ServerMode = kStore; } + else if (argv[1] == "compute"sv) + { + ServerMode = kCompute; + } else if (argv[1] == "test"sv) { ServerMode = kTest; @@ -280,6 +326,13 @@ main(int argc, char* argv[]) break; case kHub: return AppMain<ZenHubServerMain>(argc, argv); + case kCompute: +#if ZEN_WITH_COMPUTE_SERVICES + return AppMain<ZenComputeServerMain>(argc, argv); +#else + fprintf(stderr, "compute services are not compiled in!\n"); + exit(5); +#endif default: case kStore: return AppMain<ZenStorageServerMain>(argc, argv); diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp new file mode 100644 index 000000000..05be3c814 --- /dev/null +++ b/src/zenserver/sessions/httpsessions.cpp @@ -0,0 +1,264 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "httpsessions.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/trace.h> +#include "sessions.h" + +namespace zen { +using namespace std::literals; + +HttpSessionsService::HttpSessionsService(HttpStatusService& StatusService, HttpStatsService& StatsService, SessionsService& Sessions) +: m_Log(logging::Get("sessions")) +, m_StatusService(StatusService) +, m_StatsService(StatsService) +, m_Sessions(Sessions) +{ + Initialize(); +} + +HttpSessionsService::~HttpSessionsService() +{ + m_StatsService.UnregisterHandler("sessions", *this); + m_StatusService.UnregisterHandler("sessions", *this); +} + +const char* +HttpSessionsService::BaseUri() const +{ + return "/sessions/"; +} + +void +HttpSessionsService::HandleRequest(HttpServerRequest& Request) +{ + metrics::OperationTiming::Scope $(m_HttpRequests); + + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); + } +} + +CbObject +HttpSessionsService::CollectStats() +{ + ZEN_TRACE_CPU("SessionsService::Stats"); + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + + Cbo.BeginObject("sessions"); + { + Cbo << "readcount" << m_SessionsStats.SessionReadCount; + Cbo << "writecount" << m_SessionsStats.SessionWriteCount; + Cbo << "deletecount" << m_SessionsStats.SessionDeleteCount; + Cbo << "listcount" << m_SessionsStats.SessionListCount; + Cbo << "requestcount" << m_SessionsStats.RequestCount; + Cbo << "badrequestcount" << m_SessionsStats.BadRequestCount; + Cbo << "count" << m_Sessions.GetSessionCount(); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpSessionsService::HandleStatsRequest(HttpServerRequest& HttpReq) +{ + HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + +void 
+HttpSessionsService::HandleStatusRequest(HttpServerRequest& Request) +{ + ZEN_TRACE_CPU("HttpSessionsService::Status"); + CbObjectWriter Cbo; + Cbo << "ok" << true; + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpSessionsService::Initialize() +{ + using namespace std::literals; + + ZEN_INFO("Initializing Sessions Service"); + + static constexpr AsciiSet ValidHexCharactersSet{"0123456789abcdefABCDEF"}; + + m_Router.AddMatcher("session_id", [](std::string_view Str) -> bool { + return Str.length() == Oid::StringLength && AsciiSet::HasOnly(Str, ValidHexCharactersSet); + }); + + m_Router.RegisterRoute( + "list", + [this](HttpRouterRequest& Req) { ListSessionsRequest(Req); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{session_id}", + [this](HttpRouterRequest& Req) { SessionRequest(Req); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "", + [this](HttpRouterRequest& Req) { ListSessionsRequest(Req); }, + HttpVerb::kGet); + + m_StatsService.RegisterHandler("sessions", *this); + m_StatusService.RegisterHandler("sessions", *this); +} + +static void +WriteSessionInfo(CbWriter& Writer, const SessionsService::SessionInfo& Info) +{ + Writer << "id" << Info.Id; + if (!Info.AppName.empty()) + { + Writer << "appname" << Info.AppName; + } + if (Info.JobId != Oid::Zero) + { + Writer << "jobid" << Info.JobId; + } + Writer << "created_at" << Info.CreatedAt; + Writer << "updated_at" << Info.UpdatedAt; + + if (Info.Metadata.GetSize() > 0) + { + Writer.BeginObject("metadata"); + for (const CbField& Field : Info.Metadata) + { + Writer.AddField(Field); + } + Writer.EndObject(); + } +} + +void +HttpSessionsService::ListSessionsRequest(HttpRouterRequest& Req) +{ + HttpServerRequest& ServerRequest = Req.ServerRequest(); + + m_SessionsStats.SessionListCount++; + m_SessionsStats.RequestCount++; + + std::vector<Ref<SessionsService::Session>> Sessions = m_Sessions.GetSessions(); + + CbObjectWriter 
Response; + Response.BeginArray("sessions"); + for (const Ref<SessionsService::Session>& Session : Sessions) + { + Response.BeginObject(); + { + WriteSessionInfo(Response, Session->Info()); + } + Response.EndObject(); + } + Response.EndArray(); + + return ServerRequest.WriteResponse(HttpResponseCode::OK, Response.Save()); +} + +void +HttpSessionsService::SessionRequest(HttpRouterRequest& Req) +{ + HttpServerRequest& ServerRequest = Req.ServerRequest(); + + const Oid SessionId = Oid::TryFromHexString(Req.GetCapture(1)); + if (SessionId == Oid::Zero) + { + m_SessionsStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Invalid session id '{}'", Req.GetCapture(1))); + } + + m_SessionsStats.RequestCount++; + + switch (ServerRequest.RequestVerb()) + { + case HttpVerb::kPost: + case HttpVerb::kPut: + { + IoBuffer Payload = ServerRequest.ReadPayload(); + CbObject RequestObject; + + if (Payload.GetSize() > 0) + { + if (CbValidateError ValidationResult = ValidateCompactBinary(Payload.GetView(), CbValidateMode::All); + ValidationResult != CbValidateError::None) + { + m_SessionsStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Invalid payload: {}", zen::ToString(ValidationResult))); + } + RequestObject = LoadCompactBinaryObject(Payload); + } + + if (ServerRequest.RequestVerb() == HttpVerb::kPost) + { + std::string AppName(RequestObject["appname"sv].AsString()); + Oid JobId = RequestObject["jobid"sv].AsObjectId(); + CbObjectView MetadataView = RequestObject["metadata"sv].AsObjectView(); + + m_SessionsStats.SessionWriteCount++; + if (m_Sessions.RegisterSession(SessionId, std::move(AppName), JobId, MetadataView)) + { + return ServerRequest.WriteResponse(HttpResponseCode::Created, HttpContentType::kText, fmt::format("{}", SessionId)); + } + else + { + // Already exists - try update instead + if 
(m_Sessions.UpdateSession(SessionId, MetadataView)) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("{}", SessionId)); + } + return ServerRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + else + { + // PUT - update only + m_SessionsStats.SessionWriteCount++; + if (m_Sessions.UpdateSession(SessionId, RequestObject["metadata"sv].AsObjectView())) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("{}", SessionId)); + } + return ServerRequest.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("Session '{}' not found", SessionId)); + } + } + case HttpVerb::kGet: + { + m_SessionsStats.SessionReadCount++; + Ref<SessionsService::Session> Session = m_Sessions.GetSession(SessionId); + if (Session) + { + CbObjectWriter Response; + WriteSessionInfo(Response, Session->Info()); + return ServerRequest.WriteResponse(HttpResponseCode::OK, Response.Save()); + } + return ServerRequest.WriteResponse(HttpResponseCode::NotFound); + } + case HttpVerb::kDelete: + { + m_SessionsStats.SessionDeleteCount++; + if (m_Sessions.RemoveSession(SessionId)) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK); + } + return ServerRequest.WriteResponse(HttpResponseCode::NotFound); + } + } +} + +} // namespace zen diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h new file mode 100644 index 000000000..e07f3b59b --- /dev/null +++ b/src/zenserver/sessions/httpsessions.h @@ -0,0 +1,55 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpserver.h> +#include <zenhttp/httpstats.h> +#include <zenhttp/httpstatus.h> +#include <zentelemetry/stats.h> + +namespace zen { + +class SessionsService; + +class HttpSessionsService final : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider +{ +public: + HttpSessionsService(HttpStatusService& StatusService, HttpStatsService& StatsService, SessionsService& Sessions); + virtual ~HttpSessionsService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; + +private: + struct SessionsStats + { + std::atomic_uint64_t SessionReadCount{}; + std::atomic_uint64_t SessionWriteCount{}; + std::atomic_uint64_t SessionDeleteCount{}; + std::atomic_uint64_t SessionListCount{}; + std::atomic_uint64_t RequestCount{}; + std::atomic_uint64_t BadRequestCount{}; + }; + + inline LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + + void Initialize(); + + void ListSessionsRequest(HttpRouterRequest& Req); + void SessionRequest(HttpRouterRequest& Req); + + HttpStatusService& m_StatusService; + HttpStatsService& m_StatsService; + HttpRequestRouter m_Router; + SessionsService& m_Sessions; + SessionsStats m_SessionsStats; + metrics::OperationTiming m_HttpRequests; +}; + +} // namespace zen diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp new file mode 100644 index 000000000..f73aa40ff --- /dev/null +++ b/src/zenserver/sessions/sessions.cpp @@ -0,0 +1,150 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "sessions.h" + +#include <zencore/basicfile.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +namespace zen { +using namespace std::literals; + +class SessionLog : public TRefCounted<SessionLog> +{ +public: + SessionLog(std::filesystem::path LogFilePath) { m_LogFile.Open(LogFilePath, BasicFile::Mode::kWrite); } + +private: + BasicFile m_LogFile; +}; + +class SessionLogStore +{ +public: + SessionLogStore(std::filesystem::path StoragePath) : m_StoragePath(std::move(StoragePath)) {} + + ~SessionLogStore() = default; + + Ref<SessionLog> GetLogForSession(const Oid& SessionId) + { + // For now, just return a new log for each session. We can implement actual log storage and retrieval later. + return Ref(new SessionLog(m_StoragePath / (SessionId.ToString() + ".log"))); + } + + Ref<SessionLog> CreateLogForSession(const Oid& SessionId) + { + // For now, just return a new log for each session. We can implement actual log storage and retrieval later. + return Ref(new SessionLog(m_StoragePath / (SessionId.ToString() + ".log"))); + } + +private: + std::filesystem::path m_StoragePath; +}; + +SessionsService::Session::Session(const SessionInfo& Info) : m_Info(Info) +{ +} +SessionsService::Session::~Session() = default; + +////////////////////////////////////////////////////////////////////////// + +SessionsService::SessionsService() : m_Log(logging::Get("sessions")) +{ +} + +SessionsService::~SessionsService() = default; + +bool +SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, const Oid& JobId, CbObjectView Metadata) +{ + RwLock::ExclusiveLockScope Lock(m_Lock); + + if (m_Sessions.contains(SessionId)) + { + return false; + } + + const DateTime Now = DateTime::Now(); + m_Sessions.emplace(SessionId, + Ref(new Session(SessionInfo{.Id = SessionId, + .AppName = std::move(AppName), + .JobId = JobId, + .Metadata = CbObject::Clone(Metadata), + .CreatedAt = Now, + .UpdatedAt = Now}))); + + ZEN_INFO("Session {} registered (AppName: 
{}, JobId: {})", SessionId, AppName, JobId); + return true; +} + +bool +SessionsService::UpdateSession(const Oid& SessionId, CbObjectView Metadata) +{ + RwLock::ExclusiveLockScope Lock(m_Lock); + + auto It = m_Sessions.find(SessionId); + if (It == m_Sessions.end()) + { + return false; + } + + It.value()->UpdateMetadata(Metadata); + + const SessionInfo& Info = It.value()->Info(); + ZEN_DEBUG("Session {} updated (AppName: {}, JobId: {})", SessionId, Info.AppName, Info.JobId); + return true; +} + +Ref<SessionsService::Session> +SessionsService::GetSession(const Oid& SessionId) const +{ + RwLock::SharedLockScope Lock(m_Lock); + + auto It = m_Sessions.find(SessionId); + if (It == m_Sessions.end()) + { + return {}; + } + + return It->second; +} + +std::vector<Ref<SessionsService::Session>> +SessionsService::GetSessions() const +{ + RwLock::SharedLockScope Lock(m_Lock); + + std::vector<Ref<Session>> Result; + Result.reserve(m_Sessions.size()); + for (const auto& [Id, SessionRef] : m_Sessions) + { + Result.push_back(SessionRef); + } + return Result; +} + +bool +SessionsService::RemoveSession(const Oid& SessionId) +{ + RwLock::ExclusiveLockScope Lock(m_Lock); + + auto It = m_Sessions.find(SessionId); + if (It == m_Sessions.end()) + { + return false; + } + + ZEN_INFO("Session {} removed (AppName: {}, JobId: {})", SessionId, It.value()->Info().AppName, It.value()->Info().JobId); + + m_Sessions.erase(It); + return true; +} + +uint64_t +SessionsService::GetSessionCount() const +{ + RwLock::SharedLockScope Lock(m_Lock); + return m_Sessions.size(); +} + +} // namespace zen diff --git a/src/zenserver/sessions/sessions.h b/src/zenserver/sessions/sessions.h new file mode 100644 index 000000000..db9704430 --- /dev/null +++ b/src/zenserver/sessions/sessions.h @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/logbase.h> +#include <zencore/thread.h> +#include <zencore/uid.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <optional> +#include <string> +#include <vector> + +namespace zen { + +class SessionLogStore; +class SessionLog; + +/** Session tracker + * + * Acts as a log and session info concentrator when dealing with multiple + * servers and external processes acting as a group. + */ + +class SessionsService +{ +public: + struct SessionInfo + { + Oid Id; + std::string AppName; + Oid JobId; + CbObject Metadata; + DateTime CreatedAt; + DateTime UpdatedAt; + }; + + class Session : public TRefCounted<Session> + { + public: + Session(const SessionInfo& Info); + ~Session(); + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + + const SessionInfo& Info() const { return m_Info; } + void UpdateMetadata(CbObjectView Metadata) + { + // Should this be additive rather than replacing the whole thing? We'll see. 
+ m_Info.Metadata = CbObject::Clone(Metadata); + m_Info.UpdatedAt = DateTime::Now(); + } + + private: + SessionInfo m_Info; + Ref<SessionLog> m_Log; + }; + + SessionsService(); + ~SessionsService(); + + bool RegisterSession(const Oid& SessionId, std::string AppName, const Oid& JobId, CbObjectView Metadata); + bool UpdateSession(const Oid& SessionId, CbObjectView Metadata); + Ref<Session> GetSession(const Oid& SessionId) const; + std::vector<Ref<Session>> GetSessions() const; + bool RemoveSession(const Oid& SessionId); + uint64_t GetSessionCount() const; + +private: + LoggerRef& Log() { return m_Log; } + + LoggerRef m_Log; + mutable RwLock m_Lock; + tsl::robin_map<Oid, Ref<Session>, Oid::Hasher> m_Sessions; + std::unique_ptr<SessionLogStore> m_SessionLogs; +}; + +} // namespace zen diff --git a/src/zenserver/storage/admin/admin.cpp b/src/zenserver/storage/admin/admin.cpp index 19155e02b..c9f999c69 100644 --- a/src/zenserver/storage/admin/admin.cpp +++ b/src/zenserver/storage/admin/admin.cpp @@ -716,7 +716,7 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, "logs", [this](HttpRouterRequest& Req) { CbObjectWriter Obj; - auto LogLevel = logging::level::ToStringView(logging::GetLogLevel()); + auto LogLevel = logging::ToStringView(logging::GetLogLevel()); Obj.AddString("loglevel", std::string_view(LogLevel.data(), LogLevel.size())); Obj.AddString("Logfile", PathToUtf8(m_LogPaths.AbsLogPath)); Obj.BeginObject("cache"); @@ -767,8 +767,8 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, } if (std::string Param(Params.GetValue("loglevel")); Param.empty() == false) { - logging::level::LogLevel NewLevel = logging::level::ParseLogLevelString(Param); - std::string_view LogLevel = logging::level::ToStringView(NewLevel); + logging::LogLevel NewLevel = logging::ParseLogLevelString(Param); + std::string_view LogLevel = logging::ToStringView(NewLevel); if (LogLevel != Param) { return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest, diff --git 
a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index f5ba30616..de9589078 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -71,7 +71,7 @@ HttpBuildStoreService::Initialize() m_Router.RegisterRoute( "{namespace}/{bucket}/{buildid}/blobs/{hash}", [this](HttpRouterRequest& Req) { GetBlobRequest(Req); }, - HttpVerb::kGet); + HttpVerb::kGet | HttpVerb::kPost); m_Router.RegisterRoute( "{namespace}/{bucket}/{buildid}/blobs/putBlobMetadata", @@ -161,14 +161,57 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) HttpContentType::kText, fmt::format("Invalid blob hash '{}'", Hash)); } - zen::HttpRanges Ranges; - bool HasRange = ServerRequest.TryGetRanges(Ranges); - if (Ranges.size() > 1) + + std::vector<std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs; + if (ServerRequest.RequestVerb() == HttpVerb::kPost) { - // Only a single range is supported - return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, - HttpContentType::kText, - "Multiple ranges in blob request is not supported"); + CbObject RangePayload = ServerRequest.ReadPayloadObject(); + if (RangePayload) + { + CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView(); + OffsetAndLengthPairs.reserve(RangesArray.Num()); + for (CbFieldView FieldView : RangesArray) + { + CbObjectView RangeView = FieldView.AsObjectView(); + uint64_t RangeOffset = RangeView["offset"sv].AsUInt64(); + uint64_t RangeLength = RangeView["length"sv].AsUInt64(); + OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength)); + } + if (OffsetAndLengthPairs.size() > MaxRangeCountPerRequestSupported) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Number of ranges ({}) for blob request exceeds maximum range count {}", + OffsetAndLengthPairs.size(), + MaxRangeCountPerRequestSupported)); + } + } + if 
(OffsetAndLengthPairs.empty()) + { + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Fetching blob without ranges must be done with the GET verb"); + } + } + else + { + HttpRanges Ranges; + bool HasRange = ServerRequest.TryGetRanges(Ranges); + if (HasRange) + { + if (Ranges.size() > 1) + { + // Only a single http range is supported, we have limited support for http multirange responses + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Multiple ranges in blob request is only supported for {} accept type", + ToString(HttpContentType::kCbPackage))); + } + const HttpRange& FirstRange = Ranges.front(); + OffsetAndLengthPairs.push_back(std::make_pair<uint64_t, uint64_t>(FirstRange.Start, FirstRange.End - FirstRange.Start + 1)); + } } m_BuildStoreStats.BlobReadCount++; @@ -179,24 +222,79 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) HttpContentType::kText, fmt::format("Blob with hash '{}' could not be found", Hash)); } - // ZEN_INFO("Fetched blob {}. Size: {}", BlobHash, Blob.GetSize()); m_BuildStoreStats.BlobHitCount++; - if (HasRange) + + if (OffsetAndLengthPairs.empty()) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); + } + + if (ServerRequest.AcceptContentType() == HttpContentType::kCbPackage) { - const HttpRange& Range = Ranges.front(); - const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = Range.Start < BlobSize ? 
Range.Start - BlobSize : 0; - const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize); - if (Range.Start + RangeSize > BlobSize) + const uint64_t BlobSize = Blob.GetSize(); + + CbPackage ResponsePackage; + std::vector<IoBuffer> RangeBuffers; + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + for (const std::pair<uint64_t, uint64_t>& Range : OffsetAndLengthPairs) { - return ServerRequest.WriteResponse(HttpResponseCode::NoContent); + const uint64_t MaxBlobSize = Range.first < BlobSize ? BlobSize - Range.first : 0; + const uint64_t RangeSize = Min(Range.second, MaxBlobSize); + Writer.BeginObject(); + { + if (Range.first + RangeSize <= BlobSize) + { + RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize)); + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, RangeSize); + } + else + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, 0); + } + } + Writer.EndObject(); } - Blob = IoBuffer(Blob, Range.Start, RangeSize); - return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob); + Writer.EndArray(); + + CompositeBuffer Ranges(RangeBuffers); + CbAttachment PayloadAttachment(std::move(Ranges), BlobHash); + Writer.AddAttachment("payload", PayloadAttachment); + + CbObject HeaderObject = Writer.Save(); + + ResponsePackage.AddAttachment(PayloadAttachment); + ResponsePackage.SetObject(HeaderObject); + + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage); + uint64_t ResponseSize = RpcResponseBuffer.GetSize(); + ZEN_UNUSED(ResponseSize); + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); } else { - return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); + if (OffsetAndLengthPairs.size() != 1) + { + // Only a single http range is supported, we have limited support for http multirange responses + m_BuildStoreStats.BadRequestCount++; + return 
ServerRequest.WriteResponse( + HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Multiple ranges in blob request is only supported for {} accept type", ToString(HttpContentType::kCbPackage))); + } + + const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengthPairs.front(); + const uint64_t BlobSize = Blob.GetSize(); + const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? BlobSize - OffsetAndLength.first : 0; + const uint64_t RangeSize = Min(OffsetAndLength.second, MaxBlobSize); + if (OffsetAndLength.first + RangeSize > BlobSize) + { + return ServerRequest.WriteResponse(HttpResponseCode::NoContent); + } + Blob = IoBuffer(Blob, OffsetAndLength.first, RangeSize); + return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob); } } @@ -507,8 +605,8 @@ HttpBuildStoreService::BlobsExistsRequest(HttpRouterRequest& Req) return ServerRequest.WriteResponse(HttpResponseCode::OK, ResponseObject); } -void -HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request) +CbObject +HttpBuildStoreService::CollectStats() { ZEN_TRACE_CPU("HttpBuildStoreService::Stats"); @@ -562,7 +660,13 @@ HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request) } Cbo.EndObject(); - return Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + return Cbo.Save(); +} + +void +HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); } void @@ -571,6 +675,11 @@ HttpBuildStoreService::HandleStatusRequest(HttpServerRequest& Request) ZEN_TRACE_CPU("HttpBuildStoreService::Status"); CbObjectWriter Cbo; Cbo << "ok" << true; + Cbo.BeginObject("capabilities"); + { + Cbo << "maxrangecountperrequest" << MaxRangeCountPerRequestSupported; + } + Cbo.EndObject(); // capabilities Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } diff --git a/src/zenserver/storage/buildstore/httpbuildstore.h 
b/src/zenserver/storage/buildstore/httpbuildstore.h index e10986411..2a09b71cf 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.h +++ b/src/zenserver/storage/buildstore/httpbuildstore.h @@ -22,8 +22,9 @@ public: virtual const char* BaseUri() const override; virtual void HandleRequest(zen::HttpServerRequest& Request) override; - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - virtual void HandleStatusRequest(HttpServerRequest& Request) override; + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: struct BuildStoreStats @@ -45,6 +46,8 @@ private: inline LoggerRef Log() { return m_Log; } + static constexpr uint32_t MaxRangeCountPerRequestSupported = 256u; + LoggerRef m_Log; void PutBlobRequest(HttpRouterRequest& Req); diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index 72f29d14e..06b8f6c27 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -654,7 +654,7 @@ HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Reque auto NewEnd = std::unique(AllAttachments.begin(), AllAttachments.end()); AllAttachments.erase(NewEnd, AllAttachments.end()); - uint64_t AttachmentsSize = 0; + std::atomic<uint64_t> AttachmentsSize = 0; m_CidStore.IterateChunks( AllAttachments, @@ -746,7 +746,7 @@ HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, ResponseWriter << "Size" << ValuesSize; ResponseWriter << "AttachmentCount" << ContentStats.Attachments.size(); - uint64_t AttachmentsSize = 0; + std::atomic<uint64_t> AttachmentsSize = 0; WorkerThreadPool& WorkerPool = GetMediumWorkerPool(EWorkloadType::Background); @@ -1827,8 +1827,8 @@ HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request, std::st } } 
-void -HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) +CbObject +HttpStructuredCacheService::CollectStats() { ZEN_MEMSCOPE(GetCacheHttpTag()); @@ -1858,13 +1858,132 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) const CidStoreSize CidSize = m_CidStore.TotalSize(); const CacheStoreSize CacheSize = m_CacheStore.TotalSize(); + Cbo.BeginObject("cache"); + { + Cbo << "badrequestcount" << BadRequestCount; + Cbo.BeginObject("rpc"); + Cbo << "count" << RpcRequests; + Cbo << "ops" << RpcRecordBatchRequests + RpcValueBatchRequests + RpcChunkBatchRequests; + Cbo.BeginObject("records"); + Cbo << "count" << RpcRecordRequests; + Cbo << "ops" << RpcRecordBatchRequests; + Cbo.EndObject(); + Cbo.BeginObject("values"); + Cbo << "count" << RpcValueRequests; + Cbo << "ops" << RpcValueBatchRequests; + Cbo.EndObject(); + Cbo.BeginObject("chunks"); + Cbo << "count" << RpcChunkRequests; + Cbo << "ops" << RpcChunkBatchRequests; + Cbo.EndObject(); + Cbo.EndObject(); + + Cbo.BeginObject("size"); + { + Cbo << "disk" << CacheSize.DiskSize; + Cbo << "memory" << CacheSize.MemorySize; + } + Cbo.EndObject(); + + Cbo << "hits" << HitCount << "misses" << MissCount << "writes" << WriteCount; + Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0); + + if (m_UpstreamCache.IsActive()) + { + Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); + Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount; + Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); + Cbo << "upstream_ratio" << (HitCount > 0 ? 
(double(UpstreamHitCount) / double(HitCount)) : 0.0); + } + + Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount; + + { + ZenCacheStore::CacheStoreStats StoreStatsData = m_CacheStore.Stats(); + Cbo.BeginObject("store"); + Cbo << "hits" << StoreStatsData.HitCount << "misses" << StoreStatsData.MissCount << "writes" << StoreStatsData.WriteCount + << "rejected_writes" << StoreStatsData.RejectedWriteCount << "rejected_reads" << StoreStatsData.RejectedReadCount; + const uint64_t StoreTotal = StoreStatsData.HitCount + StoreStatsData.MissCount; + Cbo << "hit_ratio" << (StoreTotal > 0 ? (double(StoreStatsData.HitCount) / double(StoreTotal)) : 0.0); + EmitSnapshot("read", StoreStatsData.GetOps, Cbo); + EmitSnapshot("write", StoreStatsData.PutOps, Cbo); + Cbo.EndObject(); + } + } + Cbo.EndObject(); + + if (m_UpstreamCache.IsActive()) + { + EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo); + Cbo.BeginObject("upstream"); + { + m_UpstreamCache.GetStatus(Cbo); + } + Cbo.EndObject(); + } + + Cbo.BeginObject("cid"); + { + Cbo.BeginObject("size"); + { + Cbo << "tiny" << CidSize.TinySize; + Cbo << "small" << CidSize.SmallSize; + Cbo << "large" << CidSize.LargeSize; + Cbo << "total" << CidSize.TotalSize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) +{ + ZEN_MEMSCOPE(GetCacheHttpTag()); + bool ShowCidStoreStats = Request.GetQueryParams().GetValue("cidstorestats") == "true"; bool ShowCacheStoreStats = Request.GetQueryParams().GetValue("cachestorestats") == "true"; - CidStoreStats CidStoreStats = {}; + if (!ShowCidStoreStats && !ShowCacheStoreStats) + { + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); + return; + } + + // Full stats with optional detailed store/cid breakdowns + + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + + const uint64_t HitCount = m_CacheStats.HitCount; 
+ const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount; + const uint64_t MissCount = m_CacheStats.MissCount; + const uint64_t WriteCount = m_CacheStats.WriteCount; + const uint64_t BadRequestCount = m_CacheStats.BadRequestCount; + struct CidStoreStats StoreStats = m_CidStore.Stats(); + const uint64_t ChunkHitCount = StoreStats.HitCount; + const uint64_t ChunkMissCount = StoreStats.MissCount; + const uint64_t ChunkWriteCount = StoreStats.WriteCount; + const uint64_t TotalCount = HitCount + MissCount; + + const uint64_t RpcRequests = m_CacheStats.RpcRequests; + const uint64_t RpcRecordRequests = m_CacheStats.RpcRecordRequests; + const uint64_t RpcRecordBatchRequests = m_CacheStats.RpcRecordBatchRequests; + const uint64_t RpcValueRequests = m_CacheStats.RpcValueRequests; + const uint64_t RpcValueBatchRequests = m_CacheStats.RpcValueBatchRequests; + const uint64_t RpcChunkRequests = m_CacheStats.RpcChunkRequests; + const uint64_t RpcChunkBatchRequests = m_CacheStats.RpcChunkBatchRequests; + + const CidStoreSize CidSize = m_CidStore.TotalSize(); + const CacheStoreSize CacheSize = m_CacheStore.TotalSize(); + + CidStoreStats DetailedCidStoreStats = {}; if (ShowCidStoreStats) { - CidStoreStats = m_CidStore.Stats(); + DetailedCidStoreStats = m_CidStore.Stats(); } ZenCacheStore::CacheStoreStats CacheStoreStats = {}; if (ShowCacheStoreStats) @@ -2002,8 +2121,8 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) } Cbo.EndObject(); } - Cbo.EndObject(); } + Cbo.EndObject(); if (m_UpstreamCache.IsActive()) { @@ -2029,10 +2148,10 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) if (ShowCidStoreStats) { Cbo.BeginObject("store"); - Cbo << "hits" << CidStoreStats.HitCount << "misses" << CidStoreStats.MissCount << "writes" << CidStoreStats.WriteCount; - EmitSnapshot("read", CidStoreStats.FindChunkOps, Cbo); - EmitSnapshot("write", CidStoreStats.AddChunkOps, Cbo); - // EmitSnapshot("exists", 
CidStoreStats.ContainChunkOps, Cbo); + Cbo << "hits" << DetailedCidStoreStats.HitCount << "misses" << DetailedCidStoreStats.MissCount << "writes" + << DetailedCidStoreStats.WriteCount; + EmitSnapshot("read", DetailedCidStoreStats.FindChunkOps, Cbo); + EmitSnapshot("write", DetailedCidStoreStats.AddChunkOps, Cbo); Cbo.EndObject(); } } diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h index 5a795c215..d462415d4 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.h +++ b/src/zenserver/storage/cache/httpstructuredcache.h @@ -102,11 +102,12 @@ private: void HandleRpcRequest(HttpServerRequest& Request, std::string_view UriNamespace); void HandleDetailsRequest(HttpServerRequest& Request); - void HandleCacheRequest(HttpServerRequest& Request); - void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace); - void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket); - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - virtual void HandleStatusRequest(HttpServerRequest& Request) override; + void HandleCacheRequest(HttpServerRequest& Request); + void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace); + void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket); + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; bool AreDiskWritesAllowed() const; diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index fe32fa15b..836d84292 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -13,7 +13,12 @@ #include <zencore/scopeguard.h> 
#include <zencore/stream.h> #include <zencore/trace.h> +#include <zenhttp/httpclientauth.h> #include <zenhttp/packageformat.h> +#include <zenremotestore/builds/buildstoragecache.h> +#include <zenremotestore/builds/buildstorageutil.h> +#include <zenremotestore/jupiter/jupiterhost.h> +#include <zenremotestore/operationlogoutput.h> #include <zenremotestore/projectstore/buildsremoteprojectstore.h> #include <zenremotestore/projectstore/fileremoteprojectstore.h> #include <zenremotestore/projectstore/jupiterremoteprojectstore.h> @@ -244,6 +249,22 @@ namespace { { std::shared_ptr<RemoteProjectStore> Store; std::string Description; + double LatencySec = -1.0; + uint64_t MaxRangeCountPerRequest = 1; + + struct Cache + { + std::unique_ptr<HttpClient> Http; + std::unique_ptr<BuildStorageCache> Cache; + Oid BuildsId = Oid::Zero; + std::string Description; + double LatencySec = -1.0; + uint64_t MaxRangeCountPerRequest = 1; + BuildStorageCache::Statistics Stats; + bool Populate = false; + }; + + std::unique_ptr<Cache> OptionalCache; }; CreateRemoteStoreResult CreateRemoteStore(LoggerRef InLog, @@ -260,7 +281,7 @@ namespace { using namespace std::literals; - std::shared_ptr<RemoteProjectStore> RemoteStore; + CreateRemoteStoreResult Result; if (CbObjectView File = Params["file"sv].AsObjectView(); File) { @@ -285,7 +306,9 @@ namespace { std::string(OptionalBaseName), ForceDisableBlocks, ForceEnableTempBlocks}; - RemoteStore = CreateFileRemoteStore(Log(), Options); + Result.Store = CreateFileRemoteStore(Log(), Options); + Result.LatencySec = 0.5 / 1000.0; // 0.5 ms + Result.MaxRangeCountPerRequest = 1024u; } if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud) @@ -363,21 +386,32 @@ namespace { bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false); bool AssumeHttp2 = Cloud["assumehttp2"sv].AsBool(false); - JupiterRemoteStoreOptions Options = { - RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = 
MaxChunkEmbedSize}, - Url, - std::string(Namespace), - std::string(Bucket), - Key, - BaseKey, - std::string(OpenIdProvider), - AccessToken, - AuthManager, - OidcExePath, - ForceDisableBlocks, - ForceDisableTempBlocks, - AssumeHttp2}; - RemoteStore = CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true); + if (JupiterEndpointTestResult TestResult = TestJupiterEndpoint(Url, AssumeHttp2, /*Verbose*/ false); TestResult.Success) + { + Result.LatencySec = TestResult.LatencySeconds; + Result.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; + + JupiterRemoteStoreOptions Options = { + RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + Url, + std::string(Namespace), + std::string(Bucket), + Key, + BaseKey, + std::string(OpenIdProvider), + AccessToken, + AuthManager, + OidcExePath, + ForceDisableBlocks, + ForceDisableTempBlocks, + AssumeHttp2}; + Result.Store = + CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true); + } + else + { + return {nullptr, fmt::format("Unable to connect to jupiter host '{}'", Url)}; + } } if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen) @@ -393,12 +427,13 @@ namespace { { return {nullptr, "Missing oplog"}; } + ZenRemoteStoreOptions Options = { RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, std::string(Url), std::string(Project), std::string(Oplog)}; - RemoteStore = CreateZenRemoteStore(Log(), Options, TempFilePath); + Result.Store = CreateZenRemoteStore(Log(), Options, TempFilePath); } if (CbObjectView Builds = Params["builds"sv].AsObjectView(); Builds) @@ -471,11 +506,76 @@ namespace { MemoryView MetaDataSection = Builds["metadata"sv].AsBinaryView(); IoBuffer MetaData(IoBuffer::Wrap, MetaDataSection.GetData(), MetaDataSection.GetSize()); + auto EnsureHttps = [](const 
std::string& Host, std::string_view PreferredProtocol) { + if (!Host.empty() && Host.find("://"sv) == std::string::npos) + { + // Assume https URL + return fmt::format("{}://{}"sv, PreferredProtocol, Host); + } + return Host; + }; + + Host = EnsureHttps(Host, "https"); + OverrideHost = EnsureHttps(OverrideHost, "https"); + ZenHost = EnsureHttps(ZenHost, "http"); + + std::function<HttpClientAccessToken()> TokenProvider; + if (!OpenIdProvider.empty()) + { + TokenProvider = httpclientauth::CreateFromOpenIdProvider(AuthManager, OpenIdProvider); + } + else if (!AccessToken.empty()) + { + TokenProvider = httpclientauth::CreateFromStaticToken(AccessToken); + } + else if (!OidcExePath.empty()) + { + if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath, + Host.empty() ? OverrideHost : Host, + /*Quiet*/ false, + /*Unattended*/ false, + /*Hidden*/ true); + TokenProviderMaybe) + { + TokenProvider = TokenProviderMaybe.value(); + } + } + + if (!TokenProvider) + { + TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(AuthManager); + } + + BuildStorageResolveResult ResolveResult; + { + HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", + .AccessTokenProvider = TokenProvider, + .AssumeHttp2 = AssumeHttp2, + .AllowResume = true, + .RetryCount = 2}; + + std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log())); + + try + { + ResolveResult = ResolveBuildStorage(*Output, + ClientSettings, + Host, + OverrideHost, + ZenHost, + ZenCacheResolveMode::Discovery, + /*Verbose*/ false); + } + catch (const std::exception& Ex) + { + return {nullptr, fmt::format("Failed resolving storage host and cache. 
Reason: '{}'", Ex.what())}; + } + } + Result.LatencySec = ResolveResult.Cloud.LatencySec; + Result.MaxRangeCountPerRequest = ResolveResult.Cloud.Caps.MaxRangeCountPerRequest; + BuildsRemoteStoreOptions Options = { RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, - Host, - OverrideHost, - ZenHost, std::string(Namespace), std::string(Bucket), BuildId, @@ -485,25 +585,43 @@ namespace { OidcExePath, ForceDisableBlocks, ForceDisableTempBlocks, - AssumeHttp2, - PopulateCache, MetaData, MaximumInMemoryDownloadSize}; - RemoteStore = CreateJupiterBuildsRemoteStore(Log(), - Options, - TempFilePath, - /*Quiet*/ false, - /*Unattended*/ false, - /*Hidden*/ true, - GetTinyWorkerPool(EWorkloadType::Background)); + Result.Store = CreateJupiterBuildsRemoteStore(Log(), ResolveResult, std::move(TokenProvider), Options, TempFilePath); + + if (!ResolveResult.Cache.Address.empty()) + { + Result.OptionalCache = std::make_unique<CreateRemoteStoreResult::Cache>(); + + HttpClientSettings CacheClientSettings{.LogCategory = "httpcacheclient", + .ConnectTimeout = std::chrono::milliseconds{3000}, + .Timeout = std::chrono::milliseconds{30000}, + .AssumeHttp2 = ResolveResult.Cache.AssumeHttp2, + .AllowResume = true, + .RetryCount = 0, + .MaximumInMemoryDownloadSize = MaximumInMemoryDownloadSize}; + + Result.OptionalCache->Http = std::make_unique<HttpClient>(ResolveResult.Cache.Address, CacheClientSettings); + Result.OptionalCache->Cache = CreateZenBuildStorageCache(*Result.OptionalCache->Http, + Result.OptionalCache->Stats, + Namespace, + Bucket, + TempFilePath, + GetTinyWorkerPool(EWorkloadType::Background)); + Result.OptionalCache->BuildsId = BuildId; + Result.OptionalCache->LatencySec = ResolveResult.Cache.LatencySec; + Result.OptionalCache->MaxRangeCountPerRequest = ResolveResult.Cache.Caps.MaxRangeCountPerRequest; + Result.OptionalCache->Populate = PopulateCache; + Result.OptionalCache->Description = + 
fmt::format("[zenserver] {} namespace {} bucket {}", ResolveResult.Cache.Address, Namespace, Bucket); + } } - - if (!RemoteStore) + if (!Result.Store) { return {nullptr, "Unknown remote store type"}; } - return {std::move(RemoteStore), ""}; + return Result; } std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result) @@ -714,8 +832,8 @@ HttpProjectService::HandleRequest(HttpServerRequest& Request) } } -void -HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq) +CbObject +HttpProjectService::CollectStats() { ZEN_TRACE_CPU("ProjectService::Stats"); @@ -781,7 +899,13 @@ HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq) } Cbo.EndObject(); - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + return Cbo.Save(); +} + +void +HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq) +{ + HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats()); } void @@ -2373,15 +2497,19 @@ HttpProjectService::HandleOplogSaveRequest(HttpRouterRequest& Req) tsl::robin_set<IoHash, IoHash::Hasher> Attachments; auto HasAttachment = [this](const IoHash& RawHash) { return m_CidStore.ContainsChunk(RawHash); }; - auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) { + auto OnNeedBlock = [&AttachmentsLock, &Attachments](ThinChunkBlockDescription&& ThinBlockDescription, + std::vector<uint32_t>&& NeededChunkIndexes) { RwLock::ExclusiveLockScope _(AttachmentsLock); - if (BlockHash != IoHash::Zero) + if (ThinBlockDescription.BlockHash != IoHash::Zero) { - Attachments.insert(BlockHash); + Attachments.insert(ThinBlockDescription.BlockHash); } else { - Attachments.insert(ChunkHashes.begin(), ChunkHashes.end()); + for (uint32_t ChunkIndex : NeededChunkIndexes) + { + Attachments.insert(ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } } }; auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) { @@ -2687,36 +2815,39 @@ 
HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) bool CleanOplog = Params["clean"].AsBool(false); bool BoostWorkerCount = Params["boostworkercount"].AsBool(false); bool BoostWorkerMemory = Params["boostworkermemory"sv].AsBool(false); - - CreateRemoteStoreResult RemoteStoreResult = CreateRemoteStore(Log(), - Params, - m_AuthMgr, - MaxBlockSize, - MaxChunkEmbedSize, - GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory), - Oplog->TempPath()); - - if (RemoteStoreResult.Store == nullptr) + EPartialBlockRequestMode PartialBlockRequestMode = + PartialBlockRequestModeFromString(Params["partialblockrequestmode"sv].AsString("true")); + + std::shared_ptr<CreateRemoteStoreResult> RemoteStoreResult = + std::make_shared<CreateRemoteStoreResult>(CreateRemoteStore(Log(), + Params, + m_AuthMgr, + MaxBlockSize, + MaxChunkEmbedSize, + GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory), + Oplog->TempPath())); + + if (RemoteStoreResult->Store == nullptr) { - return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult.Description); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult->Description); } - std::shared_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.Store); - RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); JobId JobId = m_JobQueue.QueueJob( fmt::format("Import oplog '{}/{}'", Project->Identifier, Oplog->OplogId()), [this, - ChunkStore = &m_CidStore, - ActualRemoteStore = std::move(RemoteStore), + RemoteStoreResult = std::move(RemoteStoreResult), Oplog, Force, IgnoreMissingAttachments, CleanOplog, + PartialBlockRequestMode, BoostWorkerCount](JobContext& Context) { - Context.ReportMessage(fmt::format("Loading oplog '{}/{}' from {}", - Oplog->GetOuterProjectIdentifier(), - Oplog->OplogId(), - ActualRemoteStore->GetInfo().Description)); + Context.ReportMessage( + fmt::format("Loading oplog '{}/{}'\n Host: {}\n Cache: {}", + 
Oplog->GetOuterProjectIdentifier(), + Oplog->OplogId(), + RemoteStoreResult->Store->GetInfo().Description, + RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Description : "<none>")); Ref<TransferThreadWorkers> Workers = GetThreadWorkers(BoostWorkerCount, /*SingleThreaded*/ false); @@ -2724,16 +2855,26 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) WorkerThreadPool& NetworkWorkerPool = Workers->GetNetworkPool(); Context.ReportMessage(fmt::format("{}", Workers->GetWorkersInfo())); - - RemoteProjectStore::Result Result = LoadOplog(m_CidStore, - *ActualRemoteStore, - *Oplog, - NetworkWorkerPool, - WorkerPool, - Force, - IgnoreMissingAttachments, - CleanOplog, - &Context); + RemoteProjectStore::Result Result = LoadOplog(LoadOplogContext{ + .ChunkStore = m_CidStore, + .RemoteStore = *RemoteStoreResult->Store, + .OptionalCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Cache.get() : nullptr, + .CacheBuildId = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->BuildsId : Oid::Zero, + .OptionalCacheStats = RemoteStoreResult->OptionalCache ? &RemoteStoreResult->OptionalCache->Stats : nullptr, + .Oplog = *Oplog, + .NetworkWorkerPool = NetworkWorkerPool, + .WorkerPool = WorkerPool, + .ForceDownload = Force, + .IgnoreMissingAttachments = IgnoreMissingAttachments, + .CleanOplog = CleanOplog, + .PartialBlockRequestMode = PartialBlockRequestMode, + .PopulateCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Populate : false, + .StoreLatencySec = RemoteStoreResult->LatencySec, + .StoreMaxRangeCountPerRequest = RemoteStoreResult->MaxRangeCountPerRequest, + .CacheLatencySec = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->LatencySec : -1.0, + .CacheMaxRangeCountPerRequest = + RemoteStoreResult->OptionalCache ? 
RemoteStoreResult->OptionalCache->MaxRangeCountPerRequest : 0, + .OptionalJobContext = &Context}); auto Response = ConvertResult(Result); ZEN_INFO("LoadOplog: Status: {} '{}'", ToString(Response.first), Response.second); if (!IsHttpSuccessCode(Response.first)) diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h index 1d71329b1..a1f649ed6 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.h +++ b/src/zenserver/storage/projectstore/httpprojectstore.h @@ -51,8 +51,9 @@ public: virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - virtual void HandleStatusRequest(HttpServerRequest& Request) override; + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: struct ProjectStats diff --git a/src/zenserver/storage/storageconfig.cpp b/src/zenserver/storage/storageconfig.cpp index 99d0f89d7..ad1fb88ea 100644 --- a/src/zenserver/storage/storageconfig.cpp +++ b/src/zenserver/storage/storageconfig.cpp @@ -804,6 +804,7 @@ ZenStorageServerCmdLineOptions::AddCacheOptions(cxxopts::Options& options, ZenSt cxxopts::value<uint64_t>(ServerOptions.StructuredCacheConfig.MemMaxAgeSeconds)->default_value("86400"), ""); + options.add_option("compute", "", "lie-cpus", "Lie to upstream about CPU capabilities", cxxopts::value<int>(ServerOptions.LieCpu), ""); options.add_option("cache", "", "cache-bucket-maxblocksize", diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h index bc2dc78c9..d935ed8b3 100644 --- a/src/zenserver/storage/storageconfig.h +++ b/src/zenserver/storage/storageconfig.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once #include "config/config.h" @@ -156,6 +157,7 @@ struct ZenStorageServerConfig : public ZenServerConfig ZenWorkspacesConfig WorksSpacesConfig; std::filesystem::path PluginsConfigFile; // Path to plugins config file bool ObjectStoreEnabled = false; + bool ComputeEnabled = true; std::string ScrubOptions; bool RestrictContentTypes = false; }; diff --git a/src/zenserver/storage/workspaces/httpworkspaces.cpp b/src/zenserver/storage/workspaces/httpworkspaces.cpp index dc4cc7e69..785dd62f0 100644 --- a/src/zenserver/storage/workspaces/httpworkspaces.cpp +++ b/src/zenserver/storage/workspaces/httpworkspaces.cpp @@ -110,8 +110,8 @@ HttpWorkspacesService::HandleRequest(HttpServerRequest& Request) } } -void -HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq) +CbObject +HttpWorkspacesService::CollectStats() { ZEN_TRACE_CPU("WorkspacesService::Stats"); CbObjectWriter Cbo; @@ -150,7 +150,13 @@ HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq) } Cbo.EndObject(); - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + return Cbo.Save(); +} + +void +HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq) +{ + HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats()); } void diff --git a/src/zenserver/storage/workspaces/httpworkspaces.h b/src/zenserver/storage/workspaces/httpworkspaces.h index 888a34b4d..7c5ddeff1 100644 --- a/src/zenserver/storage/workspaces/httpworkspaces.h +++ b/src/zenserver/storage/workspaces/httpworkspaces.h @@ -29,8 +29,9 @@ public: virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - virtual void HandleStatusRequest(HttpServerRequest& Request) override; + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: 
struct WorkspacesStats diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index ea05bd155..f43bb9987 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -33,6 +33,7 @@ #include <zenutil/service.h> #include <zenutil/workerpools.h> #include <zenutil/zenserverprocess.h> +#include "../sessions/sessions.h" #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -133,7 +134,6 @@ void ZenStorageServer::RegisterServices() { m_Http->RegisterService(*m_AuthService); - m_Http->RegisterService(m_StatsService); m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics #if ZEN_WITH_TESTS @@ -160,6 +160,11 @@ ZenStorageServer::RegisterServices() m_Http->RegisterService(*m_HttpWorkspacesService); } + if (m_HttpSessionsService) + { + m_Http->RegisterService(*m_HttpSessionsService); + } + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService); if (m_FrontendService) @@ -182,6 +187,18 @@ ZenStorageServer::RegisterServices() #endif // ZEN_WITH_VFS m_Http->RegisterService(*m_AdminService); + + if (m_ApiService) + { + m_Http->RegisterService(*m_ApiService); + } + +#if ZEN_WITH_COMPUTE_SERVICES + if (m_HttpComputeService) + { + m_Http->RegisterService(*m_HttpComputeService); + } +#endif } void @@ -227,6 +244,11 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_Workspaces)); } + { + m_SessionsService = std::make_unique<SessionsService>(); + m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService); + } + if (ServerOptions.BuildStoreConfig.Enabled) { CidStoreConfiguration BuildCidConfig; @@ -273,6 +295,16 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions m_BuildStoreService = std::make_unique<HttpBuildStoreService>(m_StatusService, m_StatsService, 
*m_BuildStore); } +#if ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeEnabled) + { + ZEN_OTEL_SPAN("InitializeComputeService"); + + m_HttpComputeService = + std::make_unique<compute::HttpComputeService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); + } +#endif + #if ZEN_WITH_VFS m_VfsServiceImpl = std::make_unique<VfsServiceImpl>(); m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore)); @@ -305,13 +337,15 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions .AttachmentPassCount = ServerOptions.GcConfig.AttachmentPassCount}; m_GcScheduler.Initialize(GcConfig); + m_ApiService = std::make_unique<HttpApiService>(*m_Http); + // Create and register admin interface last to make sure all is properly initialized m_AdminService = std::make_unique<HttpAdminService>( m_GcScheduler, *m_JobQueue, m_CacheStore.Get(), [this]() { Flush(); }, - HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.AbsLogFile, + HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile, .HttpLogPath = ServerOptions.DataDir / "logs" / "http.log", .CacheLogPath = ServerOptions.DataDir / "logs" / "z$.log"}, ServerOptions); @@ -689,6 +723,15 @@ ZenStorageServer::Run() ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", GetCurrentProcessId()); + if (m_FrontendService) + { + ZEN_INFO("frontend link: {}", m_Http->GetServiceUri(m_FrontendService.get())); + } + else + { + ZEN_INFO("frontend service disabled"); + } + #if ZEN_PLATFORM_WINDOWS if (zen::windows::IsRunningOnWine()) { @@ -796,6 +839,8 @@ ZenStorageServer::Cleanup() m_IoRunner.join(); } + ShutdownServices(); + if (m_Http) { m_Http->Close(); @@ -811,6 +856,10 @@ ZenStorageServer::Cleanup() Flush(); +#if ZEN_WITH_COMPUTE_SERVICES + m_HttpComputeService.reset(); +#endif + m_AdminService.reset(); m_VfsService.reset(); m_VfsServiceImpl.reset(); @@ -826,6 +875,8 @@ ZenStorageServer::Cleanup() m_UpstreamCache.reset(); m_CacheStore = {}; + m_HttpSessionsService.reset(); + 
m_SessionsService.reset(); m_HttpWorkspacesService.reset(); m_Workspaces.reset(); m_HttpProjectService.reset(); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index 5ccb587d6..d625f869c 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -6,11 +6,13 @@ #include <zenhttp/auth/authmgr.h> #include <zenhttp/auth/authservice.h> +#include <zenhttp/httpapiservice.h> #include <zenhttp/httptest.h> #include <zenstore/cache/structuredcachestore.h> #include <zenstore/gc.h> #include <zenstore/projectstore.h> +#include "../sessions/httpsessions.h" #include "admin/admin.h" #include "buildstore/httpbuildstore.h" #include "cache/httpstructuredcache.h" @@ -23,6 +25,10 @@ #include "vfs/vfsservice.h" #include "workspaces/httpworkspaces.h" +#if ZEN_WITH_COMPUTE_SERVICES +# include <zencompute/httpcomputeservice.h> +#endif + namespace zen { class ZenStorageServer : public ZenServerBase @@ -34,11 +40,6 @@ public: ZenStorageServer(); ~ZenStorageServer(); - void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } - void SetTestMode(bool State) { m_TestMode = State; } - void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } - void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } - int Initialize(const ZenStorageServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry); void Run(); void Cleanup(); @@ -48,14 +49,9 @@ private: void InitializeStructuredCache(const ZenStorageServerConfig& ServerOptions); void Flush(); - bool m_IsDedicatedMode = false; - bool m_TestMode = false; - bool m_DebugOptionForcedCrash = false; - std::string m_StartupScrubOptions; - CbObject m_RootManifest; - std::filesystem::path m_DataRoot; - std::filesystem::path m_ContentRoot; - asio::steady_timer m_StateMarkerTimer{m_IoContext}; + std::string m_StartupScrubOptions; + CbObject m_RootManifest; + asio::steady_timer m_StateMarkerTimer{m_IoContext}; void 
EnqueueStateMarkerTimer(); void CheckStateMarker(); @@ -67,7 +63,6 @@ private: void InitializeServices(const ZenStorageServerConfig& ServerOptions); void RegisterServices(); - HttpStatsService m_StatsService; std::unique_ptr<JobQueue> m_JobQueue; GcManager m_GcManager; GcScheduler m_GcScheduler{m_GcManager}; @@ -87,6 +82,8 @@ private: std::unique_ptr<HttpProjectService> m_HttpProjectService; std::unique_ptr<Workspaces> m_Workspaces; std::unique_ptr<HttpWorkspacesService> m_HttpWorkspacesService; + std::unique_ptr<SessionsService> m_SessionsService; + std::unique_ptr<HttpSessionsService> m_HttpSessionsService; std::unique_ptr<UpstreamCache> m_UpstreamCache; std::unique_ptr<HttpUpstreamService> m_UpstreamService; std::unique_ptr<HttpStructuredCacheService> m_StructuredCacheService; @@ -95,6 +92,11 @@ private: std::unique_ptr<HttpBuildStoreService> m_BuildStoreService; std::unique_ptr<VfsService> m_VfsService; std::unique_ptr<HttpAdminService> m_AdminService; + std::unique_ptr<HttpApiService> m_ApiService; + +#if ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr<compute::HttpComputeService> m_HttpComputeService; +#endif }; struct ZenStorageServerConfigurator; diff --git a/src/zenserver/trace/tracerecorder.cpp b/src/zenserver/trace/tracerecorder.cpp new file mode 100644 index 000000000..5dec20e18 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.cpp @@ -0,0 +1,565 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "tracerecorder.h" + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/uid.h> + +#include <asio.hpp> + +#include <atomic> +#include <cstring> +#include <memory> +#include <mutex> +#include <thread> + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceSession : public std::enable_shared_from_this<TraceSession> +{ + TraceSession(asio::ip::tcp::socket&& Socket, const std::filesystem::path& OutputDir) + : m_Socket(std::move(Socket)) + , m_OutputDir(OutputDir) + , m_SessionId(Oid::NewOid()) + { + try + { + m_RemoteAddress = m_Socket.remote_endpoint().address().to_string(); + } + catch (...) + { + m_RemoteAddress = "unknown"; + } + + ZEN_INFO("Trace session {} started from {}", m_SessionId, m_RemoteAddress); + } + + ~TraceSession() + { + if (m_TraceFile.IsOpen()) + { + m_TraceFile.Close(); + } + + ZEN_INFO("Trace session {} ended, {} bytes recorded to '{}'", m_SessionId, m_TotalBytesRecorded, m_TraceFilePath); + } + + void Start() { ReadPreambleHeader(); } + + bool IsActive() const { return m_Socket.is_open(); } + + TraceSessionInfo GetInfo() const + { + TraceSessionInfo Info; + Info.SessionGuid = m_SessionGuid; + Info.TraceGuid = m_TraceGuid; + Info.ControlPort = m_ControlPort; + Info.TransportVersion = m_TransportVersion; + Info.ProtocolVersion = m_ProtocolVersion; + Info.RemoteAddress = m_RemoteAddress; + Info.BytesRecorded = m_TotalBytesRecorded; + Info.TraceFilePath = m_TraceFilePath; + return Info; + } + +private: + // Preamble format: + // [magic: 4 bytes][metadata_size: 2 bytes][metadata fields: variable][version: 2 bytes] + // + // Magic bytes: [0]=version_char ('2'-'9'), [1]='C', [2]='R', [3]='T' + // + // Metadata fields (repeated): + // [size: 1 byte][id: 1 byte][data: <size> bytes] + // Field 0: ControlPort (uint16) + // Field 1: SessionGuid (16 bytes) + // Field 2: TraceGuid (16 bytes) + 
// + // Version: [transport: 1 byte][protocol: 1 byte] + + static constexpr size_t kMagicSize = 4; + static constexpr size_t kMetadataSizeFieldSize = 2; + static constexpr size_t kPreambleHeaderSize = kMagicSize + kMetadataSizeFieldSize; + static constexpr size_t kVersionSize = 2; + static constexpr size_t kPreambleBufferSize = 256; + static constexpr size_t kReadBufferSize = 64 * 1024; + + void ReadPreambleHeader() + { + auto Self = shared_from_this(); + + // Read the first 6 bytes: 4 magic + 2 metadata size + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer, kPreambleHeaderSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble header", Ec); + return; + } + + if (!ValidateMagic()) + { + ZEN_WARN("Trace session {}: invalid trace magic header", m_SessionId); + CloseSocket(); + return; + } + + ReadPreambleMetadata(); + }); + } + + bool ValidateMagic() + { + const uint8_t* Cursor = m_PreambleBuffer; + + // Validate magic: bytes are version, 'C', 'R', 'T' + if (Cursor[3] != 'T' || Cursor[2] != 'R' || Cursor[1] != 'C') + { + return false; + } + + if (Cursor[0] < '2' || Cursor[0] > '9') + { + return false; + } + + // Extract the metadata fields size (does not include the trailing version bytes) + std::memcpy(&m_MetadataFieldsSize, Cursor + kMagicSize, sizeof(m_MetadataFieldsSize)); + + if (m_MetadataFieldsSize + kVersionSize > kPreambleBufferSize - kPreambleHeaderSize) + { + return false; + } + + return true; + } + + void ReadPreambleMetadata() + { + auto Self = shared_from_this(); + size_t ReadSize = m_MetadataFieldsSize + kVersionSize; + + // Read metadata fields + 2 version bytes + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer + kPreambleHeaderSize, ReadSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble metadata", Ec); + return; + } + + if (!ParseMetadata()) + { + ZEN_WARN("Trace session {}: malformed trace 
metadata", m_SessionId); + CloseSocket(); + return; + } + + if (!CreateTraceFile()) + { + CloseSocket(); + return; + } + + // Write the full preamble to the trace file so it remains a valid .utrace + size_t PreambleSize = kPreambleHeaderSize + m_MetadataFieldsSize + kVersionSize; + std::error_code WriteEc; + m_TraceFile.Write(m_PreambleBuffer, PreambleSize, 0, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: failed to write preamble: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded = PreambleSize; + + ZEN_INFO("Trace session {}: metadata - TransportV{} ProtocolV{} ControlPort:{} SessionGuid:{} TraceGuid:{}", + m_SessionId, + m_TransportVersion, + m_ProtocolVersion, + m_ControlPort, + m_SessionGuid, + m_TraceGuid); + + // Begin streaming trace data to disk + ReadMore(); + }); + } + + bool ParseMetadata() + { + const uint8_t* Cursor = m_PreambleBuffer + kPreambleHeaderSize; + int32_t Remaining = static_cast<int32_t>(m_MetadataFieldsSize); + + while (Remaining >= 2) + { + uint8_t FieldSize = Cursor[0]; + uint8_t FieldId = Cursor[1]; + Cursor += 2; + Remaining -= 2; + + if (Remaining < FieldSize) + { + return false; + } + + switch (FieldId) + { + case 0: // ControlPort + if (FieldSize >= sizeof(uint16_t)) + { + std::memcpy(&m_ControlPort, Cursor, sizeof(uint16_t)); + } + break; + case 1: // SessionGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_SessionGuid, Cursor, sizeof(Guid)); + } + break; + case 2: // TraceGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_TraceGuid, Cursor, sizeof(Guid)); + } + break; + } + + Cursor += FieldSize; + Remaining -= FieldSize; + } + + // Metadata should be fully consumed + if (Remaining != 0) + { + return false; + } + + // Version bytes follow immediately after the metadata fields + const uint8_t* VersionPtr = m_PreambleBuffer + kPreambleHeaderSize + m_MetadataFieldsSize; + m_TransportVersion = VersionPtr[0]; + m_ProtocolVersion = VersionPtr[1]; + + return 
true; + } + + bool CreateTraceFile() + { + m_TraceFilePath = m_OutputDir / fmt::format("{}.utrace", m_SessionId); + + try + { + m_TraceFile.Open(m_TraceFilePath, BasicFile::Mode::kTruncate); + ZEN_INFO("Trace session {} writing to '{}'", m_SessionId, m_TraceFilePath); + return true; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Trace session {}: failed to create trace file '{}': {}", m_SessionId, m_TraceFilePath, Ex.what()); + return false; + } + } + + void ReadMore() + { + auto Self = shared_from_this(); + + m_Socket.async_read_some(asio::buffer(m_ReadBuffer, kReadBufferSize), + [this, Self](const asio::error_code& Ec, std::size_t BytesRead) { + if (!Ec) + { + if (BytesRead > 0 && m_TraceFile.IsOpen()) + { + std::error_code WriteEc; + const uint64_t FileOffset = m_TotalBytesRecorded; + m_TraceFile.Write(m_ReadBuffer, BytesRead, FileOffset, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: write error: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded += BytesRead; + } + + ReadMore(); + } + else if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {} connection closed by peer", m_SessionId); + CloseSocket(); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {} operation aborted", m_SessionId); + } + else + { + ZEN_WARN("Trace session {} read error: {}", m_SessionId, Ec.message()); + CloseSocket(); + } + }); + } + + void HandleReadError(const char* Phase, const asio::error_code& Ec) + { + if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {}: connection closed during {}", m_SessionId, Phase); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {}: operation aborted during {}", m_SessionId, Phase); + } + else + { + ZEN_WARN("Trace session {}: error during {}: {}", m_SessionId, Phase, Ec.message()); + } + + CloseSocket(); + } + + void CloseSocket() + { + std::error_code Ec; + m_Socket.close(Ec); + + if (m_TraceFile.IsOpen()) 
+ { + m_TraceFile.Close(); + } + } + + asio::ip::tcp::socket m_Socket; + std::filesystem::path m_OutputDir; + std::filesystem::path m_TraceFilePath; + BasicFile m_TraceFile; + Oid m_SessionId; + std::string m_RemoteAddress; + + // Preamble parsing + uint8_t m_PreambleBuffer[kPreambleBufferSize] = {}; + uint16_t m_MetadataFieldsSize = 0; + + // Extracted metadata + Guid m_SessionGuid{}; + Guid m_TraceGuid{}; + uint16_t m_ControlPort = 0; + uint8_t m_TransportVersion = 0; + uint8_t m_ProtocolVersion = 0; + + // Streaming + uint8_t m_ReadBuffer[kReadBufferSize]; + uint64_t m_TotalBytesRecorded = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceRecorder::Impl +{ + Impl() : m_IoContext(), m_Acceptor(m_IoContext) {} + + ~Impl() { Shutdown(); } + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + + if (m_IsRunning) + { + ZEN_WARN("TraceRecorder already initialized"); + return; + } + + m_OutputDir = OutputDir; + + try + { + // Create output directory if it doesn't exist + CreateDirectories(m_OutputDir); + + // Configure acceptor + m_Acceptor.open(asio::ip::tcp::v4()); + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v4(), InPort)); + m_Acceptor.listen(); + + m_Port = m_Acceptor.local_endpoint().port(); + + ZEN_INFO("TraceRecorder listening on port {}, output directory: '{}'", m_Port, m_OutputDir); + + m_IsRunning = true; + + // Start accepting connections + StartAccept(); + + // Start IO thread + m_IoThread = std::thread([this]() { + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("TraceRecorder IO thread exception: {}", Ex.what()); + } + }); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to initialize TraceRecorder: {}", Ex.what()); + m_IsRunning = false; + throw; + } + } + + void Shutdown() + { + 
std::lock_guard<std::mutex> Lock(m_Mutex); + + if (!m_IsRunning) + { + return; + } + + ZEN_INFO("TraceRecorder shutting down"); + + m_IsRunning = false; + + std::error_code Ec; + m_Acceptor.close(Ec); + + m_IoContext.stop(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + + { + std::lock_guard<std::mutex> SessionLock(m_SessionsMutex); + m_Sessions.clear(); + } + + ZEN_INFO("TraceRecorder shutdown complete"); + } + + bool IsRunning() const { return m_IsRunning; } + + uint16_t GetPort() const { return m_Port; } + + std::vector<TraceSessionInfo> GetActiveSessions() const + { + std::lock_guard<std::mutex> Lock(m_SessionsMutex); + + std::vector<TraceSessionInfo> Result; + for (const auto& WeakSession : m_Sessions) + { + if (auto Session = WeakSession.lock()) + { + if (Session->IsActive()) + { + Result.push_back(Session->GetInfo()); + } + } + } + return Result; + } + +private: + void StartAccept() + { + auto Socket = std::make_shared<asio::ip::tcp::socket>(m_IoContext); + + m_Acceptor.async_accept(*Socket, [this, Socket](const asio::error_code& Ec) { + if (!Ec) + { + auto Session = std::make_shared<TraceSession>(std::move(*Socket), m_OutputDir); + + { + std::lock_guard<std::mutex> Lock(m_SessionsMutex); + + // Prune expired sessions while adding the new one + std::erase_if(m_Sessions, [](const std::weak_ptr<TraceSession>& Wp) { return Wp.expired(); }); + m_Sessions.push_back(Session); + } + + Session->Start(); + } + else if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("Accept error: {}", Ec.message()); + } + + // Continue accepting if still running + if (m_IsRunning) + { + StartAccept(); + } + }); + } + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + std::thread m_IoThread; + std::filesystem::path m_OutputDir; + std::mutex m_Mutex; + std::atomic<bool> m_IsRunning{false}; + uint16_t m_Port = 0; + + mutable std::mutex m_SessionsMutex; + std::vector<std::weak_ptr<TraceSession>> m_Sessions; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +TraceRecorder::TraceRecorder() : m_Impl(std::make_unique<Impl>()) +{ +} + +TraceRecorder::~TraceRecorder() +{ + Shutdown(); +} + +void +TraceRecorder::Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) +{ + m_Impl->Initialize(InPort, OutputDir); +} + +void +TraceRecorder::Shutdown() +{ + m_Impl->Shutdown(); +} + +bool +TraceRecorder::IsRunning() const +{ + return m_Impl->IsRunning(); +} + +uint16_t +TraceRecorder::GetPort() const +{ + return m_Impl->GetPort(); +} + +std::vector<TraceSessionInfo> +TraceRecorder::GetActiveSessions() const +{ + return m_Impl->GetActiveSessions(); +} + +} // namespace zen diff --git a/src/zenserver/trace/tracerecorder.h b/src/zenserver/trace/tracerecorder.h new file mode 100644 index 000000000..48857aec8 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.h @@ -0,0 +1,46 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/guid.h> +#include <zencore/zencore.h> + +#include <filesystem> +#include <memory> +#include <string> +#include <vector> + +namespace zen { + +struct TraceSessionInfo +{ + Guid SessionGuid{}; + Guid TraceGuid{}; + uint16_t ControlPort = 0; + uint8_t TransportVersion = 0; + uint8_t ProtocolVersion = 0; + std::string RemoteAddress; + uint64_t BytesRecorded = 0; + std::filesystem::path TraceFilePath; +}; + +class TraceRecorder +{ +public: + TraceRecorder(); + ~TraceRecorder(); + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir); + void Shutdown(); + + bool IsRunning() const; + uint16_t GetPort() const; + + std::vector<TraceSessionInfo> GetActiveSessions() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 6ee80dc62..f2ed17f05 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -2,7 +2,11 @@ target("zenserver") set_kind("binary") + if enable_unity then + add_rules("c++.unity_build", {batchsize = 4}) + end add_deps("zencore", + "zencompute", "zenhttp", "zennet", "zenremotestore", @@ -15,6 +19,12 @@ target("zenserver") add_files("**.cpp") add_files("frontend/*.zip") add_files("zenserver.cpp", {unity_ignored = true }) + + if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then + -- GCC false positives in deeply inlined code (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) + add_files("storage/projectstore/httpprojectstore.cpp", {force = {cxxflags = "-Wno-stringop-overflow"} }) + add_files("storage/storageconfig.cpp", {force = {cxxflags = "-Wno-array-bounds"} }) + end add_includedirs(".") set_symbols("debug") @@ -23,6 +33,8 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("oidctoken") + add_packages("nomad") if has_config("zenmimalloc") then add_packages("mimalloc") @@ -32,6 +44,14 @@ target("zenserver") add_packages("sentry-native") end + if has_config("zenhorde") then + add_deps("zenhorde") + end + + if has_config("zennomad") then + add_deps("zennomad") + end + if is_mode("release") then set_optimize("fastest") end @@ -141,4 +161,24 @@ target("zenserver") end copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + + local oidctoken_pkg = target:pkg("oidctoken") + if oidctoken_pkg then + local installdir = oidctoken_pkg:installdir() + local oidctoken_bin = "OidcToken" + if is_plat("windows") then + oidctoken_bin = "OidcToken.exe" + end + copy_if_newer(path.join(installdir, "bin", oidctoken_bin), path.join(target:targetdir(), oidctoken_bin), oidctoken_bin) + end + + local nomad_pkg = target:pkg("nomad") + if 
nomad_pkg then + local installdir = nomad_pkg:installdir() + local nomad_bin = "nomad" + if is_plat("windows") then + nomad_bin = "nomad.exe" + end + copy_if_newer(path.join(installdir, "bin", nomad_bin), path.join(target:targetdir(), nomad_bin), nomad_bin) + end end) diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 2bafeeaa1..bb6b02d21 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -18,11 +18,13 @@ #include <zencore/sentryintegration.h> #include <zencore/session.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/thread.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zencore/workthreadpool.h> #include <zenhttp/httpserver.h> +#include <zenhttp/security/passwordsecurityfilter.h> #include <zentelemetry/otlptrace.h> #include <zenutil/service.h> #include <zenutil/workerpools.h> @@ -44,6 +46,20 @@ ZEN_THIRD_PARTY_INCLUDES_END ////////////////////////////////////////////////////////////////////////// +#ifndef ZEN_WITH_COMPUTE_SERVICES +# define ZEN_WITH_COMPUTE_SERVICES 0 +#endif + +#ifndef ZEN_WITH_HORDE +# define ZEN_WITH_HORDE 0 +#endif + +#ifndef ZEN_WITH_NOMAD +# define ZEN_WITH_NOMAD 0 +#endif + +////////////////////////////////////////////////////////////////////////// + #include "config/config.h" #include "diag/logging.h" @@ -142,8 +158,18 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: ZEN_INFO("Effective concurrency: {} (hw: {})", GetHardwareConcurrency(), std::thread::hardware_concurrency()); + InitializeSecuritySettings(ServerOptions); + + if (ServerOptions.LieCpu) + { + SetCpuCountForReporting(ServerOptions.LieCpu); + + ZEN_INFO("Reporting concurrency: {}", ServerOptions.LieCpu); + } + m_StatusService.RegisterHandler("status", *this); m_Http->RegisterService(m_StatusService); + m_Http->RegisterService(m_StatsService); m_StatsReporter.Initialize(ServerOptions.StatsConfig); if (ServerOptions.StatsConfig.Enabled) @@ -151,10 
+177,37 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: EnqueueStatsReportingTimer(); } - m_HealthService.SetHealthInfo({.DataRoot = ServerOptions.DataDir, - .AbsLogPath = ServerOptions.AbsLogFile, - .HttpServerClass = std::string(ServerOptions.HttpConfig.ServerClass), - .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)}); + // clang-format off + HealthServiceInfo HealthInfo { + .DataRoot = ServerOptions.DataDir, + .AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile, + .HttpServerClass = std::string(ServerOptions.HttpConfig.ServerClass), + .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL), + .Port = EffectiveBasePort, + .Pid = GetCurrentProcessId(), + .IsDedicated = ServerOptions.IsDedicated, + .StartTimeMs = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::system_clock::now().time_since_epoch()).count(), + .BuildOptions = { + {"ZEN_ADDRESS_SANITIZER", ZEN_ADDRESS_SANITIZER != 0}, + {"ZEN_USE_SENTRY", ZEN_USE_SENTRY != 0}, + {"ZEN_WITH_TESTS", ZEN_WITH_TESTS != 0}, + {"ZEN_USE_MIMALLOC", ZEN_USE_MIMALLOC != 0}, + {"ZEN_USE_RPMALLOC", ZEN_USE_RPMALLOC != 0}, + {"ZEN_WITH_HTTPSYS", ZEN_WITH_HTTPSYS != 0}, + {"ZEN_WITH_MEMTRACK", ZEN_WITH_MEMTRACK != 0}, + {"ZEN_WITH_TRACE", ZEN_WITH_TRACE != 0}, + {"ZEN_WITH_COMPUTE_SERVICES", ZEN_WITH_COMPUTE_SERVICES != 0}, + {"ZEN_WITH_HORDE", ZEN_WITH_HORDE != 0}, + {"ZEN_WITH_NOMAD", ZEN_WITH_NOMAD != 0}, + }, + .RuntimeConfig = BuildSettingsList(ServerOptions), + }; + // clang-format on + + HealthInfo.RuntimeConfig.emplace(HealthInfo.RuntimeConfig.begin() + 2, "EffectivePort"sv, fmt::to_string(EffectiveBasePort)); + + m_HealthService.SetHealthInfo(std::move(HealthInfo)); LogSettingsSummary(ServerOptions); @@ -164,12 +217,23 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: void ZenServerBase::Finalize() { + m_StatsService.RegisterHandler("http", *m_Http); + + m_Http->SetDefaultRedirect("/dashboard/"); + // 
Register health service last so if we return "OK" for health it means all services have been properly initialized m_Http->RegisterService(m_HealthService); } void +ZenServerBase::ShutdownServices() +{ + m_StatsService.UnregisterHandler("http", *m_Http); + m_StatsService.Shutdown(); +} + +void ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) const { ZEN_MEMSCOPE(GetZenserverTag()); @@ -375,46 +439,65 @@ ZenServerBase::CheckSigInt() void ZenServerBase::HandleStatusRequest(HttpServerRequest& Request) { + auto Metrics = m_MetricsTracker.Query(); + CbObjectWriter Cbo; Cbo << "ok" << true; Cbo << "state" << ToString(m_CurrentState); + Cbo << "hostname" << GetMachineName(); + Cbo << "cpuUsagePercent" << Metrics.CpuUsagePercent; Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } -void -ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) +std::vector<std::pair<std::string_view, std::string>> +ZenServerBase::BuildSettingsList(const ZenServerConfig& ServerConfig) { // clang-format off - std::list<std::pair<std::string_view, std::string>> Settings = { - {"DataDir"sv, ServerConfig.DataDir.string()}, - {"AbsLogFile"sv, ServerConfig.AbsLogFile.string()}, - {"SystemRootDir"sv, ServerConfig.SystemRootDir.string()}, - {"ContentDir"sv, ServerConfig.ContentDir.string()}, + std::vector<std::pair<std::string_view, std::string>> Settings = { + {"SystemRootDir"sv, fmt::format("{}", ServerConfig.SystemRootDir)}, + {"ContentDir"sv, fmt::format("{}", ServerConfig.ContentDir)}, {"BasePort"sv, fmt::to_string(ServerConfig.BasePort)}, + {"CoreLimit"sv, fmt::to_string(ServerConfig.CoreLimit)}, {"IsDebug"sv, fmt::to_string(ServerConfig.IsDebug)}, {"IsCleanStart"sv, fmt::to_string(ServerConfig.IsCleanStart)}, {"IsPowerCycle"sv, fmt::to_string(ServerConfig.IsPowerCycle)}, {"IsTest"sv, fmt::to_string(ServerConfig.IsTest)}, {"Detach"sv, fmt::to_string(ServerConfig.Detach)}, - {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.NoConsoleOutput)}, - 
{"QuietConsole"sv, fmt::to_string(ServerConfig.QuietConsole)}, - {"CoreLimit"sv, fmt::to_string(ServerConfig.CoreLimit)}, - {"IsDedicated"sv, fmt::to_string(ServerConfig.IsDedicated)}, - {"ShouldCrash"sv, fmt::to_string(ServerConfig.ShouldCrash)}, + {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.LoggingConfig.NoConsoleOutput)}, + {"QuietConsole"sv, fmt::to_string(ServerConfig.LoggingConfig.QuietConsole)}, {"ChildId"sv, ServerConfig.ChildId}, - {"LogId"sv, ServerConfig.LogId}, + {"LogId"sv, ServerConfig.LoggingConfig.LogId}, {"Sentry DSN"sv, ServerConfig.SentryConfig.Dsn.empty() ? "not set" : ServerConfig.SentryConfig.Dsn}, {"Sentry Environment"sv, ServerConfig.SentryConfig.Environment}, {"Statsd Enabled"sv, fmt::to_string(ServerConfig.StatsConfig.Enabled)}, + {"SecurityConfigPath"sv, fmt::format("{}", ServerConfig.SecurityConfigPath)}, }; // clang-format on if (ServerConfig.StatsConfig.Enabled) { - Settings.emplace_back("Statsd Host", ServerConfig.StatsConfig.StatsdHost); - Settings.emplace_back("Statsd Port", fmt::to_string(ServerConfig.StatsConfig.StatsdPort)); + Settings.emplace_back("Statsd Host"sv, ServerConfig.StatsConfig.StatsdHost); + Settings.emplace_back("Statsd Port"sv, fmt::to_string(ServerConfig.StatsConfig.StatsdPort)); } + return Settings; +} + +void +ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) +{ + auto Settings = BuildSettingsList(ServerConfig); + + // Log-only entries not needed in RuntimeConfig + // clang-format off + Settings.insert(Settings.begin(), { + {"DataDir"sv, fmt::format("{}", ServerConfig.DataDir)}, + {"AbsLogFile"sv, fmt::format("{}", ServerConfig.LoggingConfig.AbsLogFile)}, + }); + // clang-format on + Settings.emplace_back("IsDedicated"sv, fmt::to_string(ServerConfig.IsDedicated)); + Settings.emplace_back("ShouldCrash"sv, fmt::to_string(ServerConfig.ShouldCrash)); + size_t MaxWidth = 0; for (const auto& Setting : Settings) { @@ -432,6 +515,44 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& 
ServerConfig) } } +void +ZenServerBase::InitializeSecuritySettings(const ZenServerConfig& ServerOptions) +{ + ZEN_ASSERT(m_Http); + + if (!ServerOptions.SecurityConfigPath.empty()) + { + IoBuffer SecurityJson = ReadFile(ServerOptions.SecurityConfigPath).Flatten(); + std::string_view Json(reinterpret_cast<const char*>(SecurityJson.GetData()), SecurityJson.GetSize()); + std::string JsonError; + CbObject SecurityConfig = LoadCompactBinaryFromJson(Json, JsonError).AsObject(); + if (!JsonError.empty()) + { + throw std::runtime_error( + fmt::format("Invalid security configuration file at {}. '{}'", ServerOptions.SecurityConfigPath, JsonError)); + } + + CbObjectView HttpRootFilterConfig = SecurityConfig["http"sv].AsObjectView()["root"sv].AsObjectView()["filter"sv].AsObjectView(); + if (HttpRootFilterConfig) + { + std::string_view FilterType = HttpRootFilterConfig["type"sv].AsString(); + if (FilterType == PasswordHttpFilter::TypeName) + { + PasswordHttpFilter::Configuration Config = + PasswordHttpFilter::ReadConfiguration(HttpRootFilterConfig["config"].AsObjectView()); + m_HttpRequestFilter = std::make_unique<PasswordHttpFilter>(Config); + m_Http->SetHttpRequestFilter(m_HttpRequestFilter.get()); + } + else + { + throw std::runtime_error(fmt::format("Security configuration file at {} references unknown http root filter type '{}'", + ServerOptions.SecurityConfigPath, + FilterType)); + } + } + } +} + ////////////////////////////////////////////////////////////////////////// ZenServerMain::ZenServerMain(ZenServerConfig& ServerOptions) : m_ServerOptions(ServerOptions) @@ -467,7 +588,7 @@ ZenServerMain::Run() ZEN_OTEL_SPAN("SentryInit"); std::string SentryDatabasePath = (m_ServerOptions.DataDir / ".sentry-native").string(); - std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string(); + std::string SentryAttachmentPath = m_ServerOptions.LoggingConfig.AbsLogFile.string(); Sentry.Initialize({.DatabasePath = SentryDatabasePath, .AttachmentsPath = SentryAttachmentPath, 
@@ -567,6 +688,8 @@ ZenServerMain::Run() { ZEN_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); Sleep(500); + + m_LockFile.Create(LockFilePath, MakeLockData(false), Ec); if (Ec) { ZEN_WARN(ZEN_APP_NAME " exiting, unable to grab lock at '{}' (reason: '{}')", LockFilePath, Ec.message()); @@ -622,6 +745,10 @@ ZenServerMain::Run() RequestApplicationExit(1); } +#if ZEN_USE_SENTRY + Sentry.Close(); +#endif + ShutdownServerLogging(); ReportServiceStatus(ServiceStatus::Stopped); diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index ab7122fcc..c06093f0d 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -3,11 +3,13 @@ #pragma once #include <zencore/basicfile.h> +#include <zencore/system.h> #include <zenhttp/httpserver.h> #include <zenhttp/httpstats.h> #include <zenhttp/httpstatus.h> #include <zenutil/zenserverprocess.h> +#include <atomic> #include <memory> #include <string_view> #include "config/config.h" @@ -43,11 +45,18 @@ public: void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); } + void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } + void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } + void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } + void SetTestMode(bool State) { m_TestMode = State; } + protected: int Initialize(const ZenServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry); void Finalize(); + void ShutdownServices(); void GetBuildOptions(StringBuilderBase& OutOptions, char Separator = ',') const; - void LogSettingsSummary(const ZenServerConfig& ServerConfig); + static std::vector<std::pair<std::string_view, std::string>> BuildSettingsList(const ZenServerConfig& ServerConfig); + void LogSettingsSummary(const ZenServerConfig& ServerConfig); protected: NamedMutex m_ServerMutex; @@ -55,6 +64,10 @@ protected: bool m_UseSentry = false; bool 
m_IsPowerCycle = false; + bool m_IsDedicatedMode = false; + bool m_TestMode = false; + bool m_DebugOptionForcedCrash = false; + std::thread m_IoRunner; asio::io_context m_IoContext; void EnsureIoRunner(); @@ -64,17 +77,26 @@ protected: kInitializing, kRunning, kShuttingDown - } m_CurrentState = kInitializing; + }; + std::atomic<ServerState> m_CurrentState = kInitializing; - inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; } + inline void SetNewState(ServerState NewState) { m_CurrentState.store(NewState, std::memory_order_relaxed); } static std::string_view ToString(ServerState Value); std::function<void()> m_IsReadyFunc; void OnReady(); - Ref<HttpServer> m_Http; - HttpHealthService m_HealthService; - HttpStatusService m_StatusService; + std::filesystem::path m_DataRoot; // Root directory for server state + std::filesystem::path m_ContentRoot; // Root directory for frontend content + + Ref<HttpServer> m_Http; + + std::unique_ptr<IHttpRequestFilter> m_HttpRequestFilter; + + HttpHealthService m_HealthService; + HttpStatsService m_StatsService{m_IoContext}; + HttpStatusService m_StatusService; + SystemMetricsTracker m_MetricsTracker; // Stats reporting @@ -107,8 +129,10 @@ protected: // IHttpStatusProvider virtual void HandleStatusRequest(HttpServerRequest& Request) override; -}; +private: + void InitializeSecuritySettings(const ZenServerConfig& ServerOptions); +}; class ZenServerMain { public: diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc index e0003ea8f..f353bd9cc 100644 --- a/src/zenserver/zenserver.rc +++ b/src/zenserver/zenserver.rc @@ -28,7 +28,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US // Icon with lowest ID value placed first to ensure application icon // remains consistent on all systems. 
-IDI_ICON1 ICON "..\\UnrealEngine.ico" +IDI_ICON1 ICON "..\\zen.ico" #endif // English (United States) resources ///////////////////////////////////////////////////////////////////////////// diff --git a/src/zenstore-test/zenstore-test.cpp b/src/zenstore-test/zenstore-test.cpp index c055dbb64..875373a9d 100644 --- a/src/zenstore-test/zenstore-test.cpp +++ b/src/zenstore-test/zenstore-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zenstore/zenstore.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenstore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenstore-test", zen::zenstore_forcelinktests); #else return 0; #endif diff --git a/src/zenstore/blockstore.cpp b/src/zenstore/blockstore.cpp index 3ea91ead6..6197c7f24 100644 --- a/src/zenstore/blockstore.cpp +++ b/src/zenstore/blockstore.cpp @@ -1556,6 +1556,8 @@ BlockStore::GetMetaData(uint32_t BlockIndex) const #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.blockstore"); + TEST_CASE("blockstore.blockstoredisklocation") { BlockStoreLocation Zero = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = 0}; @@ -2427,6 +2429,8 @@ 
TEST_CASE("blockstore.BlockStoreFileAppender") } } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index 04a0781d3..dff1c3c61 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -266,13 +266,12 @@ BuildStore::PutBlob(const IoHash& BlobHash, const IoBuffer& Payload) m_BlobLookup.insert({BlobHash, NewBlobIndex}); } - m_LastAccessTimeUpdateCount++; if (m_TrackedBlobKeys) { m_TrackedBlobKeys->push_back(BlobHash); if (MetadataHash != IoHash::Zero) { - m_TrackedBlobKeys->push_back(BlobHash); + m_TrackedBlobKeys->push_back(MetadataHash); } } } @@ -374,8 +373,8 @@ BuildStore::PutMetadatas(std::span<const IoHash> BlobHashes, std::span<const IoB CompressedMetadataBuffers.resize(Metadatas.size()); if (OptionalWorkerPool) { - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); for (size_t Index = 0; Index < Metadatas.size(); Index++) { @@ -506,8 +505,8 @@ BuildStore::GetMetadatas(std::span<const IoHash> BlobHashes, WorkerThreadPool* O else { ZEN_WARN("Metadata {} for blob {} is malformed (not a compressed binary format)", - MetadataHashes[ResultIndex], - BlobHashes[ResultIndex]); + MetadataHashes[Index], + BlobHashes[MetaLocationResultIndexes[Index]]); } } } @@ -562,7 +561,7 @@ BuildStore::GetStorageStats() const RwLock::SharedLockScope _(m_Lock); Result.EntryCount = m_BlobLookup.size(); - for (auto LookupIt : m_BlobLookup) + for (const auto& LookupIt : m_BlobLookup) { const BlobIndex ReadBlobIndex = LookupIt.second; const BlobEntry& ReadBlobEntry = m_BlobEntries[ReadBlobIndex]; @@ -635,7 +634,7 @@ BuildStore::CompactState() const size_t MetadataCount = m_MetadataEntries.size(); MetadataEntries.reserve(MetadataCount); - for (auto LookupIt : m_BlobLookup) + for (const auto& LookupIt : 
m_BlobLookup) { const IoHash& BlobHash = LookupIt.first; const BlobIndex ReadBlobIndex = LookupIt.second; @@ -956,7 +955,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files std::vector<AccessTimeRecord> AccessRecords; AccessRecords.reserve(Header.AccessTimeCount); - for (auto It : m_BlobLookup) + for (const auto& It : m_BlobLookup) { const IoHash& Key = It.first; const BlobIndex Index = It.second; @@ -966,7 +965,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files } uint64_t RecordsSize = sizeof(AccessTimeRecord) * Header.AccessTimeCount; TempFile.Write(AccessRecords.data(), RecordsSize, Offset); - Offset += sizeof(AccessTimesHeader) * Header.AccessTimeCount; + Offset += sizeof(AccessTimeRecord) * Header.AccessTimeCount; } if (TempFile.MoveTemporaryIntoPlace(AccessTimesPath, Ec); Ec) { @@ -1373,6 +1372,8 @@ BuildStore::LockState(GcCtx& Ctx) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.buildstore"); + TEST_CASE("BuildStore.Blobs") { ScopedTemporaryDirectory _; @@ -1822,6 +1823,8 @@ TEST_CASE("BuildStore.SizeLimit") } } +TEST_SUITE_END(); + void buildstore_forcelink() { diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index ead7e4f3a..4640309d9 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -602,7 +602,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B if (FileSize < sizeof(BucketMetaHeader)) { - ZEN_WARN("Failed to read sidecar file '{}'. Minimum size {} expected, actual size: ", + ZEN_WARN("Failed to read sidecar file '{}'. 
Minimum size {} expected, actual size: {}", SidecarPath, sizeof(BucketMetaHeader), FileSize); @@ -626,7 +626,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B return false; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(BucketMetaHeader))) / sizeof(ManifestData); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(BucketMetaHeader)) / sizeof(ManifestData); if (Header.EntryCount > ExpectedEntryCount) { ZEN_WARN( @@ -654,6 +654,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B SidecarPath, sizeof(ManifestData), CurrentReadOffset); + break; } CurrentReadOffset += sizeof(ManifestData); @@ -1011,7 +1012,7 @@ ZenCacheDiskLayer::CacheBucket::WriteIndexSnapshotLocked(uint64_t LogPosi { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1057,7 +1058,7 @@ ZenCacheDiskLayer::CacheBucket::ReadIndexFile(RwLock::ExclusiveLockScope&, const return 0; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(cache::impl::CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(cache::impl::CacheBucketIndexHeader)) / sizeof(DiskIndexEntry); if (Header.EntryCount > ExpectedEntryCount) { return 0; @@ -1267,10 +1268,10 @@ ZenCacheDiskLayer::CacheBucket::InitializeIndexFromDisk(RwLock::ExclusiveLockSco { RemoveMemCachedData(IndexLock, Payload); RemoveMetaData(IndexLock, Payload); + Location.Flags |= DiskLocation::kTombStone; + MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location}); } } - Location.Flags |= DiskLocation::kTombStone; - 
MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location}); } ZEN_ASSERT(!MissingEntries.empty()); @@ -2812,7 +2813,7 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c m_BucketDir, Ec.message(), RetriesLeft); - Sleep(100 - (3 - RetriesLeft) * 100); // Total 600 ms + Sleep((3 - RetriesLeft) * 100); // Total 600 ms Ec.clear(); DataFile.MoveTemporaryIntoPlace(FsPath, Ec); RetriesLeft--; @@ -2866,11 +2867,12 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c { EntryIndex = It.value(); ZEN_ASSERT_SLOW(EntryIndex < PayloadIndex(m_AccessTimes.size())); - BucketPayload& Payload = m_Payloads[EntryIndex]; - uint64_t OldSize = Payload.Location.Size(); + BucketPayload& Payload = m_Payloads[EntryIndex]; + uint64_t OldSize = Payload.Location.Size(); + RemoveMemCachedData(IndexLock, Payload); + RemoveMetaData(IndexLock, Payload); Payload = BucketPayload{.Location = Loc}; m_AccessTimes[EntryIndex] = GcClock::TickCount(); - RemoveMemCachedData(IndexLock, Payload); m_StandaloneSize.fetch_sub(OldSize, std::memory_order::relaxed); } if ((Value.RawSize != 0 || Value.RawHash != IoHash::Zero) && Value.RawSize <= std::numeric_limits<std::uint32_t>::max()) @@ -3521,7 +3523,7 @@ ZenCacheDiskLayer::CacheBucket::GetReferences(const LoggerRef& Logger, } else { - ZEN_WARN("Cache record {} payload is malformed. Reason: ", RawHash, ToString(Error)); + ZEN_WARN("Cache record {} payload is malformed. 
Reason: {}", RawHash, ToString(Error)); } return false; }; @@ -4282,8 +4284,8 @@ ZenCacheDiskLayer::DiscoverBuckets() RwLock SyncLock; WorkerThreadPool& Pool = GetLargeWorkerPool(EWorkloadType::Burst); - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4454,8 +4456,8 @@ ZenCacheDiskLayer::Flush() } { WorkerThreadPool& Pool = GetMediumWorkerPool(EWorkloadType::Burst); - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4496,8 +4498,8 @@ ZenCacheDiskLayer::Scrub(ScrubContext& Ctx) RwLock::SharedLockScope _(m_Lock); - std::atomic<bool> Abort; - std::atomic<bool> Pause; + std::atomic<bool> Abort{false}; + std::atomic<bool> Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -4559,9 +4561,11 @@ ZenCacheDiskLayer::Stats() const ZenCacheDiskLayer::Info ZenCacheDiskLayer::GetInfo() const { - ZenCacheDiskLayer::Info Info = {.RootDir = m_RootDir, .Config = m_Configuration}; + ZenCacheDiskLayer::Info Info; + Info.RootDir = m_RootDir; { RwLock::SharedLockScope _(m_Lock); + Info.Config = m_Configuration; Info.BucketNames.reserve(m_Buckets.size()); for (auto& Kv : m_Buckets) { diff --git a/src/zenstore/cache/cachepolicy.cpp b/src/zenstore/cache/cachepolicy.cpp index ca8a95ca1..c1e7dc5b3 100644 --- a/src/zenstore/cache/cachepolicy.cpp +++ b/src/zenstore/cache/cachepolicy.cpp @@ -284,6 +284,9 @@ CacheRecordPolicyBuilder::Build() } #if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("store.cachepolicy"); + TEST_CASE("cachepolicy") { SUBCASE("atomics serialization") @@ -400,13 +403,13 @@ TEST_CASE("cacherecordpolicy") RecordPolicy.Save(Writer); CbObject Saved = Writer.Save()->AsObject(); 
CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); - CHECK(!RecordPolicy.IsUniform()); - CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); - CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); - CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); - CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); - CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); - CHECK(RecordPolicy.GetValuePolicies().size() == 2); + CHECK(!Loaded.IsUniform()); + CHECK(Loaded.GetRecordPolicy() == UnionPolicy); + CHECK(Loaded.GetBasePolicy() == DefaultPolicy); + CHECK(Loaded.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(Loaded.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(Loaded.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(Loaded.GetValuePolicies().size() == 2); } } @@ -416,6 +419,8 @@ TEST_CASE("cacherecordpolicy") CHECK(Loaded.IsNull()); } } + +TEST_SUITE_END(); #endif void diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index 94abcf547..90c5a5e60 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -866,8 +866,8 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb Request.Complete = false; } } - Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); } + Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); }; m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete)); @@ -934,7 +934,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb *Namespace, Key.Bucket, Key.Hash, - Request.RecordObject ? ""sv : " (PARTIAL)"sv, + Request.RecordObject ? " (PARTIAL)"sv : ""sv, Request.Source ? 
Request.Source->Url : "LOCAL"sv, NiceLatencyNs(Request.ElapsedTimeUs * 1000)); m_CacheStats.MissCount++; @@ -966,7 +966,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb } else { - ResponseObject.AddBool(true); + ResponseObject.AddBool(false); } } ResponseObject.EndArray(); diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 52b494e45..cff0e9a35 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -608,7 +608,10 @@ ZenCacheStore::GetBatch::Commit() m_CacheStore.m_HitCount++; OpScope.SetBytes(Result.Value.GetSize()); } - m_CacheStore.m_MissCount++; + else + { + m_CacheStore.m_MissCount++; + } } } } @@ -683,8 +686,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context, return false; } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -719,8 +722,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context, } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -787,8 +790,8 @@ ZenCacheStore::Put(const CacheRequestContext& Context, } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put [{}] bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -813,7 +816,7 @@ ZenCacheStore::DropNamespace(std::string_view InNamespace) { std::function<void()> PostDropOp; { - RwLock::SharedLockScope _(m_NamespacesLock); + RwLock::ExclusiveLockScope _(m_NamespacesLock); if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end()) { ZenCacheNamespace& Namespace = *It->second; @@ -1392,6 +1395,8 @@ namespace testutils { } // namespace testutils +TEST_SUITE_BEGIN("store.structuredcachestore"); + TEST_CASE("cachestore.store") { 
ScopedTemporaryDirectory TempDir; @@ -1548,7 +1553,7 @@ TEST_CASE("cachestore.size") } } -TEST_CASE("cachestore.threadedinsert") // * doctest::skip(true)) +TEST_CASE("cachestore.threadedinsert" * doctest::skip()) { // for (uint32_t i = 0; i < 100; ++i) { @@ -2741,6 +2746,8 @@ TEST_CASE("cachestore.newgc.basics") } } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp index ed017988f..8855c87d8 100644 --- a/src/zenstore/cas.cpp +++ b/src/zenstore/cas.cpp @@ -153,7 +153,10 @@ CasImpl::Initialize(const CidStoreConfiguration& InConfig) } for (std::future<void>& Result : Work) { - Result.get(); + if (Result.valid()) + { + Result.get(); + } } } } @@ -300,12 +303,12 @@ GetCompactCasResults(CasContainerStrategy& Strategy, }; static void -GetFileCasResults(FileCasStrategy& Strategy, - CasStore::InsertMode Mode, - std::span<IoBuffer> Data, - std::span<IoHash> ChunkHashes, - std::span<size_t> Indexes, - std::vector<CasStore::InsertResult> Results) +GetFileCasResults(FileCasStrategy& Strategy, + CasStore::InsertMode Mode, + std::span<IoBuffer> Data, + std::span<IoHash> ChunkHashes, + std::span<size_t> Indexes, + std::vector<CasStore::InsertResult>& Results) { for (size_t Index : Indexes) { @@ -426,7 +429,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return AsyncCallback(Index, Chunk); }, OptionalWorkerPool, LargeSizeLimit == 0 ? m_Config.HugeValueThreshold : Min(LargeSizeLimit, m_Config.HugeValueThreshold))) @@ -439,7 +442,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return AsyncCallback(Index, Chunk); }, OptionalWorkerPool, LargeSizeLimit == 0 ? 
m_Config.TinyValueThreshold : Min(LargeSizeLimit, m_Config.TinyValueThreshold))) @@ -452,7 +455,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return AsyncCallback(Index, Chunk); }, OptionalWorkerPool)) { @@ -512,6 +515,8 @@ CreateCasStore(GcManager& Gc) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.cas"); + TEST_CASE("CasStore") { ScopedTemporaryDirectory TempDir; @@ -553,6 +558,8 @@ TEST_CASE("CasStore") CHECK(Lookup2); } +TEST_SUITE_END(); + void CAS_forcelink() { diff --git a/src/zenstore/caslog.cpp b/src/zenstore/caslog.cpp index 492ce9317..44664dac2 100644 --- a/src/zenstore/caslog.cpp +++ b/src/zenstore/caslog.cpp @@ -35,7 +35,7 @@ CasLogFile::~CasLogFile() } bool -CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize) +CasLogFile::IsValid(const std::filesystem::path& FileName, size_t RecordSize) { if (!IsFile(FileName)) { @@ -71,7 +71,7 @@ CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize) } void -CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode) +CasLogFile::Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode) { m_RecordSize = RecordSize; @@ -205,7 +205,7 @@ CasLogFile::Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntr m_File.Read(ReadBuffer.data(), BytesToRead, LogBaseOffset + ReadOffset); - for (int i = 0; i < int(EntriesToRead); ++i) + for (size_t i = 0; i < EntriesToRead; ++i) { Handler(ReadBuffer.data() + (i * m_RecordSize)); } diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp index bedf91287..b20d8f565 100644 --- a/src/zenstore/cidstore.cpp +++ b/src/zenstore/cidstore.cpp @@ -48,13 +48,13 @@ struct CidStore::Impl std::vector<CidStore::InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, CidStore::InsertMode Mode) { + 
ZEN_ASSERT(ChunkDatas.size() == RawHashes.size()); if (ChunkDatas.size() == 1) { std::vector<CidStore::InsertResult> Result(1); Result[0] = AddChunk(ChunkDatas[0], RawHashes[0], Mode); return Result; } - ZEN_ASSERT(ChunkDatas.size() == RawHashes.size()); std::vector<IoBuffer> Chunks; Chunks.reserve(ChunkDatas.size()); #if ZEN_BUILD_DEBUG @@ -81,6 +81,7 @@ struct CidStore::Impl m_CasStore.InsertChunks(Chunks, RawHashes, static_cast<CasStore::InsertMode>(Mode)); ZEN_ASSERT(CasResults.size() == ChunkDatas.size()); std::vector<CidStore::InsertResult> Result; + Result.reserve(CasResults.size()); for (const CasStore::InsertResult& CasResult : CasResults) { if (CasResult.New) diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp index 5d8f95c9e..b09892687 100644 --- a/src/zenstore/compactcas.cpp +++ b/src/zenstore/compactcas.cpp @@ -153,7 +153,7 @@ CasContainerStrategy::~CasContainerStrategy() } catch (const std::exception& Ex) { - ZEN_ERROR("~CasContainerStrategy failed with: ", Ex.what()); + ZEN_ERROR("~CasContainerStrategy failed with: {}", Ex.what()); } m_Gc.RemoveGcReferenceStore(*this); m_Gc.RemoveGcStorage(this); @@ -440,9 +440,9 @@ CasContainerStrategy::IterateChunks(std::span<const IoHash> ChunkHas return true; } - std::atomic<bool> AbortFlag; + std::atomic<bool> AbortFlag{false}; { - std::atomic<bool> PauseFlag; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -559,8 +559,8 @@ CasContainerStrategy::ScrubStorage(ScrubContext& Ctx) std::vector<BlockStoreLocation> ChunkLocations; std::vector<IoHash> ChunkIndexToChunkHash; - std::atomic<bool> Abort; - std::atomic<bool> Pause; + std::atomic<bool> Abort{false}; + std::atomic<bool> Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -1007,7 +1007,7 @@ CasContainerStrategy::CompactIndex(RwLock::ExclusiveLockScope&) std::vector<BlockStoreDiskLocation> Locations; 
Locations.reserve(EntryCount); LocationMap.reserve(EntryCount); - for (auto It : m_LocationMap) + for (const auto& It : m_LocationMap) { size_t EntryIndex = Locations.size(); Locations.push_back(m_Locations[It.second]); @@ -1106,7 +1106,7 @@ CasContainerStrategy::MakeIndexSnapshot(bool ResetLog) { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1136,7 +1136,7 @@ CasContainerStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint uint64_t Size = ObjectIndexFile.FileSize(); if (Size >= sizeof(CasDiskIndexHeader)) { - uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CasDiskIndexHeader))) / sizeof(CasDiskIndexEntry); + uint64_t ExpectedEntryCount = (Size - sizeof(CasDiskIndexHeader)) / sizeof(CasDiskIndexEntry); CasDiskIndexHeader Header; ObjectIndexFile.Read(&Header, sizeof(Header), 0); if ((Header.Magic == CasDiskIndexHeader::ExpectedMagic) && (Header.Version == CasDiskIndexHeader::CurrentVersion) && @@ -1348,6 +1348,8 @@ CasContainerStrategy::OpenContainer(bool IsNewStore) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.compactcas"); + TEST_CASE("compactcas.hex") { uint32_t Value; @@ -2159,6 +2161,8 @@ TEST_CASE("compactcas.iteratechunks") } } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp index 31b3a68c4..0088afe6e 100644 --- a/src/zenstore/filecas.cpp +++ b/src/zenstore/filecas.cpp @@ -383,7 +383,7 @@ FileCasStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, CasStore:: HRESULT WriteRes = PayloadFile.Write(Cursor, Size); if (FAILED(WriteRes)) { - ThrowSystemException(hRes, fmt::format("failed to write {} bytes to shard file '{}'", 
ChunkSize, ChunkPath)); + ThrowSystemException(WriteRes, fmt::format("failed to write {} bytes to shard file '{}'", ChunkSize, ChunkPath)); } }; #else @@ -669,8 +669,8 @@ FileCasStrategy::IterateChunks(std::span<IoHash> ChunkHashes, return true; }; - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -823,8 +823,8 @@ FileCasStrategy::ScrubStorage(ScrubContext& Ctx) ZEN_INFO("discovered {} files @ '{}' ({} not in index), scrubbing", m_Index.size(), m_RootDirectory, DiscoveredFilesNotInIndex); - std::atomic<bool> Abort; - std::atomic<bool> Pause; + std::atomic<bool> Abort{false}; + std::atomic<bool> Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -1016,7 +1016,7 @@ FileCasStrategy::MakeIndexSnapshot(bool ResetLog) { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1052,7 +1052,7 @@ FileCasStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& uint64_t Size = ObjectIndexFile.FileSize(); if (Size >= sizeof(FileCasIndexHeader)) { - uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(FileCasIndexHeader))) / sizeof(FileCasIndexEntry); + uint64_t ExpectedEntryCount = (Size - sizeof(FileCasIndexHeader)) / sizeof(FileCasIndexEntry); FileCasIndexHeader Header; ObjectIndexFile.Read(&Header, sizeof(Header), 0); if ((Header.Magic == FileCasIndexHeader::ExpectedMagic) && (Header.Version == FileCasIndexHeader::CurrentVersion) && @@ -1496,6 +1496,8 @@ 
FileCasStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.filecas"); + TEST_CASE("cas.chunk.mismatch") { } @@ -1793,6 +1795,8 @@ TEST_CASE("cas.file.move") # endif } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/filecas.h b/src/zenstore/filecas.h index e93356927..41756b65f 100644 --- a/src/zenstore/filecas.h +++ b/src/zenstore/filecas.h @@ -74,7 +74,7 @@ private: { static const uint32_t kTombStone = 0x0000'0001; - bool IsFlagSet(const uint32_t Flag) const { return (Flags & kTombStone) == Flag; } + bool IsFlagSet(const uint32_t Flag) const { return (Flags & Flag) == Flag; } IoHash Key; uint32_t Flags = 0; diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp index 14caa5abf..b3450b805 100644 --- a/src/zenstore/gc.cpp +++ b/src/zenstore/gc.cpp @@ -1494,7 +1494,8 @@ GcManager::CollectGarbage(const GcSettings& Settings) GcReferenceValidatorStats& Stats = Result.ReferenceValidatorStats[It.second].second; try { - // Go through all the ReferenceCheckers to see if the list of Cids the collector selected are referenced or + // Go through all the ReferenceCheckers to see if the list of Cids the collector selected + // are referenced or not SCOPED_TIMER(Stats.ElapsedMS = std::chrono::milliseconds(Timer.GetElapsedTimeMs());); ReferenceValidator->Validate(Ctx, Stats); } @@ -1952,7 +1953,7 @@ GcScheduler::AppendGCLog(std::string_view Id, GcClock::TimePoint StartTime, cons Writer << "SingleThread"sv << Settings.SingleThread; Writer << "CompactBlockUsageThresholdPercent"sv << Settings.CompactBlockUsageThresholdPercent; Writer << "AttachmentRangeMin"sv << Settings.AttachmentRangeMin; - Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMin; + Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMax; Writer << "ForceStoreCacheAttachmentMetaData"sv << Settings.StoreCacheAttachmentMetaData; Writer << "ForceStoreProjectAttachmentMetaData"sv << Settings.StoreProjectAttachmentMetaData; Writer << 
"EnableValidation"sv << Settings.EnableValidation; @@ -2893,7 +2894,7 @@ GcScheduler::CollectGarbage(const GcClock::TimePoint& CacheExpireTime, { m_LastFullGCV2Result = Result; m_LastFullAttachmentRangeMin = AttachmentRangeMin; - m_LastFullAttachmentRangeMin = AttachmentRangeMax; + m_LastFullAttachmentRangeMax = AttachmentRangeMax; } Diff.DiskSize = Result.CompactStoresStatSum.RemovedDisk; Diff.MemorySize = Result.ReferencerStatSum.RemoveExpiredDataStats.FreedMemory; @@ -3048,6 +3049,8 @@ GcScheduler::CollectGarbage(const GcClock::TimePoint& CacheExpireTime, #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("store.gc"); + TEST_CASE("gc.diskusagewindow") { DiskUsageWindow Stats; @@ -3379,6 +3382,8 @@ TEST_CASE("gc.attachmentrange") CHECK(AttachmentRangeMax == IoHash::Max); } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/include/zenstore/buildstore/buildstore.h b/src/zenstore/include/zenstore/buildstore/buildstore.h index 76cba05b9..ea2ef7f89 100644 --- a/src/zenstore/include/zenstore/buildstore/buildstore.h +++ b/src/zenstore/include/zenstore/buildstore/buildstore.h @@ -1,5 +1,5 @@ - // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once #include <zenstore/blockstore.h> @@ -223,7 +223,7 @@ private: uint64_t m_MetaLogFlushPosition = 0; std::unique_ptr<std::vector<IoHash>> m_TrackedBlobKeys; - std::atomic<uint64_t> m_LastAccessTimeUpdateCount; + std::atomic<uint64_t> m_LastAccessTimeUpdateCount{0}; friend class BuildStoreGcReferenceChecker; friend class BuildStoreGcReferencePruner; diff --git a/src/zenstore/include/zenstore/cache/cachedisklayer.h b/src/zenstore/include/zenstore/cache/cachedisklayer.h index 3d684587d..393e289ac 100644 --- a/src/zenstore/include/zenstore/cache/cachedisklayer.h +++ b/src/zenstore/include/zenstore/cache/cachedisklayer.h @@ -153,14 +153,14 @@ public: struct BucketStats { - uint64_t DiskSize; - uint64_t MemorySize; - uint64_t DiskHitCount; - uint64_t DiskMissCount; - uint64_t DiskWriteCount; - uint64_t MemoryHitCount; - uint64_t MemoryMissCount; - uint64_t MemoryWriteCount; + uint64_t DiskSize = 0; + uint64_t MemorySize = 0; + uint64_t DiskHitCount = 0; + uint64_t DiskMissCount = 0; + uint64_t DiskWriteCount = 0; + uint64_t MemoryHitCount = 0; + uint64_t MemoryMissCount = 0; + uint64_t MemoryWriteCount = 0; metrics::RequestStatsSnapshot PutOps; metrics::RequestStatsSnapshot GetOps; }; @@ -174,8 +174,8 @@ public: struct DiskStats { std::vector<NamedBucketStats> BucketStats; - uint64_t DiskSize; - uint64_t MemorySize; + uint64_t DiskSize = 0; + uint64_t MemorySize = 0; }; struct PutResult @@ -395,12 +395,12 @@ public: TCasLogFile<DiskIndexEntry> m_SlogFile; uint64_t m_LogFlushPosition = 0; - std::atomic<uint64_t> m_DiskHitCount; - std::atomic<uint64_t> m_DiskMissCount; - std::atomic<uint64_t> m_DiskWriteCount; - std::atomic<uint64_t> m_MemoryHitCount; - std::atomic<uint64_t> m_MemoryMissCount; - std::atomic<uint64_t> m_MemoryWriteCount; + std::atomic<uint64_t> m_DiskHitCount{0}; + std::atomic<uint64_t> m_DiskMissCount{0}; + std::atomic<uint64_t> m_DiskWriteCount{0}; + std::atomic<uint64_t> m_MemoryHitCount{0}; + std::atomic<uint64_t> m_MemoryMissCount{0}; + 
std::atomic<uint64_t> m_MemoryWriteCount{0}; metrics::RequestStats m_PutOps; metrics::RequestStats m_GetOps; @@ -540,7 +540,7 @@ private: Configuration m_Configuration; std::atomic_uint64_t m_TotalMemCachedSize{}; std::atomic_bool m_IsMemCacheTrimming = false; - std::atomic<GcClock::Tick> m_NextAllowedTrimTick; + std::atomic<GcClock::Tick> m_NextAllowedTrimTick{}; mutable RwLock m_Lock; BucketMap_t m_Buckets; std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets; diff --git a/src/zenstore/include/zenstore/cache/cacheshared.h b/src/zenstore/include/zenstore/cache/cacheshared.h index 791720589..8e9cd7fd7 100644 --- a/src/zenstore/include/zenstore/cache/cacheshared.h +++ b/src/zenstore/include/zenstore/cache/cacheshared.h @@ -40,12 +40,12 @@ struct CacheValueDetails { struct ValueDetails { - uint64_t Size; - uint64_t RawSize; + uint64_t Size = 0; + uint64_t RawSize = 0; IoHash RawHash; GcClock::Tick LastAccess{}; std::vector<IoHash> Attachments; - ZenContentType ContentType; + ZenContentType ContentType = ZenContentType::kBinary; }; struct BucketDetails diff --git a/src/zenstore/include/zenstore/cache/structuredcachestore.h b/src/zenstore/include/zenstore/cache/structuredcachestore.h index 5a0a8b069..3722a0d31 100644 --- a/src/zenstore/include/zenstore/cache/structuredcachestore.h +++ b/src/zenstore/include/zenstore/cache/structuredcachestore.h @@ -70,9 +70,9 @@ public: struct NamespaceStats { - uint64_t HitCount; - uint64_t MissCount; - uint64_t WriteCount; + uint64_t HitCount = 0; + uint64_t MissCount = 0; + uint64_t WriteCount = 0; metrics::RequestStatsSnapshot PutOps; metrics::RequestStatsSnapshot GetOps; ZenCacheDiskLayer::DiskStats DiskStats; @@ -342,11 +342,11 @@ private: void LogWorker(); RwLock m_LogQueueLock; std::vector<AccessLogItem> m_LogQueue; - std::atomic_bool m_ExitLogging; + std::atomic_bool m_ExitLogging{false}; Event m_LogEvent; std::thread m_AsyncLoggingThread; - std::atomic_bool m_WriteLogEnabled; - std::atomic_bool m_AccessLogEnabled; + 
std::atomic_bool m_WriteLogEnabled{false}; + std::atomic_bool m_AccessLogEnabled{false}; friend class CacheStoreReferenceChecker; }; diff --git a/src/zenstore/include/zenstore/caslog.h b/src/zenstore/include/zenstore/caslog.h index f3dd32fb1..7967d9dae 100644 --- a/src/zenstore/include/zenstore/caslog.h +++ b/src/zenstore/include/zenstore/caslog.h @@ -20,8 +20,8 @@ public: kTruncate }; - static bool IsValid(std::filesystem::path FileName, size_t RecordSize); - void Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode); + static bool IsValid(const std::filesystem::path& FileName, size_t RecordSize); + void Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode); void Append(const void* DataPointer, uint64_t DataSize); void Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount); void Flush(); @@ -48,7 +48,7 @@ private: static_assert(sizeof(FileHeader) == 64); private: - void Open(std::filesystem::path FileName, size_t RecordSize, BasicFile::Mode Mode); + void Open(const std::filesystem::path& FileName, size_t RecordSize, BasicFile::Mode Mode); BasicFile m_File; FileHeader m_Header; @@ -60,8 +60,8 @@ template<typename T> class TCasLogFile : public CasLogFile { public: - static bool IsValid(std::filesystem::path FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); } - void Open(std::filesystem::path FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); } + static bool IsValid(const std::filesystem::path& FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); } + void Open(const std::filesystem::path& FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); } // This should be called before the Replay() is called to do some basic sanity checking bool Initialize() { return true; } diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h index 734d2e5a7..67cf852f9 100644 --- a/src/zenstore/include/zenstore/gc.h +++ b/src/zenstore/include/zenstore/gc.h @@ 
-238,7 +238,7 @@ bool FilterReferences(GcCtx& Ctx, std::string_view Context, std::vector<IoHa /** * @brief An interface to implement a lock for Stop The World (from writing new data) * - * This interface is registered/unregistered to GcManager vua AddGcReferenceLocker() and RemoveGcReferenceLockerr() + * This interface is registered/unregistered to GcManager via AddGcReferenceLocker() and RemoveGcReferenceLocker() */ class GcReferenceLocker { @@ -443,8 +443,8 @@ struct GcSchedulerState uint64_t DiskFree = 0; GcClock::TimePoint LastFullGcTime{}; GcClock::TimePoint LastLightweightGcTime{}; - std::chrono::seconds RemainingTimeUntilLightweightGc; - std::chrono::seconds RemainingTimeUntilFullGc; + std::chrono::seconds RemainingTimeUntilLightweightGc{}; + std::chrono::seconds RemainingTimeUntilFullGc{}; uint64_t RemainingSpaceUntilFullGC = 0; std::chrono::milliseconds LastFullGcDuration{}; @@ -562,7 +562,7 @@ private: GcClock::TimePoint m_LastGcExpireTime{}; IoHash m_LastFullAttachmentRangeMin = IoHash::Zero; IoHash m_LastFullAttachmentRangeMax = IoHash::Max; - uint8_t m_AttachmentPassIndex; + uint8_t m_AttachmentPassIndex = 0; std::chrono::milliseconds m_LastFullGcDuration{}; GcStorageSize m_LastFullGCDiff; diff --git a/src/zenstore/include/zenstore/projectstore.h b/src/zenstore/include/zenstore/projectstore.h index 33ef996db..6f49cd024 100644 --- a/src/zenstore/include/zenstore/projectstore.h +++ b/src/zenstore/include/zenstore/projectstore.h @@ -67,8 +67,8 @@ public: struct OplogEntryAddress { - uint32_t Offset; // note: Multiple of m_OpsAlign! - uint32_t Size; + uint32_t Offset = 0; // note: Multiple of m_OpsAlign! 
+ uint32_t Size = 0; }; struct OplogEntry @@ -80,11 +80,7 @@ public: uint32_t Reserved; inline bool IsTombstone() const { return OpCoreAddress.Offset == 0 && OpCoreAddress.Size == 0 && OpLsn.Number; } - inline void MakeTombstone() - { - OpLsn = {}; - OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0; - } + inline void MakeTombstone() { OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0; } }; static_assert(IsPow2(sizeof(OplogEntry))); diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp index 1ab2b317a..03086b473 100644 --- a/src/zenstore/projectstore.cpp +++ b/src/zenstore/projectstore.cpp @@ -1488,7 +1488,7 @@ ProjectStore::Oplog::Read() else { std::vector<OplogEntry> OpLogEntries; - uint64_t InvalidEntries; + uint64_t InvalidEntries = 0; m_Storage->ReadOplogEntriesFromLog(OpLogEntries, InvalidEntries, m_LogFlushPosition); for (const OplogEntry& OpEntry : OpLogEntries) { @@ -1750,8 +1750,8 @@ ProjectStore::Oplog::Validate(const std::filesystem::path& ProjectRootDir, } }; - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -2373,7 +2373,7 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P else if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end()) { CidChunkIndexes.push_back(ChunkIndex); - CidChunkHashes.push_back(ChunkIt->second); + CidChunkHashes.push_back(MetaIt->second); } else if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end()) { @@ -2384,8 +2384,8 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P } if (OptionalWorkerPool) { - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -3817,7 
+3817,7 @@ ProjectStore::Project::OpenOplog(std::string_view OplogId, bool AllowCompact, bo std::filesystem::path DeletePath; if (!RemoveOplog(OplogId, DeletePath)) { - ZEN_WARN("Failed to clean up deleted oplog {}/{}", Identifier, OplogId, OplogBasePath); + ZEN_WARN("Failed to clean up deleted oplog {}/{} at '{}'", Identifier, OplogId, OplogBasePath); } ReOpen = true; @@ -4053,8 +4053,8 @@ ProjectStore::Project::Scrub(ScrubContext& Ctx) RwLock::SharedLockScope _(m_ProjectLock); - std::atomic<bool> Abort; - std::atomic<bool> Pause; + std::atomic<bool> Abort{false}; + std::atomic<bool> Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -4360,7 +4360,7 @@ ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcMa , m_DiskWriteBlocker(Gc.GetDiskWriteBlocker()) { ZEN_INFO("initializing project store at '{}'", m_ProjectBasePath); - // m_Log.set_level(spdlog::level::debug); + // m_Log.SetLogLevel(zen::logging::Debug); m_Gc.AddGcStorage(this); m_Gc.AddGcReferencer(*this); m_Gc.AddGcReferenceLocker(*this); @@ -4433,8 +4433,8 @@ ProjectStore::Flush() } WorkerThreadPool& WorkerPool = GetSmallWorkerPool(EWorkloadType::Burst); - std::atomic<bool> AbortFlag; - std::atomic<bool> PauseFlag; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4712,6 +4712,13 @@ ProjectStore::GetProjectsList() Response << "ProjectRootDir"sv << PathToUtf8(Prj.ProjectRootDir); Response << "EngineRootDir"sv << PathToUtf8(Prj.EngineRootDir); Response << "ProjectFilePath"sv << PathToUtf8(Prj.ProjectFilePath); + + const auto AccessTime = Prj.LastOplogAccessTime(""sv); + if (AccessTime != GcClock::TimePoint::min()) + { + Response << "LastAccessTime"sv << gsl::narrow<uint64_t>(AccessTime.time_since_epoch().count()); + } + Response.EndObject(); }); Response.EndArray(); @@ -4974,7 +4981,7 @@ 
ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl } if (WantsRawSizeField) { - ZEN_ASSERT_SLOW(Sizes[Index] == (uint64_t)-1); + ZEN_ASSERT_SLOW(RawSizes[Index] == (uint64_t)-1); RawSizes[Index] = Payload.GetSize(); } } @@ -5762,7 +5769,7 @@ public: } } - for (auto ProjectIt : m_ProjectStore.m_Projects) + for (const auto& ProjectIt : m_ProjectStore.m_Projects) { Ref<ProjectStore::Project> Project = ProjectIt.second; std::vector<std::string> OplogsToCompact = Project->GetOplogsToCompact(); @@ -6802,6 +6809,8 @@ namespace testutils { } // namespace testutils +TEST_SUITE_BEGIN("store.projectstore"); + TEST_CASE("project.opkeys") { using namespace std::literals; @@ -8473,6 +8482,8 @@ TEST_CASE("project.store.iterateoplog") } } +TEST_SUITE_END(); + #endif void diff --git a/src/zenstore/workspaces.cpp b/src/zenstore/workspaces.cpp index f0f975af4..ad21bbc68 100644 --- a/src/zenstore/workspaces.cpp +++ b/src/zenstore/workspaces.cpp @@ -383,7 +383,7 @@ Workspace::GetShares() const { std::vector<Ref<WorkspaceShare>> Shares; Shares.reserve(m_Shares.size()); - for (auto It : m_Shares) + for (const auto& It : m_Shares) { Shares.push_back(It.second); } @@ -435,7 +435,7 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId) Workspace = FindWorkspace(Lock, WorkspaceId); if (Workspace) { - for (auto Share : Workspace->GetShares()) + for (const auto& Share : Workspace->GetShares()) { DeletedShares.insert(Share->GetConfig().Id); } @@ -482,6 +482,12 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId) m_ShareAliases.erase(Share->GetConfig().Alias); } Workspace->SetShare(Configuration.Id, std::move(NewShare)); + if (!Configuration.Alias.empty()) + { + m_ShareAliases.insert_or_assign( + Configuration.Alias, + ShareAlias{.WorkspaceId = WorkspaceId, .ShareId = Configuration.Id}); + } } } else @@ -602,7 +608,7 @@ Workspaces::GetWorkspaceShareChunks(const Oid& WorkspaceId, { RequestedOffset = Size; } - if ((RequestedOffset + RequestedSize) > 
Size) + if (RequestedSize > Size - RequestedOffset) { RequestedSize = Size - RequestedOffset; } @@ -649,7 +655,7 @@ Workspaces::GetWorkspaces() const { std::vector<Oid> Workspaces; RwLock::SharedLockScope Lock(m_Lock); - for (auto It : m_Workspaces) + for (const auto& It : m_Workspaces) { Workspaces.push_back(It.first); } @@ -679,7 +685,7 @@ Workspaces::GetWorkspaceShares(const Oid& WorkspaceId) const if (Workspace) { std::vector<Oid> Shares; - for (auto Share : Workspace->GetShares()) + for (const auto& Share : Workspace->GetShares()) { Shares.push_back(Share->GetConfig().Id); } @@ -1356,6 +1362,8 @@ namespace { } // namespace +TEST_SUITE_BEGIN("store.workspaces"); + TEST_CASE("workspaces.scanfolder") { using namespace std::literals; @@ -1559,6 +1567,8 @@ TEST_CASE("workspace.share.alias") CHECK(!WS.GetShareAlias("my_share").has_value()); } +TEST_SUITE_END(); + #endif void diff --git a/src/zentelemetry-test/zentelemetry-test.cpp b/src/zentelemetry-test/zentelemetry-test.cpp index 83fd549db..5a2ac74de 100644 --- a/src/zentelemetry-test/zentelemetry-test.cpp +++ b/src/zentelemetry-test/zentelemetry-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zentelemetry/zentelemetry.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zentelemetry_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zentelemetry-test", zen::zentelemetry_forcelinktests); #else return 0; #endif diff --git a/src/zentelemetry/include/zentelemetry/otlpencoder.h b/src/zentelemetry/include/zentelemetry/otlpencoder.h index ed6665781..f280aa9ec 100644 --- a/src/zentelemetry/include/zentelemetry/otlpencoder.h +++ b/src/zentelemetry/include/zentelemetry/otlpencoder.h @@ -13,9 +13,9 @@ # include <protozero/pbf_builder.hpp> # include <protozero/types.hpp> -namespace spdlog { namespace details { - struct log_msg; -}} // namespace spdlog::details +namespace zen::logging { +struct LogMessage; +} // namespace zen::logging namespace zen::otel { enum class Resource : protozero::pbf_tag_type; @@ -46,7 +46,7 @@ public: void AddResourceAttribute(const std::string_view& Key, const std::string_view& Value); void AddResourceAttribute(const std::string_view& Key, int64_t Value); - std::string FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const; + std::string FormatOtelProtobuf(const 
logging::LogMessage& Msg) const; std::string FormatOtelMetrics() const; std::string FormatOtelTrace(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans) const; diff --git a/src/zentelemetry/include/zentelemetry/otlptrace.h b/src/zentelemetry/include/zentelemetry/otlptrace.h index 49dd90358..95718af55 100644 --- a/src/zentelemetry/include/zentelemetry/otlptrace.h +++ b/src/zentelemetry/include/zentelemetry/otlptrace.h @@ -317,6 +317,7 @@ public: ExtendableStringBuilder<128> NameBuilder; NamingFunction(NameBuilder); + Initialize(NameBuilder); } /** Construct a new span with a naming function AND initializer function @@ -350,7 +351,13 @@ public: // Execute a function with the span pointer if valid. This can // be used to add attributes or events to the span after creation - inline void WithSpan(auto Func) const { Func(*m_Span); } + inline void WithSpan(auto Func) const + { + if (m_Span) + { + Func(*m_Span); + } + } private: void Initialize(std::string_view Name); diff --git a/src/zentelemetry/include/zentelemetry/stats.h b/src/zentelemetry/include/zentelemetry/stats.h index 3e67bac1c..260b0fcfb 100644 --- a/src/zentelemetry/include/zentelemetry/stats.h +++ b/src/zentelemetry/include/zentelemetry/stats.h @@ -16,11 +16,17 @@ class CbObjectWriter; namespace zen::metrics { +/** A single atomic value that can be set and read at any time. + * + * Useful for point-in-time readings such as queue depth, active connection count, + * or any value where only the current state matters rather than history. + */ template<typename T> class Gauge { public: Gauge() : m_Value{0} {} + explicit Gauge(T InitialValue) : m_Value{InitialValue} {} T Value() const { return m_Value; } void SetValue(T Value) { m_Value = Value; } @@ -29,12 +35,12 @@ private: std::atomic<T> m_Value; }; -/** Stats counter +/** Monotonically increasing (or decreasing) counter. * - * A counter is modified by adding or subtracting a value from a current value. 
- * This would typically be used to track number of requests in flight, number - * of active jobs etc + * Suitable for tracking quantities that go up and down over time, such as + * requests in flight or active jobs. All operations are lock-free via atomics. * + * Unlike a Meter, a Counter does not track rates — it only records a running total. */ class Counter { @@ -50,34 +56,56 @@ private: std::atomic<uint64_t> m_count{0}; }; -/** Exponential Weighted Moving Average - - This is very raw, to use as little state as possible. If we - want to use this more broadly in user code we should perhaps - add a more user-friendly wrapper +/** Low-level exponential weighted moving average. + * + * Tracks a smoothed rate using the standard EWMA recurrence: + * + * rate = rate + alpha * (instantRate - rate) + * + * where instantRate = Count / Interval. The alpha value controls how quickly + * the average responds to changes — higher alpha means more weight on recent + * samples. Typical alphas are derived from a decay half-life (e.g. 1, 5, 15 + * minutes) and a fixed tick interval. + * + * This class is intentionally minimal to keep per-instance state to a single + * atomic double. See Meter for a more convenient wrapper. */ - class RawEWMA { public: - /// <summary> - /// Update EWMA with new measure - /// </summary> - /// <param name="Alpha">Smoothing factor (between 0 and 1)</param> - /// <param name="Interval">Elapsed time since last</param> - /// <param name="Count">Value</param> - /// <param name="IsInitialUpdate">Whether this is the first update or not</param> - void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate); + /** Update the EWMA with a new observation. + * + * @param Alpha Smoothing factor in (0, 1). Smaller values give a + * slower-moving average; larger values track recent + * changes more aggressively. + * @param Interval Elapsed hi-freq timer ticks since the last Tick call. 
+ * Used to compute the instantaneous rate as Count/Interval. + * @param Count Number of events observed during this interval. + * @param IsInitialUpdate True on the very first call: seeds the rate directly + * from the instantaneous rate rather than blending it in. + */ + void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate); + + /** Returns the current smoothed rate in events per second. */ double Rate() const; private: std::atomic<double> m_Rate = 0; }; -/// <summary> -/// Tracks rate of events over time (i.e requests/sec), using -/// exponential moving averages -/// </summary> +/** Tracks the rate of events over time using exponential moving averages. + * + * Maintains three EWMA windows (1, 5, 15 minutes) in addition to a simple + * mean rate computed from the total count and elapsed wall time since + * construction. This mirrors the load-average conventions familiar from Unix. + * + * Rate updates are batched: Mark() accumulates a pending count and the EWMA + * is only advanced every ~5 seconds (controlled by kTickIntervalInSeconds), + * keeping contention low even under heavy call rates. Rates are returned in + * events per second. + * + * All operations are thread-safe via lock-free atomics. 
+ */ class Meter { public: @@ -85,18 +113,18 @@ public: ~Meter(); inline uint64_t Count() const { return m_TotalCount; } - double Rate1(); // One-minute rate - double Rate5(); // Five-minute rate - double Rate15(); // Fifteen-minute rate - double MeanRate() const; // Mean rate since instantiation of this meter + double Rate1(); // One-minute EWMA rate (events/sec) + double Rate5(); // Five-minute EWMA rate (events/sec) + double Rate15(); // Fifteen-minute EWMA rate (events/sec) + double MeanRate() const; // Mean rate since instantiation (events/sec) void Mark(uint64_t Count = 1); // Register one or more events private: std::atomic<uint64_t> m_TotalCount{0}; // Accumulator counting number of marks since beginning - std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator - std::atomic<uint64_t> m_StartTick{0}; // Time this was instantiated (for mean) - std::atomic<uint64_t> m_LastTick{0}; // Timestamp of last EWMA tick - std::atomic<int64_t> m_Remainder{0}; // Tracks the "modulo" of tick time + std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator; drained on each tick + std::atomic<uint64_t> m_StartTick{0}; // Hi-freq timer value at construction (for MeanRate) + std::atomic<uint64_t> m_LastTick{0}; // Hi-freq timer value of the last EWMA tick + std::atomic<int64_t> m_Remainder{0}; // Accumulated ticks not yet consumed by EWMA updates bool m_IsFirstTick = true; RawEWMA m_RateM1; RawEWMA m_RateM5; @@ -106,7 +134,14 @@ private: void Tick(); }; -/** Moment-in-time snapshot of a distribution +/** Immutable sorted snapshot of a reservoir sample. + * + * Constructed from a vector of sampled values which are sorted on construction. + * Percentiles are computed on demand via linear interpolation between adjacent + * sorted values, following the standard R-7 quantile method. + * + * Because this is a copy of the reservoir at a point in time, it can be held + * and queried without holding any locks on the source UniformSample. 
*/ class SampleSnapshot { @@ -128,12 +163,19 @@ private: std::vector<double> m_Values; }; -/** Randomly selects samples from a stream. Uses Vitter's - Algorithm R to produce a statistically representative sample. - - http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir +/** Reservoir sampler for probabilistic distribution tracking. + * + * Maintains a fixed-size reservoir of samples drawn uniformly from the full + * history of values using Vitter's Algorithm R. This gives an unbiased + * statistical representation of the value distribution regardless of how many + * total values have been observed, at the cost of O(ReservoirSize) memory. + * + * A larger reservoir improves accuracy of tail percentiles (P99, P999) but + * increases memory and snapshot cost. The default of 1028 gives good accuracy + * for most telemetry uses. + * + * http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir */ - class UniformSample { public: @@ -159,7 +201,14 @@ private: std::vector<std::atomic<int64_t>> m_Values; }; -/** Track (probabilistic) sample distribution along with min/max +/** Tracks the statistical distribution of a stream of values. + * + * Records exact min, max, count and mean across all values ever seen, plus a + * reservoir sample (via UniformSample) used to compute percentiles. Percentiles + * are therefore probabilistic — they reflect the distribution of a representative + * sample rather than the full history. + * + * All operations are thread-safe via lock-free atomics. */ class Histogram { @@ -183,11 +232,28 @@ private: std::atomic<int64_t> m_Count{0}; }; -/** Track timing and frequency of some operation - - Example usage would be to track frequency and duration of network - requests, or function calls. - +/** Combines a Histogram and a Meter to track both the distribution and rate + * of a recurring operation. + * + * Duration values are stored in hi-freq timer ticks. 
Use GetHifreqTimerToSeconds() + * when converting for display. + * + * Typical usage via the RAII Scope helper: + * + * OperationTiming MyTiming; + * + * { + * OperationTiming::Scope Scope(MyTiming); + * DoWork(); + * // Scope destructor calls Stop() automatically + * } + * + * // Or cancel if the operation should not be counted: + * { + * OperationTiming::Scope Scope(MyTiming); + * if (CacheHit) { Scope.Cancel(); return; } + * DoExpensiveWork(); + * } */ class OperationTiming { @@ -207,13 +273,19 @@ public: double Rate15() { return m_Meter.Rate15(); } double MeanRate() const { return m_Meter.MeanRate(); } + /** RAII helper that records duration from construction to Stop() or destruction. + * + * Call Cancel() to discard the measurement (e.g. for cache hits that should + * not skew latency statistics). After Stop() or Cancel() the destructor is a + * no-op. + */ struct Scope { Scope(OperationTiming& Outer); ~Scope(); - void Stop(); - void Cancel(); + void Stop(); // Record elapsed time and mark the meter + void Cancel(); // Discard this measurement; destructor becomes a no-op private: OperationTiming& m_Outer; @@ -225,6 +297,7 @@ private: Histogram m_Histogram; }; +/** Immutable snapshot of a Meter's state at a point in time. */ struct MeterSnapshot { uint64_t Count; @@ -234,6 +307,12 @@ struct MeterSnapshot double Rate15; }; +/** Immutable snapshot of a Histogram's state at a point in time. + * + * Count and all statistical values have been scaled by the ConversionFactor + * supplied when the snapshot was taken (e.g. GetHifreqTimerToSeconds() to + * convert timer ticks to seconds). + */ struct HistogramSnapshot { double Count; @@ -246,24 +325,29 @@ struct HistogramSnapshot double P999; }; +/** Combined snapshot of a Meter and Histogram pair. */ struct StatsSnapshot { MeterSnapshot Meter; HistogramSnapshot Histogram; }; +/** Combined snapshot of request timing and byte transfer statistics. 
*/ struct RequestStatsSnapshot { StatsSnapshot Requests; StatsSnapshot Bytes; }; -/** Metrics for network requests - - Aggregates tracking of duration, payload sizes into a single - class - - */ +/** Tracks both the timing and payload size of network requests. + * + * Maintains two independent histogram+meter pairs: one for request duration + * (in hi-freq timer ticks) and one for transferred bytes. Both dimensions + * share the same request count — a single Update() call advances both. + * + * Duration accessors return values in hi-freq timer ticks. Multiply by + * GetHifreqTimerToSeconds() to convert to seconds. + */ class RequestStats { public: @@ -275,9 +359,9 @@ public: // Timing - int64_t MaxDuration() const { return m_BytesHistogram.Max(); } - int64_t MinDuration() const { return m_BytesHistogram.Min(); } - double MeanDuration() const { return m_BytesHistogram.Mean(); } + int64_t MaxDuration() const { return m_RequestTimeHistogram.Max(); } + int64_t MinDuration() const { return m_RequestTimeHistogram.Min(); } + double MeanDuration() const { return m_RequestTimeHistogram.Mean(); } SampleSnapshot DurationSnapshot() const { return m_RequestTimeHistogram.Snapshot(); } double Rate1() { return m_RequestMeter.Rate1(); } double Rate5() { return m_RequestMeter.Rate5(); } @@ -295,14 +379,23 @@ public: double ByteRate15() { return m_BytesMeter.Rate15(); } double ByteMeanRate() const { return m_BytesMeter.MeanRate(); } + /** RAII helper that records duration and byte count from construction to Stop() + * or destruction. + * + * The byte count can be supplied at construction or updated at any point via + * SetBytes() before the scope ends — useful when the response size is not + * known until the operation completes. + * + * Call Cancel() to discard the measurement entirely. 
+ */ struct Scope { Scope(RequestStats& Outer, int64_t Bytes); ~Scope(); void SetBytes(int64_t Bytes) { m_Bytes = Bytes; } - void Stop(); - void Cancel(); + void Stop(); // Record elapsed time and byte count + void Cancel(); // Discard this measurement; destructor becomes a no-op private: RequestStats& m_Outer; diff --git a/src/zentelemetry/otlpencoder.cpp b/src/zentelemetry/otlpencoder.cpp index 677545066..5477c5381 100644 --- a/src/zentelemetry/otlpencoder.cpp +++ b/src/zentelemetry/otlpencoder.cpp @@ -3,9 +3,9 @@ #include "zentelemetry/otlpencoder.h" #include <zenbase/zenbase.h> +#include <zencore/logging/logmsg.h> #include <zentelemetry/otlptrace.h> -#include <spdlog/sinks/sink.h> #include <zencore/testing.h> #include <protozero/buffer_string.hpp> @@ -29,49 +29,49 @@ OtlpEncoder::~OtlpEncoder() } static int -MapSeverity(const spdlog::level::level_enum Level) +MapSeverity(const logging::LogLevel Level) { switch (Level) { - case spdlog::level::critical: + case logging::Critical: return otel::SEVERITY_NUMBER_FATAL; - case spdlog::level::err: + case logging::Err: return otel::SEVERITY_NUMBER_ERROR; - case spdlog::level::warn: + case logging::Warn: return otel::SEVERITY_NUMBER_WARN; - case spdlog::level::info: + case logging::Info: return otel::SEVERITY_NUMBER_INFO; - case spdlog::level::debug: + case logging::Debug: return otel::SEVERITY_NUMBER_DEBUG; default: - case spdlog::level::trace: + case logging::Trace: return otel::SEVERITY_NUMBER_TRACE; } } static const char* -MapSeverityText(const spdlog::level::level_enum Level) +MapSeverityText(const logging::LogLevel Level) { switch (Level) { - case spdlog::level::critical: + case logging::Critical: return "fatal"; - case spdlog::level::err: + case logging::Err: return "error"; - case spdlog::level::warn: + case logging::Warn: return "warn"; - case spdlog::level::info: + case logging::Info: return "info"; - case spdlog::level::debug: + case logging::Debug: return "debug"; default: - case spdlog::level::trace: + case 
logging::Trace: return "trace"; } } std::string -OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const +OtlpEncoder::FormatOtelProtobuf(const logging::LogMessage& Msg) const { std::string Data; @@ -98,7 +98,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const protozero::pbf_builder<otel::InstrumentationScope> IsBuilder{SlBuilder, otel::ScopeLogs::required_InstrumentationScope_scope}; - IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.logger_name.data(), Msg.logger_name.size()); + IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.GetLoggerName().data(), Msg.GetLoggerName().size()); } // LogRecord log_records @@ -106,13 +106,13 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const protozero::pbf_builder<otel::LogRecord> LrBuilder{SlBuilder, otel::ScopeLogs::required_repeated_LogRecord_log_records}; LrBuilder.add_fixed64(otel::LogRecord::required_fixed64_time_unix_nano, - std::chrono::duration_cast<std::chrono::nanoseconds>(Msg.time.time_since_epoch()).count()); + std::chrono::duration_cast<std::chrono::nanoseconds>(Msg.GetTime().time_since_epoch()).count()); - const int Severity = MapSeverity(Msg.level); + const int Severity = MapSeverity(Msg.GetLevel()); LrBuilder.add_enum(otel::LogRecord::optional_SeverityNumber_severity_number, Severity); - LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.level)); + LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.GetLevel())); otel::TraceId TraceId; const otel::SpanId SpanId = otel::Span::GetCurrentSpanId(TraceId); @@ -127,7 +127,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const { protozero::pbf_builder<otel::AnyValue> BodyBuilder{LrBuilder, otel::LogRecord::optional_anyvalue_body}; - BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.payload.data(), Msg.payload.size()); + 
BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.GetPayload().data(), Msg.GetPayload().size()); } // attributes @@ -139,7 +139,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const { protozero::pbf_builder<otel::AnyValue> AvBuilder{KvBuilder, otel::KeyValue::AnyValue_value}; - AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.thread_id); + AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.GetThreadId()); } } } diff --git a/src/zentelemetry/otlptrace.cpp b/src/zentelemetry/otlptrace.cpp index 6a095cfeb..3888717d5 100644 --- a/src/zentelemetry/otlptrace.cpp +++ b/src/zentelemetry/otlptrace.cpp @@ -385,6 +385,8 @@ otlptrace_forcelink() # if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("telemetry.otlptrace"); + TEST_CASE("otlp.trace") { // Enable OTLP tracing for the duration of this test @@ -409,6 +411,8 @@ TEST_CASE("otlp.trace") } } +TEST_SUITE_END(); + # endif } // namespace zen::otel diff --git a/src/zentelemetry/stats.cpp b/src/zentelemetry/stats.cpp index c67fa3c66..a417bb52c 100644 --- a/src/zentelemetry/stats.cpp +++ b/src/zentelemetry/stats.cpp @@ -631,7 +631,7 @@ EmitSnapshot(const HistogramSnapshot& Snapshot, CbObjectWriter& Cbo) { Cbo << "t_count" << Snapshot.Count << "t_avg" << Snapshot.Avg; Cbo << "t_min" << Snapshot.Min << "t_max" << Snapshot.Max; - Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P999; + Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P99 << "t_p999" << Snapshot.P999; } void @@ -660,6 +660,8 @@ EmitSnapshot(std::string_view Tag, const RequestStatsSnapshot& Snapshot, CbObjec #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("telemetry.stats"); + TEST_CASE("Core.Stats.Histogram") { Histogram Histo{258}; @@ -819,6 +821,8 @@ TEST_CASE("Meter") # endif } +TEST_SUITE_END(); + namespace zen { void diff --git a/src/zentelemetry/xmake.lua b/src/zentelemetry/xmake.lua index 7739c0a08..cd9a18ec4 100644 --- a/src/zentelemetry/xmake.lua +++ 
b/src/zentelemetry/xmake.lua @@ -6,5 +6,5 @@ target('zentelemetry') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "protozero", "spdlog") + add_deps("zencore", "protozero") add_deps("robin-map") diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua index 97615e322..844ba82ef 100644 --- a/src/zentest-appstub/xmake.lua +++ b/src/zentest-appstub/xmake.lua @@ -5,6 +5,7 @@ target("zentest-appstub") set_group("tests") add_headerfiles("**.h") add_files("*.cpp") + add_deps("zencore") if is_os("linux") then add_syslinks("pthread") diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 24cf21e97..509629739 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -1,33 +1,418 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/stream.h> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +#include <fmt/format.h> + #include <stdio.h> +#include <algorithm> #include <chrono> #include <cstdlib> #include <cstring> +#include <filesystem> +#include <string> +#include <system_error> #include <thread> -using namespace std::chrono_literals; +using namespace std::literals; +using namespace zen; + +#if !defined(_MSC_VER) +# define _strnicmp strncasecmp // TEMPORARY WORKAROUND - should not be using this +#endif + +// Some basic functions to implement some test "compute" functions + +std::string +Rot13Function(std::string_view InputString) +{ + std::string OutputString{InputString}; + + std::transform(OutputString.begin(), + OutputString.end(), + OutputString.begin(), + [](std::string::value_type c) -> std::string::value_type { + if (c >= 'a' && c <= 'z') + { + return 'a' + 
(c - 'a' + 13) % 26; + } + else if (c >= 'A' && c <= 'Z') + { + return 'A' + (c - 'A' + 13) % 26; + } + else + { + return c; + } + }); + + return OutputString; +} + +std::string +ReverseFunction(std::string_view InputString) +{ + std::string OutputString{InputString}; + std::reverse(OutputString.begin(), OutputString.end()); + return OutputString; +} + +std::string +IdentityFunction(std::string_view InputString) +{ + return std::string{InputString}; +} + +std::string +NullFunction(std::string_view) +{ + return {}; +} + +zen::CbObject +DescribeFunctions() +{ + CbObjectWriter Versions; + Versions << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv); + + Versions.BeginArray("Functions"sv); + Versions.BeginObject(); + Versions << "Name"sv + << "Null"sv; + Versions << "Version"sv << Guid::FromString("00000000-0000-0000-0000-000000000000"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Identity"sv; + Versions << "Version"sv << Guid::FromString("11111111-1111-1111-1111-111111111111"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Rot13"sv; + Versions << "Version"sv << Guid::FromString("13131313-1313-1313-1313-131313131313"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Reverse"sv; + Versions << "Version"sv << Guid::FromString("31313131-3131-3131-3131-313131313131"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Sleep"sv; + Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); + Versions.EndObject(); + Versions.EndArray(); + + return Versions.Save(); +} + +struct ContentResolver +{ + std::filesystem::path InputsRoot; + + CompressedBuffer ResolveChunk(IoHash Hash, uint64_t ExpectedSize) + { + std::filesystem::path ChunkPath = InputsRoot / Hash.ToHexString(); + IoBuffer ChunkBuffer = IoBufferBuilder::MakeFromFile(ChunkPath); + + IoHash RawHash; + uint64_t RawSize = 0; + 
CompressedBuffer AsCompressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkBuffer), RawHash, RawSize); + + if (RawSize != ExpectedSize) + { + throw std::runtime_error( + fmt::format("chunk size mismatch - expected {}, got {} for '{}'", ExpectedSize, ChunkBuffer.Size(), ChunkPath)); + } + if (RawHash != Hash) + { + throw std::runtime_error(fmt::format("chunk hash mismatch - expected {}, got {} for '{}'", Hash, RawHash, ChunkPath)); + } + + return AsCompressed; + } +}; + +zen::CbPackage +ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) +{ + auto Apply = [&](auto Func) { + zen::CbPackage Result; + auto Source = Action["Inputs"sv].AsObjectView()["Source"sv].AsObjectView(); + + IoHash InputRawHash = Source["RawHash"sv].AsHash(); + uint64_t InputRawSize = Source["RawSize"sv].AsUInt64(); + + zen::CompressedBuffer InputData = ChunkResolver.ResolveChunk(InputRawHash, InputRawSize); + SharedBuffer Input = InputData.Decompress(); + + std::string Output = Func(std::string_view(static_cast<const char*>(Input.GetData()), Input.GetSize())); + zen::CompressedBuffer OutputData = + zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Output), OodleCompressor::Selkie, OodleCompressionLevel::HyperFast4); + IoHash OutputRawHash = OutputData.DecodeRawHash(); + + CbAttachment OutputAttachment(std::move(OutputData), OutputRawHash); + + CbObjectWriter Cbo; + Cbo.BeginArray("Values"sv); + Cbo.BeginObject(); + Cbo << "Id" << Oid{1, 2, 3}; + Cbo.AddAttachment("RawHash", OutputAttachment); + Cbo << "RawSize" << Output.size(); + Cbo.EndObject(); + Cbo.EndArray(); + + Result.SetObject(Cbo.Save()); + Result.AddAttachment(std::move(OutputAttachment)); + return Result; + }; + + std::string_view Function = Action["Function"sv].AsString(); + + if (Function == "Rot13"sv) + { + return Apply(Rot13Function); + } + else if (Function == "Reverse"sv) + { + return Apply(ReverseFunction); + } + else if (Function == "Identity"sv) + { + return Apply(IdentityFunction); + } + else if 
(Function == "Null"sv) + { + return Apply(NullFunction); + } + else if (Function == "Sleep"sv) + { + uint64_t SleepTimeMs = Action["Constants"sv].AsObjectView()["SleepTimeMs"sv].AsUInt64(); + zen::Sleep(static_cast<int>(SleepTimeMs)); + return Apply(IdentityFunction); + } + else + { + return {}; + } +} + +/* This implements a minimal application to help testing of process launch-related + functionality + + It also mimics the DDC2 worker command line interface, so it may be used to + exercise compute infrastructure. + */ int main(int argc, char* argv[]) { int ExitCode = 0; - for (int i = 0; i < argc; ++i) + try { - if (std::strncmp(argv[i], "-t=", 3) == 0) + std::filesystem::path BasePath = std::filesystem::current_path(); + std::filesystem::path InputPath = std::filesystem::current_path() / "Inputs"; + std::filesystem::path OutputPath = std::filesystem::current_path() / "Outputs"; + std::filesystem::path VersionPath = std::filesystem::current_path() / "Versions"; + std::vector<std::filesystem::path> ActionPaths; + + /* + GetSwitchValues(TEXT("-B="), ActionPathPatterns); + GetSwitchValues(TEXT("-Build="), ActionPathPatterns); + + GetSwitchValues(TEXT("-I="), InputDirectoryPaths); + GetSwitchValues(TEXT("-Input="), InputDirectoryPaths); + + GetSwitchValues(TEXT("-O="), OutputDirectoryPaths); + GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths); + + GetSwitchValues(TEXT("-V="), VersionPaths); + GetSwitchValues(TEXT("-Version="), VersionPaths); + */ + + auto SplitArg = [](const char* Arg) -> std::string_view { + std::string_view ArgView{Arg}; + if (auto SplitPos = ArgView.find_first_of('='); SplitPos != std::string_view::npos) + { + return ArgView.substr(SplitPos + 1); + } + else + { + return {}; + } + }; + + auto ParseIntArg = [](std::string_view Arg) -> int { + int Rv = 0; + const auto Result = std::from_chars(Arg.data(), Arg.data() + Arg.size(), Rv); + + if (Result.ec != std::errc{}) + { + throw std::invalid_argument(fmt::format("bad argument (not an integer): 
{}", Arg).c_str()); + } + + return Rv; + }; + + for (int i = 1; i < argc; ++i) { - const int SleepTime = std::atoi(argv[i] + 3); + std::string_view Arg = argv[i]; + + if (Arg.compare(0, 1, "-")) + { + continue; + } + + if (std::strncmp(argv[i], "-t=", 3) == 0) + { + const int SleepTime = std::atoi(argv[i] + 3); + + printf("[zentest] sleeping for %ds...\n", SleepTime); + + std::this_thread::sleep_for(SleepTime * 1s); + } + else if (std::strncmp(argv[i], "-f=", 3) == 0) + { + // Force a "failure" process exit code to return to the invoker + + // This may throw for invalid arguments, which makes this useful for + // testing exception handling + std::string_view ErrorArg = SplitArg(argv[i]); + ExitCode = ParseIntArg(ErrorArg); + } + else if ((_strnicmp(argv[i], "-input=", 7) == 0) || (_strnicmp(argv[i], "-i=", 3) == 0)) + { + /* mimic DDC2 + + GetSwitchValues(TEXT("-I="), InputDirectoryPaths); + GetSwitchValues(TEXT("-Input="), InputDirectoryPaths); + */ + + std::string_view InputArg = SplitArg(argv[i]); + InputPath = InputArg; + } + else if ((_strnicmp(argv[i], "-output=", 8) == 0) || (_strnicmp(argv[i], "-o=", 3) == 0)) + { + /* mimic DDC2 handling of where files storing output chunk files are directed + + GetSwitchValues(TEXT("-O="), OutputDirectoryPaths); + GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths); + */ - printf("[zentest] sleeping for %ds...\n", SleepTime); + std::string_view OutputArg = SplitArg(argv[i]); + OutputPath = OutputArg; + } + else if ((_strnicmp(argv[i], "-version=", 8) == 0) || (_strnicmp(argv[i], "-v=", 3) == 0)) + { + /* mimic DDC2 - std::this_thread::sleep_for(SleepTime * 1s); + GetSwitchValues(TEXT("-V="), VersionPaths); + GetSwitchValues(TEXT("-Version="), VersionPaths); + */ + + std::string_view VersionArg = SplitArg(argv[i]); + VersionPath = VersionArg; + } + else if ((_strnicmp(argv[i], "-build=", 7) == 0) || (_strnicmp(argv[i], "-b=", 3) == 0)) + { + /* mimic DDC2 + + GetSwitchValues(TEXT("-B="), ActionPathPatterns); + 
GetSwitchValues(TEXT("-Build="), ActionPathPatterns); + */ + + std::string_view BuildActionArg = SplitArg(argv[i]); + std::filesystem::path ActionPath{BuildActionArg}; + ActionPaths.push_back(ActionPath); + + ExitCode = 0; + } } - else if (std::strncmp(argv[i], "-f=", 3) == 0) + + // Emit version information + + if (!VersionPath.empty()) { - ExitCode = std::atoi(argv[i] + 3); + CbObjectWriter Version; + + Version << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv); + + Version.BeginArray("Functions"); + + Version.BeginObject(); + Version << "Name" + << "Rot13" + << "Version" << Guid::FromString("13131313-1313-1313-1313-131313131313"sv); + Version.EndObject(); + + Version.BeginObject(); + Version << "Name" + << "Reverse" + << "Version" << Guid::FromString("98765432-1000-0000-0000-000000000000"sv); + Version.EndObject(); + + Version.BeginObject(); + Version << "Name" + << "Identity" + << "Version" << Guid::FromString("11111111-1111-1111-1111-111111111111"sv); + Version.EndObject(); + + Version.BeginObject(); + Version << "Name" + << "Null" + << "Version" << Guid::FromString("00000000-0000-0000-0000-000000000000"sv); + Version.EndObject(); + + Version.EndArray(); + CbObject VersionObject = Version.Save(); + + BinaryWriter Writer; + zen::SaveCompactBinary(Writer, VersionObject); + zen::WriteFile(VersionPath, IoBufferBuilder::MakeFromMemory(Writer.GetView())); + } + + // Evaluate actions + + ContentResolver Resolver; + Resolver.InputsRoot = InputPath; + + for (std::filesystem::path ActionPath : ActionPaths) + { + IoBuffer ActionDescBuffer = ReadFile(ActionPath).Flatten(); + CbObject ActionDesc = LoadCompactBinaryObject(ActionDescBuffer); + CbPackage Result = ExecuteFunction(ActionDesc, Resolver); + CbObject ResultObject = Result.GetObject(); + + BinaryWriter Writer; + zen::SaveCompactBinary(Writer, ResultObject); + zen::WriteFile(ActionPath.replace_extension(".output"), IoBufferBuilder::MakeFromMemory(Writer.GetView())); + + // Also 
marshal outputs + + for (const auto& Attachment : Result.GetAttachments()) + { + const CompositeBuffer& AttachmentBuffer = Attachment.AsCompressedBinary().GetCompressed(); + zen::WriteFile(OutputPath / Attachment.GetHash().ToHexString(), AttachmentBuffer.Flatten().AsIoBuffer()); + } } } + catch (std::exception& Ex) + { + printf("[zentest] exception caught in main: '%s'\n", Ex.what()); + + ExitCode = 99; + } printf("[zentest] exiting with exit code: %d\n", ExitCode); diff --git a/src/zenutil-test/zenutil-test.cpp b/src/zenutil-test/zenutil-test.cpp index f5cfd5a72..e2b6ac9bd 100644 --- a/src/zenutil-test/zenutil-test.cpp +++ b/src/zenutil-test/zenutil-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zencore/filesystem.h> -#include <zencore/logging.h> -#include <zencore/trace.h> +#include <zencore/testing.h> #include <zenutil/zenutil.h> #include <zencore/memory/newdelete.h> -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include <zencore/testing.h> -# include <zencore/process.h> -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenutil_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenutil-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenutil-test", zen::zenutil_forcelinktests); #else return 0; #endif diff --git a/src/zenutil/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp index d94564843..25f5522d8 100644 --- a/src/zenutil/commandlineoptions.cpp +++ 
b/src/zenutil/config/commandlineoptions.cpp @@ -1,7 +1,8 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenutil/commandlineoptions.h> +#include <zenutil/config/commandlineoptions.h> +#include <zencore/filesystem.h> #include <zencore/string.h> #include <filesystem> @@ -194,6 +195,8 @@ commandlineoptions_forcelink() { } +TEST_SUITE_BEGIN("util.commandlineoptions"); + TEST_CASE("CommandLine") { std::vector<std::string> v1 = ParseCommandLine("c:\\my\\exe.exe \"quoted arg\" \"one\",two,\"three\\\""); @@ -235,5 +238,7 @@ TEST_CASE("CommandLine") CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64")); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zenutil/environmentoptions.cpp b/src/zenutil/config/environmentoptions.cpp index ee40086c1..fb7f71706 100644 --- a/src/zenutil/environmentoptions.cpp +++ b/src/zenutil/config/environmentoptions.cpp @@ -1,6 +1,6 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenutil/environmentoptions.h> +#include <zenutil/config/environmentoptions.h> #include <zencore/filesystem.h> diff --git a/src/zenutil/config/loggingconfig.cpp b/src/zenutil/config/loggingconfig.cpp new file mode 100644 index 000000000..5092c60aa --- /dev/null +++ b/src/zenutil/config/loggingconfig.cpp @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zenutil/config/loggingconfig.h" + +#include <zenbase/zenbase.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +void +ZenLoggingCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig) +{ + // clang-format off + options.add_options("logging") + ("abslog", "Path to log file", cxxopts::value<std::string>(m_AbsLogFile)) + ("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(LoggingConfig.LogId)) + ("quiet", "Configure console logger output to level WARN", cxxopts::value<bool>(LoggingConfig.QuietConsole)->default_value("false")) + ("noconsole", "Disable console logging", cxxopts::value<bool>(LoggingConfig.NoConsoleOutput)->default_value("false")) + ("log-trace", "Change selected loggers to level TRACE", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Trace])) + ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Debug])) + ("log-info", "Change selected loggers to level INFO", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Info])) + ("log-warn", "Change selected loggers to level WARN", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Warn])) + ("log-error", "Change selected loggers to level ERROR", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Err])) + ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Critical])) + ("log-off", "Change selected loggers to level OFF", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Off])) + ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value<std::string>(LoggingConfig.OtelEndpointUri)) + ; + // clang-format on +} + +void +ZenLoggingCmdLineOptions::ApplyOptions(ZenLoggingConfig& LoggingConfig) +{ + 
LoggingConfig.AbsLogFile = MakeSafeAbsolutePath(m_AbsLogFile); +} + +void +ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig) +{ + ZEN_UNUSED(options); + + if (LoggingConfig.QuietConsole) + { + bool HasExplicitConsoleLevel = false; + for (int i = 0; i < logging::LogLevelCount; ++i) + { + if (LoggingConfig.Loggers[i].find("console") != std::string::npos) + { + HasExplicitConsoleLevel = true; + break; + } + } + + if (!HasExplicitConsoleLevel) + { + std::string& WarnLoggers = LoggingConfig.Loggers[logging::Warn]; + if (!WarnLoggers.empty()) + { + WarnLoggers += ","; + } + WarnLoggers += "console"; + } + } + + for (int i = 0; i < logging::LogLevelCount; ++i) + { + logging::ConfigureLogLevels(logging::LogLevel(i), LoggingConfig.Loggers[i]); + } + logging::RefreshLogLevels(); +} + +} // namespace zen diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp new file mode 100644 index 000000000..4410d463d --- /dev/null +++ b/src/zenutil/consoletui.cpp @@ -0,0 +1,483 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/consoletui.h> + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <poll.h> +# include <sys/ioctl.h> +# include <termios.h> +# include <unistd.h> +#endif + +#include <cstdio> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// Platform-specific terminal helpers + +#if ZEN_PLATFORM_WINDOWS + +static bool +CheckIsInteractiveTerminal() +{ + DWORD dwMode = 0; + return GetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), &dwMode) && GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode); +} + +static void +EnableVirtualTerminal() +{ + HANDLE hStdOut = GetStdHandle(STD_OUTPUT_HANDLE); + DWORD dwMode = 0; + if (GetConsoleMode(hStdOut, &dwMode)) + { + SetConsoleMode(hStdOut, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); + } +} + +// RAII guard: sets the console output code page for the lifetime of the object and +// restores the original on destruction. Required for UTF-8 glyphs to render correctly +// via printf/fflush since the default console code page is not UTF-8. 
+class ConsoleCodePageGuard +{ +public: + explicit ConsoleCodePageGuard(UINT NewCP) : m_OldCP(GetConsoleOutputCP()) { SetConsoleOutputCP(NewCP); } + ~ConsoleCodePageGuard() { SetConsoleOutputCP(m_OldCP); } + +private: + UINT m_OldCP; +}; + +enum class ConsoleKey +{ + Unknown, + ArrowUp, + ArrowDown, + Enter, + Escape, +}; + +static ConsoleKey +ReadKey() +{ + HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); + INPUT_RECORD Record{}; + DWORD dwRead = 0; + while (true) + { + if (!ReadConsoleInputA(hStdin, &Record, 1, &dwRead)) + { + return ConsoleKey::Escape; // treat read error as cancel + } + if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown) + { + switch (Record.Event.KeyEvent.wVirtualKeyCode) + { + case VK_UP: + return ConsoleKey::ArrowUp; + case VK_DOWN: + return ConsoleKey::ArrowDown; + case VK_RETURN: + return ConsoleKey::Enter; + case VK_ESCAPE: + return ConsoleKey::Escape; + default: + break; + } + } + } +} + +#else // POSIX + +static bool +CheckIsInteractiveTerminal() +{ + return isatty(STDIN_FILENO) && isatty(STDOUT_FILENO); +} + +static void +EnableVirtualTerminal() +{ + // ANSI escape codes are native on POSIX terminals; nothing to do +} + +// RAII guard: switches the terminal to raw/unbuffered input mode and restores +// the original attributes on destruction. 
+class RawModeGuard +{ +public: + RawModeGuard() + { + if (tcgetattr(STDIN_FILENO, &m_OldAttrs) != 0) + { + return; + } + + struct termios Raw = m_OldAttrs; + Raw.c_iflag &= ~static_cast<tcflag_t>(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + Raw.c_cflag |= CS8; + Raw.c_lflag &= ~static_cast<tcflag_t>(ECHO | ICANON | IEXTEN | ISIG); + Raw.c_cc[VMIN] = 1; + Raw.c_cc[VTIME] = 0; + if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0) + { + m_Valid = true; + } + } + + ~RawModeGuard() + { + if (m_Valid) + { + tcsetattr(STDIN_FILENO, TCSANOW, &m_OldAttrs); + } + } + + bool IsValid() const { return m_Valid; } + +private: + struct termios m_OldAttrs = {}; + bool m_Valid = false; +}; + +static int +ReadByteWithTimeout(int TimeoutMs) +{ + struct pollfd Pfd + { + STDIN_FILENO, POLLIN, 0 + }; + if (poll(&Pfd, 1, TimeoutMs) > 0 && (Pfd.revents & POLLIN)) + { + unsigned char c = 0; + if (read(STDIN_FILENO, &c, 1) == 1) + { + return static_cast<int>(c); + } + } + return -1; +} + +// State for fullscreen live mode (alternate screen + raw input) +static struct termios s_SavedAttrs = {}; +static bool s_InLiveMode = false; + +enum class ConsoleKey +{ + Unknown, + ArrowUp, + ArrowDown, + Enter, + Escape, +}; + +static ConsoleKey +ReadKey() +{ + unsigned char c = 0; + if (read(STDIN_FILENO, &c, 1) != 1) + { + return ConsoleKey::Escape; // treat read error as cancel + } + + if (c == 27) // ESC byte or start of an escape sequence + { + int Next = ReadByteWithTimeout(50); + if (Next == '[') + { + int Final = ReadByteWithTimeout(50); + if (Final == 'A') + { + return ConsoleKey::ArrowUp; + } + if (Final == 'B') + { + return ConsoleKey::ArrowDown; + } + } + return ConsoleKey::Escape; + } + + if (c == '\r' || c == '\n') + { + return ConsoleKey::Enter; + } + + return ConsoleKey::Unknown; +} + +#endif // ZEN_PLATFORM_WINDOWS / POSIX + +////////////////////////////////////////////////////////////////////////// +// Public API + +uint32_t +TuiConsoleColumns(uint32_t Default) +{ +#if ZEN_PLATFORM_WINDOWS 
+ CONSOLE_SCREEN_BUFFER_INFO Csbi = {}; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &Csbi)) + { + return static_cast<uint32_t>(Csbi.dwSize.X); + } +#else + struct winsize Ws = {}; + if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &Ws) == 0 && Ws.ws_col > 0) + { + return static_cast<uint32_t>(Ws.ws_col); + } +#endif + return Default; +} + +void +TuiEnableOutput() +{ + EnableVirtualTerminal(); +#if ZEN_PLATFORM_WINDOWS + SetConsoleOutputCP(CP_UTF8); +#endif +} + +bool +TuiIsStdoutTty() +{ +#if ZEN_PLATFORM_WINDOWS + static bool Cached = [] { + DWORD dwMode = 0; + return GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode) != 0; + }(); + return Cached; +#else + static bool Cached = isatty(STDOUT_FILENO) != 0; + return Cached; +#endif +} + +bool +IsTuiAvailable() +{ + static bool Cached = CheckIsInteractiveTerminal(); + return Cached; +} + +int +TuiPickOne(std::string_view Title, std::span<const std::string> Items) +{ + EnableVirtualTerminal(); + +#if ZEN_PLATFORM_WINDOWS + ConsoleCodePageGuard CodePageGuard(CP_UTF8); +#else + RawModeGuard RawMode; + if (!RawMode.IsValid()) + { + return -1; + } +#endif + + const int Count = static_cast<int>(Items.size()); + int SelectedIndex = 0; + + printf("\n%.*s\n\n", static_cast<int>(Title.size()), Title.data()); + + // Hide cursor during interaction + printf("\033[?25l"); + + // Renders the full entry list and hint footer. + // On subsequent calls, moves the cursor back up first to overwrite the previous output. + bool FirstRender = true; + auto RenderAll = [&] { + if (!FirstRender) + { + printf("\033[%dA", Count + 2); // move up: entries + blank line + hint line + } + FirstRender = false; + + for (int i = 0; i < Count; ++i) + { + bool IsSelected = (i == SelectedIndex); + + printf("\r\033[K"); // erase line + + if (IsSelected) + { + printf("\033[1;7m"); // bold + reverse video + } + + // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶) + const char* Indicator = IsSelected ? 
" \xe2\x96\xb6 " : " "; + + printf("%s%s", Indicator, Items[i].c_str()); + + if (IsSelected) + { + printf("\033[0m"); // reset attributes + } + + printf("\n"); + } + + // Blank separator line + printf("\r\033[K\n"); + + // Hint footer + // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓ + printf( + "\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate " + "\033[2mEnter\033[0m confirm " + "\033[2mEsc\033[0m cancel\n"); + + fflush(stdout); + }; + + RenderAll(); + + int Result = -1; + bool Done = false; + while (!Done) + { + ConsoleKey Key = ReadKey(); + switch (Key) + { + case ConsoleKey::ArrowUp: + SelectedIndex = (SelectedIndex - 1 + Count) % Count; + RenderAll(); + break; + + case ConsoleKey::ArrowDown: + SelectedIndex = (SelectedIndex + 1) % Count; + RenderAll(); + break; + + case ConsoleKey::Enter: + Result = SelectedIndex; + Done = true; + break; + + case ConsoleKey::Escape: + Done = true; + break; + + default: + break; + } + } + + // Restore cursor and add a blank line for visual separation + printf("\033[?25h\n"); + fflush(stdout); + + return Result; +} + +void +TuiEnterAlternateScreen() +{ + EnableVirtualTerminal(); +#if ZEN_PLATFORM_WINDOWS + SetConsoleOutputCP(CP_UTF8); +#endif + + printf("\033[?1049h"); // Enter alternate screen buffer + printf("\033[?25l"); // Hide cursor + fflush(stdout); + +#if !ZEN_PLATFORM_WINDOWS + if (tcgetattr(STDIN_FILENO, &s_SavedAttrs) == 0) + { + struct termios Raw = s_SavedAttrs; + Raw.c_iflag &= ~static_cast<tcflag_t>(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + Raw.c_cflag |= CS8; + Raw.c_lflag &= ~static_cast<tcflag_t>(ECHO | ICANON | IEXTEN | ISIG); + Raw.c_cc[VMIN] = 1; + Raw.c_cc[VTIME] = 0; + if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0) + { + s_InLiveMode = true; + } + } +#endif +} + +void +TuiExitAlternateScreen() +{ + printf("\033[?25h"); // Show cursor + printf("\033[?1049l"); // Exit alternate screen buffer + fflush(stdout); + +#if !ZEN_PLATFORM_WINDOWS + if (s_InLiveMode) + { + tcsetattr(STDIN_FILENO, 
TCSANOW, &s_SavedAttrs); + s_InLiveMode = false; + } +#endif +} + +void +TuiCursorHome() +{ + printf("\033[H"); +} + +uint32_t +TuiConsoleRows(uint32_t Default) +{ +#if ZEN_PLATFORM_WINDOWS + CONSOLE_SCREEN_BUFFER_INFO Csbi = {}; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &Csbi)) + { + return static_cast<uint32_t>(Csbi.srWindow.Bottom - Csbi.srWindow.Top + 1); + } +#else + struct winsize Ws = {}; + if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &Ws) == 0 && Ws.ws_row > 0) + { + return static_cast<uint32_t>(Ws.ws_row); + } +#endif + return Default; +} + +bool +TuiPollQuit() +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); + DWORD dwCount = 0; + if (!GetNumberOfConsoleInputEvents(hStdin, &dwCount) || dwCount == 0) + { + return false; + } + INPUT_RECORD Record{}; + DWORD dwRead = 0; + while (PeekConsoleInputA(hStdin, &Record, 1, &dwRead) && dwRead > 0) + { + ReadConsoleInputA(hStdin, &Record, 1, &dwRead); + if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown) + { + WORD vk = Record.Event.KeyEvent.wVirtualKeyCode; + char ch = Record.Event.KeyEvent.uChar.AsciiChar; + if (vk == VK_ESCAPE || ch == 'q' || ch == 'Q') + { + return true; + } + } + } + return false; +#else + // Non-blocking read: character 3 = Ctrl+C, 27 = Esc, 'q'/'Q' = quit + int b = ReadByteWithTimeout(0); + return (b == 3 || b == 27 || b == 'q' || b == 'Q'); +#endif +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/commandlineoptions.h b/src/zenutil/include/zenutil/config/commandlineoptions.h index 01cceedb1..01cceedb1 100644 --- a/src/zenutil/include/zenutil/commandlineoptions.h +++ b/src/zenutil/include/zenutil/config/commandlineoptions.h diff --git a/src/zenutil/include/zenutil/environmentoptions.h b/src/zenutil/include/zenutil/config/environmentoptions.h index 7418608e4..1ecdf591a 100644 --- a/src/zenutil/include/zenutil/environmentoptions.h +++ b/src/zenutil/include/zenutil/config/environmentoptions.h @@ -3,7 +3,7 @@ #pragma 
once #include <zencore/string.h> -#include <zenutil/commandlineoptions.h> +#include <zenutil/config/commandlineoptions.h> namespace zen { diff --git a/src/zenutil/include/zenutil/config/loggingconfig.h b/src/zenutil/include/zenutil/config/loggingconfig.h new file mode 100644 index 000000000..b55b2d9f7 --- /dev/null +++ b/src/zenutil/include/zenutil/config/loggingconfig.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logbase.h> +#include <filesystem> +#include <string> + +namespace cxxopts { +class Options; +} + +namespace zen { + +struct ZenLoggingConfig +{ + bool NoConsoleOutput = false; // Control default use of stdout for diagnostics + bool QuietConsole = false; // Configure console logger output to level WARN + std::filesystem::path AbsLogFile; // Absolute path to main log file + std::string Loggers[logging::LogLevelCount]; + std::string LogId; // Id for tagging log output + std::string OtelEndpointUri; // OpenTelemetry endpoint URI +}; + +void ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig); + +class ZenLoggingCmdLineOptions +{ +public: + void AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig); + void ApplyOptions(ZenLoggingConfig& LoggingConfig); + +private: + std::string m_AbsLogFile; +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h new file mode 100644 index 000000000..5f74fa82b --- /dev/null +++ b/src/zenutil/include/zenutil/consoletui.h @@ -0,0 +1,60 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <span> +#include <string> +#include <string_view> + +namespace zen { + +// Returns the width of the console in columns, or Default if it cannot be determined. +uint32_t TuiConsoleColumns(uint32_t Default = 120); + +// Enables ANSI/VT escape code processing and UTF-8 console output. 
+// Call once before printing ANSI escape sequences or multi-byte UTF-8 characters via printf. +// Safe to call multiple times. No-op on POSIX (escape codes are native there). +void TuiEnableOutput(); + +// Returns true if stdout is connected to a real terminal (not piped or redirected). +// Useful for deciding whether to use ANSI escape codes for progress output. +bool TuiIsStdoutTty(); + +// Returns true if both stdin and stdout are connected to an interactive terminal +// (i.e. not piped or redirected). Must be checked before calling TuiPickOne(). +bool IsTuiAvailable(); + +// Displays a cursor-navigable single-select list in the terminal. +// +// - Title: a short description printed once above the list +// - Items: pre-formatted display labels, one per selectable entry +// +// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels. +// Returns the index of the selected item, or -1 if the user cancelled. +// +// Precondition: IsTuiAvailable() must be true. +int TuiPickOne(std::string_view Title, std::span<const std::string> Items); + +// Enter the alternate screen buffer for fullscreen live-update mode. +// Hides the cursor. On POSIX, switches to raw/unbuffered terminal input. +// Must be balanced by a call to TuiExitAlternateScreen(). +// Precondition: IsTuiAvailable() must be true. +void TuiEnterAlternateScreen(); + +// Exit alternate screen buffer. Restores the cursor and, on POSIX, the original +// terminal mode. Safe to call even if TuiEnterAlternateScreen() was not called. +void TuiExitAlternateScreen(); + +// Move the cursor to the top-left corner of the terminal (row 1, col 1). +void TuiCursorHome(); + +// Returns the height of the console in rows, or Default if it cannot be determined. +uint32_t TuiConsoleRows(uint32_t Default = 40); + +// Non-blocking check: returns true if the user has pressed a key that means quit +// (Esc, 'q', 'Q', or Ctrl+C). Consumes the event if one is pending. 
+// Should only be called while in alternate screen mode. +bool TuiPollQuit(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/logging.h b/src/zenutil/include/zenutil/logging.h index 85ddc86cd..95419c274 100644 --- a/src/zenutil/include/zenutil/logging.h +++ b/src/zenutil/include/zenutil/logging.h @@ -3,19 +3,12 @@ #pragma once #include <zencore/logging.h> +#include <zencore/logging/sink.h> #include <filesystem> #include <memory> #include <string> -namespace spdlog::sinks { -class sink; -} - -namespace spdlog { -using sink_ptr = std::shared_ptr<sinks::sink>; -} - ////////////////////////////////////////////////////////////////////////// // // Logging utilities @@ -45,6 +38,6 @@ void FinishInitializeLogging(const LoggingOptions& LoggingOptions); void InitializeLogging(const LoggingOptions& LoggingOptions); void ShutdownLogging(); -spdlog::sink_ptr GetFileSink(); +logging::SinkPtr GetFileSink(); } // namespace zen diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h index 9f245becd..33cb94dae 100644 --- a/src/zenutil/include/zenutil/logging/fullformatter.h +++ b/src/zenutil/include/zenutil/logging/fullformatter.h @@ -2,21 +2,19 @@ #pragma once +#include <zencore/logging/formatter.h> +#include <zencore/logging/helpers.h> #include <zencore/memory/llm.h> #include <zencore/zencore.h> #include <string_view> -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/formatter.h> -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen::logging { -class full_formatter final : public spdlog::formatter +class FullFormatter final : public Formatter { public: - full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) + FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch) , m_LogId(LogId) , m_LinePrefix(128, ' ') @@ -24,16 +22,19 @@ public: { } - full_formatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, 
' '), m_UseFullDate(true) {} + FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} - virtual std::unique_ptr<formatter> clone() const override + virtual std::unique_ptr<Formatter> Clone() const override { ZEN_MEMSCOPE(ELLMTag::Logging); - // Note: this does not properly clone m_UseFullDate - return std::make_unique<full_formatter>(m_LogId, m_Epoch); + if (m_UseFullDate) + { + return std::make_unique<FullFormatter>(m_LogId); + } + return std::make_unique<FullFormatter>(m_LogId, m_Epoch); } - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutBuffer) override + virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -44,38 +45,38 @@ public: std::chrono::seconds TimestampSeconds; - std::chrono::milliseconds millis; + std::chrono::milliseconds Millis; if (m_UseFullDate) { - TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(msg.time.time_since_epoch()); + TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); if (TimestampSeconds != m_LastLogSecs) { RwLock::ExclusiveLockScope _(m_TimestampLock); m_LastLogSecs = TimestampSeconds; - m_CachedLocalTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); + m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); m_CachedDatetime.clear(); m_CachedDatetime.push_back('['); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); 
m_CachedDatetime.push_back(' '); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); m_CachedDatetime.push_back('.'); } - millis = spdlog::details::fmt_helper::time_fraction<std::chrono::milliseconds>(msg.time); + Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); } else { - auto ElapsedTime = msg.time - m_Epoch; + auto ElapsedTime = Msg.GetTime() - m_Epoch; TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime); if (m_CacheTimestamp.load() != TimestampSeconds) @@ -93,15 +94,15 @@ public: m_CachedDatetime.clear(); m_CachedDatetime.push_back('['); - spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime); + helpers::Pad2(LogHours, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime); + helpers::Pad2(LogMins, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime); + helpers::Pad2(LogSecs, m_CachedDatetime); m_CachedDatetime.push_back('.'); } - millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); + Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); } { @@ -109,44 +110,43 @@ public: OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); } - spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), OutBuffer); + helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); if (!m_LogId.empty()) { 
OutBuffer.push_back('['); - spdlog::details::fmt_helper::append_string_view(m_LogId, OutBuffer); + helpers::AppendStringView(m_LogId, OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } // append logger name if exists - if (msg.logger_name.size() > 0) + if (Msg.GetLoggerName().size() > 0) { OutBuffer.push_back('['); - spdlog::details::fmt_helper::append_string_view(msg.logger_name, OutBuffer); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } OutBuffer.push_back('['); // wrap the level name with color - msg.color_range_start = OutBuffer.size(); - spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), OutBuffer); - msg.color_range_end = OutBuffer.size(); + Msg.ColorRangeStart = OutBuffer.size(); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); + Msg.ColorRangeEnd = OutBuffer.size(); OutBuffer.push_back(']'); OutBuffer.push_back(' '); // add source location if present - if (!msg.source.empty()) + if (Msg.GetSource()) { OutBuffer.push_back('['); - const char* filename = - spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename); - spdlog::details::fmt_helper::append_string_view(filename, OutBuffer); + const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename); + helpers::AppendStringView(Filename, OutBuffer); OutBuffer.push_back(':'); - spdlog::details::fmt_helper::append_int(msg.source.line, OutBuffer); + helpers::AppendInt(Msg.GetSource().Line, OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } @@ -156,8 +156,9 @@ public: const size_t LinePrefixCount = Min<size_t>(OutBuffer.size(), m_LinePrefix.size()); - auto ItLineBegin = msg.payload.begin(); - auto ItMessageEnd = msg.payload.end(); + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); bool IsFirstline = true; { @@ -170,9 
+171,9 @@ public: } else { - spdlog::details::fmt_helper::append_string_view(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); + helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); } - spdlog::details::fmt_helper::append_string_view(spdlog::string_view_t(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); }; while (ItLineEnd != ItMessageEnd) @@ -187,7 +188,7 @@ public: if (ItLineBegin != ItMessageEnd) { EmitLine(); - spdlog::details::fmt_helper::append_string_view("\n"sv, OutBuffer); + helpers::AppendStringView("\n"sv, OutBuffer); } } } @@ -197,7 +198,7 @@ private: std::tm m_CachedLocalTm; std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)}; - spdlog::memory_buf_t m_CachedDatetime; + MemoryBuffer m_CachedDatetime; std::string m_LogId; std::string m_LinePrefix; bool m_UseFullDate = true; diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h index 3f660e421..216b1b5e5 100644 --- a/src/zenutil/include/zenutil/logging/jsonformatter.h +++ b/src/zenutil/include/zenutil/logging/jsonformatter.h @@ -2,27 +2,26 @@ #pragma once +#include <zencore/logging/formatter.h> +#include <zencore/logging/helpers.h> #include <zencore/memory/llm.h> #include <zencore/zencore.h> #include <string_view> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/formatter.h> -ZEN_THIRD_PARTY_INCLUDES_END +#include <unordered_map> namespace zen::logging { using namespace std::literals; -class json_formatter final : public spdlog::formatter +class JsonFormatter final : public Formatter { public: - json_formatter(std::string_view LogId) : m_LogId(LogId) {} + JsonFormatter(std::string_view LogId) : m_LogId(LogId) {} - virtual std::unique_ptr<formatter> clone() const override { return 
std::make_unique<json_formatter>(m_LogId); } + virtual std::unique_ptr<Formatter> Clone() const override { return std::make_unique<JsonFormatter>(m_LogId); } - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override + virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -30,141 +29,132 @@ public: using std::chrono::milliseconds; using std::chrono::seconds; - auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); - if (secs != m_LastLogSecs) + auto Secs = std::chrono::duration_cast<seconds>(Msg.GetTime().time_since_epoch()); + if (Secs != m_LastLogSecs) { - m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); - m_LastLogSecs = secs; - } - - const auto& tm_time = m_CachedTm; + RwLock::ExclusiveLockScope _(m_TimestampLock); + m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_LastLogSecs = Secs; - // cache the date/time part for the next second. - - if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) - { + // cache the date/time part for the next second. 
m_CachedDatetime.clear(); - spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime); m_CachedDatetime.push_back(' '); - spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime); m_CachedDatetime.push_back('.'); - - m_CacheTimestamp = secs; } - dest.append("{"sv); - dest.append("\"time\": \""sv); - dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - auto millis = spdlog::details::fmt_helper::time_fraction<milliseconds>(msg.time); - spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); - dest.append("\", "sv); + helpers::AppendStringView("{"sv, Dest); + helpers::AppendStringView("\"time\": \""sv, Dest); + { + RwLock::SharedLockScope _(m_TimestampLock); + Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + } + auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime()); + helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); + helpers::AppendStringView("\", "sv, Dest); - dest.append("\"status\": \""sv); - dest.append(spdlog::level::to_string_view(msg.level)); - dest.append("\", "sv); + helpers::AppendStringView("\"status\": \""sv, Dest); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + 
helpers::AppendStringView("\", "sv, Dest); - dest.append("\"source\": \""sv); - dest.append("zenserver"sv); - dest.append("\", "sv); + helpers::AppendStringView("\"source\": \""sv, Dest); + helpers::AppendStringView("zenserver"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); - dest.append("\"service\": \""sv); - dest.append("zencache"sv); - dest.append("\", "sv); + helpers::AppendStringView("\"service\": \""sv, Dest); + helpers::AppendStringView("zencache"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); if (!m_LogId.empty()) { - dest.append("\"id\": \""sv); - dest.append(m_LogId); - dest.append("\", "sv); + helpers::AppendStringView("\"id\": \""sv, Dest); + helpers::AppendStringView(m_LogId, Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (msg.logger_name.size() > 0) + if (Msg.GetLoggerName().size() > 0) { - dest.append("\"logger.name\": \""sv); - dest.append(msg.logger_name); - dest.append("\", "sv); + helpers::AppendStringView("\"logger.name\": \""sv, Dest); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (msg.thread_id != 0) + if (Msg.GetThreadId() != 0) { - dest.append("\"logger.thread_name\": \""sv); - spdlog::details::fmt_helper::pad_uint(msg.thread_id, 0, dest); - dest.append("\", "sv); + helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); + helpers::PadUint(Msg.GetThreadId(), 0, Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (!msg.source.empty()) + if (Msg.GetSource()) { - dest.append("\"file\": \""sv); - WriteEscapedString( - dest, - spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename)); - dest.append("\","sv); - - dest.append("\"line\": \""sv); - dest.append(fmt::format("{}", msg.source.line)); - dest.append("\","sv); - - dest.append("\"logger.method_name\": \""sv); - WriteEscapedString(dest, msg.source.funcname); - dest.append("\", "sv); + helpers::AppendStringView("\"file\": \""sv, Dest); 
+ WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); + helpers::AppendStringView("\","sv, Dest); + + helpers::AppendStringView("\"line\": \""sv, Dest); + helpers::AppendInt(Msg.GetSource().Line, Dest); + helpers::AppendStringView("\","sv, Dest); } - dest.append("\"message\": \""sv); - WriteEscapedString(dest, msg.payload); - dest.append("\""sv); + helpers::AppendStringView("\"message\": \""sv, Dest); + WriteEscapedString(Dest, Msg.GetPayload()); + helpers::AppendStringView("\""sv, Dest); - dest.append("}\n"sv); + helpers::AppendStringView("}\n"sv, Dest); } private: - static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv}, - {'\f', "\\f"sv}, - {'\n', "\\n"sv}, - {'\r', "\\r"sv}, - {'\t', "\\t"sv}, - {'"', "\\\""sv}, - {'\\', "\\\\"sv}}; - - static void WriteEscapedString(spdlog::memory_buf_t& dest, const spdlog::string_view_t& payload) + static inline const std::unordered_map<char, std::string_view> s_SpecialCharacterMap{{'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}}; + + static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text) { - const char* RangeStart = payload.begin(); - for (const char* It = RangeStart; It != payload.end(); ++It) + const char* RangeStart = Text.data(); + const char* End = Text.data() + Text.size(); + for (const char* It = RangeStart; It != End; ++It) { - if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end()) { if (RangeStart != It) { - dest.append(RangeStart, It); + Dest.append(RangeStart, It); } - dest.append(SpecialIt->second); + helpers::AppendStringView(SpecialIt->second, Dest); RangeStart = It + 1; } } - if (RangeStart != payload.end()) + if (RangeStart != End) { - dest.append(RangeStart, payload.end()); + Dest.append(RangeStart, 
End); } }; std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; std::chrono::seconds m_LastLogSecs{0}; - std::chrono::seconds m_CacheTimestamp{0}; - spdlog::memory_buf_t m_CachedDatetime; + MemoryBuffer m_CachedDatetime; std::string m_LogId; + RwLock m_TimestampLock; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h index 8901b7779..cebc5b110 100644 --- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h +++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h @@ -3,14 +3,11 @@ #pragma once #include <zencore/basicfile.h> +#include <zencore/logging/formatter.h> +#include <zencore/logging/messageonlyformatter.h> +#include <zencore/logging/sink.h> #include <zencore/memory/llm.h> -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/details/log_msg.h> -#include <spdlog/pattern_formatter.h> -#include <spdlog/sinks/sink.h> -ZEN_THIRD_PARTY_INCLUDES_END - #include <atomic> #include <filesystem> @@ -19,13 +16,14 @@ namespace zen::logging { // Basically the same functionality as spdlog::sinks::rotating_file_sink with the biggest difference // being that it just ignores any errors when writing/rotating files and keeps chugging on. // It will keep trying to log, and if it starts to work it will continue to log. 
-class RotatingFileSink : public spdlog::sinks::sink +class RotatingFileSink : public Sink { public: RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false) : m_BaseFilename(BaseFilename) , m_MaxSize(MaxSize) , m_MaxFiles(MaxFiles) + , m_Formatter(std::make_unique<MessageOnlyFormatter>()) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -76,18 +74,21 @@ public: RotatingFileSink& operator=(const RotatingFileSink&) = delete; RotatingFileSink& operator=(RotatingFileSink&&) = delete; - virtual void log(const spdlog::details::log_msg& msg) override + virtual void Log(const LogMessage& Msg) override { ZEN_MEMSCOPE(ELLMTag::Logging); try { - spdlog::memory_buf_t Formatted; - if (TrySinkIt(msg, Formatted)) + MemoryBuffer Formatted; + if (TrySinkIt(Msg, Formatted)) { return; } - while (true) + + // This intentionally has no limit on the number of retries, see + // comment above. + for (;;) { { RwLock::ExclusiveLockScope RotateLock(m_Lock); @@ -113,7 +114,7 @@ public: // Silently eat errors } } - virtual void flush() override + virtual void Flush() override { if (!m_NeedFlush) { @@ -138,28 +139,14 @@ public: m_NeedFlush = false; } - virtual void set_pattern(const std::string& pattern) override + virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) override { ZEN_MEMSCOPE(ELLMTag::Logging); try { RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = spdlog::details::make_unique<spdlog::pattern_formatter>(pattern); - } - catch (const std::exception&) - { - // Silently eat errors - } - } - virtual void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = std::move(sink_formatter); + m_Formatter = std::move(InFormatter); } catch (const std::exception&) { @@ -186,11 +173,17 @@ private: return; } - // If we fail to rotate, try extending the current log file m_CurrentSize = 
m_CurrentFile.FileSize(OutEc); + if (OutEc) + { + // FileSize failed but we have an open file — reset to 0 + // so we can at least attempt writes from the start + m_CurrentSize = 0; + OutEc.clear(); + } } - bool TrySinkIt(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutFormatted) + bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -199,15 +192,15 @@ private: { return false; } - m_Formatter->format(msg, OutFormatted); - size_t add_size = OutFormatted.size(); - size_t write_pos = m_CurrentSize.fetch_add(add_size); - if (write_pos + add_size > m_MaxSize) + m_Formatter->Format(Msg, OutFormatted); + size_t AddSize = OutFormatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) { return false; } std::error_code Ec; - m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), write_pos, Ec); + m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); if (Ec) { return false; @@ -216,7 +209,7 @@ private: return true; } - bool TrySinkIt(const spdlog::memory_buf_t& Formatted) + bool TrySinkIt(const MemoryBuffer& Formatted) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -225,15 +218,15 @@ private: { return false; } - size_t add_size = Formatted.size(); - size_t write_pos = m_CurrentSize.fetch_add(add_size); - if (write_pos + add_size > m_MaxSize) + size_t AddSize = Formatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) { return false; } std::error_code Ec; - m_CurrentFile.Write(Formatted.data(), Formatted.size(), write_pos, Ec); + m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); if (Ec) { return false; @@ -242,14 +235,14 @@ private: return true; } - RwLock m_Lock; - const std::filesystem::path m_BaseFilename; - std::unique_ptr<spdlog::formatter> m_Formatter; - std::atomic_size_t m_CurrentSize; - const std::size_t m_MaxSize; - const std::size_t m_MaxFiles; - BasicFile 
m_CurrentFile; - std::atomic<bool> m_NeedFlush = false; + RwLock m_Lock; + const std::filesystem::path m_BaseFilename; + const std::size_t m_MaxSize; + const std::size_t m_MaxFiles; + std::unique_ptr<Formatter> m_Formatter; + std::atomic_size_t m_CurrentSize; + BasicFile m_CurrentFile; + std::atomic<bool> m_NeedFlush = false; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/testformatter.h b/src/zenutil/include/zenutil/logging/testformatter.h deleted file mode 100644 index 0b0c191fb..000000000 --- a/src/zenutil/include/zenutil/logging/testformatter.h +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/memory/llm.h> - -#include <spdlog/spdlog.h> - -namespace zen::logging { - -class full_test_formatter final : public spdlog::formatter -{ -public: - full_test_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) - { - } - - virtual std::unique_ptr<formatter> clone() const override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - return std::make_unique<full_test_formatter>(m_LogId, m_Epoch); - } - - static constexpr bool UseDate = false; - - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - using namespace std::literals; - - if constexpr (UseDate) - { - auto secs = std::chrono::duration_cast<std::chrono::seconds>(msg.time.time_since_epoch()); - if (secs != m_LastLogSecs) - { - m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); - m_LastLogSecs = secs; - } - } - - const auto& tm_time = m_CachedTm; - - // cache the date/time part for the next second. 
- auto duration = msg.time - m_Epoch; - auto secs = std::chrono::duration_cast<std::chrono::seconds>(duration); - - if (m_CacheTimestamp != secs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - - m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - - if constexpr (UseDate) - { - spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - - spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); - } - else - { - int Count = int(secs.count()); - - const int LogSecs = Count % 60; - Count /= 60; - - const int LogMins = Count % 60; - Count /= 60; - - const int LogHours = Count; - - spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime); - } - - m_CachedDatetime.push_back('.'); - - m_CacheTimestamp = secs; - } - - { - RwLock::SharedLockScope _(m_TimestampLock); - dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - - auto millis = spdlog::details::fmt_helper::time_fraction<std::chrono::milliseconds>(msg.time); - spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); - dest.push_back(']'); - dest.push_back(' '); - - if (!m_LogId.empty()) - { - dest.push_back('['); - spdlog::details::fmt_helper::append_string_view(m_LogId, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - // append logger name if exists 
- if (msg.logger_name.size() > 0) - { - dest.push_back('['); - spdlog::details::fmt_helper::append_string_view(msg.logger_name, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - dest.push_back('['); - // wrap the level name with color - msg.color_range_start = dest.size(); - spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), dest); - msg.color_range_end = dest.size(); - dest.push_back(']'); - dest.push_back(' '); - - // add source location if present - if (!msg.source.empty()) - { - dest.push_back('['); - const char* filename = - spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename); - spdlog::details::fmt_helper::append_string_view(filename, dest); - dest.push_back(':'); - spdlog::details::fmt_helper::append_int(msg.source.line, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - spdlog::details::fmt_helper::append_string_view(msg.payload, dest); - spdlog::details::fmt_helper::append_string_view("\n"sv, dest); - } - -private: - std::chrono::time_point<std::chrono::system_clock> m_Epoch; - std::tm m_CachedTm; - std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; - std::chrono::seconds m_CacheTimestamp{std::chrono::seconds(87654321)}; - spdlog::memory_buf_t m_CachedDatetime; - std::string m_LogId; - RwLock m_TimestampLock; -}; - -} // namespace zen::logging diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index d0402640b..2a8617162 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -42,9 +42,13 @@ public: std::filesystem::path GetTestRootDir(std::string_view Path); inline bool IsInitialized() const { return m_IsInitialized; } inline bool IsTestEnvironment() const { return m_IsTestInstance; } + inline bool IsHubEnvironment() const { return m_IsHubInstance; } inline std::string_view GetServerClass() const { return 
m_ServerClass; } inline uint16_t GetNewPortNumber() { return m_NextPortNumber.fetch_add(1); } + void SetPassthroughOutput(bool Enable) { m_PassthroughOutput = Enable; } + bool IsPassthroughOutput() const { return m_PassthroughOutput; } + // The defaults will work for a single root process only. For hierarchical // setups (e.g., hub managing storage servers), we need to be able to // allocate distinct child IDs and ports to avoid overlap/conflicts. @@ -54,9 +58,10 @@ public: private: std::filesystem::path m_ProgramBaseDir; std::filesystem::path m_ChildProcessBaseDir; - bool m_IsInitialized = false; - bool m_IsTestInstance = false; - bool m_IsHubInstance = false; + bool m_IsInitialized = false; + bool m_IsTestInstance = false; + bool m_IsHubInstance = false; + bool m_PassthroughOutput = false; std::string m_ServerClass; std::atomic_uint16_t m_NextPortNumber{20000}; }; @@ -79,6 +84,7 @@ struct ZenServerInstance { kStorageServer, // default kHubServer, + kComputeServer, }; ZenServerInstance(ZenServerEnvironment& TestEnvironment, ServerMode Mode = ServerMode::kStorageServer); @@ -96,9 +102,12 @@ struct ZenServerInstance inline int GetPid() const { return m_Process.Pid(); } inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; } void* GetProcessHandle() const { return m_Process.Handle(); } - bool IsRunning(); - bool Terminate(); - std::string GetLogOutput() const; +#if ZEN_PLATFORM_WINDOWS + void SetJobObject(JobObject* Job) { m_JobObject = Job; } +#endif + bool IsRunning(); + bool Terminate(); + std::string GetLogOutput() const; inline ServerMode GetServerMode() const { return m_ServerMode; } @@ -147,6 +156,9 @@ private: std::string m_Name; std::filesystem::path m_OutputCapturePath; std::filesystem::path m_ServerExecutablePath; +#if ZEN_PLATFORM_WINDOWS + JobObject* m_JobObject = nullptr; +#endif void CreateShutdownEvent(int BasePort); void SpawnServer(int BasePort, std::string_view AdditionalServerArgs, int WaitTimeoutMs); diff --git a/src/zenutil/logging.cpp 
b/src/zenutil/logging.cpp index 806b96d52..1258ca155 100644 --- a/src/zenutil/logging.cpp +++ b/src/zenutil/logging.cpp @@ -2,18 +2,15 @@ #include "zenutil/logging.h" -ZEN_THIRD_PARTY_INCLUDES_START -#include <spdlog/async.h> -#include <spdlog/async_logger.h> -#include <spdlog/sinks/ansicolor_sink.h> -#include <spdlog/sinks/msvc_sink.h> -#include <spdlog/spdlog.h> -ZEN_THIRD_PARTY_INCLUDES_END - #include <zencore/callstack.h> #include <zencore/compactbinary.h> #include <zencore/filesystem.h> #include <zencore/logging.h> +#include <zencore/logging/ansicolorsink.h> +#include <zencore/logging/asyncsink.h> +#include <zencore/logging/logger.h> +#include <zencore/logging/msvcsink.h> +#include <zencore/logging/registry.h> #include <zencore/memory/llm.h> #include <zencore/string.h> #include <zencore/timer.h> @@ -27,9 +24,9 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { static bool g_IsLoggingInitialized; -spdlog::sink_ptr g_FileSink; +logging::SinkPtr g_FileSink; -spdlog::sink_ptr +logging::SinkPtr GetFileSink() { return g_FileSink; @@ -52,33 +49,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) zen::logging::InitializeLogging(); zen::logging::EnableVTMode(); - bool IsAsync = LogOptions.AllowAsync; - - if (LogOptions.IsDebug) - { - IsAsync = false; - } - - if (LogOptions.IsTest) - { - IsAsync = false; - } - - if (IsAsync) - { - const int QueueSize = 8192; - const int ThreadCount = 1; - spdlog::init_thread_pool(QueueSize, ThreadCount, [&] { SetCurrentThreadName("spdlog_async"); }); - - auto AsyncSink = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main"); - zen::logging::SetDefault("main"); - } - // Sinks - spdlog::sink_ptr FileSink; - - // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + logging::SinkPtr FileSink; if (!LogOptions.AbsLogFile.empty()) { @@ -87,17 +60,17 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) 
zen::CreateDirectories(LogOptions.AbsLogFile.parent_path()); } - FileSink = std::make_shared<zen::logging::RotatingFileSink>(LogOptions.AbsLogFile, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ true); + FileSink = logging::SinkPtr(new zen::logging::RotatingFileSink(LogOptions.AbsLogFile, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true)); if (LogOptions.AbsLogFile.extension() == ".json") { - FileSink->set_formatter(std::make_unique<logging::json_formatter>(LogOptions.LogId)); + FileSink->SetFormatter(std::make_unique<logging::JsonFormatter>(LogOptions.LogId)); } else { - FileSink->set_formatter(std::make_unique<logging::full_formatter>(LogOptions.LogId)); // this will have a date prefix + FileSink->SetFormatter(std::make_unique<logging::FullFormatter>(LogOptions.LogId)); // this will have a date prefix } } @@ -127,7 +100,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) Message.push_back('\0'); // We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log - ZEN_LOG(Log(), zen::logging::level::Critical, "{}", Message.data()); + ZEN_LOG(Log(), zen::logging::Critical, "{}", Message.data()); zen::logging::FlushLogging(); } catch (const std::exception&) @@ -143,9 +116,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) // Default LoggerRef DefaultLogger = zen::logging::Default(); - auto& Sinks = DefaultLogger.SpdLogger->sinks(); - Sinks.clear(); + // Collect sinks into a local vector first so we can optionally wrap them + std::vector<logging::SinkPtr> Sinks; if (LogOptions.NoConsoleOutput) { @@ -153,10 +126,10 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) } else { - auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>(); + logging::SinkPtr ConsoleSink(new logging::AnsiColorStdoutSink()); if (LogOptions.QuietConsole) { - ConsoleSink->set_level(spdlog::level::warn); + 
ConsoleSink->SetLevel(logging::Warn); } Sinks.push_back(ConsoleSink); } @@ -169,40 +142,54 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) #if ZEN_PLATFORM_WINDOWS if (zen::IsDebuggerPresent() && LogOptions.IsDebug) { - auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>(); - DebugSink->set_level(spdlog::level::debug); + logging::SinkPtr DebugSink(new logging::MsvcSink()); + DebugSink->SetLevel(logging::Debug); Sinks.push_back(DebugSink); } #endif - spdlog::set_error_handler([](const std::string& msg) { - if (msg == std::bad_alloc().what()) - { - // Don't report out of memory in spdlog as we usually log in response to errors which will cause another OOM crashing the - // program - return; - } - // Bypass zen logging wrapping to reduce potential other error sources - if (auto ErrLogger = zen::logging::ErrorLog()) + bool IsAsync = LogOptions.AllowAsync && !LogOptions.IsDebug && !LogOptions.IsTest; + + if (IsAsync) + { + std::vector<logging::SinkPtr> AsyncSinks; + AsyncSinks.emplace_back(new logging::AsyncSink(std::move(Sinks))); + DefaultLogger->SetSinks(std::move(AsyncSinks)); + } + else + { + DefaultLogger->SetSinks(std::move(Sinks)); + } + + static struct : logging::ErrorHandler + { + void HandleError(const std::string_view& ErrorMsg) override { + if (ErrorMsg == std::bad_alloc().what()) + { + return; + } + static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"}; + if (auto ErrLogger = zen::logging::ErrorLog()) + { + try + { + ErrLogger->Log(ErrorPoint, fmt::make_format_args(ErrorMsg)); + } + catch (const std::exception&) + { + } + } try { - ErrLogger.SpdLogger->log(spdlog::level::err, msg); + Log()->Log(ErrorPoint, fmt::make_format_args(ErrorMsg)); } catch (const std::exception&) { - // Just ignore any errors when in error handler } } - try - { - Log().SpdLogger->error(msg); - } - catch (const std::exception&) - { - // Just ignore any errors when in error handler - } - }); + } s_ErrorHandler; + 
logging::Registry::Instance().SetErrorHandler(&s_ErrorHandler); g_FileSink = std::move(FileSink); } @@ -212,41 +199,47 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) { ZEN_MEMSCOPE(ELLMTag::Logging); - logging::level::LogLevel LogLevel = logging::level::Info; + logging::LogLevel LogLevel = logging::Info; if (LogOptions.IsDebug) { - LogLevel = logging::level::Debug; + LogLevel = logging::Debug; } if (LogOptions.IsTest || LogOptions.IsVerbose) { - LogLevel = logging::level::Trace; + LogLevel = logging::Trace; } // Configure all registered loggers according to settings logging::RefreshLogLevels(LogLevel); - spdlog::flush_on(spdlog::level::err); - spdlog::flush_every(std::chrono::seconds{2}); - spdlog::set_formatter(std::make_unique<logging::full_formatter>( + logging::Registry::Instance().FlushOn(logging::Err); + logging::Registry::Instance().FlushEvery(std::chrono::seconds{2}); + logging::Registry::Instance().SetFormatter(std::make_unique<logging::FullFormatter>( LogOptions.LogId, std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart()))); // default to duration prefix + // If the console logger was initialized before, the above will change the output format + // so we need to reset it + + logging::ResetConsoleLog(); + if (g_FileSink) { if (LogOptions.AbsLogFile.extension() == ".json") { - g_FileSink->set_formatter(std::make_unique<logging::json_formatter>(LogOptions.LogId)); + g_FileSink->SetFormatter(std::make_unique<logging::JsonFormatter>(LogOptions.LogId)); } else { - g_FileSink->set_formatter(std::make_unique<logging::full_formatter>(LogOptions.LogId)); // this will have a date prefix + g_FileSink->SetFormatter(std::make_unique<logging::FullFormatter>(LogOptions.LogId)); // this will have a date prefix } const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); - spdlog::apply_all([&](auto Logger) { Logger->info("log starting at {}", StartLogTime); }); + static constinit logging::LogPoint LogStartPoint{{}, 
logging::Info, "log starting at {}"}; + logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); } g_IsLoggingInitialized = true; @@ -263,7 +256,7 @@ ShutdownLogging() zen::logging::ShutdownLogging(); - g_FileSink.reset(); + g_FileSink = nullptr; } } // namespace zen diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp index 54f27dee7..28a0091cb 100644 --- a/src/zenutil/rpcrecording.cpp +++ b/src/zenutil/rpcrecording.cpp @@ -1119,7 +1119,7 @@ rpcrecord_forcelink() { } -TEST_SUITE_BEGIN("rpc.recording"); +TEST_SUITE_BEGIN("util.rpcrecording"); TEST_CASE("rpc.record") { diff --git a/src/zenutil/wildcard.cpp b/src/zenutil/wildcard.cpp index 7a44c0498..7f2f77780 100644 --- a/src/zenutil/wildcard.cpp +++ b/src/zenutil/wildcard.cpp @@ -118,6 +118,8 @@ wildcard_forcelink() { } +TEST_SUITE_BEGIN("util.wildcard"); + TEST_CASE("Wildcard") { CHECK(MatchWildcard("*.*", "normal.txt", true)); @@ -151,5 +153,7 @@ TEST_CASE("Wildcard") CHECK(MatchWildcard("*.d", "dir/path.d", true)); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index bc33adf9e..1d5be5977 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -6,7 +6,7 @@ target('zenutil') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "zenhttp", "spdlog") + add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index ef2a4fda5..b09c2d89a 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -787,6 +787,8 @@ ToString(ZenServerInstance::ServerMode Mode) return "storage"sv; case ZenServerInstance::ServerMode::kHubServer: return "hub"sv; + case ZenServerInstance::ServerMode::kComputeServer: + return "compute"sv; default: return "invalid"sv; } @@ -808,6 +810,10 @@ 
ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CommandLine << " hub"; } + else if (m_ServerMode == ServerMode::kComputeServer) + { + CommandLine << " compute"; + } CommandLine << " --child-id " << ChildEventName; @@ -829,10 +835,18 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); const std::filesystem::path Executable = m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; - const std::filesystem::path OutputPath = - OpenConsole ? std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); - CreateProcOptions CreateOptions = {.WorkingDirectory = &CurrentDirectory, .Flags = CreationFlags, .StdoutFile = OutputPath}; - CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); + const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput()) + ? 
std::filesystem::path{} + : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + .StdoutFile = OutputPath, +#if ZEN_PLATFORM_WINDOWS + .AssignToJob = m_JobObject, +#endif + }; + CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); #if ZEN_PLATFORM_WINDOWS if (!ChildPid) { @@ -841,6 +855,12 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { ZEN_DEBUG("Regular spawn failed - spawning elevated server"); CreateOptions.Flags |= CreateProcOptions::Flag_Elevated; + // ShellExecuteEx (used by the elevated path) does not support job object assignment + if (CreateOptions.AssignToJob) + { + ZEN_WARN("Elevated process spawn does not support job object assignment; child will not be auto-terminated on parent exit"); + CreateOptions.AssignToJob = nullptr; + } ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); } else @@ -934,7 +954,8 @@ ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerAr CommandLine << " " << AdditionalServerArgs; } - SpawnServerInternal(ChildId, CommandLine, !IsTest, WaitTimeoutMs); + const bool OpenConsole = !IsTest && !m_Env.IsHubEnvironment(); + SpawnServerInternal(ChildId, CommandLine, OpenConsole, WaitTimeoutMs); } void diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 51c1ee72e..291dbeadd 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,7 +5,7 @@ #if ZEN_WITH_TESTS # include <zenutil/rpcrecording.h> -# include <zenutil/commandlineoptions.h> +# include <zenutil/config/commandlineoptions.h> # include <zenutil/wildcard.h> namespace zen { diff --git a/src/zenvfs/xmake.lua b/src/zenvfs/xmake.lua index 7f790c2d4..47665a5d5 100644 --- a/src/zenvfs/xmake.lua +++ b/src/zenvfs/xmake.lua @@ -6,5 +6,5 @@ target('zenvfs') add_headerfiles("**.h") add_files("**.cpp") 
add_includedirs("include", {public=true}) - add_deps("zencore", "spdlog") + add_deps("zencore") |