diff options
Diffstat (limited to 'src/zenutil')
49 files changed, 8660 insertions, 232 deletions
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp index dde1dc019..a23cb9c28 100644 --- a/src/zenutil/cloud/imdscredentials.cpp +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -64,6 +64,7 @@ ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptio .LogCategory = "imds", .ConnectTimeout = Options.ConnectTimeout, .Timeout = Options.RequestTimeout, + .RetryCount = 3, }) { ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri()); @@ -115,7 +116,7 @@ ImdsCredentialProvider::FetchToken() HttpClient::KeyValueMap Headers; Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); - HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers); if (!Response.IsSuccess()) { ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); @@ -213,7 +214,7 @@ ImdsCredentialProvider::FetchCredentials() } else { - // Expiration is in the past or unparseable — force refresh next time + // Expiration is in the past or unparseable - force refresh next time NewExpiresAt = std::chrono::steady_clock::now(); } @@ -226,7 +227,7 @@ ImdsCredentialProvider::FetchCredentials() if (KeyChanged) { - ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8)); + ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {})", HideSensitiveString(m_CachedCredentials.AccessKeyId)); } else { @@ -369,7 +370,7 @@ TEST_CASE("imdscredentials.fetch_from_mock") TEST_CASE("imdscredentials.unreachable_endpoint") { - // Point at a non-existent server — should return empty credentials, not crash + // Point at a non-existent server - should return empty credentials, not crash ImdsCredentialProviderOptions Opts; Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening Opts.ConnectTimeout = 
std::chrono::milliseconds(100); diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp index 565705731..2db0010dc 100644 --- a/src/zenutil/cloud/minioprocess.cpp +++ b/src/zenutil/cloud/minioprocess.cpp @@ -45,7 +45,7 @@ struct MinioProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); @@ -72,11 +72,12 @@ struct MinioProcess::Impl ZEN_INFO("MinIO server started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); return; } - } while (Timer.GetElapsedTimeMs() < 10000); + } while (Timer.GetElapsedTimeMs() < 30000); } - // Report failure - ZEN_WARN("MinIO server failed to start within timeout period"); + // Report failure - throw so test failures show the real cause instead of a confusing + // assertion failure later when S3 operations fail silently. 
+ throw std::runtime_error(fmt::format("MinIO server on port {} failed to start within timeout", m_Options.Port)); } void StopMinioServer() @@ -101,7 +102,7 @@ struct MinioProcess::Impl { if (m_DataDir.empty()) { - ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first"); + ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized - call SpawnMinioServer() first"); return; } diff --git a/src/zenutil/cloud/mockimds.cpp b/src/zenutil/cloud/mockimds.cpp index 6919fab4d..88b348ed6 100644 --- a/src/zenutil/cloud/mockimds.cpp +++ b/src/zenutil/cloud/mockimds.cpp @@ -93,7 +93,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Autoscaling lifecycle state — 404 when not in an ASG + // Autoscaling lifecycle state - 404 when not in an ASG if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") { if (Aws.AutoscalingState.empty()) @@ -105,7 +105,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Spot interruption notice — 404 when no interruption pending + // Spot interruption notice - 404 when no interruption pending if (Uri == "latest/meta-data/spot/instance-action") { if (Aws.SpotAction.empty()) @@ -117,7 +117,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // IAM role discovery — returns the role name + // IAM role discovery - returns the role name if (Uri == "latest/meta-data/iam/security-credentials/") { if (Aws.IamRoleName.empty()) diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index 88d844b61..f8bed92da 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -135,8 +135,45 @@ namespace { return nullptr; } + /// Extract Code/Message from an S3 XML error body. Returns true if an <Error> element was + /// found, even if Code/Message are empty. 
+ bool ExtractS3Error(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) + { + if (Body.find("<Error>") == std::string_view::npos) + { + return false; + } + OutCode = ExtractXmlValue(Body, "Code"); + OutMessage = ExtractXmlValue(Body, "Message"); + return true; + } + + /// Build a human-readable error message for a failed S3 response. When the response body + /// contains an S3 `<Error>` element, the Code and Message fields are included in the string + /// so transient 4xx/5xx failures (SignatureDoesNotMatch, AuthorizationHeaderMalformed, etc.) + /// show up in logs instead of being swallowed. Falls back to the generic HTTP/transport + /// message when no XML body is available (HEAD responses, transport errors). + std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) + { + if (!Response.Error.has_value() && Response.ResponsePayload) + { + std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string_view Code; + std::string_view Message; + if (ExtractS3Error(Body, Code, Message) && (!Code.empty() || !Message.empty())) + { + ExtendableStringBuilder<256> Decoded; + DecodeXmlEntities(Message, Decoded); + return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); + } + } + return Response.ErrorMessage(Prefix); + } + } // namespace +std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; + S3Client::S3Client(const S3ClientOptions& Options) : m_Log(logging::Get("s3")) , m_BucketName(Options.BucketName) @@ -145,13 +182,8 @@ S3Client::S3Client(const S3ClientOptions& Options) , m_PathStyle(Options.PathStyle) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) -, m_HttpClient(BuildEndpoint(), - HttpClientSettings{ - .LogCategory = "s3", - .ConnectTimeout = Options.ConnectTimeout, - .Timeout = Options.Timeout, - .RetryCount 
= Options.RetryCount, - }) +, m_HttpClient(BuildEndpoint(), Options.HttpSettings) +, m_Verbose(Options.HttpSettings.Verbose) { m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", @@ -171,20 +203,33 @@ S3Client::GetCurrentCredentials() SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); if (!Creds.AccessKeyId.empty()) { - // Invalidate the signing key cache when the access key changes + // Invalidate the signing key cache when the access key changes, and update stored + // credentials atomically under the same lock so callers see a consistent snapshot. + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); if (Creds.AccessKeyId != m_Credentials.AccessKeyId) { - RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); m_CachedDateStamp.clear(); } m_Credentials = Creds; + // Return Creds directly - avoids reading m_Credentials after releasing the lock, + // which would race with another concurrent write. + return Creds; } + // IMDS returned empty credentials; fall back to the last known-good credentials. 
return m_Credentials; } return m_Credentials; } std::string +S3Client::BuildNoCredentialsError(std::string Context) +{ + std::string Err = fmt::format("{}: no credentials available", Context); + ZEN_WARN("{}", Err); + return Err; +} + +std::string S3Client::BuildEndpoint() const { if (!m_Endpoint.empty()) @@ -252,7 +297,7 @@ S3Client::BucketRootPath() const Sha256Digest S3Client::GetSigningKey(std::string_view DateStamp) { - // Fast path: shared lock for cache hit (common case — key only changes once per day) + // Fast path: shared lock for cache hit (common case - key only changes once per day) { RwLock::SharedLockScope SharedLock(m_SigningKeyLock); if (m_CachedDateStamp == DateStamp) @@ -284,14 +329,18 @@ S3Client::GetSigningKey(std::string_view DateStamp) } HttpClient::KeyValueMap -S3Client::SignRequest(std::string_view Method, std::string_view Path, std::string_view CanonicalQueryString, std::string_view PayloadHash) +S3Client::SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view CanonicalQueryString, + std::string_view PayloadHash, + std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders) { - SigV4Credentials Credentials = GetCurrentCredentials(); - std::string AmzDate = GetAmzTimestamp(); // Build sorted headers to sign (must be sorted by lowercase name) std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); HeadersToSign.emplace_back("host", m_Host); HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); HeadersToSign.emplace_back("x-amz-date", AmzDate); @@ -299,6 +348,10 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin { HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); } + for (const auto& [K, V] : ExtraSignedHeaders) + { + HeadersToSign.emplace_back(K, V); + } std::sort(HeadersToSign.begin(), HeadersToSign.end()); 
std::string_view DateStamp(AmzDate.data(), 8); @@ -315,6 +368,10 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin { Result->emplace("x-amz-security-token", Credentials.SessionToken); } + for (const auto& [K, V] : ExtraSignedHeaders) + { + Result->emplace(K, V); + } return Result; } @@ -322,69 +379,210 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin S3Result S3Client::PutObject(std::string_view Key, IoBuffer Content) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 PUT '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); // Hash the payload std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, "", PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", PayloadHash); HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 PUT failed"); + std::string Err = S3ErrorMessage("S3 PUT failed", Response); ZEN_WARN("S3 PUT '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + } return {}; } S3GetObjectResult -S3Client::GetObject(std::string_view Key) +S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 GET '{}' failed", Key); !Err.empty()) + { + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = 
SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); - HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 GET failed"); + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + + std::string Err = S3ErrorMessage("S3 GET failed", Response); ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + } + return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; +} + +S3GetObjectResult +S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t RangeSize) +{ + ZEN_ASSERT(RangeSize > 0); + + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 GET range '{}' [{}-{}] failed", Key, RangeStart, RangeStart + RangeSize - 1); + !Err.empty()) + { + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); + Headers->emplace("Range", fmt::format("bytes={}-{}", RangeStart, RangeStart + RangeSize - 1)); + + HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + if (!Response.IsSuccess()) + { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + + std::string Err = S3ErrorMessage("S3 GET range failed", Response); + ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err); + return 
S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + // Callers are expected to request only ranges that lie within the known object size (e.g. + // by calling HeadObject first). Treat a short read as an error rather than silently + // returning a truncated buffer - a partial write is more dangerous than a hard failure. + if (Response.ResponsePayload.GetSize() != RangeSize) + { + std::string Err = fmt::format("S3 GET range '{}' [{}-{}] returned {} bytes, expected {}", + Key, + RangeStart, + RangeStart + RangeSize - 1, + Response.ResponsePayload.GetSize(), + RangeSize); + ZEN_WARN("{}", Err); + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + if (m_Verbose) + { + ZEN_INFO("S3 GET range '{}' [{}-{}] succeeded ({} bytes)", + Key, + RangeStart, + RangeStart + RangeSize - 1, + Response.ResponsePayload.GetSize()); + } return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; } S3Result S3Client::DeleteObject(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 DELETE '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, "", EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 DELETE failed"); + std::string Err = S3ErrorMessage("S3 DELETE failed", Response); ZEN_WARN("S3 DELETE '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + if (m_Verbose) + { + ZEN_INFO("S3 DELETE '{}' succeeded", Key); + } + return {}; +} + +S3Result +S3Client::Touch(std::string_view Key) +{ + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 Touch '{}' failed", Key); !Err.empty()) + { 
+ return S3Result{std::move(Err)}; + } + + std::string Path = KeyToPath(Key); + + // x-amz-copy-source is always "/bucket/key" regardless of addressing style. + // Key must be URI-encoded except for '/' separators. When source and destination + // are identical, REPLACE is required; COPY is rejected with InvalidRequest. + const std::array<std::pair<std::string, std::string>, 2> ExtraSigned{{ + {"x-amz-copy-source", fmt::format("/{}/{}", m_BucketName, AwsUriEncode(Key, false))}, + {"x-amz-metadata-directive", "REPLACE"}, + }}; + + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", EmptyPayloadHash, ExtraSigned); + + HttpClient::Response Response = m_HttpClient.Put(Path, IoBuffer{}, Headers); + if (!Response.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 Touch failed", Response); + ZEN_WARN("S3 Touch '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + // Copy operations can return HTTP 200 with an error in the XML body. + std::string_view ResponseBody = Response.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + { + std::string Err = fmt::format("S3 Touch '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); + ZEN_WARN("{}", Err); + return S3Result{std::move(Err)}; + } + + if (m_Verbose) + { + ZEN_INFO("S3 Touch '{}' succeeded", Key); + } return {}; } S3HeadObjectResult S3Client::HeadObject(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 HEAD '{}' failed", Key); !Err.empty()) + { + return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("HEAD", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "HEAD", Path, "", EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Head(Path, Headers); if 
(!Response.IsSuccess()) @@ -394,7 +592,7 @@ S3Client::HeadObject(std::string_view Key) return S3HeadObjectResult{{}, {}, HeadObjectResult::NotFound}; } - std::string Err = Response.ErrorMessage("S3 HEAD failed"); + std::string Err = S3ErrorMessage("S3 HEAD failed", Response); ZEN_WARN("S3 HEAD '{}' failed: {}", Key, Err); return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; } @@ -417,7 +615,10 @@ S3Client::HeadObject(std::string_view Key) Info.LastModified = *V; } - ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + if (m_Verbose) + { + ZEN_INFO("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + } return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; } @@ -430,6 +631,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) for (;;) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 ListObjectsV2 prefix='{}' failed", Prefix); !Err.empty()) + { + Result.Error = std::move(Err); + return Result; + } + // Build query parameters for ListObjectsV2 std::vector<std::pair<std::string, std::string>> QueryParams; QueryParams.emplace_back("list-type", "2"); @@ -448,13 +656,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); std::string RootPath = BucketRootPath(); - HttpClient::KeyValueMap Headers = SignRequest("GET", RootPath, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", RootPath, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 ListObjectsV2 failed"); + std::string Err = S3ErrorMessage("S3 ListObjectsV2 failed", Response); ZEN_WARN("S3 ListObjectsV2 prefix='{}' failed: {}", Prefix, Err); Result.Error = std::move(Err); 
return Result; @@ -514,10 +722,16 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } ContinuationToken = std::string(NextToken); - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } } - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + } return Result; } @@ -527,16 +741,22 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) S3CreateMultipartUploadResult S3Client::CreateMultipartUpload(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 CreateMultipartUpload '{}' failed", Key); !Err.empty()) + { + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); - HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 CreateMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 CreateMultipartUpload failed", Response); ZEN_WARN("S3 CreateMultipartUpload '{}' failed: {}", Key, Err); return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } @@ -556,13 +776,22 @@ S3Client::CreateMultipartUpload(std::string_view Key) return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded 
(uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; } S3UploadPartResult S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 UploadPart '{}' part {} failed", Key, PartNumber); !Err.empty()) + { + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({ {"partNumber", fmt::format("{}", PartNumber)}, @@ -571,13 +800,13 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber)); + std::string Err = S3ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber), Response); ZEN_WARN("S3 UploadPart '{}' part {} failed: {}", Key, PartNumber, Err); return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } @@ -591,7 +820,10 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + if (m_Verbose) + { + ZEN_INFO("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + } return 
S3UploadPartResult{{}, *ETag}; } @@ -600,6 +832,12 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view UploadId, const std::vector<std::pair<uint32_t, std::string>>& PartETags) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 CompleteMultipartUpload '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); @@ -615,7 +853,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view XmlView = XmlBody.ToView(); std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); - HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); Headers->emplace("Content-Type", "application/xml"); IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); @@ -624,44 +862,56 @@ S3Client::CompleteMultipartUpload(std::string_view Key, HttpClient::Response Response = m_HttpClient.Post(FullPath, Payload, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 CompleteMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 CompleteMultipartUpload failed", Response); ZEN_WARN("S3 CompleteMultipartUpload '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } // Check for error in response body - S3 can return 200 with an error in the XML body std::string_view ResponseBody = Response.AsText(); - if (ResponseBody.find("<Error>") != std::string_view::npos) + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) { - std::string_view ErrorCode = ExtractXmlValue(ResponseBody, "Code"); - std::string_view ErrorMessage = ExtractXmlValue(ResponseBody, "Message"); - std::string Err = fmt::format("S3 
CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); + std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + if (m_Verbose) + { + ZEN_INFO("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + } return {}; } S3Result S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 AbortMultipartUpload '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); - HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 AbortMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 AbortMultipartUpload failed", Response); ZEN_WARN("S3 AbortMultipartUpload '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return {}; } @@ -680,6 +930,12 @@ S3Client::GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds Exp std::string S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn) { + SigV4Credentials Credentials; + if 
(std::string Err = RequireCredentials(Credentials, "S3 GeneratePresignedUrl '{}' {} failed", Key, Method); !Err.empty()) + { + return {}; + } + std::string Path = KeyToPath(Key); std::string Scheme = "https"; @@ -688,24 +944,25 @@ S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view M Scheme = "http"; } - SigV4Credentials Credentials = GetCurrentCredentials(); return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); } S3Result -S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize) +S3Client::PutObjectMultipart(std::string_view Key, + uint64_t TotalSize, + std::function<IoBuffer(uint64_t Offset, uint64_t Size)> FetchRange, + uint64_t PartSize) { - const uint64_t ContentSize = Content.GetSize(); - // If the content fits in a single part, just use PutObject - if (ContentSize <= PartSize) + if (TotalSize <= PartSize) { - return PutObject(Key, Content); + return PutObject(Key, TotalSize > 0 ? 
FetchRange(0, TotalSize) : IoBuffer{}); } - ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, ContentSize, (ContentSize + PartSize - 1) / PartSize); - - // Initiate multipart upload + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + } S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); if (!InitResult) @@ -722,38 +979,54 @@ S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t Pa uint64_t Offset = 0; uint32_t PartNumber = 1; - while (Offset < ContentSize) + try { - uint64_t ThisPartSize = std::min(PartSize, ContentSize - Offset); + while (Offset < TotalSize) + { + uint64_t ThisPartSize = std::min(PartSize, TotalSize - Offset); + IoBuffer PartContent = FetchRange(Offset, ThisPartSize); + S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, std::move(PartContent)); + if (!PartResult) + { + AbortMultipartUpload(Key, UploadId); + return S3Result{std::move(PartResult.Error)}; + } - // Create a sub-buffer referencing the part data within the original content - IoBuffer PartContent(Content, Offset, ThisPartSize); + PartETags.emplace_back(PartNumber, std::move(PartResult.ETag)); + Offset += ThisPartSize; + PartNumber++; + } - S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, PartContent); - if (!PartResult) + S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); + if (!CompleteResult) { - // Attempt to abort the multipart upload on failure AbortMultipartUpload(Key, UploadId); - return S3Result{std::move(PartResult.Error)}; + return CompleteResult; } - - PartETags.emplace_back(PartNumber, std::move(PartResult.ETag)); - Offset += ThisPartSize; - PartNumber++; } - - // Complete multipart upload - S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); - if (!CompleteResult) + catch (...) 
{ AbortMultipartUpload(Key, UploadId); - return CompleteResult; + throw; } - ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), ContentSize); + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + } return {}; } +S3Result +S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize) +{ + return PutObjectMultipart( + Key, + Content.GetSize(), + [&Content](uint64_t Offset, uint64_t Size) { return IoBuffer(Content, Offset, Size); }, + PartSize); +} + ////////////////////////////////////////////////////////////////////////// // Tests @@ -828,7 +1101,10 @@ TEST_CASE("s3client.minio_integration") { using namespace std::literals; - // Spawn a local MinIO server + // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered + // the TEST_CASE from the top, spawning and killing MinIO per subcase - slow and flaky on + // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance + // that is torn down via RAII at scope exit. 
MinioProcessOptions MinioOpts; MinioOpts.Port = 19000; MinioOpts.RootUser = "testuser"; @@ -836,11 +1112,8 @@ TEST_CASE("s3client.minio_integration") MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); - - // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) Minio.CreateBucket("integration-test"); - // Configure S3Client for the test bucket S3ClientOptions Opts; Opts.BucketName = "integration-test"; Opts.Region = "us-east-1"; @@ -851,7 +1124,7 @@ TEST_CASE("s3client.minio_integration") S3Client Client(Opts); - SUBCASE("put_get_delete") + // -- put_get_delete ------------------------------------------------------- { // PUT std::string_view TestData = "hello, minio integration test!"sv; @@ -880,14 +1153,50 @@ TEST_CASE("s3client.minio_integration") CHECK(HeadRes2.Status == HeadObjectResult::NotFound); } - SUBCASE("head_not_found") + // -- touch ---------------------------------------------------------------- + { + std::string_view TestData = "touch-me"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("touch/obj.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + S3HeadObjectResult Before = Client.HeadObject("touch/obj.txt"); + REQUIRE(Before.IsSuccess()); + REQUIRE(Before.Status == HeadObjectResult::Found); + + // S3 LastModified has second precision; sleep past the second boundary so + // the touched timestamp is strictly greater. + Sleep(1100); + + S3Result TouchRes = Client.Touch("touch/obj.txt"); + REQUIRE(TouchRes.IsSuccess()); + + S3HeadObjectResult After = Client.HeadObject("touch/obj.txt"); + REQUIRE(After.IsSuccess()); + REQUIRE(After.Status == HeadObjectResult::Found); + CHECK(After.Info.Size == Before.Info.Size); + CHECK(After.Info.LastModified != Before.Info.LastModified); + + // Content must be unchanged by a self-copy. 
+ S3GetObjectResult GetRes = Client.GetObject("touch/obj.txt"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.AsText() == TestData); + + // Touching a missing key must fail. + S3Result MissRes = Client.Touch("touch/does-not-exist.txt"); + CHECK_FALSE(MissRes.IsSuccess()); + + Client.DeleteObject("touch/obj.txt"); + } + + // -- head_not_found ------------------------------------------------------- { S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); CHECK(Res.IsSuccess()); CHECK(Res.Status == HeadObjectResult::NotFound); } - SUBCASE("list_objects") + // -- list_objects --------------------------------------------------------- { // Upload several objects with a common prefix for (int i = 0; i < 3; ++i) @@ -922,7 +1231,7 @@ TEST_CASE("s3client.minio_integration") } } - SUBCASE("multipart_upload") + // -- multipart_upload ----------------------------------------------------- { // Create a payload large enough to exercise multipart (use minimum part size) constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum @@ -949,7 +1258,7 @@ TEST_CASE("s3client.minio_integration") Client.DeleteObject("multipart/large.bin"); } - SUBCASE("presigned_urls") + // -- presigned_urls ------------------------------------------------------- { // Upload an object std::string_view TestData = "presigned-url-test-data"sv; @@ -975,8 +1284,6 @@ TEST_CASE("s3client.minio_integration") // Cleanup Client.DeleteObject("presigned/test.txt"); } - - Minio.StopMinioServer(); } TEST_SUITE_END(); diff --git a/src/zenutil/config/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp index 25f5522d8..42ce0d06a 100644 --- a/src/zenutil/config/commandlineoptions.cpp +++ b/src/zenutil/config/commandlineoptions.cpp @@ -8,6 +8,12 @@ #include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/fixed_vector.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + #if ZEN_WITH_TESTS # include <zencore/testing.h> #endif // ZEN_WITH_TESTS @@ -188,6 +194,422 @@ 
CommandLineConverter::CommandLineConverter(int& argc, char**& argv) argv = RawArgs.data(); } +////////////////////////////////////////////////////////////////////////// +// Command-line scrubber (used by invocation history and Sentry integration). + +namespace { + + // Suffixes (matched against the normalized option name) that mark an option + // as carrying a sensitive value. Names are normalized before comparison: + // leading dashes dropped, remaining '-' / '_' stripped, lowercased. So + // `--access-token` / `--Access-Token` / `--access_token` all normalize to + // "accesstoken" and match the "token" suffix. + // + // Picked deliberately to be sharp: catches every credential-bearing option + // in zen and zenserver (access tokens, OAuth client secrets, AES keys, + // Sentry DSNs, upstream Jupiter tokens) without false-positive masks on + // look-alikes (--access-token-env, --access-token-path, --oidctoken-exe-path, + // --valuekey, --opkey, the bare --key cloud lookup hash, --encryption-aes-iv). + // + // For freeform / unknown sensitive values that follow a known format (AWS + // access keys, Google API keys, JWT bearer tokens) the value-pattern scanner + // (kSecretPatterns below) provides an orthogonal safety net. + constexpr std::string_view kSensitiveNameSuffixes[] = { + "token", + "aeskey", + "secret", + "dsn", + }; + + constexpr char ToLowerAscii(char C) { return (C >= 'A' && C <= 'Z') ? char(C + ('a' - 'A')) : C; } + + bool IsSensitiveOptionName(std::string_view Name) + { + // Normalize: skip syntactic quotes, drop leading dashes, strip remaining + // '-' / '_', lowercase ASCII. "--access-token" / "--Access-Token" / + // "--access_token" / `"--Access_Token` all collapse to "accesstoken". + // The 64-byte inline buffer covers every realistic option name; the + // builder spills to the heap on a pathological name. 
+ ExtendableStringBuilder<64> Norm; + bool LeadingDashes = true; + for (char C : Name) + { + if (C == '"') + { + continue; + } + if (LeadingDashes) + { + if (C == '-') + { + continue; + } + LeadingDashes = false; + } + if (C == '-' || C == '_') + { + continue; + } + Norm.Append(ToLowerAscii(C)); + } + const std::string_view View = Norm.ToView(); + for (std::string_view Suffix : kSensitiveNameSuffixes) + { + if (View.ends_with(Suffix)) + { + return true; + } + } + return false; + } + + constexpr std::string_view kUserAndPass = "***:***"; + constexpr std::string_view kUserOnly = "***"; + + // Locate scheme://[user[:pass]]@ credentials in Text. Sets HostStart to the + // offset right after "://", UserInfoLen to the length of the userinfo span + // to redact (i.e. AtPos - HostStart), and HasPassword to true when a ':' + // separates user and password within the userinfo. Returns false if Text + // has no credentialed authority component. + struct UrlCredentials + { + size_t HostStart; + size_t UserInfoLen; + bool HasPassword; + }; + bool FindUrlCredentials(std::string_view Text, UrlCredentials& Out) + { + const size_t SchemePos = Text.find("://"); + if (SchemePos == std::string_view::npos) + { + return false; + } + const size_t HostStart = SchemePos + 3; + const size_t AtPos = Text.find('@', HostStart); + if (AtPos == std::string_view::npos) + { + return false; + } + // '@' must be in the authority, not the path/query/fragment. + const size_t TermPos = Text.find_first_of("/?#", HostStart); + if (TermPos != std::string_view::npos && TermPos < AtPos) + { + return false; + } + const size_t ColonPos = Text.find(':', HostStart); + Out.HostStart = HostStart; + Out.UserInfoLen = AtPos - HostStart; + Out.HasPassword = (ColonPos != std::string_view::npos && ColonPos < AtPos); + return true; + } + + // Inline capacity for the tokenized cmdline. 32 covers any realistic + // invocation; if exceeded the eastl::fixed_vector spills to the heap. 
+ constexpr size_t kTokenInlineCapacity = 32; + using TokenVector = eastl::fixed_vector<std::string_view, kTokenInlineCapacity>; + + constexpr AsciiSet kUpperAlnum = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + constexpr AsciiSet kBase64Url = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + constexpr AsciiSet kJwtCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_."; + + // Known credential value formats. The cmdline scanner walks each non-argv[0] + // token and at every position checks whether one of these prefixes matches. + // On match it greedily consumes [MinBodyLen, MaxBodyLen] characters from + // BodyChars; if the consumed body is at least MinBodyLen the prefix + body + // range is masked. Patterns are kept strict (distinctive prefix + length + // + charset) to keep false positives near zero. + // + // Only formats that zen/zenserver interacts with (or may plausibly accept + // via a cloud config) are listed here: + // - AWS access keys: used by the S3 client (SigV4) and EC2 IMDS provider. + // - Google API keys: covers keys passed via GCP-adjacent configuration. + // - JWTs: OAuth/OIDC bearer tokens accepted via --access-token and the + // upstream-jupiter-oauth-* options are typically JWTs. + struct SecretPattern + { + std::string_view Prefix; + size_t MinBodyLen; + size_t MaxBodyLen; + AsciiSet BodyChars; + }; + + constexpr SecretPattern kSecretPatterns[] = { + // AWS access key (long-term and temporary / STS). + {"AKIA", 16, 16, kUpperAlnum}, + {"ASIA", 16, 16, kUpperAlnum}, + // Google API key: exactly "AIza" + 35 base64url body chars. + {"AIza", 35, 35, kBase64Url}, + // JWT (header.payload.signature). Loose - prefix "eyJ" plus body chars. + // Verified to contain at least two '.' before accepting. + {"eyJ", 20, 8192, kJwtCharset}, + }; + + // Scan a token for embedded secret patterns. For each match append an Edit + // that replaces the matched range with HideSensitiveString(matched value). 
+ // The first 4 characters of any 17+ char match leak through - safe here + // since they are part of the public format prefix that triggered the match + // (e.g. "AKIA...", "AIza...", "eyJh..."). Edit offsets are absolute + // Cmdline offsets. + template<typename EditVecT> + void FindSecretPatternEdits(std::string_view Token, size_t TokStart, EditVecT& Edits) + { + size_t I = 0; + while (I < Token.size()) + { + bool Matched = false; + for (const SecretPattern& Pat : kSecretPatterns) + { + if (I + Pat.Prefix.size() > Token.size()) + { + continue; + } + if (Token.compare(I, Pat.Prefix.size(), Pat.Prefix) != 0) + { + continue; + } + const size_t BodyStart = I + Pat.Prefix.size(); + size_t J = BodyStart; + while (J < Token.size() && (J - BodyStart) < Pat.MaxBodyLen && Pat.BodyChars.Contains(Token[J])) + { + ++J; + } + const size_t BodyLen = J - BodyStart; + if (BodyLen < Pat.MinBodyLen) + { + continue; + } + if (Pat.Prefix == "eyJ") + { + int DotCount = 0; + for (size_t K = BodyStart; K < J; ++K) + { + if (Token[K] == '.' && ++DotCount == 2) + { + break; + } + } + if (DotCount < 2) + { + continue; + } + } + const size_t MatchLen = Pat.Prefix.size() + BodyLen; + Edits.push_back({TokStart + I, MatchLen, HideSensitiveString(Token.substr(I, MatchLen))}); + I = J; + Matched = true; + break; + } + if (!Matched) + { + ++I; + } + } + } + + // Tokenize a Windows-style command line into views into the original + // string. Each token preserves any surrounding/embedded quote characters. + // Splits on unquoted ASCII whitespace. No heap allocation up to + // kTokenInlineCapacity tokens. 
+ TokenVector SplitCommandLineTokens(std::string_view Input) + { + TokenVector Tokens; + size_t I = 0; + while (I < Input.size()) + { + while (I < Input.size() && (Input[I] == ' ' || Input[I] == '\t')) + { + ++I; + } + if (I >= Input.size()) + { + break; + } + const size_t Start = I; + bool InQuote = false; + while (I < Input.size()) + { + const char C = Input[I]; + if (C == '"') + { + InQuote = !InQuote; + ++I; + continue; + } + if (!InQuote && (C == ' ' || C == '\t')) + { + break; + } + ++I; + } + Tokens.push_back(Input.substr(Start, I - Start)); + } + return Tokens; + } + + // First non-quote character in a token (or '\0' if there isn't one). + // Used to test whether a token starts with '-' through any leading quotes. + char FirstNonQuote(std::string_view Token) + { + for (char C : Token) + { + if (C != '"') + { + return C; + } + } + return '\0'; + } + + void ScrubSensitiveValuesImpl(std::string& Cmdline) + { + const TokenVector Tokens = SplitCommandLineTokens(Cmdline); + if (Tokens.size() <= 1) + { + return; + } + + // Edits are accumulated and applied right-to-left at the end, so + // untouched offsets remain valid. fixed_vector keeps the storage + // inline; spills to heap only on a pathological number of edits. The + // std::string Replacement uses small-string optimization for the + // short masks we produce (max ~12 chars) so no heap traffic per edit + // in the common case. + struct Edit + { + size_t Start; + size_t Len; + std::string Replacement; + }; + eastl::fixed_vector<Edit, kTokenInlineCapacity> Edits; + + auto PushMask = [&](size_t Start, size_t Len, std::string_view Replacement) { + Edits.push_back({Start, Len, std::string(Replacement)}); + }; + + bool MaskNext = false; + // Skip Tokens[0] (executable path) - never scrub. 
+ for (size_t Idx = 1; Idx < Tokens.size(); ++Idx) + { + const std::string_view Tok = Tokens[Idx]; + const size_t TokStart = static_cast<size_t>(Tok.data() - Cmdline.data()); + const char Lead = FirstNonQuote(Tok); + const bool IsFlag = (Lead == '-'); + + if (MaskNext) + { + MaskNext = false; + if (!IsFlag && Lead != '\0') + { + PushMask(TokStart, Tok.size(), kUserOnly); + continue; + } + // Otherwise fall through to re-evaluate this token as a flag. + } + + if (!IsFlag) + { + // Positional. URL credentials and secret patterns target + // different ranges so they can coexist; overlap is filtered + // out below. + if (UrlCredentials U; FindUrlCredentials(Tok, U)) + { + PushMask(TokStart + U.HostStart, U.UserInfoLen, U.HasPassword ? kUserAndPass : kUserOnly); + } + FindSecretPatternEdits(Tok, TokStart, Edits); + continue; + } + + // It's a flag. Look for inline =value. + const size_t EqPos = Tok.find('='); + if (EqPos == std::string_view::npos) + { + if (IsSensitiveOptionName(Tok)) + { + MaskNext = true; + } + else + { + // Bare flag - still scan for embedded secret prefixes + // (e.g. "--AKIAEXAMPLE..." would be unusual but harmless). + FindSecretPatternEdits(Tok, TokStart, Edits); + } + continue; + } + + const std::string_view Name = Tok.substr(0, EqPos); + if (IsSensitiveOptionName(Name)) + { + const size_t ValueStart = TokStart + EqPos + 1; + size_t ValueLen = Tok.size() - (EqPos + 1); + // Preserve the closing quote when the whole token is outer-quoted + // (e.g. `"--name=value"`); otherwise the replacement would eat it + // and leave an unbalanced quote in the cmdline string. + if (ValueLen > 0 && Tok.front() == '"' && Tok.back() == '"') + { + --ValueLen; + } + PushMask(ValueStart, ValueLen, kUserOnly); + } + else + { + // Non-sensitive flag with a value: URL scrub + pattern scan + // on the value (same coexistence as the positional branch). 
+ const std::string_view Value = Tok.substr(EqPos + 1); + const size_t ValueAbsStart = TokStart + EqPos + 1; + if (UrlCredentials U; FindUrlCredentials(Value, U)) + { + PushMask(ValueAbsStart + U.HostStart, U.UserInfoLen, U.HasPassword ? kUserAndPass : kUserOnly); + } + FindSecretPatternEdits(Value, ValueAbsStart, Edits); + } + } + + if (Edits.empty()) + { + return; + } + + // Sort by start ascending and drop overlapping/duplicate edits so the + // right-to-left replacement pass does not corrupt the string. Earlier + // (lower-Start) edits win over later ones that fall inside their range. + std::sort(Edits.begin(), Edits.end(), [](const Edit& A, const Edit& B) { return A.Start < B.Start; }); + size_t Write = 0; + for (size_t Read = 0; Read < Edits.size(); ++Read) + { + if (Write > 0 && Edits[Read].Start < Edits[Write - 1].Start + Edits[Write - 1].Len) + { + continue; // overlaps the prior kept edit + } + if (Write != Read) + { + Edits[Write] = Edits[Read]; + } + ++Write; + } + Edits.resize(Write); + + // Apply right-to-left so earlier offsets stay valid as later edits + // resize the string. + for (size_t E = Edits.size(); E-- > 0;) + { + Cmdline.replace(Edits[E].Start, Edits[E].Len, Edits[E].Replacement); + } + } + +} // namespace + +void +ScrubSensitiveValues(std::string& Cmdline) noexcept +{ + try + { + ScrubSensitiveValuesImpl(Cmdline); + } + catch (...) + { + } +} + #if ZEN_WITH_TESTS void @@ -195,6 +617,22 @@ commandlineoptions_forcelink() { } +namespace { + // Test helper: in-place redaction of user:password@ in a URL. Production + // code calls FindUrlCredentials + PushMask directly inside the cmdline + // walker; this wrapper exists only for direct unit-test coverage of the + // URL redaction rule. + void ScrubUrlCredentials(std::string& Token) + { + UrlCredentials U; + if (!FindUrlCredentials(Token, U)) + { + return; + } + Token.replace(U.HostStart, U.UserInfoLen, U.HasPassword ? 
kUserAndPass : kUserOnly); + } +} // namespace + TEST_SUITE_BEGIN("util.commandlineoptions"); TEST_CASE("CommandLine") @@ -238,6 +676,321 @@ TEST_CASE("CommandLine") CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64")); } +TEST_CASE("IsSensitiveOptionName.matches") +{ + // Real zen / zenserver options ending in one of the sensitive suffixes. + CHECK(IsSensitiveOptionName("--access-token")); // token + CHECK(IsSensitiveOptionName("--openid-refresh-token")); // token + CHECK(IsSensitiveOptionName("--upstream-jupiter-token")); // token + CHECK(IsSensitiveOptionName("--encryption-aes-key")); // aeskey + CHECK(IsSensitiveOptionName("--oauth-clientsecret")); // secret + CHECK(IsSensitiveOptionName("--upstream-jupiter-oauth-clientsecret")); // secret + CHECK(IsSensitiveOptionName("--sentry-dsn")); // dsn + + // Generic forms ending in one of the suffixes. + CHECK(IsSensitiveOptionName("--token")); + CHECK(IsSensitiveOptionName("--my-aeskey")); + + // Normalization equivalents - dashes / underscores stripped, case folded. + CHECK(IsSensitiveOptionName("--ACCESS-TOKEN")); + CHECK(IsSensitiveOptionName("--access_token")); + CHECK(IsSensitiveOptionName("--AccessToken")); + CHECK(IsSensitiveOptionName("--Encryption_AES_Key")); + CHECK(IsSensitiveOptionName("--Sentry-DSN")); + + // Quotes around the option name are stripped before matching. + CHECK(IsSensitiveOptionName("\"--access-token\"")); + CHECK(IsSensitiveOptionName("\"--sentry-dsn")); +} + +TEST_CASE("IsSensitiveOptionName.no-match") +{ + // Real zen options whose name contains "token" or "key" mid-name but + // whose value is NOT a secret (suffix doesn't match). 
+ CHECK_FALSE(IsSensitiveOptionName("--access-token-env")); // env name + CHECK_FALSE(IsSensitiveOptionName("--access-token-path")); // file path + CHECK_FALSE(IsSensitiveOptionName("--oidctoken-exe-path")); // exe path + CHECK_FALSE(IsSensitiveOptionName("--allow-external-oidctoken-exe")); // boolean + CHECK_FALSE(IsSensitiveOptionName("--encryption-aes-iv")); // IV is public + CHECK_FALSE(IsSensitiveOptionName("--key")); // cloud lookup hash + CHECK_FALSE(IsSensitiveOptionName("--valuekey")); // IoHash filter + CHECK_FALSE(IsSensitiveOptionName("--opkey")); // chunk OID filter + + // "key" alone is not a sensitive suffix in this scheme - keeps lookup + // hashes / cache filter / api-key / ssh-key visible. + CHECK_FALSE(IsSensitiveOptionName("--api-key")); + CHECK_FALSE(IsSensitiveOptionName("--ssh-key")); + + // Non-sensitive option names. + CHECK_FALSE(IsSensitiveOptionName("--port")); + CHECK_FALSE(IsSensitiveOptionName("--data-dir")); + CHECK_FALSE(IsSensitiveOptionName("--debug")); + CHECK_FALSE(IsSensitiveOptionName("--no-sentry")); + CHECK_FALSE(IsSensitiveOptionName("--filter")); + CHECK_FALSE(IsSensitiveOptionName("--monkey")); + CHECK_FALSE(IsSensitiveOptionName("--keychain")); + + CHECK_FALSE(IsSensitiveOptionName("")); + CHECK_FALSE(IsSensitiveOptionName("--")); + CHECK_FALSE(IsSensitiveOptionName("---")); + + // Pathologically long non-sensitive name spills the inline buffer but + // still resolves cleanly. 
+ CHECK_FALSE(IsSensitiveOptionName(std::string(200, 'x'))); +} + +TEST_CASE("ScrubUrlCredentials.basic") +{ + std::string T = "https://user:[email protected]/path"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://***:***@host.example.com/path"); + + T = "https://[email protected]"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://***@host.example.com"); + + T = "ftp://u:[email protected]:21/file"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "ftp://***:***@example.com:21/file"); +} + +TEST_CASE("ScrubUrlCredentials.passthrough") +{ + std::string T = "just a string"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "just a string"); + + T = "https://example.com/path"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://example.com/path"); + + T = "https://example.com/users/foo@bar"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://example.com/users/foo@bar"); + + T = "user@host"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "user@host"); +} + +TEST_CASE("ScrubSensitiveValues.inline-equals") +{ + std::string C = "zen.exe --access-token=secret --port=8558 version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** --port=8558 version"); +} + +TEST_CASE("ScrubSensitiveValues.next-token") +{ + std::string C = "zen.exe --access-token mysecret version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** version"); + + // Multiple next-token sensitive options in a row. + C = "zen.exe --access-token tok1 --oauth-clientsecret sec1 --port 8558 version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** --oauth-clientsecret *** --port 8558 version"); +} + +TEST_CASE("ScrubSensitiveValues.sensitive-flag-followed-by-flag") +{ + // `--access-token` here is acting like a switch (no value); the next + // flag must NOT be masked. 
+ std::string C = "zen.exe --access-token --debug version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token --debug version"); +} + +TEST_CASE("ScrubSensitiveValues.normalized-name-forms") +{ + // Same option masked under different casing/separator styles. + std::string C = "zen.exe --Access-Token bar --ACCESS_TOKEN=foo --AccessToken=baz version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --Access-Token *** --ACCESS_TOKEN=*** --AccessToken=*** version"); +} + +TEST_CASE("ScrubSensitiveValues.aes-key-and-iv-and-bare-key") +{ + // --encryption-aes-key matches the "aeskey" suffix and is masked. + // --encryption-aes-iv ends in "iv", not in the suffix set, and stays. + // --key (cloud lookup hash) ends in just "key", not in the suffix set, + // and stays. + std::string C = "zen.exe --encryption-aes-key=BASE64STUFF --encryption-aes-iv ABCDEF --key=cloudval --key bareval version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --encryption-aes-key=*** --encryption-aes-iv ABCDEF --key=cloudval --key bareval version"); +} + +TEST_CASE("ScrubSensitiveValues.no-false-positives") +{ + // Names that don't end in any sensitive suffix stay untouched. 
+ std::string C = "zen.exe --password mypass --api-key bar --monkey=banana --filter zen --no-sentry --port 8558 version"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.url-credentials-positional") +{ + std::string C = "zen.exe serve https://user:[email protected]/path"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe serve https://***:***@host.example.com/path"); +} + +TEST_CASE("ScrubSensitiveValues.url-credentials-in-option-value") +{ + std::string C = "zen.exe --url=https://user:[email protected] version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --url=https://***:***@host.example.com version"); +} + +TEST_CASE("ScrubSensitiveValues.outer-quoted-token-preserves-closing-quote") +{ + std::string C = "zen.exe status \"--access-token=Bearer xyz\""; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe status \"--access-token=***\""); +} + +TEST_CASE("ScrubSensitiveValues.executable-path-never-touched") +{ + // argv[0] stays untouched even if it embeds a token-shaped substring. 
+ std::string C = "/path/with/AKIAIOSFODNN7EXAMPLE/zen.exe version"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.empty-and-trivial") +{ + std::string C; + ScrubSensitiveValues(C); + CHECK_EQ(C, ""); + + C = "zen.exe"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe"); +} + +TEST_CASE("ScrubSensitiveValues.long-value") +{ + const std::string LongSecret(4096, 'X'); + std::string C = "zen.exe --access-token=" + LongSecret + " version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** version"); + + C = "zen.exe --access-token " + LongSecret + " version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** version"); + + C = "zen.exe \"--access-token=" + LongSecret + "\""; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe \"--access-token=***\""); +} + +TEST_CASE("ScrubSensitiveValues.long-option-name-spills-buffer") +{ + // Names longer than the 64-byte inline buffer normalize correctly via + // ExtendableStringBuilder's heap-spill path, and the suffix check still + // identifies them as sensitive. + const std::string LongName = "--" + std::string(80, 'a') + "-token"; + std::string C = "zen.exe " + LongName + "=secretvalue version"; + const std::string Expected = "zen.exe " + LongName + "=*** version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, Expected); +} + +TEST_CASE("ScrubSensitiveValues.no-allocation-on-noop") +{ + std::string C = "zen.exe version --port=8558 --debug"; + C.reserve(256); + const auto* DataBefore = C.data(); + const auto CapacityBefore = C.capacity(); + ScrubSensitiveValues(C); + CHECK_EQ(C.data(), DataBefore); + CHECK_EQ(C.capacity(), CapacityBefore); +} + +// Value-based pattern matching ------------------------------------------ + +// Pattern-matched values are masked via HideSensitiveString, which leaks +// the first 4 characters for any value over 16 chars. 
Safe here because +// those 4 characters are part of the public format prefix that triggered +// the match (AKIA, AIza, eyJh) and serve as a useful debugging hint when +// reading a crash report. + +TEST_CASE("ScrubSensitiveValues.aws-access-key") +{ + std::string C = "zen.exe upload --bucket foo AKIAIOSFODNN7EXAMPLE done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe upload --bucket foo AKIAXXXX... done"); + + // Inside an option value with non-sensitive name. + C = "zen.exe --note=AKIAIOSFODNN7EXAMPLE"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --note=AKIAXXXX..."); + + // AKIA followed by too few chars: not a match. + C = "zen.exe AKIA12345"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AKIA12345"); +} + +TEST_CASE("ScrubSensitiveValues.google-api-key") +{ + // Google API keys are exactly AIza + 35 body chars. + std::string C = "zen.exe lookup AIzaSyA_abcdefghijklmnopqrstuvwx-123456 done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe lookup AIzaXXXX... done"); + + // AIza body too short: not a match. + C = "zen.exe AIzaShort"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AIzaShort"); +} + +TEST_CASE("ScrubSensitiveValues.jwt") +{ + std::string C = + "zen.exe call eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c " + "done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe call eyJhXXXX... done"); + + // "eyJ" but no two dots: not a JWT. 
+ C = "zen.exe eyJonly"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe eyJonly"); +} + +TEST_CASE("ScrubSensitiveValues.no-pattern-false-positives") +{ + // Common identifier formats used by zen that must NOT be flagged: + // - 24-hex Oid + // - SHA-like hex hashes + // - file paths + // - port numbers + // - lowercase words / common shell commands + std::string C = "zen.exe cache get 09f7831b0139270d22cf2fe2 --port=8558 run /var/tmp/zen.log"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.pattern-and-name-on-same-token") +{ + // Sensitive-name mask wins; pattern scan is suppressed for the value. + std::string C = "zen.exe --access-token=AKIAIOSFODNN7EXAMPLE done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** done"); +} + +TEST_CASE("ScrubSensitiveValues.multiple-patterns-in-one-token") +{ + // Two AWS keys glued together via separator - both should mask. + std::string C = "zen.exe AKIAIOSFODNN7EXAMPLE,AKIAIOSFODNN7OTHER12 list"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AKIAXXXX...,AKIAXXXX... 
list"); +} + TEST_SUITE_END(); #endif diff --git a/src/zenutil/config/loggingconfig.cpp b/src/zenutil/config/loggingconfig.cpp index 5092c60aa..e2db31160 100644 --- a/src/zenutil/config/loggingconfig.cpp +++ b/src/zenutil/config/loggingconfig.cpp @@ -29,6 +29,8 @@ ZenLoggingCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenLoggingCon ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Critical])) ("log-off", "Change selected loggers to level OFF", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Off])) ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value<std::string>(LoggingConfig.OtelEndpointUri)) + ("force-color", "Force colored log output even when stdout is not a terminal", cxxopts::value<bool>(LoggingConfig.ForceColor)->default_value("false")) + ("log-stream", "TCP log stream endpoint (host:port)", cxxopts::value<std::string>(LoggingConfig.LogStreamEndpoint)) ; // clang-format on } diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp index 124132aed..10e8abb31 100644 --- a/src/zenutil/consoletui.cpp +++ b/src/zenutil/consoletui.cpp @@ -311,7 +311,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items) printf("\033[1;7m"); // bold + reverse video } - // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶) + // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (>) const char* Indicator = IsSelected ? 
" \xe2\x96\xb6 " : " "; printf("%s%s", Indicator, Items[i].c_str()); @@ -328,7 +328,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items) printf("\r\033[K\n"); // Hint footer - // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓ + // \xe2\x86\x91 = U+2191 ^ \xe2\x86\x93 = U+2193 v printf( "\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate " "\033[2mEnter\033[0m confirm " diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp index d463c0938..3d16a9188 100644 --- a/src/zenutil/consul/consul.cpp +++ b/src/zenutil/consul/consul.cpp @@ -9,10 +9,18 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpserver.h> + +#include <unordered_set> + +ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen::consul { @@ -31,7 +39,7 @@ struct ConsulProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL); CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); @@ -107,12 +115,30 @@ ConsulProcess::StopConsulAgent() ////////////////////////////////////////////////////////////////////////// -ConsulClient::ConsulClient(std::string_view BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri) +ConsulClient::ConsulClient(const Configuration& Config) +: m_Config(Config) +, m_HttpClient(m_Config.BaseUri, + HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds{500}, .Timeout = std::chrono::milliseconds{500}}, + [this] { return m_Stop.load(); }) { + m_Worker = std::thread(&ConsulClient::WorkerLoop, this); } 
ConsulClient::~ConsulClient() { + try + { + m_Stop.store(true); + m_Wakeup.Set(); + if (m_Worker.joinable()) + { + m_Worker.join(); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("ConsulClient::~ConsulClient threw exception: {}", Ex.what()); + } } void @@ -158,15 +184,35 @@ ConsulClient::DeleteKey(std::string_view Key) } } -bool +void ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) { + PendingOp Op{PendingOp::Kind::Register, Info}; + m_QueueLock.WithExclusiveLock([&] { m_Queue.push_back(std::move(Op)); }); + m_Wakeup.Set(); +} + +void +ConsulClient::DeregisterService(std::string_view ServiceId) +{ + PendingOp Op; + Op.Type = PendingOp::Kind::Deregister; + Op.Info.ServiceId = std::string(ServiceId); + m_QueueLock.WithExclusiveLock([&] { m_Queue.push_back(std::move(Op)); }); + m_Wakeup.Set(); +} + +bool +ConsulClient::DoRegister(const ServiceRegistrationInfo& Info) +{ using namespace std::literals; HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON)); + HttpClient::KeyValueMap AdditionalParameters(std::make_pair<std::string, std::string>("replace-existing-checks", "true")); + CbObjectWriter Writer; { Writer.AddString("ID"sv, Info.ServiceId); @@ -185,13 +231,27 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) } Writer.EndArray(); // Tags } - Writer.BeginObject("Check"sv); + if (Info.HealthIntervalSeconds != 0) { - Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint)); - Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds)); - Writer.AddString("DeregisterCriticalServiceAfter"sv, fmt::format("{}s", Info.DeregisterAfterSeconds)); + // Consul requires Interval whenever HTTP is specified; omit the Check block entirely + // when no interval is configured (e.g. during Provisioning). 
+ Writer.BeginObject("Check"sv); + { + Writer.AddString( + "HTTP"sv, + fmt::format("http://{}:{}/{}", Info.Address.empty() ? "localhost" : Info.Address, Info.Port, Info.HealthEndpoint)); + Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds)); + if (Info.DeregisterAfterSeconds != 0) + { + Writer.AddString("DeregisterCriticalServiceAfter"sv, fmt::format("{}s", Info.DeregisterAfterSeconds)); + } + if (!Info.InitialStatus.empty()) + { + Writer.AddString("Status"sv, Info.InitialStatus); + } + } + Writer.EndObject(); // Check } - Writer.EndObject(); // Check } ExtendableStringBuilder<512> SB; @@ -199,11 +259,11 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()); PayloadBuffer.SetContentType(HttpContentType::kJSON); - HttpClient::Response Result = m_HttpClient.Put("v1/agent/service/register", PayloadBuffer, AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Put("v1/agent/service/register", PayloadBuffer, AdditionalHeaders, AdditionalParameters); if (!Result) { - ZEN_WARN("ConsulClient::RegisterService() failed to register service '{}' ({})", Info.ServiceId, Result.ErrorMessage("")); + ZEN_WARN("ConsulClient::DoRegister() failed to register service '{}' ({})", Info.ServiceId, Result.ErrorMessage("")); return false; } @@ -211,29 +271,114 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) } bool -ConsulClient::DeregisterService(std::string_view ServiceId) +ConsulClient::DoDeregister(std::string_view ServiceId) { + using namespace std::literals; + HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON)); - HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), 
IoBuffer{}, AdditionalHeaders); + if (Result) + { + return true; + } + + // Agent deregister failed - fall back to catalog deregister. + // This handles cases where the service was registered via a different Consul agent + // (e.g. load-balanced endpoint routing to different agents). + std::string NodeName = GetNodeName(); + if (!NodeName.empty()) + { + CbObjectWriter Writer; + Writer.AddString("Node"sv, NodeName); + Writer.AddString("ServiceID"sv, ServiceId); + + ExtendableStringBuilder<256> SB; + CompactBinaryToJson(Writer.Save(), SB); + + IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()); + PayloadBuffer.SetContentType(HttpContentType::kJSON); + + HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders); + if (CatalogResult) + { + ZEN_INFO("ConsulClient::DoDeregister() deregistered service '{}' via catalog fallback (agent error: {})", + ServiceId, + Result.ErrorMessage("")); + return true; + } + + ZEN_WARN("ConsulClient::DoDeregister() failed to deregister service '{}' (agent: {}, catalog: {})", + ServiceId, + Result.ErrorMessage(""), + CatalogResult.ErrorMessage("")); + } + else + { + ZEN_WARN( + "ConsulClient::DoDeregister() failed to deregister service '{}' (agent: {}, could not determine node name for catalog " + "fallback)", + ServiceId, + Result.ErrorMessage("")); + } + + return false; +} + +std::string +ConsulClient::GetNodeName() +{ + using namespace std::literals; + + HttpClient::KeyValueMap AdditionalHeaders; + ApplyCommonHeaders(AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders); if (!Result) { - ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage("")); - return false; + return {}; } - return true; + std::string JsonError; + CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError); + if (!Root || !JsonError.empty()) + { + return {}; + } + + for 
(CbFieldView Field : Root) + { + if (Field.GetName() == "Config"sv) + { + CbObjectView Config = Field.AsObjectView(); + if (Config) + { + return std::string(Config["NodeName"sv].AsString()); + } + } + } + + return {}; } void ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap) { - if (!m_Token.empty()) + std::string Token; + if (!m_Config.StaticToken.empty()) + { + Token = m_Config.StaticToken; + } + else if (!m_Config.TokenEnvName.empty()) + { + Token = GetEnvVariable(m_Config.TokenEnvName); + } + + if (!Token.empty()) { - InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token); + InOutHeaderMap.Entries.emplace("X-Consul-Token", Token); } } @@ -285,8 +430,10 @@ ConsulClient::WatchService(std::string_view ServiceId, uint64_t& InOutIndex, int HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); - // Note: m_HttpClient uses unlimited HTTP timeout (Timeout{0}); the WaitSeconds parameter - // governs the server-side bound on the blocking query. Do not add a separate client timeout. + // Note: m_HttpClient runs with a 500ms client-side timeout to keep Register/Deregister from + // stalling the hub state machine when the agent is unreachable. That bound applies here too: + // WaitSeconds is effectively capped at ~500ms regardless of the argument, so callers must + // treat this as a short-poll and loop rather than rely on a true blocking query. 
HttpClient::KeyValueMap Parameters({{"index", std::to_string(InOutIndex)}, {"wait", fmt::format("{}s", WaitSeconds)}}); HttpClient::Response Result = m_HttpClient.Get("v1/agent/services", AdditionalHeaders, Parameters); if (!Result) @@ -321,6 +468,96 @@ ConsulClient::GetAgentServicesJson() return Result.ToText(); } +std::string +ConsulClient::GetAgentChecksJson() +{ + HttpClient::KeyValueMap AdditionalHeaders; + ApplyCommonHeaders(AdditionalHeaders); + + HttpClient::Response Result = m_HttpClient.Get("v1/agent/checks", AdditionalHeaders); + if (!Result) + { + return "{}"; + } + return Result.ToText(); +} + +void +ConsulClient::WorkerLoop() +{ + SetCurrentThreadName("ConsulClient"); + + std::unordered_set<std::string> RegisteredServices; + + while (true) + { + m_Wakeup.Wait(-1); + m_Wakeup.Reset(); + + const bool Stopping = m_Stop.load(); + + std::vector<PendingOp> Batch; + m_QueueLock.WithExclusiveLock([&] { Batch.swap(m_Queue); }); + + for (size_t Index = 0; Index < Batch.size(); ++Index) + { + PendingOp& Op = Batch[Index]; + + if (Stopping && Op.Type == PendingOp::Kind::Register) + { + continue; + } + + const std::string_view OpName = (Op.Type == PendingOp::Kind::Register) ? 
"register" : "deregister"; + + try + { + if (Op.Type == PendingOp::Kind::Register) + { + bool Ok = DoRegister(Op.Info); + if (Ok) + { + RegisteredServices.insert(Op.Info.ServiceId); + } + else + { + const size_t Remaining = Batch.size() - Index - 1; + ZEN_WARN("ConsulClient worker: {} for '{}' failed; dropping {} remaining queued op(s)", + OpName, + Op.Info.ServiceId, + Remaining); + break; + } + } + else + { + ZEN_ASSERT(Op.Type == PendingOp::Kind::Deregister); + if (RegisteredServices.erase(Op.Info.ServiceId) == 1u) + { + if (!DoDeregister(Op.Info.ServiceId)) + { + ZEN_WARN("ConsulClient worker: {} for '{}' failed", OpName, Op.Info.ServiceId); + } + } + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("ConsulClient worker: {} for '{}' threw: {}", OpName, Op.Info.ServiceId, Ex.what()); + } + catch (...) + { + ZEN_WARN("ConsulClient worker: {} for '{}' threw unknown exception", OpName, Op.Info.ServiceId); + } + } + + if (Stopping) + { + break; + } + } +} + ////////////////////////////////////////////////////////////////////////// ServiceRegistration::ServiceRegistration(ConsulClient* Client, const ServiceRegistrationInfo& Info) : m_Client(Client), m_Info(Info) @@ -341,7 +578,7 @@ ServiceRegistration::~ServiceRegistration() if (m_IsRegistered.load()) { - if (!m_Client->DeregisterService(m_Info.ServiceId)) + if (!m_Client->DoDeregister(m_Info.ServiceId)) { ZEN_WARN("ServiceRegistration: Failed to deregister service '{}' during cleanup", m_Info.ServiceId); } @@ -381,7 +618,7 @@ ServiceRegistration::RegistrationLoop() // Try to register with exponential backoff for (int Attempt = 0; Attempt < MaxAttempts; ++Attempt) { - if (m_Client->RegisterService(m_Info)) + if (m_Client->DoRegister(m_Info)) { Succeeded = true; break; @@ -393,7 +630,7 @@ ServiceRegistration::RegistrationLoop() } } - if (Succeeded || m_Client->RegisterService(m_Info)) + if (Succeeded || m_Client->DoRegister(m_Info)) { break; } @@ -422,4 +659,201 @@ ServiceRegistration::RegistrationLoop() } } 
+////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +consul_forcelink() +{ +} + +struct MockHealthService : public HttpService +{ + std::atomic<bool> FailHealth{false}; + std::atomic<int> HealthCheckCount{0}; + + const char* BaseUri() const override { return "/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + std::string_view Uri = Request.RelativeUri(); + if (Uri == "health/" || Uri == "health") + { + HealthCheckCount.fetch_add(1); + if (FailHealth.load()) + { + Request.WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } +}; + +struct TestHealthServer +{ + MockHealthService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + int Port() const { return m_Port; } + + ~TestHealthServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +static bool +WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs)) + { + if (Predicate()) + { + return true; + } + Sleep(PollIntervalMs); + } + return Predicate(); +} + +static std::string +GetCheckStatus(ConsulClient& Client, std::string_view ServiceId) +{ + using namespace std::literals; + + std::string JsonError; + CbFieldIterator 
ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError); + if (!ChecksRoot || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView F : ChecksRoot) + { + if (!F.IsObject()) + { + continue; + } + for (CbFieldView C : F.AsObjectView()) + { + CbObjectView Check = C.AsObjectView(); + if (Check["ServiceID"sv].AsString() == ServiceId) + { + return std::string(Check["Status"sv].AsString()); + } + } + } + return {}; +} + +TEST_SUITE_BEGIN("util.consul"); + +TEST_CASE("util.consul.service_lifecycle") +{ + ConsulProcess ConsulProc; + ConsulProc.SpawnConsulAgent(); + + TestHealthServer HealthServer; + HealthServer.Start(); + + ConsulClient Client({.BaseUri = "http://localhost:8500/"}); + + const std::string ServiceId = "test-health-svc"; + + ServiceRegistrationInfo Info; + Info.ServiceId = ServiceId; + Info.ServiceName = "zen-test-health"; + Info.Address = "127.0.0.1"; + Info.Port = static_cast<uint16_t>(HealthServer.Port()); + Info.HealthEndpoint = "health/"; + Info.HealthIntervalSeconds = 1; + Info.DeregisterAfterSeconds = 60; + + // Register/Deregister are async; wait for the worker to propagate to Consul. 
+ + // Phase 1: Register and verify Consul sends health checks to our service + Client.RegisterService(Info); + REQUIRE(WaitForCondition([&]() { return Client.HasService(ServiceId); }, 10000, 50)); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + // Phase 2: Explicit deregister + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + // Phase 3: Register with InitialStatus, verify immediately passing before any health check fires, + // then fail health and verify check goes critical + HealthServer.Mock.HealthCheckCount.store(0); + HealthServer.Mock.FailHealth.store(false); + + Info.InitialStatus = "passing"; + Client.RegisterService(Info); + REQUIRE(WaitForCondition([&]() { return Client.HasService(ServiceId); }, 10000, 50)); + + // Registration is async; by the time HasService observes the service the 1s health interval + // may already have fired, so we can't robustly assert HealthCheckCount==0. The "passing" status + // below still proves InitialStatus applied (it can only be "passing" via InitialStatus or a + // successful health check - both are acceptable demonstrations). 
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + HealthServer.Mock.FailHealth.store(true); + + // Wait for Consul to observe the failing check + REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical"); + + // Phase 4: Explicit deregister while critical + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + // Phase 5: Deregister an already-deregistered service - should not crash + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + ConsulProc.StopConsulAgent(); +} + +TEST_SUITE_END(); + +#endif + } // namespace zen::consul diff --git a/src/zenutil/filesystemutils.cpp b/src/zenutil/filesystemutils.cpp new file mode 100644 index 000000000..f8f7bfb18 --- /dev/null +++ b/src/zenutil/filesystemutils.cpp @@ -0,0 +1,724 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/filesystemutils.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/parallelwork.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +BufferedOpenFile::BufferedOpenFile(const std::filesystem::path Path, + std::atomic<uint64_t>& OpenReadCount, + std::atomic<uint64_t>& CurrentOpenFileCount, + std::atomic<uint64_t>& ReadCount, + std::atomic<uint64_t>& ReadByteCount) +: m_Source(Path, BasicFile::Mode::kRead) +, m_SourceSize(m_Source.FileSize()) +, m_OpenReadCount(OpenReadCount) +, m_CurrentOpenFileCount(CurrentOpenFileCount) +, m_ReadCount(ReadCount) +, m_ReadByteCount(ReadByteCount) + +{ + m_OpenReadCount++; + m_CurrentOpenFileCount++; +} + +BufferedOpenFile::~BufferedOpenFile() +{ + m_CurrentOpenFileCount--; +} + +CompositeBuffer +BufferedOpenFile::GetRange(uint64_t Offset, uint64_t Size) +{ + ZEN_TRACE_CPU("BufferedOpenFile::GetRange"); + + ZEN_ASSERT((m_CacheBlockIndex == (uint64_t)-1) || m_Cache); + auto _ = MakeGuard([&]() { ZEN_ASSERT((m_CacheBlockIndex == (uint64_t)-1) || m_Cache); }); + + ZEN_ASSERT((Offset + Size) <= m_SourceSize); + const uint64_t BlockIndexStart = Offset / BlockSize; + const uint64_t BlockIndexEnd = (Offset + Size - 1) / BlockSize; + + std::vector<SharedBuffer> BufferRanges; + BufferRanges.reserve(BlockIndexEnd - BlockIndexStart + 1); + + uint64_t ReadOffset = Offset; + for (uint64_t BlockIndex = BlockIndexStart; BlockIndex <= BlockIndexEnd; BlockIndex++) + { + const uint64_t BlockStartOffset = BlockIndex * BlockSize; + if (m_CacheBlockIndex != BlockIndex) + { + uint64_t CacheSize = Min(BlockSize, m_SourceSize - BlockStartOffset); + ZEN_ASSERT(CacheSize > 0); + m_Cache = IoBuffer(CacheSize); + m_Source.Read(m_Cache.GetMutableView().GetData(), CacheSize, BlockStartOffset); + m_ReadCount++; + m_ReadByteCount += 
CacheSize; + m_CacheBlockIndex = BlockIndex; + } + + const uint64_t BytesRead = ReadOffset - Offset; + ZEN_ASSERT(BlockStartOffset <= ReadOffset); + const uint64_t OffsetIntoBlock = ReadOffset - BlockStartOffset; + ZEN_ASSERT(OffsetIntoBlock < m_Cache.GetSize()); + const uint64_t BlockBytes = Min(m_Cache.GetSize() - OffsetIntoBlock, Size - BytesRead); + BufferRanges.emplace_back(SharedBuffer(IoBuffer(m_Cache, OffsetIntoBlock, BlockBytes))); + ReadOffset += BlockBytes; + } + CompositeBuffer Result(std::move(BufferRanges)); + ZEN_ASSERT(Result.GetSize() == Size); + return Result; +} + +bool +IsFileWithRetry(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Result = IsFile(Path, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + Sleep(100 + int(Retries * 50)); + Ec.clear(); + Result = IsFile(Path, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed to check path '{}' is file, reason: ({}) {}", Path, Ec.value(), Ec.message())); + } + return Result; +} + +bool +SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly) +{ + std::error_code Ec; + bool Result = SetFileReadOnly(Path, ReadOnly, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsFileWithRetry(Path)) + { + return false; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + Result = SetFileReadOnly(Path, ReadOnly, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed {} read only flag for file '{}', reason: ({}) {}", + ReadOnly ? 
"setting" : "clearing", + Path, + Ec.value(), + Ec.message())); + } + return Result; +} + +std::error_code +RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameFile(SourcePath, TargetPath, Ec); + for (size_t Retries = 0; Ec && Retries < 5; Retries++) + { + ZEN_ASSERT_SLOW(IsFile(SourcePath)); + Sleep(50 + int(Retries * 150)); + Ec.clear(); + RenameFile(SourcePath, TargetPath, Ec); + } + return Ec; +} + +std::error_code +RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameDirectory(SourcePath, TargetPath, Ec); + for (size_t Retries = 0; Ec && Retries < 5; Retries++) + { + ZEN_ASSERT_SLOW(IsDir(SourcePath)); + Sleep(50 + int(Retries * 150)); + Ec.clear(); + RenameDirectory(SourcePath, TargetPath, Ec); + } + return Ec; +} + +std::error_code +TryRemoveFile(const std::filesystem::path& Path) +{ + std::error_code Ec; + RemoveFile(Path, Ec); + if (Ec) + { + if (IsFile(Path, Ec)) + { + Ec.clear(); + RemoveFile(Path, Ec); + if (Ec) + { + return Ec; + } + } + } + return {}; +} + +void +RemoveFileWithRetry(const std::filesystem::path& Path) +{ + std::error_code Ec; + RemoveFile(Path, Ec); + for (size_t Retries = 0; Ec && Retries < 6; Retries++) + { + if (!IsFileWithRetry(Path)) + { + return; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + RemoveFile(Path, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed removing file '{}', reason: ({}) {}", Path, Ec.value(), Ec.message())); + } +} + +void +FastCopyFile(bool AllowFileClone, + bool UseSparseFiles, + const std::filesystem::path& SourceFilePath, + const std::filesystem::path& TargetFilePath, + uint64_t RawSize, + std::atomic<uint64_t>& WriteCount, + std::atomic<uint64_t>& WriteByteCount, + std::atomic<uint64_t>& CloneCount, + std::atomic<uint64_t>& CloneByteCount) +{ + 
ZEN_TRACE_CPU("CopyFile"); + std::error_code CloneEc = + AllowFileClone ? TryCloneFile(SourceFilePath, TargetFilePath) : std::make_error_code(std::errc::operation_not_supported); + if (!CloneEc) + { + WriteCount += 1; + WriteByteCount += RawSize; + CloneCount += 1; + CloneByteCount += RawSize; + } + else + { + BasicFile TargetFile(TargetFilePath, BasicFile::Mode::kTruncate); + if (UseSparseFiles) + { + PrepareFileForScatteredWrite(TargetFile.Handle(), RawSize); + } + uint64_t Offset = 0; + std::error_code ScanEc = ScanFile(SourceFilePath, 512u * 1024u, [&](const void* Data, size_t Size) { + TargetFile.Write(Data, Size, Offset); + Offset += Size; + WriteCount++; + WriteByteCount += Size; + }); + if (ScanEc) + { + throw std::system_error(ScanEc, fmt::format("Failed to copy file '{}' to '{}'", SourceFilePath, TargetFilePath)); + } + } +} + +void +GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent) +{ + struct Visitor : public GetDirectoryContentVisitor + { + Visitor(zen::DirectoryContent& OutContent, const std::filesystem::path& InRootPath) : Content(OutContent), RootPath(InRootPath) {} + virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const + { + ZEN_UNUSED(Parent, DirectoryName); + return true; + } + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& InContent) + { + std::vector<std::filesystem::path> Files; + std::vector<std::filesystem::path> Directories; + + if (!InContent.FileNames.empty()) + { + Files.reserve(InContent.FileNames.size()); + for (const std::filesystem::path& FileName : InContent.FileNames) + { + if (RelativeRoot.empty()) + { + Files.push_back(RootPath / FileName); + } + else + { + Files.push_back(RootPath / RelativeRoot / FileName); + } + } + } + + if (!InContent.DirectoryNames.empty()) + { + 
Directories.reserve(InContent.DirectoryNames.size()); + for (const std::filesystem::path& DirName : InContent.DirectoryNames) + { + if (RelativeRoot.empty()) + { + Directories.push_back(RootPath / DirName); + } + else + { + Directories.push_back(RootPath / RelativeRoot / DirName); + } + } + } + + Lock.WithExclusiveLock([&]() { + if (!InContent.FileNames.empty()) + { + for (const std::filesystem::path& FileName : InContent.FileNames) + { + if (RelativeRoot.empty()) + { + Content.Files.push_back(RootPath / FileName); + } + else + { + Content.Files.push_back(RootPath / RelativeRoot / FileName); + } + } + } + if (!InContent.FileSizes.empty()) + { + Content.FileSizes.insert(Content.FileSizes.end(), InContent.FileSizes.begin(), InContent.FileSizes.end()); + } + if (!InContent.FileAttributes.empty()) + { + Content.FileAttributes.insert(Content.FileAttributes.end(), + InContent.FileAttributes.begin(), + InContent.FileAttributes.end()); + } + if (!InContent.FileModificationTicks.empty()) + { + Content.FileModificationTicks.insert(Content.FileModificationTicks.end(), + InContent.FileModificationTicks.begin(), + InContent.FileModificationTicks.end()); + } + + if (!InContent.DirectoryNames.empty()) + { + for (const std::filesystem::path& DirName : InContent.DirectoryNames) + { + if (RelativeRoot.empty()) + { + Content.Directories.push_back(RootPath / DirName); + } + else + { + Content.Directories.push_back(RootPath / RelativeRoot / DirName); + } + } + } + if (!InContent.DirectoryAttributes.empty()) + { + Content.DirectoryAttributes.insert(Content.DirectoryAttributes.end(), + InContent.DirectoryAttributes.begin(), + InContent.DirectoryAttributes.end()); + } + }); + } + RwLock Lock; + zen::DirectoryContent& Content; + const std::filesystem::path& RootPath; + }; + + Visitor RootVisitor(OutContent, Path); + + Latch PendingWork(1); + GetDirectoryContent(Path, Flags, RootVisitor, WorkerPool, PendingWork); + PendingWork.CountDown(); + PendingWork.Wait(); +} + +CleanDirectoryResult 
+CleanDirectory( + WorkerThreadPool& IOWorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Path, + std::span<const std::string> ExcludeDirectories, + std::function<void(const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted)>&& + ProgressFunc, + uint32_t ProgressUpdateDelayMS) +{ + ZEN_TRACE_CPU("CleanDirectory"); + Stopwatch Timer; + + std::atomic<uint64_t> DiscoveredItemCount = 0; + std::atomic<uint64_t> DeletedItemCount = 0; + std::atomic<uint64_t> DeletedByteCount = 0; + + std::vector<std::filesystem::path> DirectoriesToDelete; + CleanDirectoryResult Result; + RwLock ResultLock; + auto _ = MakeGuard([&]() { + Result.DeletedCount = DeletedItemCount.load(); + Result.DeletedByteCount = DeletedByteCount.load(); + Result.FoundCount = DiscoveredItemCount.load(); + }); + + ParallelWork Work(AbortFlag, + PauseFlag, + ProgressFunc ? WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog); + + struct AsyncVisitor : public GetDirectoryContentVisitor + { + AsyncVisitor(const std::filesystem::path& InPath, + std::atomic<bool>& InAbortFlag, + std::atomic<uint64_t>& InDiscoveredItemCount, + std::atomic<uint64_t>& InDeletedItemCount, + std::atomic<uint64_t>& InDeletedByteCount, + std::span<const std::string> InExcludeDirectories, + std::vector<std::filesystem::path>& OutDirectoriesToDelete, + CleanDirectoryResult& InResult, + RwLock& InResultLock) + : Path(InPath) + , AbortFlag(InAbortFlag) + , DiscoveredItemCount(InDiscoveredItemCount) + , DeletedItemCount(InDeletedItemCount) + , DeletedByteCount(InDeletedByteCount) + , ExcludeDirectories(InExcludeDirectories) + , DirectoriesToDelete(OutDirectoriesToDelete) + , Result(InResult) + , ResultLock(InResultLock) + { + } + + virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const override + { + ZEN_UNUSED(Parent); + + if (AbortFlag) + { + 
return false; + } + const std::string DirectoryString = DirectoryName.string(); + for (const std::string_view ExcludeDirectory : ExcludeDirectories) + { + if (DirectoryString == ExcludeDirectory) + { + return false; + } + } + return true; + } + + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& Content) override + { + ZEN_TRACE_CPU("CleanDirectory_AsyncVisitDirectory"); + if (!AbortFlag) + { + DiscoveredItemCount += Content.FileNames.size(); + + ZEN_TRACE_CPU("DeleteFiles"); + std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths; + for (size_t FileIndex = 0; FileIndex < Content.FileNames.size(); FileIndex++) + { + const std::filesystem::path& FileName = Content.FileNames[FileIndex]; + const std::filesystem::path FilePath = (Path / RelativeRoot / FileName).make_preferred(); + + bool IsRemoved = false; + std::error_code Ec; + (void)SetFileReadOnly(FilePath, false, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsFileWithRetry(FilePath)) + { + IsRemoved = true; + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + (void)SetFileReadOnly(FilePath, false, Ec); + } + if (!IsRemoved && !Ec) + { + (void)RemoveFile(FilePath, Ec); + for (size_t Retries = 0; Ec && Retries < 6; Retries++) + { + if (!IsFileWithRetry(FilePath)) + { + IsRemoved = true; + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + (void)RemoveFile(FilePath, Ec); + } + } + if (!IsRemoved && Ec) + { + FailedRemovePaths.push_back(std::make_pair(FilePath, Ec)); + } + else + { + DeletedItemCount++; + DeletedByteCount += Content.FileSizes[FileIndex]; + } + } + + if (!FailedRemovePaths.empty()) + { + RwLock::ExclusiveLockScope _(ResultLock); + Result.FailedRemovePaths.insert(Result.FailedRemovePaths.end(), FailedRemovePaths.begin(), FailedRemovePaths.end()); + } + else if (!RelativeRoot.empty()) + { + DiscoveredItemCount++; + RwLock::ExclusiveLockScope _(ResultLock); + 
DirectoriesToDelete.push_back(RelativeRoot); + } + } + } + const std::filesystem::path& Path; + std::atomic<bool>& AbortFlag; + std::atomic<uint64_t>& DiscoveredItemCount; + std::atomic<uint64_t>& DeletedItemCount; + std::atomic<uint64_t>& DeletedByteCount; + std::span<const std::string> ExcludeDirectories; + std::vector<std::filesystem::path>& DirectoriesToDelete; + CleanDirectoryResult& Result; + RwLock& ResultLock; + } Visitor(Path, + AbortFlag, + DiscoveredItemCount, + DeletedItemCount, + DeletedByteCount, + ExcludeDirectories, + DirectoriesToDelete, + Result, + ResultLock); + + GetDirectoryContent(Path, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | DirectoryContentFlags::IncludeFileSizes, + Visitor, + IOWorkerPool, + Work.PendingWork()); + + uint64_t LastUpdateTimeMs = Timer.GetElapsedTimeMs(); + + if (ProgressFunc && ProgressUpdateDelayMS != 0) + { + Work.Wait(ProgressUpdateDelayMS, [&](bool IsAborted, bool IsPaused, ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + LastUpdateTimeMs = Timer.GetElapsedTimeMs(); + + uint64_t Deleted = DeletedItemCount.load(); + uint64_t DeletedBytes = DeletedByteCount.load(); + uint64_t Discovered = DiscoveredItemCount.load(); + std::string Details = fmt::format("Found {}, Deleted {} ({})", Discovered, Deleted, NiceBytes(DeletedBytes)); + ProgressFunc(Details, Discovered, Discovered - Deleted, IsPaused, IsAborted); + }); + } + else + { + Work.Wait(); + } + + { + ZEN_TRACE_CPU("DeleteDirs"); + + std::sort(DirectoriesToDelete.begin(), + DirectoriesToDelete.end(), + [](const std::filesystem::path& Lhs, const std::filesystem::path& Rhs) { + auto DistanceLhs = std::distance(Lhs.begin(), Lhs.end()); + auto DistanceRhs = std::distance(Rhs.begin(), Rhs.end()); + return DistanceLhs > DistanceRhs; + }); + + for (const std::filesystem::path& DirectoryToDelete : DirectoriesToDelete) + { + if (AbortFlag) + { + break; + } + else + { + while (PauseFlag && !AbortFlag) + { + Sleep(2000); + } + } + + const 
std::filesystem::path FullPath = Path / DirectoryToDelete; + + std::error_code Ec; + RemoveDir(FullPath, Ec); + if (Ec) + { + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsDir(FullPath)) + { + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + RemoveDir(FullPath, Ec); + } + } + if (Ec) + { + RwLock::ExclusiveLockScope __(ResultLock); + Result.FailedRemovePaths.push_back(std::make_pair(DirectoryToDelete, Ec)); + } + else + { + DeletedItemCount++; + } + + if (ProgressFunc) + { + uint64_t NowMs = Timer.GetElapsedTimeMs(); + + if ((NowMs - LastUpdateTimeMs) > 0) + { + LastUpdateTimeMs = NowMs; + + uint64_t Deleted = DeletedItemCount.load(); + uint64_t DeletedBytes = DeletedByteCount.load(); + uint64_t Discovered = DiscoveredItemCount.load(); + std::string Details = fmt::format("Found {}, Deleted {} ({})", Discovered, Deleted, NiceBytes(DeletedBytes)); + ProgressFunc(Details, Discovered, Discovered - Deleted, PauseFlag, AbortFlag); + } + } + } + } + + return Result; +} + +bool +CleanAndRemoveDirectory(WorkerThreadPool& WorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Directory) +{ + if (!IsDir(Directory)) + { + return true; + } + if (CleanDirectoryResult Res = CleanDirectory( + WorkerPool, + AbortFlag, + PauseFlag, + Directory, + {}, + [](const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted) { + ZEN_UNUSED(Details, TotalCount, RemainingCount, IsPaused, IsAborted); + }, + 1000); + Res.FailedRemovePaths.empty()) + { + std::error_code Ec; + RemoveDir(Directory, Ec); + return !Ec; + } + return false; +} + +#if ZEN_WITH_TESTS + +void +filesystemutils_forcelink() +{ +} + +namespace { + void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); } +} // namespace + +TEST_SUITE_BEGIN("util.filesystemutils"); + +TEST_CASE("filesystemutils.CleanDirectory") +{ + 
ScopedTemporaryDirectory TmpDir; + + CreateDirectories(TmpDir.Path() / ".keepme"); + GenerateFile(TmpDir.Path() / ".keepme" / "keep"); + GenerateFile(TmpDir.Path() / "deleteme1"); + GenerateFile(TmpDir.Path() / "deleteme2"); + GenerateFile(TmpDir.Path() / "deleteme3"); + CreateDirectories(TmpDir.Path() / ".keepmenot"); + CreateDirectories(TmpDir.Path() / "no.keepme"); + + CreateDirectories(TmpDir.Path() / "DeleteMe"); + GenerateFile(TmpDir.Path() / "DeleteMe" / "delete1"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete1"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete3"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe" / ".keepme"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2" / "delete2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2" / "delete3"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe2" / ".keepme"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept1"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme"); + + WorkerThreadPool Pool(4); + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + + CleanDirectory(Pool, AbortFlag, PauseFlag, TmpDir.Path(), std::vector<std::string>{".keepme"}, {}, 0); + + CHECK(IsDir(TmpDir.Path() / ".keepme")); + CHECK(IsFile(TmpDir.Path() / ".keepme" / "keep")); + CHECK(!IsFile(TmpDir.Path() / "deleteme1")); + CHECK(!IsFile(TmpDir.Path() / "deleteme2")); + CHECK(!IsFile(TmpDir.Path() / "deleteme3")); + CHECK(!IsFile(TmpDir.Path() / ".keepmenot")); + CHECK(!IsFile(TmpDir.Path() / "no.keepme")); + + CHECK(!IsDir(TmpDir.Path() / "DeleteMe")); + CHECK(!IsDir(TmpDir.Path() / "DeleteMe2")); + + 
CHECK(IsDir(TmpDir.Path() / "CantDeleteMe")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe" / ".keepme")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2" / ".keepme")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept")); + CHECK(IsFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept1")); + CHECK(IsFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept2")); + CHECK(!IsFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme")); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/filteredrate.cpp b/src/zenutil/filteredrate.cpp new file mode 100644 index 000000000..de01af57b --- /dev/null +++ b/src/zenutil/filteredrate.cpp @@ -0,0 +1,92 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/filteredrate.h> + +namespace zen { + +void +FilteredRate::Start() +{ + if (StartTimeUS == (uint64_t)-1) + { + uint64_t Expected = (uint64_t)-1; + if (StartTimeUS.compare_exchange_strong(Expected, Timer.GetElapsedTimeUs())) + { + LastTimeUS = StartTimeUS.load(); + } + } +} + +void +FilteredRate::Stop() +{ + if (EndTimeUS == (uint64_t)-1) + { + uint64_t Expected = (uint64_t)-1; + EndTimeUS.compare_exchange_strong(Expected, Timer.GetElapsedTimeUs()); + } +} + +void +FilteredRate::Update(uint64_t Count) +{ + if (LastTimeUS == (uint64_t)-1) + { + return; + } + uint64_t TimeUS = Timer.GetElapsedTimeUs(); + uint64_t TimeDeltaUS = TimeUS - LastTimeUS; + if (TimeDeltaUS >= 2000000) + { + uint64_t Delta = Count - LastCount; + uint64_t PerSecond = (Delta * 1000000) / TimeDeltaUS; + + FilteredPerSecond = (PerSecond + (LastPerSecond * 7)) / 8; + + LastPerSecond = PerSecond; + LastCount = Count; + LastTimeUS = TimeUS; + } +} + +uint64_t +FilteredRate::GetCurrent() const +{ + if (LastTimeUS == (uint64_t)-1) + { + return 0; + } + return FilteredPerSecond; +} + +uint64_t +FilteredRate::GetElapsedTimeUS() const +{ + if (StartTimeUS == (uint64_t)-1) + { + return 0; + } + 
if (EndTimeUS == (uint64_t)-1) + { + return 0; + } + return EndTimeUS - StartTimeUS; +} + +bool +FilteredRate::IsActive() const +{ + return (StartTimeUS != (uint64_t)-1) && (EndTimeUS == (uint64_t)-1); +} + +uint64_t +GetBytesPerSecond(uint64_t ElapsedWallTimeUS, uint64_t Count) +{ + if (ElapsedWallTimeUS == 0) + { + return 0; + } + return Count * 1000000 / ElapsedWallTimeUS; +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h index d0c0155b0..28e1e8ba6 100644 --- a/src/zenutil/include/zenutil/cloud/mockimds.h +++ b/src/zenutil/include/zenutil/cloud/mockimds.h @@ -23,7 +23,7 @@ namespace zen::compute { * * When a request arrives for a provider that is not the ActiveProvider, the * mock returns 404, causing CloudMetadata to write a sentinel file and move - * on to the next provider — exactly like a failed probe on bare metal. + * on to the next provider - exactly like a failed probe on bare metal. * * All config fields are public and can be mutated between poll cycles to * simulate state changes (e.g. a spot interruption appearing mid-run). @@ -45,13 +45,13 @@ public: std::string AvailabilityZone = "us-east-1a"; std::string LifeCycle = "on-demand"; // "spot" or "on-demand" - // Empty string → endpoint returns 404 (instance not in an ASG). - // Non-empty → returned as the response body. "InService" means healthy; + // Empty string -> endpoint returns 404 (instance not in an ASG). + // Non-empty -> returned as the response body. "InService" means healthy; // anything else (e.g. "Terminated:Wait") triggers termination detection. std::string AutoscalingState; - // Empty string → endpoint returns 404 (no spot interruption). - // Non-empty → returned as the response body, signalling a spot reclaim. + // Empty string -> endpoint returns 404 (no spot interruption). + // Non-empty -> returned as the response body, signalling a spot reclaim. 
std::string SpotAction; // IAM credential fields for ImdsCredentialProvider testing @@ -69,10 +69,10 @@ public: std::string Location = "eastus"; std::string Priority = "Regular"; // "Spot" or "Regular" - // Empty → instance is not in a VM Scale Set (no autoscaling). + // Empty -> instance is not in a VM Scale Set (no autoscaling). std::string VmScaleSetName; - // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // Empty -> no scheduled events. Set to "Preempt", "Terminate", or // "Reboot" to simulate a termination-class event. std::string ScheduledEventType; std::string ScheduledEventStatus = "Scheduled"; diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h index 47501c5b5..1ce2a768e 100644 --- a/src/zenutil/include/zenutil/cloud/s3client.h +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -11,6 +11,12 @@ #include <zencore/thread.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <functional> +#include <span> #include <string> #include <string_view> #include <vector> @@ -34,9 +40,7 @@ struct S3ClientOptions /// Overrides the static Credentials field. Ref<ImdsCredentialProvider> CredentialProvider; - std::chrono::milliseconds ConnectTimeout{5000}; - std::chrono::milliseconds Timeout{}; - uint8_t RetryCount = 3; + HttpClientSettings HttpSettings = {.LogCategory = "s3", .ConnectTimeout = std::chrono::milliseconds(5000), .RetryCount = 3}; }; struct S3ObjectInfo @@ -63,34 +67,36 @@ enum class HeadObjectResult Error, }; -/// Result of GetObject — carries the downloaded content. +/// Result of GetObject - carries the downloaded content. struct S3GetObjectResult : S3Result { IoBuffer Content; std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); } + + static std::string_view NotFoundErrorText; }; -/// Result of HeadObject — carries object metadata and existence status. 
+/// Result of HeadObject - carries object metadata and existence status. struct S3HeadObjectResult : S3Result { S3ObjectInfo Info; HeadObjectResult Status = HeadObjectResult::NotFound; }; -/// Result of ListObjects — carries the list of matching objects. +/// Result of ListObjects - carries the list of matching objects. struct S3ListObjectsResult : S3Result { std::vector<S3ObjectInfo> Objects; }; -/// Result of CreateMultipartUpload — carries the upload ID. +/// Result of CreateMultipartUpload - carries the upload ID. struct S3CreateMultipartUploadResult : S3Result { std::string UploadId; }; -/// Result of UploadPart — carries the part ETag. +/// Result of UploadPart - carries the part ETag. struct S3UploadPartResult : S3Result { std::string ETag; @@ -118,11 +124,21 @@ public: S3Result PutObject(std::string_view Key, IoBuffer Content); /// Download an object from S3 - S3GetObjectResult GetObject(std::string_view Key); + S3GetObjectResult GetObject(std::string_view Key, const std::filesystem::path& TempFilePath = {}); + + /// Download a byte range of an object from S3 + /// @param RangeStart First byte offset (inclusive) + /// @param RangeSize Number of bytes to download + S3GetObjectResult GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t RangeSize); /// Delete an object from S3 S3Result DeleteObject(std::string_view Key); + /// Refresh an object's LastModified timestamp via a PUT Object - Copy onto itself + /// (x-amz-metadata-directive: REPLACE). Useful to reset lifecycle-expiration timers + /// without re-uploading the content. 
+ S3Result Touch(std::string_view Key); + /// Check if an object exists and get its metadata S3HeadObjectResult HeadObject(std::string_view Key); @@ -151,6 +167,16 @@ public: /// @param PartSize Size of each part in bytes (minimum 5 MB, default 8 MB) S3Result PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize = 8 * 1024 * 1024); + /// High-level multipart upload: calls FetchRange(Offset, Size) to read each part on demand, + /// avoiding loading the full content into memory. + /// @param TotalSize Total object size in bytes + /// @param FetchRange Callback invoked once per part; must return exactly Size bytes + /// @param PartSize Size of each part in bytes (minimum 5 MB, default 8 MB) + S3Result PutObjectMultipart(std::string_view Key, + uint64_t TotalSize, + std::function<IoBuffer(uint64_t Offset, uint64_t Size)> FetchRange, + uint64_t PartSize = 8 * 1024 * 1024); + /// Generate a pre-signed URL for downloading an object (GET) /// @param Key The object key /// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) @@ -182,11 +208,15 @@ private: /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style) std::string BucketRootPath() const; - /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256 - HttpClient::KeyValueMap SignRequest(std::string_view Method, - std::string_view Path, - std::string_view QueryString, - std::string_view PayloadHash); + /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256. + /// Additional x-amz-* headers that must participate in the signature are passed via + /// ExtraSignedHeaders (lowercase name, value); they are also copied into the returned map. 
+ HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view QueryString, + std::string_view PayloadHash, + std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {}); /// Get or compute the signing key for the given date stamp, caching across requests on the same day Sha256Digest GetSigningKey(std::string_view DateStamp); @@ -194,6 +224,23 @@ private: /// Get the current credentials, either from the provider or from static config SigV4Credentials GetCurrentCredentials(); + /// Populate OutCredentials and return empty string on success; on failure return a + /// "<context>: no credentials available" error (also logged). Context args are only + /// formatted on the failure path. + template<typename... ContextArgs> + std::string RequireCredentials(SigV4Credentials& OutCredentials, fmt::format_string<ContextArgs...> ContextFmt, ContextArgs&&... Args) + { + OutCredentials = GetCurrentCredentials(); + if (!OutCredentials.AccessKeyId.empty()) + { + return {}; + } + OutCredentials = {}; + return BuildNoCredentialsError(fmt::format(ContextFmt, std::forward<ContextArgs>(Args)...)); + } + + std::string BuildNoCredentialsError(std::string Context); + LoggerRef m_Log; std::string m_BucketName; std::string m_Region; @@ -203,6 +250,7 @@ private: SigV4Credentials m_Credentials; Ref<ImdsCredentialProvider> m_CredentialProvider; HttpClient m_HttpClient; + bool m_Verbose = false; // Cached signing key (only changes once per day, protected by RwLock for thread safety) mutable RwLock m_SigningKeyLock; diff --git a/src/zenutil/include/zenutil/config/commandlineoptions.h b/src/zenutil/include/zenutil/config/commandlineoptions.h index 01cceedb1..ed7f46a08 100644 --- a/src/zenutil/include/zenutil/config/commandlineoptions.h +++ b/src/zenutil/include/zenutil/config/commandlineoptions.h @@ -22,6 +22,17 @@ std::vector<char*> StripCommandlineQuotes(std::vector<std::string>& InOutArgs) 
std::filesystem::path StringToPath(const std::string_view& Path); std::string_view RemoveQuotes(const std::string_view& Arg); +/// Scrub sensitive values from a command-line string in place. Masks the value +/// of options whose normalized name ends in "token", "aeskey", "secret" or +/// "dsn"; redacts user:password@ in URL authorities; partially masks values +/// matching credential formats that zen/zenserver handles (AWS access keys, +/// Google API keys, JWT bearer tokens). On any exception the string is left +/// unchanged. +/// +/// Used by the invocation-history writer and by the Sentry integration before +/// the command line is attached to a crash report. +void ScrubSensitiveValues(std::string& Cmdline) noexcept; + class CommandLineConverter { public: diff --git a/src/zenutil/include/zenutil/config/loggingconfig.h b/src/zenutil/include/zenutil/config/loggingconfig.h index b55b2d9f7..33a5eb172 100644 --- a/src/zenutil/include/zenutil/config/loggingconfig.h +++ b/src/zenutil/include/zenutil/config/loggingconfig.h @@ -16,10 +16,12 @@ struct ZenLoggingConfig { bool NoConsoleOutput = false; // Control default use of stdout for diagnostics bool QuietConsole = false; // Configure console logger output to level WARN + bool ForceColor = false; // Force colored output even when stdout is not a terminal std::filesystem::path AbsLogFile; // Absolute path to main log file std::string Loggers[logging::LogLevelCount]; - std::string LogId; // Id for tagging log output - std::string OtelEndpointUri; // OpenTelemetry endpoint URI + std::string LogId; // Id for tagging log output + std::string OtelEndpointUri; // OpenTelemetry endpoint URI + std::string LogStreamEndpoint; // TCP log stream endpoint (host:port) }; void ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig); diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 22737589b..49bb0cc92 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ 
b/src/zenutil/include/zenutil/consoletui.h @@ -30,7 +30,7 @@ bool IsTuiAvailable(); // - Title: a short description printed once above the list // - Items: pre-formatted display labels, one per selectable entry // -// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels. +// Arrow keys (^/v) navigate the selection, Enter confirms, Esc cancels. // Returns the index of the selected item, or -1 if the user cancelled. // // Precondition: IsTuiAvailable() must be true. diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h index 7bf2ce437..4efb10263 100644 --- a/src/zenutil/include/zenutil/consul.h +++ b/src/zenutil/include/zenutil/consul.h @@ -3,6 +3,7 @@ #pragma once #include <zenbase/zenbase.h> +#include <zencore/thread.h> #include <zenhttp/httpclient.h> #include <atomic> @@ -10,6 +11,7 @@ #include <string> #include <string_view> #include <thread> +#include <vector> namespace zen::consul { @@ -21,14 +23,22 @@ struct ServiceRegistrationInfo uint16_t Port = 0; std::string HealthEndpoint; std::vector<std::pair<std::string, std::string>> Tags; - int HealthIntervalSeconds = 10; - int DeregisterAfterSeconds = 30; + uint32_t HealthIntervalSeconds = 10; + uint32_t DeregisterAfterSeconds = 30; + std::string InitialStatus; }; class ConsulClient { public: - ConsulClient(std::string_view BaseUri, std::string_view Token = ""); + struct Configuration + { + std::string BaseUri; + std::string StaticToken; + std::string TokenEnvName; + }; + + ConsulClient(const Configuration& Config); ~ConsulClient(); ConsulClient(const ConsulClient&) = delete; @@ -38,12 +48,20 @@ public: std::string GetKeyValue(std::string_view Key); void DeleteKey(std::string_view Key); - bool RegisterService(const ServiceRegistrationInfo& Info); - bool DeregisterService(std::string_view ServiceId); + // Async. Enqueue onto the worker thread and return immediately. + // Transport outcome is not reported to the caller. 
+ void RegisterService(const ServiceRegistrationInfo& Info); + void DeregisterService(std::string_view ServiceId); + + // Synchronous counterparts. Block on the HTTP call and return true on + // success. Use when the caller needs a result (e.g. a retry loop). + bool DoRegister(const ServiceRegistrationInfo& Info); + bool DoDeregister(std::string_view ServiceId); // Query methods for testing bool HasService(std::string_view ServiceId); std::string GetAgentServicesJson(); + std::string GetAgentChecksJson(); // Blocking query on v1/agent/services. Blocks until the service list changes or // the wait period expires. InOutIndex must be 0 for the first call; it is updated @@ -52,11 +70,29 @@ public: bool WatchService(std::string_view ServiceId, uint64_t& InOutIndex, int WaitSeconds); private: + struct PendingOp + { + enum class Kind + { + Register, + Deregister + }; + Kind Type; + ServiceRegistrationInfo Info; + }; + static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId); void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap); - - std::string m_Token; - HttpClient m_HttpClient; + std::string GetNodeName(); + void WorkerLoop(); + + Configuration m_Config; + std::atomic<bool> m_Stop{false}; + HttpClient m_HttpClient; + RwLock m_QueueLock; + std::vector<PendingOp> m_Queue; + Event m_Wakeup; + std::thread m_Worker; }; class ConsulProcess @@ -108,4 +144,6 @@ private: void RegistrationLoop(); }; +void consul_forcelink(); + } // namespace zen::consul diff --git a/src/zenutil/include/zenutil/filesystemutils.h b/src/zenutil/include/zenutil/filesystemutils.h new file mode 100644 index 000000000..05defd1a8 --- /dev/null +++ b/src/zenutil/include/zenutil/filesystemutils.h @@ -0,0 +1,98 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> + +namespace zen { + +class CompositeBuffer; + +class BufferedOpenFile +{ +public: + static constexpr uint64_t BlockSize = 256u * 1024u; + + BufferedOpenFile(const std::filesystem::path Path, + std::atomic<uint64_t>& OpenReadCount, + std::atomic<uint64_t>& CurrentOpenFileCount, + std::atomic<uint64_t>& ReadCount, + std::atomic<uint64_t>& ReadByteCount); + ~BufferedOpenFile(); + BufferedOpenFile() = delete; + BufferedOpenFile(const BufferedOpenFile&) = delete; + BufferedOpenFile(BufferedOpenFile&&) = delete; + BufferedOpenFile& operator=(BufferedOpenFile&&) = delete; + BufferedOpenFile& operator=(const BufferedOpenFile&) = delete; + + CompositeBuffer GetRange(uint64_t Offset, uint64_t Size); + +public: + void* Handle() { return m_Source.Handle(); } + +private: + BasicFile m_Source; + const uint64_t m_SourceSize; + std::atomic<uint64_t>& m_OpenReadCount; + std::atomic<uint64_t>& m_CurrentOpenFileCount; + std::atomic<uint64_t>& m_ReadCount; + std::atomic<uint64_t>& m_ReadByteCount; + uint64_t m_CacheBlockIndex = (uint64_t)-1; + IoBuffer m_Cache; +}; + +bool IsFileWithRetry(const std::filesystem::path& Path); + +bool SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly); + +std::error_code RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); +std::error_code RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); + +std::error_code TryRemoveFile(const std::filesystem::path& Path); + +void RemoveFileWithRetry(const std::filesystem::path& Path); + +void FastCopyFile(bool AllowFileClone, + bool UseSparseFiles, + const std::filesystem::path& SourceFilePath, + const std::filesystem::path& TargetFilePath, + uint64_t RawSize, + std::atomic<uint64_t>& WriteCount, + std::atomic<uint64_t>& WriteByteCount, + std::atomic<uint64_t>& CloneCount, + std::atomic<uint64_t>& 
CloneByteCount); + +struct CleanDirectoryResult +{ + uint64_t FoundCount = 0; + uint64_t DeletedCount = 0; + uint64_t DeletedByteCount = 0; + std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths; +}; + +class WorkerThreadPool; + +void GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent); + +CleanDirectoryResult CleanDirectory( + WorkerThreadPool& IOWorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Path, + std::span<const std::string> ExcludeDirectories, + std::function<void(const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted)>&& + ProgressFunc, + uint32_t ProgressUpdateDelayMS); + +bool CleanAndRemoveDirectory(WorkerThreadPool& WorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Directory); + +void filesystemutils_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/filteredrate.h b/src/zenutil/include/zenutil/filteredrate.h new file mode 100644 index 000000000..3349823d0 --- /dev/null +++ b/src/zenutil/include/zenutil/filteredrate.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/timer.h> + +#include <atomic> +#include <cstdint> + +namespace zen { + +class FilteredRate +{ +public: + FilteredRate() {} + + void Start(); + void Stop(); + void Update(uint64_t Count); + + uint64_t GetCurrent() const; + uint64_t GetElapsedTimeUS() const; + bool IsActive() const; + +private: + Stopwatch Timer; + std::atomic<uint64_t> StartTimeUS = (uint64_t)-1; + std::atomic<uint64_t> EndTimeUS = (uint64_t)-1; + std::atomic<uint64_t> LastTimeUS = (uint64_t)-1; + uint64_t LastCount = 0; + uint64_t LastPerSecond = 0; + uint64_t FilteredPerSecond = 0; +}; + +uint64_t GetBytesPerSecond(uint64_t ElapsedWallTimeUS, uint64_t Count); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/invocationhistory.h b/src/zenutil/include/zenutil/invocationhistory.h new file mode 100644 index 000000000..9843d14bc --- /dev/null +++ b/src/zenutil/include/zenutil/invocationhistory.h @@ -0,0 +1,59 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <filesystem> +#include <initializer_list> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct HistoryRecord +{ + std::string Id; + std::string Ts; + std::string Exe; + std::string Mode; + std::string Cwd; + std::string Path; + std::string CmdLine; + uint32_t Pid = 0; +}; + +/// Append an invocation record to the per-user history file. +/// +/// Must be called BEFORE CommandLineConverter has had a chance to mutate argv +/// (which strips quoting and re-parses the Windows command line). Typical call +/// site is right after InstallCrashHandler() in main(). +/// +/// Never throws. Any I/O or parsing failure is swallowed so that logging cannot +/// break the actual command being run. If argv contains the single-token +/// `--enable-execution-history=false` (or `=0` / `=no`), returns without doing +/// anything. +/// +/// Exe should be "zen" or "zenserver". 
Mode is zenserver-only (hub/store/ +/// compute/proxy/test); pass an empty string_view for zen. +/// +/// ExcludeSubcommands is an optional list of subcommand names. If argv[1] +/// matches any entry, logging is skipped. Used e.g. so `zen history` does not +/// pollute the history file it inspects. +void LogInvocation(std::string_view Exe, + std::string_view Mode, + int argc, + char** argv, + std::initializer_list<std::string_view> ExcludeSubcommands = {}) noexcept; + +/// Read the most recent MaxRecords entries from the history file (newest last). +/// Returns an empty vector if the file does not exist or is unreadable. +std::vector<HistoryRecord> ReadInvocationHistory(size_t MaxRecords = 100); + +/// Resolve the per-user history file path. Does not create the file or the +/// parent directory. Never throws; returns an empty path on any error. +std::filesystem::path GetInvocationHistoryPath() noexcept; + +void invocationhistory_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/logging.h b/src/zenutil/include/zenutil/logging.h index 95419c274..6abf6a96f 100644 --- a/src/zenutil/include/zenutil/logging.h +++ b/src/zenutil/include/zenutil/logging.h @@ -18,6 +18,10 @@ // for sharing across different executables // +namespace zen::logging { +class BroadcastSink; +} + namespace zen { struct LoggingOptions @@ -28,7 +32,8 @@ struct LoggingOptions bool AllowAsync = true; bool NoConsoleOutput = false; bool QuietConsole = false; - std::filesystem::path AbsLogFile; // Absolute path to main log file + bool ForceColor = false; // Force colored output even when stdout is not a terminal + std::filesystem::path AbsLogFile; // Absolute path to main log file std::string LogId; }; @@ -38,6 +43,7 @@ void FinishInitializeLogging(const LoggingOptions& LoggingOptions); void InitializeLogging(const LoggingOptions& LoggingOptions); void ShutdownLogging(); -logging::SinkPtr GetFileSink(); +logging::SinkPtr GetFileSink(); +Ref<logging::BroadcastSink> 
GetDefaultBroadcastSink(); } // namespace zen diff --git a/src/zenutil/include/zenutil/parallelsort.h b/src/zenutil/include/zenutil/parallelsort.h new file mode 100644 index 000000000..ed455ce9d --- /dev/null +++ b/src/zenutil/include/zenutil/parallelsort.h @@ -0,0 +1,119 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/workthreadpool.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/sort.h> +#include <EASTL/vector.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + +namespace zen { + +// Bottom-up parallel merge sort using WorkerThreadPool + Latch. +// +// Splits the range into chunks, sorts each chunk in parallel via +// eastl::sort, then iteratively merges adjacent pairs in parallel +// using std::inplace_merge until a single sorted range remains. +// +// Falls back to eastl::sort for ranges below kMinParallelSortSize. +template<typename RandomIt, typename Compare> +void +ParallelSort(WorkerThreadPool& Pool, RandomIt First, RandomIt Last, Compare Comp) +{ + constexpr size_t kMinParallelSortSize = 65536; + constexpr size_t kMinChunkSize = 65536; + constexpr size_t kMaxChunks = 64; + + size_t Count = size_t(Last - First); + if (Count <= kMinParallelSortSize) + { + eastl::sort(First, Last, Comp); + return; + } + + // Determine chunk count: enough to saturate workers, but not so many + // that scheduling overhead dominates. + size_t ChunkCount = (Count + kMinChunkSize - 1) / kMinChunkSize; + if (ChunkCount > kMaxChunks) + { + ChunkCount = kMaxChunks; + } + if (ChunkCount < 2) + { + ChunkCount = 2; + } + + // Compute chunk boundaries. + eastl::vector<RandomIt> Boundaries; + Boundaries.reserve(ChunkCount + 1); + size_t ChunkSize = Count / ChunkCount; + for (size_t I = 0; I < ChunkCount; ++I) + { + Boundaries.push_back(First + ptrdiff_t(I * ChunkSize)); + } + Boundaries.push_back(Last); + + // Phase 1: Sort each chunk in parallel. 
+ { + Latch Done(1); + for (size_t I = 0; I < ChunkCount; ++I) + { + Done.AddCount(1); + Pool.ScheduleWork( + [&Done, Begin = Boundaries[I], End = Boundaries[I + 1], &Comp]() { + auto Guard = MakeGuard([&Done]() { Done.CountDown(); }); + eastl::sort(Begin, End, Comp); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + Done.CountDown(); + Done.Wait(); + } + + // Phase 2: Pairwise merge rounds until a single sorted range remains. + // Each round merges non-overlapping adjacent pairs in parallel, then + // compacts the boundary list. An odd trailing chunk is carried forward. + while (ChunkCount > 1) + { + size_t Pairs = ChunkCount / 2; + + { + Latch Done(1); + for (size_t I = 0; I < Pairs; ++I) + { + size_t Left = 2 * I; + Done.AddCount(1); + Pool.ScheduleWork( + [&Done, F = Boundaries[Left], M = Boundaries[Left + 1], L = Boundaries[Left + 2], &Comp]() { + auto Guard = MakeGuard([&Done]() { Done.CountDown(); }); + std::inplace_merge(F, M, L, Comp); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + Done.CountDown(); + Done.Wait(); + } + + // Compact boundaries: merged pairs collapse, odd chunk carried forward. + eastl::vector<RandomIt> NewBounds; + NewBounds.reserve((ChunkCount + 1) / 2 + 1); + for (size_t I = 0; I < ChunkCount; I += 2) + { + NewBounds.push_back(Boundaries[I]); + } + NewBounds.push_back(Last); + + ChunkCount = NewBounds.size() - 1; + Boundaries = eastl::move(NewBounds); + } +} + +void parallelsort_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h new file mode 100644 index 000000000..95d7fa43d --- /dev/null +++ b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -0,0 +1,285 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/process.h> +#include <zencore/zencore.h> + +#include <filesystem> +#include <functional> +#include <memory> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/** Tracked process entry with latest metrics snapshot. + */ +struct TrackedProcessEntry +{ + int Pid = 0; + ProcessMetrics Metrics; + + // Derived CPU usage percentage (delta-based, requires two samples). + // -1.0 means not yet sampled. + float CpuUsagePercent = -1.0f; +}; + +/** Aggregate metrics across all tracked processes. + */ +struct AggregateProcessMetrics +{ + uint64_t TotalWorkingSetSize = 0; + uint64_t TotalPeakWorkingSetSize = 0; + uint64_t TotalUserTimeMs = 0; + uint64_t TotalKernelTimeMs = 0; + uint32_t ProcessCount = 0; +}; + +} // namespace zen + +namespace asio { +class io_context; +} + +namespace zen { + +class ManagedProcess; +class ProcessGroup; + +/// Callback invoked when a managed process exits. +using ProcessExitCallback = std::function<void(ManagedProcess& Process, int ExitCode)>; + +/// Callback invoked when data is read from a managed process's stdout or stderr. +using ProcessDataCallback = std::function<void(ManagedProcess& Process, std::string_view Data)>; + +/// Configuration for SubprocessManager. +struct SubprocessManagerConfig +{ + /// Interval for periodic metrics sampling. Set to 0 to disable. + uint64_t MetricsSampleIntervalMs = 5000; + + /// Number of processes sampled per metrics tick (round-robin). + uint32_t MetricsBatchSize = 16; +}; + +/// Manages a set of child processes with async exit detection, stdout/stderr +/// capture, and periodic metrics sampling. +/// +/// All callbacks are posted to the io_context and never invoked under internal +/// locks. The caller must ensure the io_context outlives this manager and that +/// its run loop is active. 
+/// +/// Usage: +/// asio::io_context IoContext; +/// SubprocessManager Manager(IoContext); +/// +/// StdoutPipeHandles StdoutPipe; +/// CreateStdoutPipe(StdoutPipe); +/// +/// CreateProcOptions Options; +/// Options.StdoutPipe = &StdoutPipe; +/// +/// auto* Proc = Manager.Spawn(Executable, CommandLine, Options, +/// [](ManagedProcess& P, int Code) { ... }); +class SubprocessManager +{ +public: + explicit SubprocessManager(asio::io_context& IoContext, SubprocessManagerConfig Config = {}); + ~SubprocessManager(); + + SubprocessManager(const SubprocessManager&) = delete; + SubprocessManager& operator=(const SubprocessManager&) = delete; + + /// Spawn a new child process and begin monitoring it. + /// + /// If Options.StdoutPipe is set, the pipe is consumed and async reading + /// begins automatically. Similarly for Options.StderrPipe. When providing + /// pipes, pass the corresponding data callback here so it is installed + /// before the first async read completes - setting it later via + /// SetStdoutCallback risks losing early output. + /// + /// Returns a non-owning pointer valid until Remove() or manager destruction. + /// The exit callback fires on an io_context thread when the process terminates. + ManagedProcess* Spawn(const std::filesystem::path& Executable, + std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); + + /// Adopt an already-running process by handle. Takes ownership of handle internals. + ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); + + /// Stop monitoring a process by pid. Does NOT kill the process - call + /// process->Kill() first if needed. The exit callback will not fire after + /// this returns. + void Remove(int Pid); + + /// Remove all managed processes. + void RemoveAll(); + + /// Set default stdout callback. Per-process callbacks override this. 
+ void SetDefaultStdoutCallback(ProcessDataCallback Callback); + + /// Set default stderr callback. Per-process callbacks override this. + void SetDefaultStderrCallback(ProcessDataCallback Callback); + + /// Snapshot of per-process metrics for all managed processes. + [[nodiscard]] std::vector<TrackedProcessEntry> GetMetricsSnapshot() const; + + /// Aggregate metrics across all managed processes. + [[nodiscard]] AggregateProcessMetrics GetAggregateMetrics() const; + + /// Number of currently managed processes. + [[nodiscard]] size_t GetProcessCount() const; + + /// Enumerate all managed processes under a shared lock. + void Enumerate(std::function<void(const ManagedProcess&)> Callback) const; + + /// Create a new process group. The group is owned by this manager. + /// On Windows the group is backed by a JobObject (kill-on-close guarantee). + /// On POSIX the group uses setpgid for bulk signal delivery. + ProcessGroup* CreateGroup(std::string Name); + + /// Destroy a group by name. Kills all processes in the group first. + void DestroyGroup(std::string_view Name); + + /// Find a group by name. Returns nullptr if not found. + [[nodiscard]] ProcessGroup* FindGroup(std::string_view Name) const; + + /// Enumerate all groups. + void EnumerateGroups(std::function<void(const ProcessGroup&)> Callback) const; + +private: + friend class ProcessGroup; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +/// A process managed by SubprocessManager. +/// +/// Not user-constructible. Pointers obtained from Spawn()/Adopt() remain valid +/// until Remove() or manager destruction. +class ManagedProcess +{ +public: + ~ManagedProcess(); + + ManagedProcess(const ManagedProcess&) = delete; + ManagedProcess& operator=(const ManagedProcess&) = delete; + + /// Process id. + [[nodiscard]] int Pid() const; + + /// Whether the process is still running. + [[nodiscard]] bool IsRunning() const; + + /// Underlying process handle. 
+ [[nodiscard]] const ProcessHandle& GetHandle() const; + + /// Most recently sampled metrics (best-effort snapshot). + [[nodiscard]] ProcessMetrics GetLatestMetrics() const; + + /// CPU usage percentage from the last two samples. Returns -1.0 if not + /// yet computed. + [[nodiscard]] float GetCpuUsagePercent() const; + + /// Return all stdout captured so far. When a callback is set, output is + /// delivered there instead of being accumulated. + [[nodiscard]] std::string GetCapturedStdout() const; + + /// Return all stderr captured so far. + [[nodiscard]] std::string GetCapturedStderr() const; + + /// Graceful shutdown with fallback to forced kill. + bool Kill(); + + /// Immediate forced termination. + bool Terminate(int ExitCode); + + /// User-defined tag for identifying this process in callbacks. + void SetTag(std::string Tag); + + /// Get the user-defined tag. + [[nodiscard]] std::string_view GetTag() const; + +private: + friend class SubprocessManager; + friend class ProcessGroup; + + struct Impl; + std::unique_ptr<Impl> m_Impl; + + explicit ManagedProcess(std::unique_ptr<Impl> InImpl); +}; + +/// A group of managed processes with OS-level backing. +/// +/// On Windows: backed by a JobObject. All processes assigned on spawn. +/// Kill-on-close guarantee - if the group is destroyed, the OS terminates +/// all member processes. +/// On Linux/macOS: uses setpgid() so children share a process group. +/// Enables bulk signal delivery via kill(-pgid, sig). +/// +/// Created via SubprocessManager::CreateGroup(). Not user-constructible. +class ProcessGroup +{ +public: + ~ProcessGroup(); + + ProcessGroup(const ProcessGroup&) = delete; + ProcessGroup& operator=(const ProcessGroup&) = delete; + + /// Group name (as passed to CreateGroup). + [[nodiscard]] std::string_view GetName() const; + + /// Spawn a process into this group. See SubprocessManager::Spawn for + /// details on the stdout/stderr callback parameters. 
+ ManagedProcess* Spawn(const std::filesystem::path& Executable, + std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); + + /// Adopt an already-running process into this group. + /// On Windows the process is assigned to the group's JobObject. + /// On POSIX the process cannot be moved into a different process group + /// after creation, so OS-level grouping is best-effort for adopted processes. + ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); + + /// Remove a process from this group. Does NOT kill it. + void Remove(int Pid); + + /// Kill all processes in the group. + /// On Windows: uses TerminateJobObject for atomic group kill. + /// On POSIX: sends SIGTERM then SIGKILL to the process group. + void KillAll(); + + /// Aggregate metrics for this group's processes. + [[nodiscard]] AggregateProcessMetrics GetAggregateMetrics() const; + + /// Per-process metrics snapshot for this group. + [[nodiscard]] std::vector<TrackedProcessEntry> GetMetricsSnapshot() const; + + /// Number of processes in this group. + [[nodiscard]] size_t GetProcessCount() const; + + /// Enumerate processes in this group. + void Enumerate(std::function<void(const ManagedProcess&)> Callback) const; + +private: + friend class SubprocessManager; + + struct Impl; + std::unique_ptr<Impl> m_Impl; + + explicit ProcessGroup(std::unique_ptr<Impl> InImpl); +}; + +void subprocessmanager_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/progress.h b/src/zenutil/include/zenutil/progress.h new file mode 100644 index 000000000..4103723b3 --- /dev/null +++ b/src/zenutil/include/zenutil/progress.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+
+#pragma once
+
+#include <zencore/logbase.h>
+
+#include <memory>
+#include <string>
+
+namespace zen {
+
+/// Abstract interface for reporting hierarchical operation progress.
+/// Implementations decide how operation names, step counts, and progress
+/// bars are presented (see CreateStandardProgress for a logger-backed one).
+class ProgressBase
+{
+public:
+	virtual ~ProgressBase() = default;
+
+	/// Set the display name of the current log operation.
+	virtual void SetLogOperationName(std::string_view Name) = 0;
+	/// Report step-based progress for the current operation (StepIndex out of StepCount).
+	virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0;
+	/// Enter a nested named operation (paired with PopLogOperation).
+	virtual void PushLogOperation(std::string_view Name) = 0;
+	/// Leave the most recently pushed operation.
+	virtual void PopLogOperation() = 0;
+	/// Delay between progress updates, in milliseconds.
+	/// NOTE(review): presumably used by callers to throttle UpdateState calls - confirm.
+	virtual uint32_t GetProgressUpdateDelayMS() const = 0;
+
+	/// A single task's progress display, created via CreateProgressBar.
+	class ProgressBar
+	{
+	public:
+		/// Value snapshot of a progress bar; equality-comparable so
+		/// implementations can cheaply skip redundant updates.
+		struct State
+		{
+			bool operator==(const State&) const = default;
+			// Name of the task being tracked.
+			std::string Task;
+			// Free-form detail text shown alongside the task.
+			std::string Details;
+			// Total number of work items; 0 when unknown.
+			uint64_t TotalCount = 0;
+			// Work items still outstanding (progress = TotalCount - RemainingCount).
+			uint64_t RemainingCount = 0;
+			// Sentinel (uint64_t)-1 means "no elapsed time supplied".
+			// NOTE(review): units (ms vs. s) are not visible here - confirm at the implementation.
+			uint64_t OptionalElapsedTime = (uint64_t)-1;
+			enum class EStatus
+			{
+				Running,
+				Aborted,
+				Paused
+			};
+			EStatus Status = EStatus::Running;
+
+			/// Map the two independent flags onto a single status.
+			/// Abort takes precedence over pause; both take precedence over running.
+			static constexpr EStatus CalculateStatus(bool IsAborted, bool IsPaused)
+			{
+				if (IsAborted)
+				{
+					return EStatus::Aborted;
+				}
+				if (IsPaused)
+				{
+					return EStatus::Paused;
+				}
+				return EStatus::Running;
+			}
+		};
+
+		virtual ~ProgressBar() = default;
+
+		/// Publish a new state snapshot; DoLinebreak requests that the display
+		/// advance to a fresh line after rendering this update.
+		virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0;
+		/// Force the display onto a fresh line without a state change.
+		virtual void ForceLinebreak() = 0;
+		/// Mark the bar complete; no further updates should follow.
+		virtual void Finish() = 0;
+	};
+
+	/// Create a progress bar for a named sub-task. Ownership transfers to the caller.
+	virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) = 0;
+};
+
+/// Factory for the standard implementation reporting through the given logger.
+/// NOTE(review): returns a raw pointer - the ownership/lifetime contract is not
+/// visible in this header; confirm whether the caller must delete it.
+ProgressBase* CreateStandardProgress(LoggerRef Log);
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/sessionsclient.h b/src/zenutil/include/zenutil/sessionsclient.h
new file mode 100644
index 000000000..c144a9baa
--- /dev/null
+++ b/src/zenutil/include/zenutil/sessionsclient.h
@@ -0,0 +1,95 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+ +#pragma once + +#include <zencore/blockingqueue.h> +#include <zencore/compactbinary.h> +#include <zencore/logging.h> +#include <zencore/logging/sink.h> +#include <zencore/string.h> +#include <zencore/uid.h> +#include <zenhttp/httpclient.h> + +#include <memory> +#include <string> +#include <thread> + +namespace zen { + +/// Client for announcing and maintaining a session on a remote zenserver's /sessions/ endpoint. +/// All HTTP I/O runs on a single background worker thread. Public methods enqueue commands +/// and return immediately. +class SessionsServiceClient +{ +public: + struct Options + { + std::string TargetUrl; // Base URL of the target zenserver (e.g. "http://localhost:8558") + std::string AppName; // Application name to register + std::string Mode; // Server mode (e.g. "Server", "Compute", "Proxy") + Oid SessionId = Oid::Zero; // Session ID to register under + Oid JobId = Oid::Zero; // Optional job ID + HttpClientSettings ClientSettings; // Optional; timeouts are overridden internally (e.g. for unix sockets) + }; + + /// Command sent to the background worker thread. + struct SessionCommand + { + enum class Type : uint8_t + { + Announce, + UpdateMetadata, + Remove, + Log, + FlushLogs, + Shutdown + }; + + Type CommandType = Type::Log; + CbObject Metadata; // Announce, UpdateMetadata + logging::LogLevel LogLevel{}; // Log + CompactString LogMessage; // Log + }; + + explicit SessionsServiceClient(Options Opts); + ~SessionsServiceClient(); + + SessionsServiceClient(const SessionsServiceClient&) = delete; + SessionsServiceClient& operator=(const SessionsServiceClient&) = delete; + + /// POST /sessions/{id} — enqueues an announce command (fire-and-forget). + void Announce(CbObjectView Metadata = {}); + + /// PUT /sessions/{id} — enqueues a metadata update command (fire-and-forget). + void UpdateMetadata(CbObjectView Metadata = {}); + + /// DELETE /sessions/{id} — enqueues a remove command (fire-and-forget). 
+ void Remove(); + + /// Create a logging sink that forwards log messages to the session's log endpoint. + /// The sink enqueues messages onto this client's worker thread. + /// The returned sink can be added to any logger via Logger::AddSink(). + logging::SinkPtr CreateLogSink(); + + const Options& GetOptions() const { return m_Options; } + const std::string& GetSessionPath() const { return m_SessionPath; } + +private: + CbObject BuildRequestBody(CbObjectView Metadata) const; + + void WorkerLoop(); + void DoAnnounce(HttpClient& Http, CbObjectView Metadata); + void DoUpdateMetadata(HttpClient& Http, CbObjectView Metadata); + void DoRemove(HttpClient& Http); + void SendLogBatch(HttpClient& Http, const std::string& LogPath, const std::vector<SessionCommand>& Batch); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + Options m_Options; + std::string m_SessionPath; // "/sessions/<hex>" + BlockingQueue<SessionCommand> m_Queue; + std::thread m_WorkerThread; +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/splitconsole/logstreamlistener.h b/src/zenutil/include/zenutil/splitconsole/logstreamlistener.h new file mode 100644 index 000000000..f3b960f51 --- /dev/null +++ b/src/zenutil/include/zenutil/splitconsole/logstreamlistener.h @@ -0,0 +1,61 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <cstdint> +#include <memory> +#include <string_view> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +/// Abstract target for log lines received over a TCP log stream. +class LogStreamTarget +{ +public: + virtual ~LogStreamTarget() = default; + + /// Called (potentially from any thread) when a log line is received. + virtual void AppendLogLine(std::string_view Text) = 0; +}; + +/// TCP listener that accepts connections from remote processes streaming log messages. 
+/// Each message is a CbObject with fields: "text" (string), "source" (string), "level" (string, optional).
+///
+/// Two modes of operation:
+/// - Owned thread: pass only Target and Port; an internal IO thread is created.
+/// - External io_context: pass an existing asio::io_context; no thread is created,
+///   the caller is responsible for running the io_context.
+class LogStreamListener
+{
+public:
+	/// Start listening with an internal IO thread.
+	/// Port 0 presumably requests an OS-assigned ephemeral port (query via GetPort()) - confirm.
+	/// NOTE(review): Target is taken by reference and must outlive this listener;
+	/// confirm the Impl does not copy it.
+	LogStreamListener(LogStreamTarget& Target, uint16_t Port = 0);
+
+	/// Start listening on an externally-driven io_context (no thread created).
+	LogStreamListener(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port = 0);
+
+	~LogStreamListener();
+
+	LogStreamListener(const LogStreamListener&) = delete;
+	LogStreamListener& operator=(const LogStreamListener&) = delete;
+
+	/// Returns the actual port the listener is bound to.
+	/// (Meaningful when 0 was passed to the constructor.)
+	uint16_t GetPort() const;
+
+	/// Gracefully stop accepting new connections and shut down existing sessions.
+	void Shutdown();
+
+private:
+	// Pimpl keeps asio session details out of this header.
+	struct Impl;
+	std::unique_ptr<Impl> m_Impl;
+};
+
+void logstreamlistener_forcelink();
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h
new file mode 100644
index 000000000..e59ebc7f4
--- /dev/null
+++ b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h
@@ -0,0 +1,193 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/logging/sink.h>
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <EASTL/fixed_vector.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <atomic>
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+namespace zen {
+
+/// Logging sink that connects to a LogStreamListener via TCP and sends CbObject-framed log messages.
+/// Connection is lazy (on first enqueued message) and silently reconnects on failure.
+/// Messages are serialized on the caller thread and written asynchronously from a dedicated IO thread.
+/// A bounded queue drops the oldest messages on overflow to prevent unbounded memory growth.
+class TcpLogStreamSink : public logging::Sink
+{
+public:
+	/// Starts the IO thread immediately; the TCP connection itself is deferred
+	/// until the first batch is dequeued (see IoThreadMain/Connect).
+	TcpLogStreamSink(const std::string& Host, uint16_t Port, std::string Source, uint32_t MaxQueueSize = 4096)
+	: m_Host(Host)
+	, m_Port(Port)
+	, m_Source(std::move(Source))
+	, m_MaxQueueSize(MaxQueueSize)
+	{
+		m_IoThread = std::thread([this]() { IoThreadMain(); });
+	}
+
+	/// Signals the IO thread to stop, allows up to ~2 seconds to drain queued
+	/// messages, then joins it.
+	/// NOTE(review): the drain deadline is only checked between batches; the
+	/// blocking asio::write has no timeout, so a stalled peer can extend
+	/// shutdown beyond 2s - confirm this is acceptable.
+	~TcpLogStreamSink() override
+	{
+		{
+			std::lock_guard<std::mutex> Lock(m_QueueMutex);
+			m_Stopping = true;
+			m_DrainDeadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
+		}
+		m_QueueCv.notify_one();
+		if (m_IoThread.joinable())
+		{
+			m_IoThread.join();
+		}
+	}
+
+	/// Serialize one log message into a CbObject and enqueue it for the IO thread.
+	/// Safe to call from any thread. When the queue is full, the OLDEST pending
+	/// messages are discarded; receivers observe this as gaps in the "seq" field.
+	void Log(const logging::LogMessage& Msg) override
+	{
+		std::string_view Text = Msg.GetPayload();
+
+		// Strip trailing newlines
+		while (!Text.empty() && (Text.back() == '\n' || Text.back() == '\r'))
+		{
+			Text.remove_suffix(1);
+		}
+
+		// Relaxed is sufficient: the counter only needs to be unique/monotonic,
+		// it does not order any other memory operations.
+		uint64_t Seq = m_NextSequence.fetch_add(1, std::memory_order_relaxed);
+
+		// Build CbObject with text, source, level, and sequence number fields
+		CbObjectWriter Writer;
+		Writer.AddString("text", Text);
+		Writer.AddString("source", m_Source);
+		Writer.AddString("level", logging::ToString(Msg.GetLevel()));
+		Writer.AddInteger("seq", Seq);
+		CbObject Obj = Writer.Save();
+
+		// Enqueue for async write
+		{
+			std::lock_guard<std::mutex> Lock(m_QueueMutex);
+			// Drop-oldest overflow policy: bounds memory while keeping the
+			// newest (most relevant) messages.
+			while (m_Queue.size() >= m_MaxQueueSize)
+			{
+				m_Queue.pop_front();
+			}
+			m_Queue.push_back(std::move(Obj));
+		}
+		m_QueueCv.notify_one();
+	}
+
+	void Flush() override
+	{
+		// Nothing to flush - writes happen asynchronously
+		// NOTE(review): Flush() does not wait for the queue to drain; callers
+		// relying on flush-before-exit semantics only get the destructor's
+		// bounded drain - confirm this is intended.
+	}
+
+	void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override
+	{
+		// Not used - we output the raw payload directly
+	}
+
+private:
+	/// Worker loop: waits for queued messages, connects lazily, and writes each
+	/// batch with a single gathered write. Runs until m_Stopping and either the
+	/// queue is drained or the drain deadline has passed.
+	void IoThreadMain()
+	{
+		zen::SetCurrentThreadName("TcpLogSink");
+
+		for (;;)
+		{
+			std::deque<CbObject> Batch;
+			{
+				std::unique_lock<std::mutex> Lock(m_QueueMutex);
+				m_QueueCv.wait(Lock, [this]() { return m_Stopping || !m_Queue.empty(); });
+
+				if (m_Stopping && m_Queue.empty())
+				{
+					break;
+				}
+
+				// Stop draining once the destructor's deadline has passed,
+				// even if messages remain (they are dropped).
+				if (m_Stopping && std::chrono::steady_clock::now() >= m_DrainDeadline)
+				{
+					break;
+				}
+
+				// Swap the entire queue out under the lock so Log() callers
+				// are never blocked behind socket I/O.
+				Batch.swap(m_Queue);
+			}
+
+			if (!m_Connected && !Connect())
+			{
+				if (m_Stopping)
+				{
+					break; // don't retry during shutdown
+				}
+				continue; // drop batch - will retry on next batch
+			}
+
+			// Build a gathered buffer sequence so the entire batch is written
+			// in a single socket operation (or as few as the OS needs).
+			// The buffers reference memory owned by Batch, which stays alive
+			// for the duration of the write.
+			eastl::fixed_vector<asio::const_buffer, 64> Buffers;
+			Buffers.reserve(Batch.size());
+			for (auto& Obj : Batch)
+			{
+				MemoryView View = Obj.GetView();
+				Buffers.emplace_back(View.GetData(), View.GetSize());
+			}
+
+			asio::error_code Ec;
+			asio::write(m_Socket, Buffers, Ec);
+			if (Ec)
+			{
+				// The failed batch is dropped; clearing m_Connected forces a
+				// reconnect attempt on the next batch.
+				m_Connected = false;
+			}
+		}
+	}
+
+	/// Synchronous resolve + connect; returns true on success. On failure the
+	/// socket is recreated so the next attempt starts from a clean state.
+	/// (Only synchronous asio operations are used, so m_IoContext is never run.)
+	bool Connect()
+	{
+		try
+		{
+			asio::ip::tcp::resolver Resolver(m_IoContext);
+			auto Endpoints = Resolver.resolve(m_Host, std::to_string(m_Port));
+			asio::connect(m_Socket, Endpoints);
+			m_Connected = true;
+			return true;
+		}
+		catch (const std::exception&)
+		{
+			// Reset the socket for next attempt
+			m_Socket = asio::ip::tcp::socket(m_IoContext);
+			m_Connected = false;
+			return false;
+		}
+	}
+
+	// IO thread state (only accessed from m_IoThread)
+	asio::io_context m_IoContext;
+	asio::ip::tcp::socket m_Socket{m_IoContext};
+	bool m_Connected = false;
+
+	// Configuration (immutable after construction)
+	std::string m_Host;
+	uint16_t m_Port;
+	std::string m_Source;
+	uint32_t m_MaxQueueSize;
+
+	// Sequence counter - incremented atomically by Log() callers.
+	// Gaps in the sequence seen by the receiver indicate dropped messages.
+	std::atomic<uint64_t> m_NextSequence{0};
+
+	// Queue shared between Log() callers and IO thread
+	std::mutex m_QueueMutex;
+	std::condition_variable m_QueueCv;
+	std::deque<CbObject> m_Queue;
+	bool m_Stopping = false;
+	// Set by the destructor (guarded by m_QueueMutex) to bound shutdown drain time.
+	std::chrono::steady_clock::time_point m_DrainDeadline;
+
+	std::thread m_IoThread;
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/testartifactprovider.h b/src/zenutil/include/zenutil/testartifactprovider.h
new file mode 100644
index 000000000..77af7b850
--- /dev/null
+++ b/src/zenutil/include/zenutil/testartifactprovider.h
@@ -0,0 +1,109 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenutil/cloud/s3client.h>
+
+#include <zenbase/refcount.h>
+#include <zencore/iobuffer.h>
+
+#include <filesystem>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/// Metadata for a single artifact as returned by TestArtifactProvider::List().
+struct TestArtifactInfo
+{
+	// Forward-slash relative path identifying the artifact.
+	std::string RelativePath;
+	// Artifact size in bytes.
+	uint64_t Size = 0;
+};
+
+/// Base result type for provider operations: an empty Error string means success.
+struct TestArtifactResult
+{
+	std::string Error;
+
+	bool IsSuccess() const { return Error.empty(); }
+	explicit operator bool() const { return IsSuccess(); }
+};
+
+/// Result of Fetch(): the artifact's contents on success.
+struct TestArtifactFetchResult : TestArtifactResult
+{
+	IoBuffer Content;
+};
+
+/// Result of List(): the artifacts found under the requested prefix.
+struct TestArtifactListResult : TestArtifactResult
+{
+	std::vector<TestArtifactInfo> Artifacts;
+};
+
+/// Provider for read/write access to test artifacts.
+///
+/// Artifacts are identified by a forward-slash relative path (e.g. "traces/powercycle.utrace").
+/// A provider is backed by a writable local cache directory and, optionally, an S3 source
+/// used as the read-through primary. Fetch consults the cache first; on miss it pulls from the
+/// primary (if configured) and writes the result back into the cache.
+///
+/// NOTE(review): no virtual destructor is declared here - this relies on
+/// RefCounted providing one (or on ref-counted release going through the most
+/// derived type); confirm.
+class TestArtifactProvider : public RefCounted
+{
+public:
+	/// Human-readable description of this provider (e.g. for logging).
+	virtual std::string Describe() const = 0;
+	/// Whether an artifact exists at the given relative path.
+	virtual bool Exists(std::string_view RelativePath) = 0;
+	/// Retrieve an artifact's contents (cache-first; see class comment).
+	virtual TestArtifactFetchResult Fetch(std::string_view RelativePath) = 0;
+	/// List artifacts whose relative path starts with Prefix.
+	virtual TestArtifactListResult List(std::string_view Prefix) = 0;
+	/// Store an artifact at the given relative path.
+	virtual TestArtifactResult Store(std::string_view RelativePath, IoBuffer Content) = 0;
+};
+
+/// Environment variable specifying the local cache directory (overrides GetDefaultLocalTestArtifactPath()).
+inline constexpr std::string_view kTestArtifactsPathEnvVar = "ZEN_TEST_ARTIFACTS_PATH";
+
+/// Environment variable specifying the primary S3 source. Expected format:
+/// "s3://<bucket>[/<prefix>]" (the "s3://" scheme is optional).
+inline constexpr std::string_view kTestArtifactsS3EnvVar = "ZEN_TEST_ARTIFACTS_S3"; + +/// Directory name placed next to the default zenserver state directory for the local cache. +inline constexpr std::string_view kDefaultLocalTestArtifactDirName = "zen-artifact-cache"; + +/// Returns the default local cache root, derived from PickDefaultSystemRootDirectory() by +/// appending kDefaultLocalTestArtifactDirName. Returns an empty path if no system root can be +/// determined on the current platform. +std::filesystem::path GetDefaultLocalTestArtifactPath(); + +struct TestArtifactProviderOptions +{ + /// Local cache directory. If empty, ZEN_TEST_ARTIFACTS_PATH is consulted, then + /// GetDefaultLocalTestArtifactPath() is used as a final fallback. + std::filesystem::path CacheDir; + + /// Optional remote S3 primary source. Empty fields are filled from environment variables: + /// ZEN_TEST_ARTIFACTS_S3 for bucket + key prefix, AWS_DEFAULT_REGION / AWS_REGION for region, + /// AWS_ENDPOINT_URL for endpoint, AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN + /// for credentials. + S3ClientOptions S3Client; + std::string S3KeyPrefix; +}; + +/// Create a test artifact provider. Combines a writable local cache with an optional S3 primary +/// source; Fetch consults the cache first and populates it on remote hits. +/// +/// With default-constructed options the provider is configured entirely from environment variables +/// (see kTestArtifactsPathEnvVar / kTestArtifactsS3EnvVar / AWS_*). Returns a null reference only +/// on platforms where no cache directory can be determined. +Ref<TestArtifactProvider> CreateTestArtifactProvider(TestArtifactProviderOptions Options = {}); + +/// Returns true when at least one test artifact source is configured and reachable: either a +/// local artifact directory (ZEN_TEST_ARTIFACTS_PATH) or an S3 source with signing credentials +/// available (see S3TestArtifactsAvailable). 
Intended for doctest::skip decorators: +/// TEST_CASE("needs_artifacts" * doctest::skip(!TestArtifactsAvailable())) +bool TestArtifactsAvailable(); + +/// Returns true when the S3 test artifact source is configured (ZEN_TEST_ARTIFACTS_S3) AND the +/// current environment can sign requests: either static AWS credentials are present +/// (AWS_ACCESS_KEY_ID / AWS_SESSION_TOKEN) or the EC2 Instance Metadata Service is usable +/// (not macOS, not explicitly disabled via AWS_EC2_METADATA_DISABLED=true). +/// TEST_CASE("needs_s3" * doctest::skip(!S3TestArtifactsAvailable())) +bool S3TestArtifactsAvailable(); + +void testartifactprovider_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/windows/windowsservice.h b/src/zenutil/include/zenutil/windows/windowsservice.h index ca0270a36..d7b3347a9 100644 --- a/src/zenutil/include/zenutil/windows/windowsservice.h +++ b/src/zenutil/include/zenutil/windows/windowsservice.h @@ -8,7 +8,7 @@ class WindowsService { public: WindowsService(); - ~WindowsService(); + virtual ~WindowsService(); virtual int Run() = 0; diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 308ae0ef2..2fa212f92 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -12,6 +12,8 @@ #include <filesystem> #include <functional> #include <optional> +#include <stdexcept> +#include <string> namespace zen { @@ -64,6 +66,7 @@ public: std::filesystem::path CreateNewTestDir(); std::filesystem::path CreateChildDir(std::string_view ChildName); std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; } + std::filesystem::path GetChildBaseDir() const { return m_ChildProcessBaseDir; } std::filesystem::path GetTestRootDir(std::string_view Path); inline bool IsInitialized() const { return m_IsInitialized; } inline bool IsTestEnvironment() const { return m_IsTestInstance; } @@ -133,12 +136,19 @@ struct ZenServerInstance 
#endif bool IsRunning() const; bool Terminate(); + void ResetDeadProcess(); std::string GetLogOutput() const; inline ServerMode GetServerMode() const { return m_ServerMode; } inline void SetServerExecutablePath(std::filesystem::path ExecutablePath) { m_ServerExecutablePath = ExecutablePath; } + // Controls whether the spawned server records its invocation in the per-user + // execution history file. Default is false so tests, hubs, and auto-spawned + // instances do not pollute the user's history. User-invoked launches (zen up) + // should enable this so the spawn shows up in `zen history`. + inline void SetEnableExecutionHistory(bool Enable) { m_EnableExecutionHistory = Enable; } + void SetDataDir(std::filesystem::path TestDir); inline void SpawnServer(std::string_view AdditionalServerArgs = std::string_view()) @@ -187,6 +197,7 @@ private: std::string m_Name; std::filesystem::path m_OutputCapturePath; std::filesystem::path m_ServerExecutablePath; + bool m_EnableExecutionHistory = false; #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; #endif @@ -330,4 +341,29 @@ CbObject MakeLockFilePayload(const LockFileInfo& Info); LockFileInfo ReadLockFilePayload(const CbObject& Payload); bool ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason); +struct StartupZenServerOptions +{ + std::filesystem::path ProgramBaseDir; // empty = auto-resolve from running executable + uint16_t Port = 0; + bool OpenConsole = false; // open a console window for the server process + bool ShowLog = false; // emit captured server log to LogRef on successful start + std::string ExtraArgs; // e.g. GlobalOptions.PassthroughCommandLine + ZenServerInstance::ServerMode Mode = ZenServerInstance::ServerMode::kStorageServer; + bool EnableExecutionHistory = false; // record the spawned zenserver in `zen history` +}; + +// Returns std::nullopt if a matching server is already running (no action taken); logs instance info via LogRef. 
+// Returns 0 if the server was successfully started and is ready. +// Returns a non-zero exit code if startup failed; the captured server log is emitted via LogRef before returning. +std::optional<int> StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options); + +// Attempts graceful shutdown of a running server entry. +// First tries ZenServerInstance::SignalShutdown; falls back to +// ZenServerEntry::SignalShutdownRequest + polling. +// Returns true on successful shutdown, false if it timed out. +bool ShutdownZenServer(LoggerRef LogRef, + ZenServerState& State, + ZenServerState::ZenServerEntry* Entry, + const std::filesystem::path& ProgramBaseDir); + } // namespace zen diff --git a/src/zenutil/invocationhistory.cpp b/src/zenutil/invocationhistory.cpp new file mode 100644 index 000000000..077061752 --- /dev/null +++ b/src/zenutil/invocationhistory.cpp @@ -0,0 +1,308 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/invocationhistory.h> + +#include <zencore/basicfile.h> +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/memoryview.h> +#include <zencore/process.h> +#include <zencore/uid.h> +#include <zenutil/config/commandlineoptions.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <system_error> + +namespace zen { + +namespace { + + constexpr size_t kMaxRecords = 100; + constexpr std::string_view kHistoryFileName = "invocations.jsonl"; + + // Safety cap. With 100 records at typical ~500-1000 bytes each the file + // normally sits around 50-100 KB. If it has grown past this threshold + // (external corruption, runaway producer, another tool writing garbage) + // we refuse to read it and start fresh with just the new record. Keeps + // LogInvocation from slowing startup on a pathological file. 
+ constexpr uintmax_t kMaxReadSize = 1 * 1024 * 1024; // 1 MB + + bool ExecutionHistoryDisabled(int argc, char** argv) + { + for (int I = 1; I < argc; ++I) + { + if (argv[I] == nullptr) + { + continue; + } + std::string_view A = argv[I]; + if (A == "--enable-execution-history=false" || A == "--enable-execution-history=0" || A == "--enable-execution-history=no") + { + return true; + } + } + return false; + } + + std::filesystem::path ResolveHistoryDir() + { +#if ZEN_PLATFORM_WINDOWS + std::string LocalAppData = GetEnvVariable("LOCALAPPDATA"); + if (!LocalAppData.empty()) + { + return std::filesystem::path(LocalAppData) / "Epic" / "Zen" / "History"; + } +#endif + std::filesystem::path SystemRoot = PickDefaultSystemRootDirectory(); + if (SystemRoot.empty()) + { + return {}; + } + return SystemRoot / "History"; + } + + std::string BuildJsonRecord(const HistoryRecord& Rec) + { + json11::Json::object Obj{ + {"id", Rec.Id}, + {"ts", Rec.Ts}, + {"exe", Rec.Exe}, + {"pid", static_cast<int>(Rec.Pid)}, + {"cwd", Rec.Cwd}, + {"path", Rec.Path}, + {"cmdline", Rec.CmdLine}, + }; + if (!Rec.Mode.empty()) + { + Obj.emplace("mode", Rec.Mode); + } + return json11::Json(Obj).dump(); + } + + bool ParseJsonRecord(std::string_view Line, HistoryRecord& OutRec) + { + std::string Err; + json11::Json J = json11::Json::parse(std::string(Line), Err); + if (!Err.empty() || !J.is_object()) + { + return false; + } + OutRec.Id = J["id"].string_value(); + OutRec.Ts = J["ts"].string_value(); + OutRec.Exe = J["exe"].string_value(); + OutRec.Mode = J["mode"].string_value(); + OutRec.Cwd = J["cwd"].string_value(); + OutRec.Path = J["path"].string_value(); + OutRec.CmdLine = J["cmdline"].string_value(); + OutRec.Pid = static_cast<uint32_t>(J["pid"].int_value()); + return true; + } + + std::vector<std::string> ReadHistoryLines(const std::filesystem::path& Path) + { + std::vector<std::string> Lines; + + std::error_code SizeEc; + const std::uintmax_t FileSize = std::filesystem::file_size(Path, SizeEc); + 
if (SizeEc || FileSize > kMaxReadSize) + { + return Lines; + } + + FileContents Contents = ReadFile(Path); + if (!Contents) + { + return Lines; + } + IoBuffer Flat = Contents.Flatten(); + const char* Data = static_cast<const char*>(Flat.GetData()); + const size_t Size = Flat.GetSize(); + size_t Start = 0; + for (size_t I = 0; I < Size; ++I) + { + if (Data[I] == '\n') + { + if (I > Start) + { + size_t LineEnd = I; + if (LineEnd > Start && Data[LineEnd - 1] == '\r') + { + --LineEnd; + } + if (LineEnd > Start) + { + Lines.emplace_back(Data + Start, LineEnd - Start); + } + } + Start = I + 1; + } + } + if (Start < Size) + { + Lines.emplace_back(Data + Start, Size - Start); + } + return Lines; + } + +} // namespace + +std::filesystem::path +GetInvocationHistoryPath() noexcept +{ + try + { + std::filesystem::path Dir = ResolveHistoryDir(); + if (Dir.empty()) + { + return {}; + } + return Dir / kHistoryFileName; + } + catch (...) + { + return {}; + } +} + +void +LogInvocation(std::string_view Exe, + std::string_view Mode, + int argc, + char** argv, + std::initializer_list<std::string_view> ExcludeSubcommands) noexcept +{ + try + { + if (ExecutionHistoryDisabled(argc, argv)) + { + return; + } + + if (argc >= 2 && argv[1] != nullptr) + { + std::string_view A1 = argv[1]; + for (std::string_view Excluded : ExcludeSubcommands) + { + if (A1 == Excluded) + { + return; + } + } + } + + std::filesystem::path Dir = ResolveHistoryDir(); + if (Dir.empty()) + { + return; + } + + std::error_code Ec; + CreateDirectories(Dir, Ec); + if (Ec) + { + return; + } + + std::filesystem::path Path = Dir / kHistoryFileName; + + HistoryRecord Rec; + Rec.Id = Oid::NewOid().ToString(); + Rec.Ts = DateTime::Now().ToIso8601(); + Rec.Exe = std::string(Exe); + Rec.Mode = std::string(Mode); + + std::error_code CwdEc; + Rec.Cwd = std::filesystem::current_path(CwdEc).string(); + + Rec.Path = GetRunningExecutablePath().string(); + Rec.Pid = static_cast<uint32_t>(GetCurrentProcessId()); + + std::string Raw = 
GetRawCommandLine(); + if (Raw.empty()) + { + std::vector<std::string> Args; + Args.reserve(argc); + for (int I = 0; I < argc; ++I) + { + if (argv[I] != nullptr) + { + Args.emplace_back(argv[I]); + } + } + Raw = BuildCommandLine(Args); + } + ScrubSensitiveValues(Raw); + Rec.CmdLine = std::move(Raw); + + std::vector<std::string> Lines = ReadHistoryLines(Path); + if (Lines.size() >= kMaxRecords) + { + Lines.erase(Lines.begin(), Lines.begin() + (Lines.size() - (kMaxRecords - 1))); + } + Lines.push_back(BuildJsonRecord(Rec)); + + std::string NewContents; + size_t TotalSize = 0; + for (const std::string& L : Lines) + { + TotalSize += L.size() + 1; + } + NewContents.reserve(TotalSize); + for (const std::string& L : Lines) + { + NewContents.append(L); + NewContents.push_back('\n'); + } + + std::error_code WriteEc; + TemporaryFile::SafeWriteFile(Path, MemoryView(NewContents.data(), NewContents.size()), WriteEc); + } + catch (...) + { + } +} + +std::vector<HistoryRecord> +ReadInvocationHistory(size_t MaxRecords) +{ + std::vector<HistoryRecord> Records; + try + { + std::filesystem::path Path = GetInvocationHistoryPath(); + if (Path.empty()) + { + return Records; + } + + std::vector<std::string> Lines = ReadHistoryLines(Path); + if (Lines.size() > MaxRecords) + { + Lines.erase(Lines.begin(), Lines.begin() + (Lines.size() - MaxRecords)); + } + + Records.reserve(Lines.size()); + for (const std::string& L : Lines) + { + HistoryRecord Rec; + if (ParseJsonRecord(L, Rec)) + { + Records.push_back(std::move(Rec)); + } + } + } + catch (...) 
+ { + } + return Records; +} + +void +invocationhistory_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp index 2a4840241..283a8bc37 100644 --- a/src/zenutil/logging/fullformatter.cpp +++ b/src/zenutil/logging/fullformatter.cpp @@ -12,6 +12,7 @@ #include <atomic> #include <chrono> #include <string> +#include "zencore/logging.h" namespace zen::logging { @@ -25,7 +26,7 @@ struct FullFormatter::Impl { } - explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' ') {} std::chrono::time_point<std::chrono::system_clock> m_Epoch; std::tm m_CachedLocalTm{}; @@ -155,15 +156,7 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(' '); } - // append logger name if exists - if (Msg.GetLoggerName().size() > 0) - { - OutBuffer.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - + // level OutBuffer.push_back('['); if (IsColorEnabled()) { @@ -177,6 +170,23 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(']'); OutBuffer.push_back(' '); + // logger name + if (Msg.GetLoggerName().size() > 0) + { + if (IsColorEnabled()) + { + OutBuffer.append("\033[1m"sv); + } + OutBuffer.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); + OutBuffer.push_back(']'); + if (IsColorEnabled()) + { + OutBuffer.append("\033[0m"sv); + } + OutBuffer.push_back(' '); + } + // add source location if present if (Msg.GetSource()) { diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp index 673a03c94..c63ad891e 100644 --- a/src/zenutil/logging/jsonformatter.cpp +++ b/src/zenutil/logging/jsonformatter.cpp @@ -19,8 +19,6 @@ static void WriteEscapedString(MemoryBuffer& Dest, 
std::string_view Text) { // Strip ANSI SGR sequences before escaping so they don't appear in JSON output - static const auto IsEscapeStart = [](char C) { return C == '\033'; }; - const char* RangeStart = Text.data(); const char* End = Text.data() + Text.size(); diff --git a/src/zenutil/logging/logging.cpp b/src/zenutil/logging/logging.cpp index ea2448a42..c1636da61 100644 --- a/src/zenutil/logging/logging.cpp +++ b/src/zenutil/logging/logging.cpp @@ -8,6 +8,7 @@ #include <zencore/logging.h> #include <zencore/logging/ansicolorsink.h> #include <zencore/logging/asyncsink.h> +#include <zencore/logging/broadcastsink.h> #include <zencore/logging/logger.h> #include <zencore/logging/msvcsink.h> #include <zencore/logging/registry.h> @@ -23,8 +24,9 @@ namespace zen { -static bool g_IsLoggingInitialized; -logging::SinkPtr g_FileSink; +static bool g_IsLoggingInitialized; +logging::SinkPtr g_FileSink; +Ref<logging::BroadcastSink> g_BroadcastSink; logging::SinkPtr GetFileSink() @@ -32,6 +34,12 @@ GetFileSink() return g_FileSink; } +Ref<logging::BroadcastSink> +GetDefaultBroadcastSink() +{ + return g_BroadcastSink; +} + void InitializeLogging(const LoggingOptions& LogOptions) { @@ -47,7 +55,6 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) ZEN_MEMSCOPE(ELLMTag::Logging); zen::logging::InitializeLogging(); - zen::logging::EnableVTMode(); // Sinks @@ -117,8 +124,10 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) LoggerRef DefaultLogger = zen::logging::Default(); - // Collect sinks into a local vector first so we can optionally wrap them - std::vector<logging::SinkPtr> Sinks; + // Build the broadcast sink - a shared indirection point that all + // loggers cloned from the default will share. Adding or removing + // a child sink later is immediately visible to every logger. 
+ std::vector<logging::SinkPtr> BroadcastChildren; if (LogOptions.NoConsoleOutput) { @@ -126,17 +135,18 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) } else { - logging::SinkPtr ConsoleSink(new logging::AnsiColorStdoutSink()); + logging::SinkPtr ConsoleSink( + new logging::AnsiColorStdoutSink(LogOptions.ForceColor ? logging::ColorMode::On : logging::ColorMode::Auto)); if (LogOptions.QuietConsole) { ConsoleSink->SetLevel(logging::Warn); } - Sinks.push_back(ConsoleSink); + BroadcastChildren.push_back(ConsoleSink); } if (FileSink) { - Sinks.push_back(FileSink); + BroadcastChildren.push_back(FileSink); } #if ZEN_PLATFORM_WINDOWS @@ -144,21 +154,21 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) { logging::SinkPtr DebugSink(new logging::MsvcSink()); DebugSink->SetLevel(logging::Debug); - Sinks.push_back(DebugSink); + BroadcastChildren.push_back(DebugSink); } #endif + g_BroadcastSink = Ref<logging::BroadcastSink>(new logging::BroadcastSink(std::move(BroadcastChildren))); + bool IsAsync = LogOptions.AllowAsync && !LogOptions.IsDebug && !LogOptions.IsTest; if (IsAsync) { - std::vector<logging::SinkPtr> AsyncSinks; - AsyncSinks.emplace_back(new logging::AsyncSink(std::move(Sinks))); - DefaultLogger->SetSinks(std::move(AsyncSinks)); + DefaultLogger->SetSink(logging::SinkPtr(new logging::AsyncSink({logging::SinkPtr(g_BroadcastSink.Get())}))); } else { - DefaultLogger->SetSinks(std::move(Sinks)); + DefaultLogger->SetSink(logging::SinkPtr(g_BroadcastSink.Get())); } static struct : logging::ErrorHandler @@ -169,7 +179,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) { return; } - static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"}; + static constinit logging::LogPoint ErrorPoint{0, 0, logging::Err, "{}"}; if (auto ErrLogger = zen::logging::ErrorLog()) { try @@ -239,7 +249,7 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); 
logging::Registry::Instance().ApplyAll([&](auto Logger) { - static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + static constinit logging::LogPoint LogStartPoint{0, 0, logging::Info, "log starting at {}"}; Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); } @@ -258,7 +268,8 @@ ShutdownLogging() zen::logging::ShutdownLogging(); - g_FileSink = nullptr; + g_FileSink = nullptr; + g_BroadcastSink = nullptr; } } // namespace zen diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp index 23cf60d16..df59af5fe 100644 --- a/src/zenutil/logging/rotatingfilesink.cpp +++ b/src/zenutil/logging/rotatingfilesink.cpp @@ -85,7 +85,7 @@ struct RotatingFileSink::Impl m_CurrentSize = m_CurrentFile.FileSize(OutEc); if (OutEc) { - // FileSize failed but we have an open file — reset to 0 + // FileSize failed but we have an open file - reset to 0 // so we can at least attempt writes from the start m_CurrentSize = 0; OutEc.clear(); diff --git a/src/zenutil/parallelsort.cpp b/src/zenutil/parallelsort.cpp new file mode 100644 index 000000000..8a9f547bc --- /dev/null +++ b/src/zenutil/parallelsort.cpp @@ -0,0 +1,148 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/parallelsort.h> + +#include <zencore/testing.h> + +namespace zen { + +#if ZEN_WITH_TESTS + +void +parallelsort_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.parallelsort"); + +TEST_CASE("empty range") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + CHECK(Vec.empty()); +} + +TEST_CASE("single element") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec = {42}; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + CHECK(Vec.size() == 1); + CHECK(Vec[0] == 42); +} + +TEST_CASE("small array below threshold") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec = {5, 3, 8, 1, 9, 2, 7, 4, 6, 0}; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + for (size_t I = 0; I < Vec.size(); ++I) + { + CHECK(Vec[I] == int(I)); + } +} + +TEST_CASE("large array triggers parallel path") +{ + WorkerThreadPool Pool(4); + + // 200K elements — well above the 64K threshold. + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + + // Fill with descending values. 
+ for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(N - 1 - I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("already sorted") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("reverse sorted") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(N - 1 - I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("duplicate keys") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I % 100); // only 100 distinct values + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 1; I < N; ++I) + { + CHECK_MESSAGE(Vec[I - 1] <= Vec[I], "not sorted at index=", I); + } +} + +TEST_CASE("custom comparator descending") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A > B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(N - 1 - I), "index=", I, " got=", Vec[I]); + } +} + +TEST_SUITE_END(); + +#endif + +} 
// namespace zen diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp new file mode 100644 index 000000000..8eac350c6 --- /dev/null +++ b/src/zenutil/process/asyncpipereader.cpp @@ -0,0 +1,276 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "asyncpipereader.h" + +#include <zencore/logging.h> + +#include <array> + +ZEN_THIRD_PARTY_INCLUDES_START + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <asio/io_context.hpp> +# include <asio/windows/stream_handle.hpp> +#else +# include <fcntl.h> +# include <unistd.h> +# include <asio/io_context.hpp> +# include <asio/posix/stream_descriptor.hpp> +#endif + +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +static constexpr size_t kReadBufferSize = 4096; + +// ============================================================================ +// POSIX: non-blocking pipe + stream_descriptor +// ============================================================================ + +#if !ZEN_PLATFORM_WINDOWS + +struct AsyncPipeReader::Impl +{ + asio::io_context& m_IoContext; + std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; + std::function<void(std::string_view)> m_DataCallback; + std::function<void()> m_EofCallback; + std::array<char, kReadBufferSize> m_Buffer{}; + + explicit Impl(asio::io_context& IoContext) : m_IoContext(IoContext) {} + + ~Impl() { Stop(); } + + void Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view)> DataCallback, std::function<void()> EofCallback) + { + m_DataCallback = std::move(DataCallback); + m_EofCallback = std::move(EofCallback); + + int Fd = Pipe.ReadFd; + + // Close the write end - child already has it + Pipe.CloseWriteEnd(); + + // Set non-blocking + int Flags = fcntl(Fd, F_GETFL, 0); + fcntl(Fd, F_SETFL, Flags | O_NONBLOCK); + + // Take ownership of the fd. Detach it from StdoutPipeHandles so it + // doesn't get double-closed. 
+ Pipe.ReadFd = -1; + + m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, Fd); + EnqueueRead(); + } + + void Stop() + { + if (m_Descriptor) + { + asio::error_code Ec; + m_Descriptor->cancel(Ec); + m_Descriptor.reset(); + } + } + + void EnqueueRead() + { + if (!m_Descriptor) + { + return; + } + + m_Descriptor->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && m_EofCallback) + { + m_EofCallback(); + } + return; + } + + if (BytesRead > 0 && m_DataCallback) + { + m_DataCallback(std::string_view(m_Buffer.data(), BytesRead)); + } + + EnqueueRead(); + }); + } +}; + +bool +CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe) +{ + // On POSIX, regular pipes work fine with non-blocking I/O + return CreateStdoutPipe(OutPipe); +} + +// ============================================================================ +// Windows: overlapped named pipe + asio::windows::stream_handle +// +// Anonymous pipes (CreatePipe) do not support overlapped I/O. Instead, we +// create a named pipe pair with FILE_FLAG_OVERLAPPED on the read (server) end. +// The write (client) end is inheritable and used as the child's stdout/stderr. +// +// Callers must use CreateOverlappedStdoutPipe() instead of CreateStdoutPipe() +// so the pipe is overlapped from the start. Passing a non-overlapped anonymous +// pipe to Start() will fail. 
+// ============================================================================ + +#else // ZEN_PLATFORM_WINDOWS + +static std::atomic<uint64_t> s_PipeSerial{0}; + +bool +CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe) +{ + // Generate a unique pipe name + uint64_t Serial = s_PipeSerial.fetch_add(1); + wchar_t PipeName[128]; + swprintf_s(PipeName, + _countof(PipeName), + L"\\\\.\\pipe\\zen_async_%u_%llu", + GetCurrentProcessId(), + static_cast<unsigned long long>(Serial)); + + // Create the server (read) end with FILE_FLAG_OVERLAPPED + HANDLE ReadHandle = CreateNamedPipeW(PipeName, + PIPE_ACCESS_INBOUND | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_BYTE | PIPE_WAIT, + 1, // max instances + 0, // out buffer size + kReadBufferSize, + 0, // default timeout + nullptr); + + if (ReadHandle == INVALID_HANDLE_VALUE) + { + ZEN_WARN("CreateNamedPipeW failed: {}", GetLastError()); + return false; + } + + // The read end should not be inherited by the child + SetHandleInformation(ReadHandle, HANDLE_FLAG_INHERIT, 0); + + // Open the client (write) end - inheritable, for the child process + SECURITY_ATTRIBUTES Sa; + Sa.nLength = sizeof(Sa); + Sa.lpSecurityDescriptor = nullptr; + Sa.bInheritHandle = TRUE; + + HANDLE WriteHandle = CreateFileW(PipeName, + GENERIC_WRITE, + 0, // no sharing + &Sa, // inheritable + OPEN_EXISTING, + 0, // no special flags on write end + nullptr); + + if (WriteHandle == INVALID_HANDLE_VALUE) + { + DWORD Err = GetLastError(); + CloseHandle(ReadHandle); + ZEN_WARN("CreateFileW for pipe client end failed: {}", Err); + return false; + } + + OutPipe.ReadHandle = ReadHandle; + OutPipe.WriteHandle = WriteHandle; + return true; +} + +struct AsyncPipeReader::Impl +{ + asio::io_context& m_IoContext; + std::unique_ptr<asio::windows::stream_handle> m_StreamHandle; + std::function<void(std::string_view)> m_DataCallback; + std::function<void()> m_EofCallback; + std::array<char, kReadBufferSize> m_Buffer{}; + + explicit Impl(asio::io_context& IoContext) : 
m_IoContext(IoContext) {} + + ~Impl() { Stop(); } + + void Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view)> DataCallback, std::function<void()> EofCallback) + { + m_DataCallback = std::move(DataCallback); + m_EofCallback = std::move(EofCallback); + + HANDLE ReadHandle = static_cast<HANDLE>(Pipe.ReadHandle); + + // Close the write end - child already has it + Pipe.CloseWriteEnd(); + + // Take ownership of the read handle + Pipe.ReadHandle = nullptr; + + m_StreamHandle = std::make_unique<asio::windows::stream_handle>(m_IoContext, ReadHandle); + EnqueueRead(); + } + + void Stop() + { + if (m_StreamHandle) + { + asio::error_code Ec; + m_StreamHandle->cancel(Ec); + m_StreamHandle.reset(); + } + } + + void EnqueueRead() + { + if (!m_StreamHandle) + { + return; + } + + m_StreamHandle->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && m_EofCallback) + { + m_EofCallback(); + } + return; + } + + if (BytesRead > 0 && m_DataCallback) + { + m_DataCallback(std::string_view(m_Buffer.data(), BytesRead)); + } + + EnqueueRead(); + }); + } +}; + +#endif + +// ============================================================================ +// Common wrapper +// ============================================================================ + +AsyncPipeReader::AsyncPipeReader(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext)) +{ +} + +AsyncPipeReader::~AsyncPipeReader() = default; + +void +AsyncPipeReader::Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view)> DataCallback, std::function<void()> EofCallback) +{ + m_Impl->Start(std::move(Pipe), std::move(DataCallback), std::move(EofCallback)); +} + +void +AsyncPipeReader::Stop() +{ + m_Impl->Stop(); +} + +} // namespace zen diff --git a/src/zenutil/process/asyncpipereader.h b/src/zenutil/process/asyncpipereader.h new file mode 100644 index 000000000..ad2ff8455 --- /dev/null +++ 
b/src/zenutil/process/asyncpipereader.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/process.h> +#include <zencore/zencore.h> + +#include <functional> +#include <memory> +#include <string_view> + +namespace asio { +class io_context; +} + +namespace zen { + +/// Create an overlapped pipe pair suitable for async I/O on Windows. +/// +/// Unlike CreateStdoutPipe() (which creates anonymous non-overlapped pipes), +/// this creates a named pipe with FILE_FLAG_OVERLAPPED on the read end, so it +/// can be used with asio::windows::stream_handle for fully async reads. +/// The write end is inheritable and suitable for child process redirection. +/// +/// On non-Windows platforms this simply delegates to CreateStdoutPipe(). +bool CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe); + +/// Async pipe reader for capturing child process stdout/stderr. +/// +/// Takes ownership of a pipe's read end and reads asynchronously: +/// Linux/macOS: non-blocking fd + asio::posix::stream_descriptor +/// Windows: overlapped named pipe + asio::windows::stream_handle +/// +/// On Windows the pipe must have been created with CreateOverlappedStdoutPipe() +/// for async I/O to work. Pipes from CreateStdoutPipe() will fail. +/// +/// DataCallback is invoked for each chunk read (on the io_context). +/// EofCallback is invoked when the pipe closes (child exited or pipe broken). +class AsyncPipeReader +{ +public: + explicit AsyncPipeReader(asio::io_context& IoContext); + ~AsyncPipeReader(); + + AsyncPipeReader(const AsyncPipeReader&) = delete; + AsyncPipeReader& operator=(const AsyncPipeReader&) = delete; + + /// Take ownership of the pipe read-end and start async reading. + /// The write end is closed immediately (caller should have already launched + /// the child process). DataCallback receives raw chunks. EofCallback fires + /// once when the pipe reaches EOF. 
+ void Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view Data)> DataCallback, std::function<void()> EofCallback); + + /// Stop reading and close the pipe. Callbacks will not fire after this returns. + void Stop(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenutil/process/exitwatcher.cpp b/src/zenutil/process/exitwatcher.cpp new file mode 100644 index 000000000..cef31ebca --- /dev/null +++ b/src/zenutil/process/exitwatcher.cpp @@ -0,0 +1,294 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "exitwatcher.h" + +#include <zencore/logging.h> + +ZEN_THIRD_PARTY_INCLUDES_START + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <asio/io_context.hpp> +# include <asio/windows/object_handle.hpp> +#elif ZEN_PLATFORM_LINUX +# include <sys/syscall.h> +# include <sys/wait.h> +# include <unistd.h> +# include <asio/io_context.hpp> +# include <asio/posix/stream_descriptor.hpp> + +# ifndef SYS_pidfd_open +# define SYS_pidfd_open 434 // x86_64 +# endif +#elif ZEN_PLATFORM_MAC +# include <sys/event.h> +# include <sys/wait.h> +# include <unistd.h> +# include <asio/io_context.hpp> +# include <asio/posix/stream_descriptor.hpp> +#endif + +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +// ============================================================================ +// Linux: pidfd_open + stream_descriptor +// ============================================================================ + +#if ZEN_PLATFORM_LINUX + +struct ProcessExitWatcher::Impl +{ + asio::io_context& m_IoContext; + std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; + int m_PidFd = -1; + int m_Pid = 0; + + explicit Impl(asio::io_context& IoContext) : m_IoContext(IoContext) {} + + ~Impl() { Cancel(); } + + void Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit) + { + m_Pid = Handle.Pid(); + + // pidfd_open returns an fd that becomes readable when the process exits. 
+ // Available since Linux 5.3. + m_PidFd = static_cast<int>(syscall(SYS_pidfd_open, m_Pid, 0)); + if (m_PidFd < 0) + { + ZEN_WARN("pidfd_open failed for pid {}: {}", m_Pid, strerror(errno)); + return; + } + + m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_PidFd); + + m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read, + [this, Callback = std::move(OnExit)](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled or error + } + + int ExitCode = -1; + int Status = 0; + // The pidfd told us the process exited. Reap it with waitpid. + if (waitpid(m_Pid, &Status, WNOHANG) > 0) + { + if (WIFEXITED(Status)) + { + ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + constexpr int kSignalExitBase = 128; + ExitCode = kSignalExitBase + WTERMSIG(Status); + } + } + + Callback(ExitCode); + }); + } + + void Cancel() + { + if (m_Descriptor) + { + asio::error_code Ec; + m_Descriptor->cancel(Ec); + m_Descriptor.reset(); + // stream_descriptor closes the fd on destruction, so don't close m_PidFd separately + m_PidFd = -1; + } + else if (m_PidFd >= 0) + { + close(m_PidFd); + m_PidFd = -1; + } + } +}; + +// ============================================================================ +// Windows: object_handle::async_wait +// ============================================================================ + +#elif ZEN_PLATFORM_WINDOWS + +struct ProcessExitWatcher::Impl +{ + asio::io_context& m_IoContext; + std::unique_ptr<asio::windows::object_handle> m_ObjectHandle; + void* m_DuplicatedHandle = nullptr; + + explicit Impl(asio::io_context& IoContext) : m_IoContext(IoContext) {} + + ~Impl() { Cancel(); } + + void Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit) + { + // Duplicate the process handle so ASIO can take ownership independently + HANDLE SourceHandle = static_cast<HANDLE>(Handle.Handle()); + HANDLE CurrentProcess = GetCurrentProcess(); + BOOL Success = 
DuplicateHandle(CurrentProcess, + SourceHandle, + CurrentProcess, + reinterpret_cast<LPHANDLE>(&m_DuplicatedHandle), + SYNCHRONIZE | PROCESS_QUERY_INFORMATION, + FALSE, + 0); + + if (!Success) + { + ZEN_WARN("DuplicateHandle failed for pid {}: {}", Handle.Pid(), GetLastError()); + return; + } + + // object_handle takes ownership of the handle + m_ObjectHandle = std::make_unique<asio::windows::object_handle>(m_IoContext, m_DuplicatedHandle); + + m_ObjectHandle->async_wait([this, DupHandle = m_DuplicatedHandle, Callback = std::move(OnExit)](const asio::error_code& Ec) { + if (Ec) + { + return; + } + + DWORD ExitCode = 0; + GetExitCodeProcess(static_cast<HANDLE>(DupHandle), &ExitCode); + Callback(static_cast<int>(ExitCode)); + }); + } + + void Cancel() + { + if (m_ObjectHandle) + { + asio::error_code Ec; + m_ObjectHandle->cancel(Ec); + m_ObjectHandle.reset(); // Closes the duplicated handle + m_DuplicatedHandle = nullptr; + } + else if (m_DuplicatedHandle) + { + CloseHandle(static_cast<HANDLE>(m_DuplicatedHandle)); + m_DuplicatedHandle = nullptr; + } + } +}; + +// ============================================================================ +// macOS: kqueue EVFILT_PROC + stream_descriptor +// ============================================================================ + +#elif ZEN_PLATFORM_MAC + +struct ProcessExitWatcher::Impl +{ + asio::io_context& m_IoContext; + std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; + int m_KqueueFd = -1; + int m_Pid = 0; + + explicit Impl(asio::io_context& IoContext) : m_IoContext(IoContext) {} + + ~Impl() { Cancel(); } + + void Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit) + { + m_Pid = Handle.Pid(); + + m_KqueueFd = kqueue(); + if (m_KqueueFd < 0) + { + ZEN_WARN("kqueue() failed for pid {}: {}", m_Pid, strerror(errno)); + return; + } + + // Register interest in the process exit event + struct kevent Change; + EV_SET(&Change, static_cast<uintptr_t>(m_Pid), EVFILT_PROC, EV_ADD | EV_ONESHOT, 
NOTE_EXIT, 0, nullptr); + + if (kevent(m_KqueueFd, &Change, 1, nullptr, 0, nullptr) < 0) + { + ZEN_WARN("kevent register failed for pid {}: {}", m_Pid, strerror(errno)); + close(m_KqueueFd); + m_KqueueFd = -1; + return; + } + + m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_KqueueFd); + + m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read, + [this, Callback = std::move(OnExit)](const asio::error_code& Ec) { + if (Ec) + { + return; + } + + // Drain the kqueue event + struct kevent Event; + struct timespec Timeout = {0, 0}; + kevent(m_KqueueFd, nullptr, 0, &Event, 1, &Timeout); + + int ExitCode = -1; + int Status = 0; + if (waitpid(m_Pid, &Status, WNOHANG) > 0) + { + if (WIFEXITED(Status)) + { + ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + constexpr int kSignalExitBase = 128; + ExitCode = kSignalExitBase + WTERMSIG(Status); + } + } + + Callback(ExitCode); + }); + } + + void Cancel() + { + if (m_Descriptor) + { + asio::error_code Ec; + m_Descriptor->cancel(Ec); + m_Descriptor.reset(); + // stream_descriptor closes the kqueue fd on destruction + m_KqueueFd = -1; + } + else if (m_KqueueFd >= 0) + { + close(m_KqueueFd); + m_KqueueFd = -1; + } + } +}; + +#endif + +// ============================================================================ +// Common wrapper (delegates to Impl) +// ============================================================================ + +ProcessExitWatcher::ProcessExitWatcher(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext)) +{ +} + +ProcessExitWatcher::~ProcessExitWatcher() = default; + +void +ProcessExitWatcher::Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit) +{ + m_Impl->Watch(Handle, std::move(OnExit)); +} + +void +ProcessExitWatcher::Cancel() +{ + m_Impl->Cancel(); +} + +} // namespace zen diff --git a/src/zenutil/process/exitwatcher.h b/src/zenutil/process/exitwatcher.h new file mode 100644 index 
000000000..24906d7d0 --- /dev/null +++ b/src/zenutil/process/exitwatcher.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/process.h> +#include <zencore/zencore.h> + +#include <functional> +#include <memory> + +namespace asio { +class io_context; +} + +namespace zen { + +/// Async process exit watcher. +/// +/// Uses platform-specific mechanisms for scalable, non-polling exit detection: +/// Linux: pidfd_open() + asio::posix::stream_descriptor +/// Windows: asio::windows::object_handle +/// macOS: kqueue EVFILT_PROC/NOTE_EXIT + asio::posix::stream_descriptor +/// +/// The callback is invoked exactly once when the process exits, posted to the +/// io_context. Call Cancel() to suppress the callback. +class ProcessExitWatcher +{ +public: + explicit ProcessExitWatcher(asio::io_context& IoContext); + ~ProcessExitWatcher(); + + ProcessExitWatcher(const ProcessExitWatcher&) = delete; + ProcessExitWatcher& operator=(const ProcessExitWatcher&) = delete; + + /// Begin watching the given process. The callback is posted to the io_context + /// when the process exits. Only one Watch() may be active at a time. + void Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit); + + /// Cancel any outstanding watch. The callback will not be invoked after this + /// returns. Safe to call if no watch is active. + void Cancel(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp new file mode 100644 index 000000000..d0b912a0d --- /dev/null +++ b/src/zenutil/process/subprocessmanager.cpp @@ -0,0 +1,1977 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/process/subprocessmanager.h> + +#include "asyncpipereader.h" +#include "exitwatcher.h" + +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <atomic> +#include <numeric> +#include <random> +#include <string> +#include <unordered_map> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <csignal> +#endif +#include <asio/io_context.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +// ============================================================================ +// ManagedProcess::Impl +// ============================================================================ + +struct ManagedProcess::Impl +{ + asio::io_context& m_IoContext; + ProcessHandle m_Handle; + ProcessExitWatcher m_ExitWatcher; + ProcessExitCallback m_ExitCallback; + std::atomic<bool> m_Exited{false}; + + // Stdout capture + std::unique_ptr<AsyncPipeReader> m_StdoutReader; + ProcessDataCallback m_StdoutCallback; + mutable RwLock m_StdoutLock; + std::string m_CapturedStdout; + + // Stderr capture + std::unique_ptr<AsyncPipeReader> m_StderrReader; + ProcessDataCallback m_StderrCallback; + mutable RwLock m_StderrLock; + std::string m_CapturedStderr; + + // Metrics + ProcessMetrics m_LastMetrics; + std::atomic<float> m_CpuUsagePercent{-1.0f}; + uint64_t m_PrevUserTimeMs = 0; + uint64_t m_PrevKernelTimeMs = 0; + uint64_t m_PrevSampleTicks = 0; + + // User tag + std::string m_Tag; + + explicit Impl(asio::io_context& IoContext) : m_IoContext(IoContext), m_ExitWatcher(IoContext) {} + + void OnStdoutData(ManagedProcess& Self, std::string_view Data) + { + if (m_StdoutCallback) + { + m_StdoutCallback(Self, Data); + } + else + { + RwLock::ExclusiveLockScope $(m_StdoutLock); + m_CapturedStdout.append(Data); + } + } + + void OnStderrData(ManagedProcess& Self, std::string_view Data) + { + 
if (m_StderrCallback) + { + m_StderrCallback(Self, Data); + } + else + { + RwLock::ExclusiveLockScope $(m_StderrLock); + m_CapturedStderr.append(Data); + } + } + + void SampleMetrics() + { + if (m_Exited.load()) + { + return; + } + + ProcessMetrics Metrics; + GetProcessMetrics(m_Handle, Metrics); + + uint64_t NowTicks = GetHifreqTimerValue(); + + if (m_PrevSampleTicks > 0) + { + uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - m_PrevSampleTicks); + uint64_t DeltaCpuTimeMs = (Metrics.UserTimeMs + Metrics.KernelTimeMs) - (m_PrevUserTimeMs + m_PrevKernelTimeMs); + if (ElapsedMs > 0) + { + m_CpuUsagePercent.store(static_cast<float>(static_cast<double>(DeltaCpuTimeMs) / ElapsedMs * 100.0)); + } + } + + m_PrevUserTimeMs = Metrics.UserTimeMs; + m_PrevKernelTimeMs = Metrics.KernelTimeMs; + m_PrevSampleTicks = NowTicks; + m_LastMetrics = Metrics; + } + + [[nodiscard]] int Pid() const { return m_Handle.Pid(); } + + [[nodiscard]] bool IsRunning() const { return !m_Exited.load() && m_Handle.IsValid() && m_Handle.IsRunning(); } + + [[nodiscard]] std::string GetCapturedStdout() const + { + RwLock::SharedLockScope $(m_StdoutLock); + return m_CapturedStdout; + } + + [[nodiscard]] std::string GetCapturedStderr() const + { + RwLock::SharedLockScope $(m_StderrLock); + return m_CapturedStderr; + } + + void CancelAll() + { + m_ExitWatcher.Cancel(); + if (m_StdoutReader) + { + m_StdoutReader->Stop(); + } + if (m_StderrReader) + { + m_StderrReader->Stop(); + } + } +}; + +// ============================================================================ +// ManagedProcess +// ============================================================================ + +ManagedProcess::ManagedProcess(std::unique_ptr<Impl> InImpl) : m_Impl(std::move(InImpl)) +{ +} + +ManagedProcess::~ManagedProcess() +{ + if (m_Impl) + { + m_Impl->CancelAll(); + } +} + +int +ManagedProcess::Pid() const +{ + return m_Impl->Pid(); +} + +bool +ManagedProcess::IsRunning() const +{ + return m_Impl->IsRunning(); +} + 
+const ProcessHandle& +ManagedProcess::GetHandle() const +{ + return m_Impl->m_Handle; +} + +ProcessMetrics +ManagedProcess::GetLatestMetrics() const +{ + return m_Impl->m_LastMetrics; +} + +float +ManagedProcess::GetCpuUsagePercent() const +{ + return m_Impl->m_CpuUsagePercent.load(); +} + +std::string +ManagedProcess::GetCapturedStdout() const +{ + return m_Impl->GetCapturedStdout(); +} + +std::string +ManagedProcess::GetCapturedStderr() const +{ + return m_Impl->GetCapturedStderr(); +} + +bool +ManagedProcess::Kill() +{ + return m_Impl->m_Handle.Kill(); +} + +bool +ManagedProcess::Terminate(int ExitCode) +{ + return m_Impl->m_Handle.Terminate(ExitCode); +} + +void +ManagedProcess::SetTag(std::string Tag) +{ + m_Impl->m_Tag = std::move(Tag); +} + +std::string_view +ManagedProcess::GetTag() const +{ + return m_Impl->m_Tag; +} + +// ============================================================================ +// SubprocessManager::Impl +// ============================================================================ + +struct SubprocessManager::Impl +{ + asio::io_context& m_IoContext; + SubprocessManagerConfig m_Config; + + // Ungrouped processes + mutable RwLock m_Lock; + std::unordered_map<int, std::unique_ptr<ManagedProcess>> m_Processes; + + // Groups + mutable RwLock m_GroupsLock; + std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> m_Groups; + + // Cross-group metrics index: all pids (grouped + ungrouped) for round-robin sampling + mutable RwLock m_MetricsLock; + std::unordered_map<int, ManagedProcess*> m_AllProcesses; // non-owning + std::vector<int> m_KeyOrder; + size_t m_NextSampleIndex = 0; + + ProcessDataCallback m_DefaultStdoutCallback; + ProcessDataCallback m_DefaultStderrCallback; + + std::unique_ptr<asio::steady_timer> m_MetricsTimer; + std::atomic<bool> m_Running{true}; + + explicit Impl(asio::io_context& IoContext, SubprocessManagerConfig Config); + ~Impl(); + + ManagedProcess* AddProcess(std::unique_ptr<ManagedProcess> Process); + void 
// Sets up manager state and, when a sampling interval is configured, starts
// the repeating metrics timer on the caller-provided io_context.
// MetricsSampleIntervalMs == 0 disables metrics sampling entirely (the tests
// rely on this to run without a background timer).
SubprocessManager::Impl::Impl(asio::io_context& IoContext, SubprocessManagerConfig Config) : m_IoContext(IoContext), m_Config(Config)
{
	if (m_Config.MetricsSampleIntervalMs > 0)
	{
		m_MetricsTimer = std::make_unique<asio::steady_timer>(IoContext);
		EnqueueMetricsTimer();
	}
}
+{ + m_Running = false; + if (m_MetricsTimer) + { + m_MetricsTimer->cancel(); + } + + // Destroy groups first (they reference m_Manager back to us) + { + RwLock::ExclusiveLockScope $(m_GroupsLock); + m_Groups.clear(); + } + + RemoveAll(); +} + +ManagedProcess* +SubprocessManager::Impl::AddProcess(std::unique_ptr<ManagedProcess> Process) +{ + int Pid = Process->Pid(); + ManagedProcess* Ptr = Process.get(); + + { + RwLock::ExclusiveLockScope $(m_Lock); + m_Processes[Pid] = std::move(Process); + } + + RegisterForMetrics(Pid, Ptr); + return Ptr; +} + +void +SubprocessManager::Impl::RegisterForMetrics(int Pid, ManagedProcess* Ptr) +{ + RwLock::ExclusiveLockScope $(m_MetricsLock); + m_AllProcesses[Pid] = Ptr; + m_KeyOrder.push_back(Pid); +} + +void +SubprocessManager::Impl::UnregisterFromMetrics(int Pid) +{ + RwLock::ExclusiveLockScope $(m_MetricsLock); + m_AllProcesses.erase(Pid); + m_KeyOrder.erase(std::remove(m_KeyOrder.begin(), m_KeyOrder.end(), Pid), m_KeyOrder.end()); + if (m_NextSampleIndex >= m_KeyOrder.size()) + { + m_NextSampleIndex = 0; + } +} + +ManagedProcess* +SubprocessManager::Impl::FindProcess(int Pid) const +{ + RwLock::SharedLockScope $(m_MetricsLock); + auto It = m_AllProcesses.find(Pid); + if (It != m_AllProcesses.end()) + { + return It->second; + } + return nullptr; +} + +void +SubprocessManager::Impl::SetupExitWatcher(ManagedProcess* Proc, ProcessExitCallback OnExit) +{ + int Pid = Proc->Pid(); + + Proc->m_Impl->m_ExitWatcher.Watch(Proc->m_Impl->m_Handle, [this, Pid, Callback = std::move(OnExit)](int ExitCode) { + ManagedProcess* Found = FindProcess(Pid); + + if (Found) + { + Found->m_Impl->m_Exited.store(true); + Callback(*Found, ExitCode); + } + }); +} + +void +SubprocessManager::Impl::SetupStdoutReader(ManagedProcess* Proc, StdoutPipeHandles&& Pipe) +{ + int Pid = Proc->Pid(); + Proc->m_Impl->m_StdoutReader = std::make_unique<AsyncPipeReader>(m_IoContext); + Proc->m_Impl->m_StdoutReader->Start( + std::move(Pipe), + [this, Pid](std::string_view 
// Starts the async stderr reader for Proc. Each data chunk is routed, in
// priority order, to: the per-process stderr callback, else the manager-wide
// default stderr callback, else the internal capture buffer (GetCapturedStderr).
// The lambda re-looks the process up by pid on every chunk so data arriving
// after Remove() is dropped instead of touching a destroyed object.
void
SubprocessManager::Impl::SetupStderrReader(ManagedProcess* Proc, StdoutPipeHandles&& Pipe)
{
	int Pid = Proc->Pid();
	Proc->m_Impl->m_StderrReader = std::make_unique<AsyncPipeReader>(m_IoContext);
	Proc->m_Impl->m_StderrReader->Start(
		std::move(Pipe),
		[this, Pid](std::string_view Data) {
			ManagedProcess* Found = FindProcess(Pid);
			if (Found)
			{
				if (Found->m_Impl->m_StderrCallback)
				{
					Found->m_Impl->m_StderrCallback(*Found, Data);
				}
				else if (m_DefaultStderrCallback)
				{
					m_DefaultStderrCallback(*Found, Data);
				}
				else
				{
					Found->m_Impl->OnStderrData(*Found, Data);
				}
			}
		},
		// EOF handler: nothing to do here; process exit is reported by the exit watcher.
		[] {});
}
+ if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); + + ManagedProcess* Ptr = AddProcess(std::move(Proc)); + SetupExitWatcher(Ptr, std::move(OnExit)); + + if (HasStdout) + { + SetupStdoutReader(Ptr, std::move(*Options.StdoutPipe)); + } + if (HasStderr) + { + SetupStderrReader(Ptr, std::move(*Options.StderrPipe)); + } + + return Ptr; +} + +ManagedProcess* +SubprocessManager::Impl::Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit) +{ + int Pid = Handle.Pid(); + + auto ImplPtr = std::make_unique<ManagedProcess::Impl>(m_IoContext); + ImplPtr->m_Handle.Initialize(Pid); + + // Reset the original handle so caller doesn't double-close + Handle.Reset(); + + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); + + ManagedProcess* Ptr = AddProcess(std::move(Proc)); + SetupExitWatcher(Ptr, std::move(OnExit)); + + return Ptr; +} + +void +SubprocessManager::Impl::Remove(int Pid) +{ + UnregisterFromMetrics(Pid); + + RwLock::ExclusiveLockScope $(m_Lock); + auto It = m_Processes.find(Pid); + if (It != m_Processes.end()) + { + It->second->m_Impl->CancelAll(); + m_Processes.erase(It); + } +} + +void +SubprocessManager::Impl::RemoveAll() +{ + { + RwLock::ExclusiveLockScope $(m_Lock); + for (auto& [Pid, Proc] : m_Processes) + { + Proc->m_Impl->CancelAll(); + } + m_Processes.clear(); + } + + { + RwLock::ExclusiveLockScope $(m_MetricsLock); + m_AllProcesses.clear(); + m_KeyOrder.clear(); + m_NextSampleIndex = 0; + } +} + +void +SubprocessManager::Impl::EnqueueMetricsTimer() +{ + if (!m_MetricsTimer || !m_Running.load()) + { + return; + } + + m_MetricsTimer->expires_after(std::chrono::milliseconds(m_Config.MetricsSampleIntervalMs)); + m_MetricsTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec || !m_Running.load()) + { + return; + } + + SampleBatch(); + 
EnqueueMetricsTimer(); + }); +} + +void +SubprocessManager::Impl::SampleBatch() +{ + RwLock::SharedLockScope $(m_MetricsLock); + + if (m_KeyOrder.empty()) + { + return; + } + + size_t Remaining = std::min(static_cast<size_t>(m_Config.MetricsBatchSize), m_KeyOrder.size()); + + while (Remaining > 0) + { + if (m_NextSampleIndex >= m_KeyOrder.size()) + { + m_NextSampleIndex = 0; + } + + int Pid = m_KeyOrder[m_NextSampleIndex]; + auto It = m_AllProcesses.find(Pid); + + if (It != m_AllProcesses.end()) + { + It->second->m_Impl->SampleMetrics(); + } + + m_NextSampleIndex++; + Remaining--; + } +} + +std::vector<TrackedProcessEntry> +SubprocessManager::Impl::GetMetricsSnapshot() const +{ + std::vector<TrackedProcessEntry> Result; + + RwLock::SharedLockScope $(m_MetricsLock); + Result.reserve(m_AllProcesses.size()); + + for (const auto& [Pid, Proc] : m_AllProcesses) + { + TrackedProcessEntry Entry; + Entry.Pid = Pid; + Entry.Metrics = Proc->m_Impl->m_LastMetrics; + Entry.CpuUsagePercent = Proc->m_Impl->m_CpuUsagePercent.load(); + Result.push_back(std::move(Entry)); + } + + return Result; +} + +AggregateProcessMetrics +SubprocessManager::Impl::GetAggregateMetrics() const +{ + AggregateProcessMetrics Agg; + + RwLock::SharedLockScope $(m_MetricsLock); + + for (const auto& [Pid, Proc] : m_AllProcesses) + { + const ProcessMetrics& M = Proc->m_Impl->m_LastMetrics; + Agg.TotalWorkingSetSize += M.WorkingSetSize; + Agg.TotalPeakWorkingSetSize += M.PeakWorkingSetSize; + Agg.TotalUserTimeMs += M.UserTimeMs; + Agg.TotalKernelTimeMs += M.KernelTimeMs; + Agg.ProcessCount++; + } + + return Agg; +} + +size_t +SubprocessManager::Impl::GetProcessCount() const +{ + RwLock::SharedLockScope $(m_MetricsLock); + return m_AllProcesses.size(); +} + +void +SubprocessManager::Impl::Enumerate(std::function<void(const ManagedProcess&)> Callback) const +{ + RwLock::SharedLockScope $(m_MetricsLock); + for (const auto& [Pid, Proc] : m_AllProcesses) + { + Callback(*Proc); + } +} + +ProcessGroup* 
+SubprocessManager::Impl::CreateGroup(std::string Name) +{ + auto GroupImpl = std::make_unique<ProcessGroup::Impl>(std::move(Name), *this, m_IoContext); + ProcessGroup* Ptr = nullptr; + + auto Group = std::unique_ptr<ProcessGroup>(new ProcessGroup(std::move(GroupImpl))); + Ptr = Group.get(); + + RwLock::ExclusiveLockScope $(m_GroupsLock); + m_Groups[std::string(Ptr->GetName())] = std::move(Group); + + return Ptr; +} + +void +SubprocessManager::Impl::DestroyGroup(std::string_view Name) +{ + RwLock::ExclusiveLockScope $(m_GroupsLock); + auto It = m_Groups.find(std::string(Name)); + if (It != m_Groups.end()) + { + It->second->KillAll(); + m_Groups.erase(It); + } +} + +ProcessGroup* +SubprocessManager::Impl::FindGroup(std::string_view Name) const +{ + RwLock::SharedLockScope $(m_GroupsLock); + auto It = m_Groups.find(std::string(Name)); + if (It != m_Groups.end()) + { + return It->second.get(); + } + return nullptr; +} + +void +SubprocessManager::Impl::EnumerateGroups(std::function<void(const ProcessGroup&)> Callback) const +{ + RwLock::SharedLockScope $(m_GroupsLock); + for (const auto& [Name, Group] : m_Groups) + { + Callback(*Group); + } +} + +// ============================================================================ +// SubprocessManager +// ============================================================================ + +SubprocessManager::SubprocessManager(asio::io_context& IoContext, SubprocessManagerConfig Config) +: m_Impl(std::make_unique<Impl>(IoContext, Config)) +{ +} + +SubprocessManager::~SubprocessManager() = default; + +ManagedProcess* +SubprocessManager::Spawn(const std::filesystem::path& Executable, + std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) +{ + ZEN_TRACE_CPU("SubprocessManager::Spawn"); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); +} + +ManagedProcess* 
// Creates a named process group. The returned pointer is owned by the manager
// and remains valid until DestroyGroup() or manager destruction.
ProcessGroup*
SubprocessManager::CreateGroup(std::string Name)
{
	ZEN_TRACE_CPU("SubprocessManager::CreateGroup");
	return m_Impl->CreateGroup(std::move(Name));
}

// Kills every process in the named group and destroys the group.
// No-op when the name is unknown.
void
SubprocessManager::DestroyGroup(std::string_view Name)
{
	ZEN_TRACE_CPU("SubprocessManager::DestroyGroup");
	m_Impl->DestroyGroup(Name);
}

// Returns the group with the given name, or nullptr if none exists.
ProcessGroup*
SubprocessManager::FindGroup(std::string_view Name) const
{
	return m_Impl->FindGroup(Name);
}

// Invokes Callback for every group. NOTE: the underlying enumeration holds a
// shared lock on the groups map, so Callback should avoid re-entrant group
// management calls on this manager.
void
SubprocessManager::EnumerateGroups(std::function<void(const ProcessGroup&)> Callback) const
{
	m_Impl->EnumerateGroups(std::move(Callback));
}
+struct ProcessGroup::Impl +{ + std::string m_Name; + SubprocessManager::Impl& m_Manager; + asio::io_context& m_IoContext; + + mutable RwLock m_Lock; + std::unordered_map<int, std::unique_ptr<ManagedProcess>> m_Processes; + +#if ZEN_PLATFORM_WINDOWS + JobObject m_JobObject; +#else + int m_Pgid = 0; +#endif + + Impl(std::string Name, SubprocessManager::Impl& Manager, asio::io_context& IoContext); + ~Impl(); + + ManagedProcess* AddProcess(std::unique_ptr<ManagedProcess> Process); + + ManagedProcess* Spawn(const std::filesystem::path& Executable, + std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); + ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); + void Remove(int Pid); + void KillAll(); + + AggregateProcessMetrics GetAggregateMetrics() const; + std::vector<TrackedProcessEntry> GetMetricsSnapshot() const; + [[nodiscard]] size_t GetProcessCount() const; + void Enumerate(std::function<void(const ManagedProcess&)> Callback) const; +}; + +// ============================================================================ +// ProcessGroup::Impl method definitions +// ============================================================================ + +ProcessGroup::Impl::Impl(std::string Name, SubprocessManager::Impl& Manager, asio::io_context& IoContext) +: m_Name(std::move(Name)) +, m_Manager(Manager) +, m_IoContext(IoContext) +{ +#if ZEN_PLATFORM_WINDOWS + m_JobObject.Initialize(); +#endif +} + +ProcessGroup::Impl::~Impl() +{ + KillAll(); +} + +ManagedProcess* +ProcessGroup::Impl::AddProcess(std::unique_ptr<ManagedProcess> Process) +{ + int Pid = Process->Pid(); + ManagedProcess* Ptr = Process.get(); + + { + RwLock::ExclusiveLockScope $(m_Lock); + m_Processes[Pid] = std::move(Process); + } + + m_Manager.RegisterForMetrics(Pid, Ptr); + return Ptr; +} + +ManagedProcess* +ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, + 
std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) +{ + bool HasStdout = Options.StdoutPipe != nullptr; + bool HasStderr = Options.StderrPipe != nullptr; + +#if ZEN_PLATFORM_WINDOWS + if (m_JobObject.IsValid()) + { + Options.AssignToJob = &m_JobObject; + } +#else + if (m_Pgid == 0) + { + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; + } + else + { + Options.ProcessGroupId = m_Pgid; + } +#endif + + CreateProcResult Result = CreateProc(Executable, CommandLine, Options); + + auto ImplPtr = std::make_unique<ManagedProcess::Impl>(m_IoContext); +#if ZEN_PLATFORM_WINDOWS + ImplPtr->m_Handle.Initialize(Result); +#else + int Pid = static_cast<int>(Result); + ImplPtr->m_Handle.Initialize(Pid); + + // First process becomes the group leader + if (m_Pgid == 0) + { + m_Pgid = Pid; + } +#endif + + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); + + ManagedProcess* Ptr = AddProcess(std::move(Proc)); + m_Manager.SetupExitWatcher(Ptr, std::move(OnExit)); + + if (HasStdout) + { + m_Manager.SetupStdoutReader(Ptr, std::move(*Options.StdoutPipe)); + } + if (HasStderr) + { + m_Manager.SetupStderrReader(Ptr, std::move(*Options.StderrPipe)); + } + + return Ptr; +} + +ManagedProcess* +ProcessGroup::Impl::Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit) +{ + int Pid = Handle.Pid(); + + auto ImplPtr = std::make_unique<ManagedProcess::Impl>(m_IoContext); + ImplPtr->m_Handle.Initialize(Pid); + Handle.Reset(); + +#if ZEN_PLATFORM_WINDOWS + if (m_JobObject.IsValid()) + { + m_JobObject.AssignProcess(ImplPtr->m_Handle.Handle()); + } +#endif + + auto Proc = std::unique_ptr<ManagedProcess>(new 
ManagedProcess(std::move(ImplPtr))); + + ManagedProcess* Ptr = AddProcess(std::move(Proc)); + m_Manager.SetupExitWatcher(Ptr, std::move(OnExit)); + + return Ptr; +} + +void +ProcessGroup::Impl::Remove(int Pid) +{ + m_Manager.UnregisterFromMetrics(Pid); + + RwLock::ExclusiveLockScope $(m_Lock); + auto It = m_Processes.find(Pid); + if (It != m_Processes.end()) + { + It->second->m_Impl->CancelAll(); + m_Processes.erase(It); + } +} + +void +ProcessGroup::Impl::KillAll() +{ +#if ZEN_PLATFORM_WINDOWS + if (m_JobObject.IsValid()) + { + TerminateJobObject(static_cast<HANDLE>(m_JobObject.Handle()), 1); + } +#else + if (m_Pgid > 0) + { + kill(-m_Pgid, SIGTERM); + } +#endif + // Also kill individually as fallback and clean up + RwLock::ExclusiveLockScope $(m_Lock); + for (auto& [Pid, Proc] : m_Processes) + { + if (Proc->IsRunning()) + { + Proc->Kill(); + } + m_Manager.UnregisterFromMetrics(Pid); + Proc->m_Impl->CancelAll(); + } + m_Processes.clear(); +} + +AggregateProcessMetrics +ProcessGroup::Impl::GetAggregateMetrics() const +{ + AggregateProcessMetrics Agg; + + RwLock::SharedLockScope $(m_Lock); + + for (const auto& [Pid, Proc] : m_Processes) + { + const ProcessMetrics& M = Proc->m_Impl->m_LastMetrics; + Agg.TotalWorkingSetSize += M.WorkingSetSize; + Agg.TotalPeakWorkingSetSize += M.PeakWorkingSetSize; + Agg.TotalUserTimeMs += M.UserTimeMs; + Agg.TotalKernelTimeMs += M.KernelTimeMs; + Agg.ProcessCount++; + } + + return Agg; +} + +std::vector<TrackedProcessEntry> +ProcessGroup::Impl::GetMetricsSnapshot() const +{ + std::vector<TrackedProcessEntry> Result; + + RwLock::SharedLockScope $(m_Lock); + Result.reserve(m_Processes.size()); + + for (const auto& [Pid, Proc] : m_Processes) + { + TrackedProcessEntry Entry; + Entry.Pid = Pid; + Entry.Metrics = Proc->m_Impl->m_LastMetrics; + Entry.CpuUsagePercent = Proc->m_Impl->m_CpuUsagePercent.load(); + Result.push_back(std::move(Entry)); + } + + return Result; +} + +size_t +ProcessGroup::Impl::GetProcessCount() const +{ + 
RwLock::SharedLockScope $(m_Lock); + return m_Processes.size(); +} + +void +ProcessGroup::Impl::Enumerate(std::function<void(const ManagedProcess&)> Callback) const +{ + RwLock::SharedLockScope $(m_Lock); + for (const auto& [Pid, Proc] : m_Processes) + { + Callback(*Proc); + } +} + +// ============================================================================ +// ProcessGroup +// ============================================================================ + +ProcessGroup::ProcessGroup(std::unique_ptr<Impl> InImpl) : m_Impl(std::move(InImpl)) +{ +} + +ProcessGroup::~ProcessGroup() = default; + +std::string_view +ProcessGroup::GetName() const +{ + return m_Impl->m_Name; +} + +ManagedProcess* +ProcessGroup::Spawn(const std::filesystem::path& Executable, + std::string_view CommandLine, + CreateProcOptions& Options, + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) +{ + ZEN_TRACE_CPU("ProcessGroup::Spawn"); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); +} + +ManagedProcess* +ProcessGroup::Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit) +{ + ZEN_TRACE_CPU("ProcessGroup::Adopt"); + return m_Impl->Adopt(std::move(Handle), std::move(OnExit)); +} + +void +ProcessGroup::Remove(int Pid) +{ + ZEN_TRACE_CPU("ProcessGroup::Remove"); + m_Impl->Remove(Pid); +} + +void +ProcessGroup::KillAll() +{ + ZEN_TRACE_CPU("ProcessGroup::KillAll"); + m_Impl->KillAll(); +} + +AggregateProcessMetrics +ProcessGroup::GetAggregateMetrics() const +{ + return m_Impl->GetAggregateMetrics(); +} + +std::vector<TrackedProcessEntry> +ProcessGroup::GetMetricsSnapshot() const +{ + return m_Impl->GetMetricsSnapshot(); +} + +size_t +ProcessGroup::GetProcessCount() const +{ + return m_Impl->GetProcessCount(); +} + +void +ProcessGroup::Enumerate(std::function<void(const ManagedProcess&)> Callback) const +{ + m_Impl->Enumerate(std::move(Callback)); +} + +} // namespace zen + +// 
============================================================================ +// Tests +// ============================================================================ + +#if ZEN_WITH_TESTS + +# include <zencore/testing.h> + +# include <chrono> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio/io_context.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +using namespace zen; +using namespace std::literals; + +void +zen::subprocessmanager_forcelink() +{ +} + +namespace { + +std::filesystem::path +GetAppStubPath() +{ + std::error_code Ec; + std::filesystem::path SelfPath = GetProcessExecutablePath(zen::GetCurrentProcessId(), Ec); + return SelfPath.parent_path() / "zentest-appstub" ZEN_EXE_SUFFIX_LITERAL; +} + +} // namespace + +TEST_SUITE_BEGIN("util.subprocessmanager"); + +TEST_CASE("SubprocessManager.SpawnAndDetectExit") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -f=42"; + + int ReceivedExitCode = -1; + bool CallbackFired = false; + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int ExitCode) { + ReceivedExitCode = ExitCode; + CallbackFired = true; + }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } + + CHECK(CallbackFired); + CHECK(ReceivedExitCode == 42); +} + +TEST_CASE("SubprocessManager.SpawnAndDetectCleanExit") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string(); + + int ReceivedExitCode = -1; + bool CallbackFired = false; + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, 
int ExitCode) { + ReceivedExitCode = ExitCode; + CallbackFired = true; + }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } + + CHECK(CallbackFired); + CHECK(ReceivedExitCode == 0); +} + +TEST_CASE("SubprocessManager.StdoutCapture") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -echo=hello_world"; + + StdoutPipeHandles StdoutPipe; + REQUIRE(CreateOverlappedStdoutPipe(StdoutPipe)); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.StdoutPipe = &StdoutPipe; + + bool Exited = false; + + ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } + + CHECK(Exited); + std::string Captured = Proc->GetCapturedStdout(); + CHECK(Captured.find("hello_world") != std::string::npos); +} + +TEST_CASE("SubprocessManager.StderrCapture") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -echoerr=error_msg"; + + StdoutPipeHandles StdoutPipe; + StdoutPipeHandles StderrPipe; + REQUIRE(CreateOverlappedStdoutPipe(StdoutPipe)); + REQUIRE(CreateOverlappedStdoutPipe(StderrPipe)); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.StdoutPipe = &StdoutPipe; + Options.StderrPipe = &StderrPipe; + + bool Exited = false; + + ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + 
IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } + + CHECK(Exited); + std::string CapturedErr = Proc->GetCapturedStderr(); + CHECK(CapturedErr.find("error_msg") != std::string::npos); +} + +TEST_CASE("SubprocessManager.StdoutCallback") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -echo=callback_test"; + + StdoutPipeHandles StdoutPipe; + REQUIRE(CreateOverlappedStdoutPipe(StdoutPipe)); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.StdoutPipe = &StdoutPipe; + + std::string ReceivedData; + bool Exited = false; + + ManagedProcess* Proc = Manager.Spawn( + AppStub, + CmdLine, + Options, + [&](ManagedProcess&, int) { Exited = true; }, + [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } + + CHECK(Exited); + CHECK(ReceivedData.find("callback_test") != std::string::npos); + // When a callback is set, accumulated buffer should be empty + CHECK(Proc->GetCapturedStdout().empty()); +} + +TEST_CASE("SubprocessManager.MetricsSampling") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 100, .MetricsBatchSize = 16}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -t=2"; + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + bool Exited = false; + + ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); + + // Poll until metrics are available + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Proc->GetLatestMetrics().WorkingSetSize > 0) + { + break; + } + } + } + + 
// Verifies that Remove() detaches a still-running process: the exit callback
// must not fire afterwards and the manager forgets the process.
TEST_CASE("SubprocessManager.RemoveWhileRunning")
{
	asio::io_context IoContext;
	SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0});

	std::filesystem::path AppStub = GetAppStubPath();
	// -t=10: stub sleeps 10 seconds so it is still running when removed.
	std::string CmdLine = AppStub.string() + " -t=10";

	CreateProcOptions Options;
	Options.Flags = CreateProcOptions::Flag_NoConsole;

	bool CallbackFired = false;

	ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { CallbackFired = true; });

	int Pid = Proc->Pid();

	// Let it start
	IoContext.run_for(100ms);

	// Remove without killing - callback should NOT fire after this
	// NOTE(review): Remove() destroys the ManagedProcess, so Proc dangles past
	// this point (it is not dereferenced again). The child keeps running for
	// the remaining ~10 s after the test ends — confirm this orphaned process
	// is acceptable in CI.
	Manager.Remove(Pid);

	IoContext.run_for(500ms);

	CHECK_FALSE(CallbackFired);
	CHECK(Manager.GetProcessCount() == 0);
}
+TEST_CASE("SubprocessManager.AdoptProcess") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -f=7"; + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + CreateProcResult Result = CreateProc(AppStub, CmdLine, Options); + + int ReceivedExitCode = -1; + + Manager.Adopt(ProcessHandle(Result), [&](ManagedProcess&, int ExitCode) { ReceivedExitCode = ExitCode; }); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (ReceivedExitCode != -1) + { + break; + } + } + } + + CHECK(ReceivedExitCode == 7); +} + +TEST_CASE("SubprocessManager.UserTag") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + std::filesystem::path AppStub = GetAppStubPath(); + std::string CmdLine = AppStub.string() + " -f=0"; + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + std::string ReceivedTag; + + ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess& P, int) { ReceivedTag = std::string(P.GetTag()); }); + + Proc->SetTag("my-worker-1"); + CHECK(Proc->GetTag() == "my-worker-1"); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (!ReceivedTag.empty()) + { + break; + } + } + } + + CHECK(ReceivedTag == "my-worker-1"); +} + +TEST_CASE("ProcessGroup.SpawnAndMembership") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + ProcessGroup* Group = Manager.CreateGroup("test-group"); + REQUIRE(Group != nullptr); + CHECK(Group->GetName() == "test-group"); + + std::filesystem::path AppStub = GetAppStubPath(); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + int ExitCount = 0; + + std::string CmdLine1 = 
AppStub.string() + " -f=0"; + std::string CmdLine2 = AppStub.string() + " -f=1"; + + Group->Spawn(AppStub, CmdLine1, Options, [&](ManagedProcess&, int) { ExitCount++; }); + Group->Spawn(AppStub, CmdLine2, Options, [&](ManagedProcess&, int) { ExitCount++; }); + + CHECK(Group->GetProcessCount() == 2); + CHECK(Manager.GetProcessCount() == 2); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (ExitCount == 2) + { + break; + } + } + } + + CHECK(ExitCount == 2); +} + +TEST_CASE("ProcessGroup.KillAll") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + ProcessGroup* Group = Manager.CreateGroup("kill-group"); + + std::filesystem::path AppStub = GetAppStubPath(); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + int ExitCount = 0; + + std::string CmdLine = AppStub.string() + " -t=60"; + + Group->Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { ExitCount++; }); + Group->Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { ExitCount++; }); + + // Let them start + IoContext.run_for(200ms); + CHECK(Group->GetProcessCount() == 2); + + Group->KillAll(); + CHECK(Group->GetProcessCount() == 0); +} + +TEST_CASE("ProcessGroup.AggregateMetrics") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 100, .MetricsBatchSize = 16}); + + ProcessGroup* Group = Manager.CreateGroup("metrics-group"); + + std::filesystem::path AppStub = GetAppStubPath(); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + std::string CmdLine = AppStub.string() + " -t=3"; + + Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); + Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); + + // Wait for metrics sampling + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if 
(Group->GetAggregateMetrics().TotalWorkingSetSize > 0) + { + break; + } + } + } + + AggregateProcessMetrics GroupAgg = Group->GetAggregateMetrics(); + CHECK(GroupAgg.ProcessCount == 2); + CHECK(GroupAgg.TotalWorkingSetSize > 0); + + // Manager-level metrics should include group processes + AggregateProcessMetrics ManagerAgg = Manager.GetAggregateMetrics(); + CHECK(ManagerAgg.ProcessCount == 2); + + Group->KillAll(); +} + +TEST_CASE("ProcessGroup.DestroyGroup") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + ProcessGroup* Group = Manager.CreateGroup("destroy-group"); + + std::filesystem::path AppStub = GetAppStubPath(); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + std::string CmdLine = AppStub.string() + " -t=60"; + + Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); + Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); + + IoContext.run_for(200ms); + CHECK(Manager.GetProcessCount() == 2); + + Manager.DestroyGroup("destroy-group"); + + CHECK(Manager.FindGroup("destroy-group") == nullptr); + CHECK(Manager.GetProcessCount() == 0); +} + +TEST_CASE("ProcessGroup.MixedGroupedAndUngrouped") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + ProcessGroup* Group = Manager.CreateGroup("mixed-group"); + + std::filesystem::path AppStub = GetAppStubPath(); + + CreateProcOptions Options; + Options.Flags = CreateProcOptions::Flag_NoConsole; + + int GroupExitCount = 0; + int UngroupedExitCode = -1; + + std::string CmdLine = AppStub.string() + " -f=0"; + + // Grouped processes + Group->Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { GroupExitCount++; }); + Group->Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { GroupExitCount++; }); + + // Ungrouped process + Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int ExitCode) { UngroupedExitCode = 
ExitCode; }); + + CHECK(Group->GetProcessCount() == 2); + CHECK(Manager.GetProcessCount() == 3); + + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (GroupExitCount == 2 && UngroupedExitCode != -1) + { + break; + } + } + } + + CHECK(GroupExitCount == 2); + CHECK(UngroupedExitCode == 0); +} + +TEST_CASE("ProcessGroup.FindGroup") +{ + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 0}); + + CHECK(Manager.FindGroup("nonexistent") == nullptr); + + ProcessGroup* Group = Manager.CreateGroup("findable"); + CHECK(Manager.FindGroup("findable") == Group); + CHECK(Manager.FindGroup("findable")->GetName() == "findable"); +} + +TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) +{ + // Seed for reproducibility - change to explore different orderings + // + // Note that while this is a stress test, it is still single-threaded + + constexpr uint32_t Seed = 42; + std::mt19937 Rng(Seed); + ZEN_INFO("StressTest: seed={}", Seed); + + asio::io_context IoContext; + SubprocessManager Manager(IoContext, {.MetricsSampleIntervalMs = 200, .MetricsBatchSize = 32}); + + std::filesystem::path AppStub = GetAppStubPath(); + CreateProcOptions BaseOptions; + BaseOptions.Flags = CreateProcOptions::Flag_NoConsole; + + std::atomic<int> TotalExitCallbacks{0}; + std::atomic<int> KilledGroupProcessCount{0}; + + auto MakeExitCallback = [&](std::atomic<int>& Counter) { + return [&Counter, &TotalExitCallbacks](ManagedProcess&, int) { + Counter++; + TotalExitCallbacks++; + }; + }; + + // ======================================================================== + // Phase 1: Spawn multiple groups with varied workloads + // ======================================================================== + + ZEN_INFO("StressTest: Phase 1 - spawning initial groups"); + + constexpr int NumInitialGroups = 8; + std::vector<std::string> GroupNames; + std::vector<std::atomic<int>> GroupExitCounts(NumInitialGroups); + 
std::uniform_int_distribution<int> ProcCountDist(5, 100); + std::uniform_int_distribution<int> SleepDist(1, 5); + std::uniform_int_distribution<int> ExitCodeDist(0, 10); + std::uniform_int_distribution<int> WorkloadDist(0, 2); // 0=sleep, 1=exit-code, 2=echo+exit + int TotalPhase1Spawned = 0; + + for (int G = 0; G < NumInitialGroups; G++) + { + std::string GroupName = fmt::format("stress-group-{}", G); + ProcessGroup* Group = Manager.CreateGroup(GroupName); + GroupNames.push_back(GroupName); + + int ProcCount = ProcCountDist(Rng); + for (int P = 0; P < ProcCount; P++) + { + std::string CmdLine; + int Workload = WorkloadDist(Rng); + if (Workload == 0) + { + int Sleep = SleepDist(Rng); + CmdLine = fmt::format("{} -t={}", AppStub.string(), Sleep); + } + else if (Workload == 1) + { + int Code = ExitCodeDist(Rng); + CmdLine = fmt::format("{} -f={}", AppStub.string(), Code); + } + else + { + int Code = ExitCodeDist(Rng); + CmdLine = fmt::format("{} -echo=stress_g{}_p{} -f={}", AppStub.string(), G, P, Code); + } + + Group->Spawn(AppStub, CmdLine, BaseOptions, MakeExitCallback(GroupExitCounts[G])); + TotalPhase1Spawned++; + } + + ZEN_INFO("StressTest: group '{}' spawned {} processes", GroupName, ProcCount); + } + + ZEN_INFO("StressTest: Phase 1 total spawned: {}", TotalPhase1Spawned); + + // Let processes start running and some short-lived ones exit + IoContext.run_for(1s); + + // ======================================================================== + // Phase 2: Randomly kill some groups, create replacements, add ungrouped + // ======================================================================== + + ZEN_INFO("StressTest: Phase 2 - random group kills and replacements"); + + constexpr int NumGroupsToKill = 3; + + // Pick random groups to kill + std::vector<int> GroupIndices(NumInitialGroups); + std::iota(GroupIndices.begin(), GroupIndices.end(), 0); + std::shuffle(GroupIndices.begin(), GroupIndices.end(), Rng); + + std::vector<int> KilledIndices(GroupIndices.begin(), 
GroupIndices.begin() + NumGroupsToKill); + + for (int Idx : KilledIndices) + { + ProcessGroup* Group = Manager.FindGroup(GroupNames[Idx]); + if (Group) + { + size_t Count = Group->GetProcessCount(); + ZEN_INFO("StressTest: killing group '{}' ({} processes)", GroupNames[Idx], Count); + Manager.DestroyGroup(GroupNames[Idx]); + } + } + + // Let kills propagate + IoContext.run_for(500ms); + + // Create replacement groups + std::atomic<int> ReplacementExitCount{0}; + std::uniform_int_distribution<int> ReplacementCountDist(3, 10); + + for (int R = 0; R < NumGroupsToKill; R++) + { + std::string Name = fmt::format("replacement-group-{}", R); + ProcessGroup* Group = Manager.CreateGroup(Name); + int Count = ReplacementCountDist(Rng); + + for (int P = 0; P < Count; P++) + { + int Sleep = SleepDist(Rng); + std::string CmdLine = fmt::format("{} -t={}", AppStub.string(), Sleep); + Group->Spawn(AppStub, CmdLine, BaseOptions, MakeExitCallback(ReplacementExitCount)); + } + + ZEN_INFO("StressTest: replacement group '{}' spawned {} processes", Name, Count); + } + + // Also spawn some ungrouped processes + std::atomic<int> UngroupedExitCount{0}; + constexpr int NumUngrouped = 10; + + for (int U = 0; U < NumUngrouped; U++) + { + int ExitCode = ExitCodeDist(Rng); + std::string CmdLine = fmt::format("{} -f={}", AppStub.string(), ExitCode); + Manager.Spawn(AppStub, CmdLine, BaseOptions, MakeExitCallback(UngroupedExitCount)); + } + + ZEN_INFO("StressTest: spawned {} ungrouped processes", NumUngrouped); + + // Let things run + IoContext.run_for(2s); + + // ======================================================================== + // Phase 3: Rapid spawn/exit churn + // ======================================================================== + + ZEN_INFO("StressTest: Phase 3 - rapid spawn/exit churn"); + + std::atomic<int> ChurnExitCount{0}; + int TotalChurnSpawned = 0; + constexpr int NumChurnBatches = 10; + std::uniform_int_distribution<int> ChurnBatchSizeDist(10, 20); + + for (int Batch = 
0; Batch < NumChurnBatches; Batch++) + { + std::string Name = fmt::format("churn-batch-{}", Batch); + ProcessGroup* Group = Manager.CreateGroup(Name); + int Count = ChurnBatchSizeDist(Rng); + + for (int P = 0; P < Count; P++) + { + // Immediate exit processes to stress spawn/exit path + std::string CmdLine = fmt::format("{} -f=0", AppStub.string()); + Group->Spawn(AppStub, CmdLine, BaseOptions, MakeExitCallback(ChurnExitCount)); + TotalChurnSpawned++; + } + + // Brief pump to allow some exits to be processed + IoContext.run_for(200ms); + + // Destroy the group - any still-running processes get killed + Manager.DestroyGroup(Name); + } + + ZEN_INFO("StressTest: Phase 3 spawned {} churn processes across {} batches", TotalChurnSpawned, NumChurnBatches); + + // ======================================================================== + // Phase 4: Drain and verify + // ======================================================================== + + ZEN_INFO("StressTest: Phase 4 - draining remaining processes"); + + // Check metrics were collected before we wind down + AggregateProcessMetrics Agg = Manager.GetAggregateMetrics(); + ZEN_INFO("StressTest: aggregate metrics: {} processes, {} bytes working set", Agg.ProcessCount, Agg.TotalWorkingSetSize); + + // Let remaining processes finish (replacement groups have up to 5s sleep) + IoContext.run_for(8s); + + // Kill anything still running + Manager.RemoveAll(); + + // Final pump to process any remaining callbacks + IoContext.run_for(1s); + + ZEN_INFO("StressTest: Results:"); + ZEN_INFO("StressTest: total exit callbacks fired: {}", TotalExitCallbacks.load()); + ZEN_INFO("StressTest: ungrouped exits: {}", UngroupedExitCount.load()); + ZEN_INFO("StressTest: replacement exits: {}", ReplacementExitCount.load()); + ZEN_INFO("StressTest: churn exits: {}", ChurnExitCount.load()); + + // Verify the manager is clean + CHECK(Manager.GetProcessCount() == 0); + + // Ungrouped processes should all have exited (they were immediate-exit) + 
CHECK(UngroupedExitCount.load() == NumUngrouped); + + // Verify we got a reasonable number of total callbacks + // (exact count is hard to predict due to killed groups, but should be > 0) + CHECK(TotalExitCallbacks.load() > 0); + + ZEN_INFO("StressTest: PASSED - seed={}", Seed); +} + +TEST_SUITE_END(); + +#endif diff --git a/src/zenutil/progress.cpp b/src/zenutil/progress.cpp new file mode 100644 index 000000000..01f6529f2 --- /dev/null +++ b/src/zenutil/progress.cpp @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/progress.h> + +#include <zencore/logging.h> + +namespace zen { + +class StandardProgressBase; + +class StandardProgressBar : public ProgressBase::ProgressBar +{ +public: + StandardProgressBar(StandardProgressBase& Owner, std::string_view InSubTask) : m_Owner(Owner), m_SubTask(InSubTask) {} + + virtual void UpdateState(const State& NewState, bool DoLinebreak) override; + virtual void ForceLinebreak() override {} + virtual void Finish() override; + +private: + LoggerRef Log(); + StandardProgressBase& m_Owner; + std::string m_SubTask; + State m_State; +}; + +class StandardProgressBase : public ProgressBase +{ +public: + StandardProgressBase(LoggerRef Log) : m_Log(Log) {} + + virtual void SetLogOperationName(std::string_view Name) override + { + m_LogOperationName = Name; + ZEN_INFO("{}", m_LogOperationName); + } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override + { + const size_t PercentDone = StepCount > 0u ? 
(100u * StepIndex) / StepCount : 0u; + ZEN_INFO("{}: {}%", m_LogOperationName, PercentDone); + } + virtual void PushLogOperation(std::string_view Name) override { ZEN_UNUSED(Name); } + virtual void PopLogOperation() override {} + virtual uint32_t GetProgressUpdateDelayMS() const override { return 2000; } + virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) override + { + return std::make_unique<StandardProgressBar>(*this, InSubTask); + } + +private: + friend class StandardProgressBar; + LoggerRef m_Log; + std::string m_LogOperationName; + LoggerRef Log() { return m_Log; } +}; + +LoggerRef +StandardProgressBar::Log() +{ + return m_Owner.Log(); +} + +void +StandardProgressBar::UpdateState(const State& NewState, bool DoLinebreak) +{ + ZEN_UNUSED(DoLinebreak); + const size_t PercentDone = + NewState.TotalCount > 0u ? (100u * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount : 0u; + std::string Task = NewState.Task; + switch (NewState.Status) + { + case State::EStatus::Aborted: + Task = "Aborting"; + break; + case State::EStatus::Paused: + Task = "Paused"; + break; + default: + break; + } + ZEN_INFO("{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? 
"" : fmt::format(" {}", NewState.Details)); + m_State = NewState; +} +void +StandardProgressBar::Finish() +{ + if (m_State.RemainingCount > 0) + { + State NewState = m_State; + NewState.RemainingCount = 0; + NewState.Details = ""; + UpdateState(NewState, /*DoLinebreak*/ true); + } +} + +ProgressBase* +CreateStandardProgress(LoggerRef Log) +{ + return new StandardProgressBase(Log); +} + +} // namespace zen diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp index a9e95b9ce..3c273abb6 100644 --- a/src/zenutil/rpcrecording.cpp +++ b/src/zenutil/rpcrecording.cpp @@ -13,12 +13,12 @@ #include <zencore/testutils.h> ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/deque.h> #include <fmt/format.h> #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END #include <condition_variable> -#include <deque> #include <mutex> #include <thread> @@ -371,7 +371,7 @@ private: std::mutex m_QueueMutex; std::condition_variable m_QueueCondition; bool m_IsWriterReady = false; - std::deque<QueuedRequest> m_RequestQueue; + eastl::deque<QueuedRequest> m_RequestQueue; void WriterThreadMain(); }; diff --git a/src/zenutil/sessionsclient.cpp b/src/zenutil/sessionsclient.cpp new file mode 100644 index 000000000..6ba997a62 --- /dev/null +++ b/src/zenutil/sessionsclient.cpp @@ -0,0 +1,381 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/sessionsclient.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging/logmsg.h> +#include <zencore/thread.h> + +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// SessionLogSink — thin enqueuer that posts log messages to the +// SessionsServiceClient worker thread via its BlockingQueue. 
+// + +class SessionLogSink final : public logging::Sink +{ +public: + explicit SessionLogSink(BlockingQueue<SessionsServiceClient::SessionCommand>* Queue) : m_Queue(Queue) { SetLevel(logging::Info); } + + ~SessionLogSink() override = default; + + void Log(const logging::LogMessage& Msg) override + { + SessionsServiceClient::SessionCommand Cmd; + Cmd.CommandType = SessionsServiceClient::SessionCommand::Type::Log; + Cmd.LogLevel = Msg.GetLevel(); + Cmd.LogMessage = CompactString(Msg.GetPayload()); + m_Queue->Enqueue(std::move(Cmd)); + } + + void Flush() override + { + SessionsServiceClient::SessionCommand Cmd; + Cmd.CommandType = SessionsServiceClient::SessionCommand::Type::FlushLogs; + m_Queue->Enqueue(std::move(Cmd)); + } + + void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override + { + // No formatting needed - we send raw message text + } + +private: + BlockingQueue<SessionsServiceClient::SessionCommand>* m_Queue; +}; + +////////////////////////////////////////////////////////////////////////// +// +// SessionsServiceClient +// + +SessionsServiceClient::SessionsServiceClient(Options Opts) +: m_Log(logging::Get("sessionsclient")) +, m_Options(std::move(Opts)) +, m_SessionPath(fmt::format("/sessions/{}", m_Options.SessionId)) +{ + // Strip trailing slash to avoid double-slash when appending paths like /sessions/{id} + while (m_Options.TargetUrl.ends_with('/')) + { + m_Options.TargetUrl.pop_back(); + } + + m_WorkerThread = std::thread([this]() { + zen::SetCurrentThreadName("SessionIO"); + WorkerLoop(); + }); +} + +SessionsServiceClient::~SessionsServiceClient() +{ + SessionCommand ShutdownCmd; + ShutdownCmd.CommandType = SessionCommand::Type::Shutdown; + m_Queue.Enqueue(std::move(ShutdownCmd)); + m_Queue.CompleteAdding(); + + if (m_WorkerThread.joinable()) + { + m_WorkerThread.join(); + } +} + +CbObject +SessionsServiceClient::BuildRequestBody(CbObjectView Metadata) const +{ + CbObjectWriter Writer; + Writer << "appname" << 
m_Options.AppName; + if (!m_Options.Mode.empty()) + { + Writer << "mode" << m_Options.Mode; + } + if (m_Options.JobId != Oid::Zero) + { + Writer << "jobid" << m_Options.JobId; + } + if (Metadata) + { + Writer.AddObject("metadata", Metadata); + } + return Writer.Save(); +} + +////////////////////////////////////////////////////////////////////////// +// Public API — non-blocking enqueuers + +void +SessionsServiceClient::Announce(CbObjectView Metadata) +{ + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::Announce; + if (Metadata) + { + Cmd.Metadata = CbObject::Clone(Metadata); + } + m_Queue.Enqueue(std::move(Cmd)); +} + +void +SessionsServiceClient::UpdateMetadata(CbObjectView Metadata) +{ + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::UpdateMetadata; + if (Metadata) + { + Cmd.Metadata = CbObject::Clone(Metadata); + } + m_Queue.Enqueue(std::move(Cmd)); +} + +void +SessionsServiceClient::Remove() +{ + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::Remove; + m_Queue.Enqueue(std::move(Cmd)); +} + +logging::SinkPtr +SessionsServiceClient::CreateLogSink() +{ + return Ref(new SessionLogSink(&m_Queue)); +} + +////////////////////////////////////////////////////////////////////////// +// Worker thread — processes all session HTTP I/O + +void +SessionsServiceClient::DoAnnounce(HttpClient& Http, CbObjectView Metadata) +{ + try + { + CbObject Body = BuildRequestBody(Metadata); + + HttpClient::Response Result = Http.Post(m_SessionPath, std::move(Body)); + + if (Result.Error) + { + ZEN_WARN("sessions announce failed for '{}': HTTP error {} - {}", + m_Options.TargetUrl, + static_cast<int>(Result.Error->ErrorCode), + Result.Error->ErrorMessage); + return; + } + if (!IsHttpOk(Result.StatusCode)) + { + ZEN_WARN("sessions announce failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); + return; + } + + ZEN_DEBUG("session announced to '{}'", m_Options.TargetUrl); + } + catch (const std::exception& Ex) + 
{ + ZEN_WARN("sessions announce failed for '{}': {}", m_Options.TargetUrl, Ex.what()); + } +} + +void +SessionsServiceClient::DoUpdateMetadata(HttpClient& Http, CbObjectView Metadata) +{ + try + { + CbObject Body = BuildRequestBody(Metadata); + + MemoryView View = Body.GetView(); + IoBuffer Payload = IoBufferBuilder::MakeCloneFromMemory(View, ZenContentType::kCbObject); + + HttpClient::Response Result = Http.Put(m_SessionPath, Payload); + + if (Result.Error) + { + ZEN_WARN("sessions update failed for '{}': HTTP error {} - {}", + m_Options.TargetUrl, + static_cast<int>(Result.Error->ErrorCode), + Result.Error->ErrorMessage); + return; + } + if (!IsHttpOk(Result.StatusCode)) + { + ZEN_WARN("sessions update failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); + return; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("sessions update failed for '{}': {}", m_Options.TargetUrl, Ex.what()); + } +} + +void +SessionsServiceClient::DoRemove(HttpClient& Http) +{ + try + { + HttpClient::Response Result = Http.Delete(m_SessionPath); + + if (Result.Error) + { + ZEN_WARN("sessions remove failed for '{}': HTTP error {} - {}", + m_Options.TargetUrl, + static_cast<int>(Result.Error->ErrorCode), + Result.Error->ErrorMessage); + return; + } + if (!IsHttpOk(Result.StatusCode)) + { + ZEN_WARN("sessions remove failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); + return; + } + + ZEN_DEBUG("session removed from '{}'", m_Options.TargetUrl); + } + catch (const std::exception& Ex) + { + ZEN_WARN("sessions remove failed for '{}': {}", m_Options.TargetUrl, Ex.what()); + } +} + +void +SessionsServiceClient::SendLogBatch(HttpClient& Http, const std::string& LogPath, const std::vector<SessionCommand>& Batch) +{ + try + { + CbObjectWriter Writer; + Writer.BeginArray("entries"); + for (const SessionCommand& Entry : Batch) + { + Writer.BeginObject(); + Writer << "level" << static_cast<int32_t>(Entry.LogLevel); 
+ Writer << "message" << Entry.LogMessage.c_str(); + Writer.EndObject(); + } + Writer.EndArray(); + + HttpClient::Response Result = Http.Post(LogPath, Writer.Save()); + (void)Result; // Best-effort + } + catch (const std::exception&) + { + // Best-effort — silently discard on failure + } +} + +void +SessionsServiceClient::WorkerLoop() +{ + HttpClientSettings Settings = m_Options.ClientSettings; + Settings.ConnectTimeout = std::chrono::milliseconds(3000); + Settings.Timeout = std::chrono::milliseconds(5000); + HttpClient Http(m_Options.TargetUrl, Settings); + + std::string LogPath = m_SessionPath + "/log"; + bool Removed = false; + + static constexpr size_t BatchSize = 50; + + std::vector<SessionCommand> LogBatch; + LogBatch.reserve(BatchSize); + + auto FlushLogBatch = [&]() { + if (!LogBatch.empty()) + { + SendLogBatch(Http, LogPath, LogBatch); + LogBatch.clear(); + } + }; + + // Returns false to signal loop exit (Shutdown received) + auto ProcessCommand = [&](SessionCommand& Cmd) -> bool { + switch (Cmd.CommandType) + { + case SessionCommand::Type::Log: + LogBatch.push_back(std::move(Cmd)); + if (LogBatch.size() >= BatchSize) + { + FlushLogBatch(); + } + return true; + + case SessionCommand::Type::FlushLogs: + FlushLogBatch(); + return true; + + case SessionCommand::Type::Announce: + FlushLogBatch(); + DoAnnounce(Http, Cmd.Metadata); + return true; + + case SessionCommand::Type::UpdateMetadata: + FlushLogBatch(); + DoUpdateMetadata(Http, Cmd.Metadata); + return true; + + case SessionCommand::Type::Remove: + FlushLogBatch(); + if (!Removed) + { + Removed = true; + DoRemove(Http); + } + return true; + + case SessionCommand::Type::Shutdown: + { + // Drain remaining log entries from the queue + SessionCommand Remaining; + while (m_Queue.WaitAndDequeue(Remaining)) + { + if (Remaining.CommandType == SessionCommand::Type::Log) + { + LogBatch.push_back(std::move(Remaining)); + } + } + FlushLogBatch(); + + if (!Removed) + { + Removed = true; + DoRemove(Http); + } + return 
false; + } + } + return true; + }; + + SessionCommand Cmd; + while (m_Queue.WaitAndDequeue(Cmd)) + { + if (!ProcessCommand(Cmd)) + { + return; + } + + // Drain additional queued entries without blocking (batching optimization) + while (LogBatch.size() < BatchSize && m_Queue.Size() > 0) + { + SessionCommand Extra; + if (m_Queue.WaitAndDequeue(Extra)) + { + if (!ProcessCommand(Extra)) + { + return; + } + } + } + + FlushLogBatch(); + } +} + +} // namespace zen diff --git a/src/zenutil/splitconsole/logstreamlistener.cpp b/src/zenutil/splitconsole/logstreamlistener.cpp new file mode 100644 index 000000000..df985a196 --- /dev/null +++ b/src/zenutil/splitconsole/logstreamlistener.cpp @@ -0,0 +1,426 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/splitconsole/logstreamlistener.h> + +#include <zenbase/refcount.h> +#include <zencore/compactbinary.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <vector> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// LogStreamSession - reads CbObject-framed messages from a single TCP connection + +class LogStreamSession : public RefCounted +{ +public: + LogStreamSession(asio::ip::tcp::socket Socket, LogStreamTarget& Target) : m_Socket(std::move(Socket)), m_Target(Target) {} + + void Start() { DoRead(); } + +private: + void DoRead() + { + Ref<LogStreamSession> Self(this); + m_Socket.async_read_some(asio::buffer(m_ReadBuf.data() + m_BufferUsed, m_ReadBuf.size() - m_BufferUsed), + [Self](const asio::error_code& Ec, std::size_t BytesRead) { + if (Ec) + { + return; // connection closed or error - session ends + } + Self->m_BufferUsed += BytesRead; + Self->ProcessBuffer(); + Self->DoRead(); + }); + } + + void ProcessBuffer() + { + // Try to consume as many complete CbObject messages as possible + while (m_BufferUsed > 0) + { + 
MemoryView View = MakeMemoryView(m_ReadBuf.data(), m_BufferUsed); + CbFieldType Type; + uint64_t Size = 0; + + if (!TryMeasureCompactBinary(View, Type, Size)) + { + break; // need more data + } + + if (Size > m_BufferUsed) + { + break; // need more data + } + + // Parse the CbObject + CbObject Obj = CbObject(SharedBuffer::Clone(MakeMemoryView(m_ReadBuf.data(), Size))); + + std::string_view Text = Obj["text"].AsString(); + std::string_view Source = Obj["source"].AsString(); + + // Check sequence number for gaps (dropped messages) + uint64_t Seq = Obj["seq"].AsUInt64(); + if (Seq > m_NextExpectedSeq) + { + uint64_t Dropped = Seq - m_NextExpectedSeq; + m_Target.AppendLogLine(fmt::format("[{}] *** {} log message(s) dropped ***", Source.empty() ? "log" : Source, Dropped)); + } + m_NextExpectedSeq = Seq + 1; + + // Split multi-line messages into individual AppendLogLine calls so that + // each line gets its own row in the target's log output. + while (!Text.empty()) + { + std::string_view Line = Text; + auto Pos = Text.find('\n'); + if (Pos != std::string_view::npos) + { + Line = Text.substr(0, Pos); + Text.remove_prefix(Pos + 1); + } + else + { + Text = {}; + } + + // Strip trailing CR from CRLF + if (!Line.empty() && Line.back() == '\r') + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + continue; + } + + if (!Source.empty()) + { + m_Target.AppendLogLine(fmt::format("[{}] {}", Source, Line)); + } + else + { + m_Target.AppendLogLine(Line); + } + } + + // Remove consumed bytes from buffer + std::size_t Consumed = static_cast<std::size_t>(Size); + std::memmove(m_ReadBuf.data(), m_ReadBuf.data() + Consumed, m_BufferUsed - Consumed); + m_BufferUsed -= Consumed; + } + + // If buffer is full and we can't parse a message, the message is too large - drop connection + if (m_BufferUsed == m_ReadBuf.size()) + { + ZEN_WARN("LogStreamSession: buffer full with no complete message, dropping connection"); + asio::error_code Ec; + m_Socket.close(Ec); + m_BufferUsed = 0; + } + 
} + + asio::ip::tcp::socket m_Socket; + LogStreamTarget& m_Target; + std::array<uint8_t, 65536> m_ReadBuf{}; + std::size_t m_BufferUsed = 0; + uint64_t m_NextExpectedSeq = 0; +}; + +////////////////////////////////////////////////////////////////////////// +// LogStreamListener::Impl + +struct LogStreamListener::Impl +{ + // Owned io_context mode - creates and runs its own thread + Impl(LogStreamTarget& Target, uint16_t Port) + : m_Target(Target) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_Acceptor(*m_OwnedIoContext) + { + SetupAcceptor(Port); + m_IoThread = std::thread([this]() { + zen::SetCurrentThreadName("LogStreamIO"); + m_OwnedIoContext->run(); + }); + } + + // External io_context mode - caller drives the io_context + Impl(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port) : m_Target(Target), m_Acceptor(IoContext) + { + SetupAcceptor(Port); + } + + ~Impl() { Shutdown(); } + + void Shutdown() + { + if (m_Stopped.exchange(true)) + { + return; + } + + asio::error_code Ec; + m_Acceptor.close(Ec); + + if (m_OwnedIoContext) + { + m_OwnedIoContext->stop(); + } + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + uint16_t GetPort() const { return m_Port; } + +private: + void SetupAcceptor(uint16_t Port) + { + auto& IoCtx = m_OwnedIoContext ? 
*m_OwnedIoContext : m_Acceptor.get_executor().context(); + ZEN_UNUSED(IoCtx); + + // Try dual-stack IPv6 first (accepts both IPv4 and IPv6), fall back to IPv4-only + asio::error_code Ec; + m_Acceptor.open(asio::ip::tcp::v6(), Ec); + if (!Ec) + { + m_Acceptor.set_option(asio::ip::v6_only(false), Ec); + if (!Ec) + { + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v6(), Port), Ec); + } + } + + if (Ec) + { + // Fall back to IPv4-only + if (m_Acceptor.is_open()) + { + m_Acceptor.close(); + } + m_Acceptor.open(asio::ip::tcp::v4()); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v4(), Port)); + } + + m_Acceptor.listen(); + m_Port = m_Acceptor.local_endpoint().port(); + StartAccept(); + } + + void StartAccept() + { + m_Acceptor.async_accept([this](const asio::error_code& Ec, asio::ip::tcp::socket Socket) { + if (Ec) + { + return; // acceptor closed + } + + Ref<LogStreamSession> Session(new LogStreamSession(std::move(Socket), m_Target)); + Session->Start(); + + if (!m_Stopped.load()) + { + StartAccept(); + } + }); + } + + LogStreamTarget& m_Target; + std::unique_ptr<asio::io_context> m_OwnedIoContext; // null when using external io_context + asio::ip::tcp::acceptor m_Acceptor; + std::thread m_IoThread; + uint16_t m_Port = 0; + std::atomic<bool> m_Stopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// +// LogStreamListener + +LogStreamListener::LogStreamListener(LogStreamTarget& Target, uint16_t Port) : m_Impl(std::make_unique<Impl>(Target, Port)) +{ +} + +LogStreamListener::LogStreamListener(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port) +: m_Impl(std::make_unique<Impl>(Target, IoContext, Port)) +{ +} + +LogStreamListener::~LogStreamListener() = default; + +uint16_t +LogStreamListener::GetPort() const +{ + return m_Impl->GetPort(); +} + +void +LogStreamListener::Shutdown() +{ + m_Impl->Shutdown(); +} + +} // namespace zen + +#if ZEN_WITH_TESTS + +# include <zencore/testing.h> +# include 
<zenutil/splitconsole/tcplogstreamsink.h> + +namespace zen { + +void +logstreamlistener_forcelink() +{ +} + +namespace { + + class CollectingTarget : public LogStreamTarget + { + public: + void AppendLogLine(std::string_view Text) override + { + std::lock_guard<std::mutex> Lock(m_Mutex); + m_Lines.emplace_back(Text); + m_Cv.notify_all(); + } + + std::vector<std::string> WaitForLines(size_t Count, std::chrono::milliseconds Timeout = std::chrono::milliseconds(5000)) + { + std::unique_lock<std::mutex> Lock(m_Mutex); + m_Cv.wait_for(Lock, Timeout, [&]() { return m_Lines.size() >= Count; }); + return m_Lines; + } + + private: + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::vector<std::string> m_Lines; + }; + + logging::LogMessage MakeLogMessage(std::string_view Text, logging::LogLevel Level = logging::Info) + { + static logging::LogPoint Point{0, 0, Level, {}}; + Point.Level = Level; + return logging::LogMessage(Point, "test", Text); + } + +} // namespace + +TEST_SUITE_BEGIN("util.logstreamlistener"); + +TEST_CASE("BasicMessageDelivery") +{ + CollectingTarget Target; + LogStreamListener Listener(Target); + + { + TcpLogStreamSink Sink("127.0.0.1", Listener.GetPort(), "TestSource", 64); + Sink.Log(MakeLogMessage("hello world")); + Sink.Log(MakeLogMessage("second line")); + } + + auto Lines = Target.WaitForLines(2); + REQUIRE(Lines.size() == 2); + CHECK(Lines[0] == "[TestSource] hello world"); + CHECK(Lines[1] == "[TestSource] second line"); +} + +TEST_CASE("MultiLineMessageSplit") +{ + CollectingTarget Target; + LogStreamListener Listener(Target); + + { + TcpLogStreamSink Sink("127.0.0.1", Listener.GetPort(), "src", 64); + Sink.Log(MakeLogMessage("line1\nline2\nline3")); + } + + auto Lines = Target.WaitForLines(3); + REQUIRE(Lines.size() == 3); + CHECK(Lines[0] == "[src] line1"); + CHECK(Lines[1] == "[src] line2"); + CHECK(Lines[2] == "[src] line3"); +} + +TEST_CASE("DroppedMessageDetection") +{ + // Test sequence-gap detection deterministically by sending 
raw CbObjects + // with an explicit gap in sequence numbers, bypassing TcpLogStreamSink. + CollectingTarget Target; + LogStreamListener Listener(Target); + + { + asio::io_context IoContext; + asio::ip::tcp::socket Socket(IoContext); + Socket.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), Listener.GetPort())); + + // Send seq=0, then seq=5 - the listener should detect a gap of 4 + for (uint64_t Seq : {uint64_t(0), uint64_t(5)}) + { + CbObjectWriter Writer; + Writer.AddString("text", fmt::format("msg{}", Seq)); + Writer.AddString("source", "src"); + Writer.AddInteger("seq", Seq); + CbObject Obj = Writer.Save(); + MemoryView View = Obj.GetView(); + + asio::write(Socket, asio::buffer(View.GetData(), View.GetSize())); + } + } + + // Expect: msg0, drop notice, msg5 + auto Lines = Target.WaitForLines(3); + REQUIRE(Lines.size() >= 3); + CHECK(Lines[0] == "[src] msg0"); + CHECK(Lines[1].find("4 log message(s) dropped") != std::string::npos); + CHECK(Lines[2] == "[src] msg5"); +} + +TEST_CASE("SequenceNumbersAreContiguous") +{ + CollectingTarget Target; + LogStreamListener Listener(Target); + + constexpr int NumMessages = 5; + { + TcpLogStreamSink Sink("127.0.0.1", Listener.GetPort(), "seq", 64); + for (int i = 0; i < NumMessages; i++) + { + Sink.Log(MakeLogMessage(fmt::format("msg{}", i))); + } + } + + auto Lines = Target.WaitForLines(NumMessages); + REQUIRE(Lines.size() == NumMessages); + + // No "dropped" notices should appear when nothing is dropped + for (auto& Line : Lines) + { + CHECK(Line.find("dropped") == std::string::npos); + } + + // Verify ordering + for (int i = 0; i < NumMessages; i++) + { + CHECK(Lines[i] == fmt::format("[seq] msg{}", i)); + } +} + +TEST_SUITE_END(); + +} // namespace zen + +#endif diff --git a/src/zenutil/testartifactprovider.cpp b/src/zenutil/testartifactprovider.cpp new file mode 100644 index 000000000..666a1758d --- /dev/null +++ b/src/zenutil/testartifactprovider.cpp @@ -0,0 +1,588 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#include <zenutil/testartifactprovider.h> + +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +#include <memory> +#include <system_error> + +namespace zen { + +namespace { + + std::string JoinKey(std::string_view Prefix, std::string_view RelativePath) + { + if (Prefix.empty()) + { + return std::string(RelativePath); + } + std::string Result; + Result.reserve(Prefix.size() + 1 + RelativePath.size()); + Result.append(Prefix); + if (Result.back() != '/') + { + Result.push_back('/'); + } + Result.append(RelativePath); + return Result; + } + + class LocalBackend + { + public: + explicit LocalBackend(std::filesystem::path RootDir) : m_RootDir(std::move(RootDir)) {} + + const std::filesystem::path& RootDir() const { return m_RootDir; } + + bool Exists(std::string_view RelativePath) const + { + std::error_code Ec; + return std::filesystem::exists(Resolve(RelativePath), Ec) && !Ec; + } + + [[nodiscard]] TestArtifactFetchResult Fetch(std::string_view RelativePath) const + { + TestArtifactFetchResult Result; + + std::filesystem::path FullPath = Resolve(RelativePath); + FileContents Contents = ReadFile(FullPath); + if (!Contents) + { + Result.Error = fmt::format("failed to read '{}': {}", FullPath.string(), Contents.ErrorCode.message()); + return Result; + } + Result.Content = Contents.Flatten(); + return Result; + } + + [[nodiscard]] TestArtifactResult Store(std::string_view RelativePath, IoBuffer Content) const + { + TestArtifactResult Result; + + std::filesystem::path FullPath = Resolve(RelativePath); + std::filesystem::path ParentPath = FullPath.parent_path(); + if (!ParentPath.empty()) + { + std::error_code Ec; + std::filesystem::create_directories(ParentPath, Ec); + if (Ec) + { + Result.Error = fmt::format("failed to create '{}': {}", ParentPath.string(), Ec.message()); + return Result; + } + } + + try + { + WriteFile(FullPath, std::move(Content)); + } + catch (const std::exception& Ex) + { + 
Result.Error = fmt::format("failed to write '{}': {}", FullPath.string(), Ex.what()); + } + return Result; + } + + [[nodiscard]] TestArtifactListResult List(std::string_view Prefix) const + { + TestArtifactListResult Result; + + std::error_code Ec; + if (!std::filesystem::exists(m_RootDir, Ec)) + { + Result.Error = fmt::format("root directory does not exist: {}", m_RootDir.string()); + return Result; + } + + std::filesystem::recursive_directory_iterator It(m_RootDir, Ec); + if (Ec) + { + Result.Error = fmt::format("failed to iterate '{}': {}", m_RootDir.string(), Ec.message()); + return Result; + } + + for (const std::filesystem::directory_entry& Entry : It) + { + if (!Entry.is_regular_file()) + { + continue; + } + + std::filesystem::path Relative = std::filesystem::relative(Entry.path(), m_RootDir, Ec); + if (Ec) + { + continue; + } + + std::string RelString = Relative.generic_string(); + if (!Prefix.empty() && RelString.compare(0, Prefix.size(), Prefix) != 0) + { + continue; + } + + TestArtifactInfo Info; + Info.RelativePath = std::move(RelString); + Info.Size = Entry.file_size(Ec); + Result.Artifacts.push_back(std::move(Info)); + } + return Result; + } + + private: + std::filesystem::path Resolve(std::string_view RelativePath) const { return m_RootDir / std::filesystem::path(RelativePath); } + + std::filesystem::path m_RootDir; + }; + + class S3Backend + { + public: + S3Backend(S3ClientOptions ClientOptions, std::string KeyPrefix) + : m_ClientOptions(std::move(ClientOptions)) + , m_KeyPrefix(std::move(KeyPrefix)) + , m_Client(m_ClientOptions) + { + } + + std::string Describe() const { return fmt::format("s3:{}/{}", m_Client.BucketName(), m_KeyPrefix); } + + bool Exists(std::string_view RelativePath) + { + S3HeadObjectResult Head = m_Client.HeadObject(JoinKey(m_KeyPrefix, RelativePath)); + return Head.Status == HeadObjectResult::Found; + } + + TestArtifactFetchResult Fetch(std::string_view RelativePath) + { + TestArtifactFetchResult Result; + + S3GetObjectResult 
Get = m_Client.GetObject(JoinKey(m_KeyPrefix, RelativePath)); + if (!Get) + { + Result.Error = std::move(Get.Error); + return Result; + } + Result.Content = std::move(Get.Content); + return Result; + } + + TestArtifactListResult List(std::string_view Prefix) + { + TestArtifactListResult Result; + + std::string FullPrefix = JoinKey(m_KeyPrefix, Prefix); + S3ListObjectsResult Listing = m_Client.ListObjects(FullPrefix); + if (!Listing) + { + Result.Error = std::move(Listing.Error); + return Result; + } + + size_t StripLen = m_KeyPrefix.size(); + if (StripLen > 0 && m_KeyPrefix.back() != '/') + { + StripLen += 1; // account for the separator JoinKey inserts + } + + Result.Artifacts.reserve(Listing.Objects.size()); + for (S3ObjectInfo& Obj : Listing.Objects) + { + TestArtifactInfo Info; + Info.RelativePath = (StripLen <= Obj.Key.size()) ? Obj.Key.substr(StripLen) : std::move(Obj.Key); + Info.Size = Obj.Size; + Result.Artifacts.push_back(std::move(Info)); + } + return Result; + } + + private: + S3ClientOptions m_ClientOptions; + std::string m_KeyPrefix; + S3Client m_Client; + }; + + class TestArtifactProviderImpl final : public TestArtifactProvider + { + public: + TestArtifactProviderImpl(LocalBackend Cache, std::unique_ptr<S3Backend> Primary) + : m_Cache(std::move(Cache)) + , m_Primary(std::move(Primary)) + { + } + + std::string Describe() const override + { + if (m_Primary) + { + return fmt::format("cache:{} <- {}", m_Cache.RootDir().string(), m_Primary->Describe()); + } + return fmt::format("local:{}", m_Cache.RootDir().string()); + } + + bool Exists(std::string_view RelativePath) override + { + if (m_Cache.Exists(RelativePath)) + { + return true; + } + return m_Primary && m_Primary->Exists(RelativePath); + } + + TestArtifactFetchResult Fetch(std::string_view RelativePath) override + { + if (m_Cache.Exists(RelativePath)) + { + TestArtifactFetchResult Result = m_Cache.Fetch(RelativePath); + if (Result) + { + ZEN_INFO("test artifact '{}' served from cache {} ({} 
bytes)", + RelativePath, + m_Cache.RootDir().string(), + Result.Content.GetSize()); + } + return Result; + } + + if (!m_Primary) + { + TestArtifactFetchResult Result; + Result.Error = fmt::format("artifact '{}' not found in cache and no remote source configured", RelativePath); + return Result; + } + + ZEN_INFO("downloading test artifact '{}' from {}", RelativePath, m_Primary->Describe()); + + TestArtifactFetchResult Fetched = m_Primary->Fetch(RelativePath); + if (!Fetched) + { + return Fetched; + } + + ZEN_INFO("test artifact '{}' fetched from {} ({} bytes)", RelativePath, m_Primary->Describe(), Fetched.Content.GetSize()); + + TestArtifactResult StoreRes = m_Cache.Store(RelativePath, Fetched.Content); + if (!StoreRes) + { + ZEN_WARN("failed to cache artifact '{}' into {}: {}", RelativePath, m_Cache.RootDir().string(), StoreRes.Error); + } + return Fetched; + } + + TestArtifactListResult List(std::string_view Prefix) override + { + if (m_Primary) + { + return m_Primary->List(Prefix); + } + return m_Cache.List(Prefix); + } + + TestArtifactResult Store(std::string_view RelativePath, IoBuffer Content) override + { + return m_Cache.Store(RelativePath, std::move(Content)); + } + + private: + LocalBackend m_Cache; + std::unique_ptr<S3Backend> m_Primary; + }; + + void ApplyS3UrlToOptions(std::string_view Url, S3ClientOptions& Client, std::string& KeyPrefix) + { + constexpr std::string_view kScheme = "s3://"; + if (Url.substr(0, kScheme.size()) == kScheme) + { + Url.remove_prefix(kScheme.size()); + } + + size_t Slash = Url.find('/'); + std::string_view Bucket = (Slash == std::string_view::npos) ? Url : Url.substr(0, Slash); + std::string_view Prefix = (Slash == std::string_view::npos) ? 
std::string_view{} : Url.substr(Slash + 1); + + if (Client.BucketName.empty() && !Bucket.empty()) + { + Client.BucketName.assign(Bucket); + } + if (KeyPrefix.empty() && !Prefix.empty()) + { + KeyPrefix.assign(Prefix); + } + } + + // Returns true when the caller has explicitly (or implicitly by platform) disabled + // the EC2 Instance Metadata Service fallback. Honors the standard AWS env var + // AWS_EC2_METADATA_DISABLED=true and skips by default on macOS, where Mac EC2 + // instances are rare and the link-local probe would just emit noise on failure. + bool IsImdsDisabled() + { +#if ZEN_PLATFORM_MAC + return true; +#else + std::string Disabled = GetEnvVariable("AWS_EC2_METADATA_DISABLED"); + return Disabled == "true" || Disabled == "TRUE" || Disabled == "1"; +#endif + } + + void ApplyAwsEnvDefaults(S3ClientOptions& Client) + { + if (std::string Region = GetEnvVariable("AWS_DEFAULT_REGION"); !Region.empty()) + { + Client.Region = std::move(Region); + } + else if (std::string FallbackRegion = GetEnvVariable("AWS_REGION"); !FallbackRegion.empty()) + { + Client.Region = std::move(FallbackRegion); + } + + if (Client.Endpoint.empty()) + { + Client.Endpoint = GetEnvVariable("AWS_ENDPOINT_URL"); + } + + if (Client.Credentials.AccessKeyId.empty() && !Client.CredentialProvider) + { + std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID"); + if (!AccessKeyId.empty()) + { + Client.Credentials.AccessKeyId = std::move(AccessKeyId); + Client.Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY"); + Client.Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN"); + } + else if (!IsImdsDisabled()) + { + // Fall back to the EC2 Instance Metadata Service so self-hosted runners + // with an attached IAM role can sign S3 requests without static creds. 
+ Client.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({})); + } + } + } + +} // namespace + +std::filesystem::path +GetDefaultLocalTestArtifactPath() +{ + std::filesystem::path SystemRoot = PickDefaultSystemRootDirectory(); + if (SystemRoot.empty()) + { + return {}; + } + return SystemRoot / kDefaultLocalTestArtifactDirName; +} + +bool +S3TestArtifactsAvailable() +{ + if (GetEnvVariable(kTestArtifactsS3EnvVar).empty()) + { + return false; + } + if (!GetEnvVariable("AWS_ACCESS_KEY_ID").empty() || !GetEnvVariable("AWS_SESSION_TOKEN").empty()) + { + return true; + } + return !IsImdsDisabled(); +} + +bool +TestArtifactsAvailable() +{ + if (!GetEnvVariable(kTestArtifactsPathEnvVar).empty()) + { + return true; + } + return S3TestArtifactsAvailable(); +} + +Ref<TestArtifactProvider> +CreateTestArtifactProvider(TestArtifactProviderOptions Options) +{ + if (Options.CacheDir.empty()) + { + std::string EnvValue = GetEnvVariable(kTestArtifactsPathEnvVar); + if (!EnvValue.empty()) + { + Options.CacheDir = std::filesystem::path(EnvValue); + } + } + if (Options.CacheDir.empty()) + { + Options.CacheDir = GetDefaultLocalTestArtifactPath(); + } + if (Options.CacheDir.empty()) + { + return {}; + } + + if (Options.S3Client.BucketName.empty() || Options.S3KeyPrefix.empty()) + { + std::string EnvValue = GetEnvVariable(kTestArtifactsS3EnvVar); + if (!EnvValue.empty()) + { + ApplyS3UrlToOptions(EnvValue, Options.S3Client, Options.S3KeyPrefix); + } + } + + std::unique_ptr<S3Backend> Primary; + if (!Options.S3Client.BucketName.empty()) + { + ApplyAwsEnvDefaults(Options.S3Client); + Primary = std::make_unique<S3Backend>(std::move(Options.S3Client), std::move(Options.S3KeyPrefix)); + } + + return Ref<TestArtifactProvider>(new TestArtifactProviderImpl(LocalBackend(std::move(Options.CacheDir)), std::move(Primary))); +} + +void +testartifactprovider_forcelink() +{ +} + +} // namespace zen + +////////////////////////////////////////////////////////////////////////// 
+// Tests + +#if ZEN_WITH_TESTS + +# include <zenutil/cloud/minioprocess.h> + +# include <zencore/memoryview.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include <cstring> + +namespace zen::tests { + +using namespace std::literals; + +namespace { + + bool ContentMatches(const IoBuffer& Buf, std::string_view Expected) + { + return Buf.GetSize() == Expected.size() && std::memcmp(Buf.GetData(), Expected.data(), Expected.size()) == 0; + } + +} // namespace + +TEST_SUITE_BEGIN("util.testartifactprovider"); + +TEST_CASE("local_only.store_fetch_list") +{ + ScopedTemporaryDirectory CacheDir; + + // Keep the test hermetic: ignore any S3 configuration the developer may have in the environment. + ScopedEnvVar EnvS3(kTestArtifactsS3EnvVar, ""); + + TestArtifactProviderOptions Opts; + Opts.CacheDir = CacheDir.Path(); + Ref<TestArtifactProvider> Provider = CreateTestArtifactProvider(std::move(Opts)); + REQUIRE(Provider); + + CHECK_FALSE(Provider->Exists("missing.txt")); + + constexpr std::string_view kContent = "local payload"sv; + TestArtifactResult StoreRes = Provider->Store("greet/hello.txt", IoBufferBuilder::MakeFromMemory(MakeMemoryView(kContent))); + REQUIRE_MESSAGE(StoreRes.IsSuccess(), StoreRes.Error); + + CHECK(Provider->Exists("greet/hello.txt")); + + TestArtifactFetchResult Fetch = Provider->Fetch("greet/hello.txt"); + REQUIRE_MESSAGE(Fetch.IsSuccess(), Fetch.Error); + CHECK(ContentMatches(Fetch.Content, kContent)); + + TestArtifactFetchResult Missing = Provider->Fetch("nope.txt"); + CHECK_FALSE(Missing.IsSuccess()); + + TestArtifactListResult ListRes = Provider->List(""); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + REQUIRE_EQ(ListRes.Artifacts.size(), 1u); + CHECK_EQ(ListRes.Artifacts.front().RelativePath, "greet/hello.txt"); +} + +TEST_CASE("minio.s3_primary") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19020; + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("artifacts-test"); + + // 
Seed S3 directly so we can verify the provider's read path. + S3ClientOptions SeedOpts; + SeedOpts.BucketName = "artifacts-test"; + SeedOpts.Region = "us-east-1"; + SeedOpts.Endpoint = Minio.Endpoint(); + SeedOpts.PathStyle = true; + SeedOpts.Credentials.AccessKeyId = std::string(Minio.RootUser()); + SeedOpts.Credentials.SecretAccessKey = std::string(Minio.RootPassword()); + S3Client SeedClient(SeedOpts); + + constexpr std::string_view kHello = "hello from minio"sv; + REQUIRE(SeedClient.PutObject("artifacts/hello.txt", IoBufferBuilder::MakeFromMemory(MakeMemoryView(kHello))).IsSuccess()); + + ScopedTemporaryDirectory CacheDir; + + // Configure everything via environment variables to exercise the env-based defaults. + ScopedEnvVar EnvPath(kTestArtifactsPathEnvVar, CacheDir.Path().string()); + ScopedEnvVar EnvS3(kTestArtifactsS3EnvVar, "s3://artifacts-test/artifacts"); + ScopedEnvVar EnvEndpoint("AWS_ENDPOINT_URL", Minio.Endpoint()); + ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "us-east-1"); + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + TestArtifactProviderOptions Opts; + Opts.S3Client.PathStyle = true; // MinIO requires path-style addressing + Ref<TestArtifactProvider> Provider = CreateTestArtifactProvider(std::move(Opts)); + REQUIRE(Provider); + + // -- exists checks consult primary when cache is cold -------------------- + CHECK(Provider->Exists("hello.txt")); + CHECK_FALSE(Provider->Exists("nope.txt")); + + // -- first fetch pulls from primary and populates the cache -------------- + TestArtifactFetchResult First = Provider->Fetch("hello.txt"); + REQUIRE_MESSAGE(First.IsSuccess(), First.Error); + CHECK(ContentMatches(First.Content, kHello)); + CHECK(std::filesystem::exists(CacheDir.Path() / "hello.txt")); + + // -- delete from S3, then fetch again: must come from cache -------------- + REQUIRE(SeedClient.DeleteObject("artifacts/hello.txt").IsSuccess()); + 
TestArtifactFetchResult Cached = Provider->Fetch("hello.txt"); + REQUIRE_MESSAGE(Cached.IsSuccess(), Cached.Error); + CHECK(ContentMatches(Cached.Content, kHello)); + + // -- list reflects the primary (S3) source ------------------------------ + REQUIRE(SeedClient.PutObject("artifacts/sub/file-a.bin", IoBufferBuilder::MakeFromMemory(MakeMemoryView("A"sv))).IsSuccess()); + REQUIRE(SeedClient.PutObject("artifacts/sub/file-b.bin", IoBufferBuilder::MakeFromMemory(MakeMemoryView("B"sv))).IsSuccess()); + + TestArtifactListResult ListRes = Provider->List(""); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + + auto HasPath = [&](std::string_view Rel) { + for (const TestArtifactInfo& Info : ListRes.Artifacts) + { + if (Info.RelativePath == Rel) + { + return true; + } + } + return false; + }; + CHECK(HasPath("sub/file-a.bin")); + CHECK(HasPath("sub/file-b.bin")); + + // -- fetching a missing artifact surfaces a remote error ----------------- + TestArtifactFetchResult MissingFetch = Provider->Fetch("does-not-exist.bin"); + CHECK_FALSE(MissingFetch.IsSuccess()); +} + +TEST_SUITE_END(); + +} // namespace zen::tests + +#endif diff --git a/src/zenutil/windows/windowsservice.cpp b/src/zenutil/windows/windowsservice.cpp index ebb88b018..383568650 100644 --- a/src/zenutil/windows/windowsservice.cpp +++ b/src/zenutil/windows/windowsservice.cpp @@ -16,7 +16,6 @@ SERVICE_STATUS gSvcStatus; SERVICE_STATUS_HANDLE gSvcStatusHandle; -HANDLE ghSvcStopEvent = NULL; void SvcInstall(void); @@ -205,21 +204,6 @@ WindowsService::SvcMain() ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000); - // Create an event. The control handler function, SvcCtrlHandler, - // signals this event when it receives the stop control code. 
- - ghSvcStopEvent = CreateEvent(NULL, // default security attributes - TRUE, // manual reset event - FALSE, // not signaled - NULL); // no name - - if (ghSvcStopEvent == NULL) - { - ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0); - - return 1; - } - int ReturnCode = Run(); return ReturnCode; } @@ -549,9 +533,18 @@ ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint) gSvcStatus.dwWaitHint = dwWaitHint; if (dwCurrentState == SERVICE_START_PENDING) + { gSvcStatus.dwControlsAccepted = 0; + } + else if (dwCurrentState == SERVICE_STOP_PENDING || dwCurrentState == SERVICE_STOPPED) + { + // We are already stopping/stopped - don't accept further control codes. + gSvcStatus.dwControlsAccepted = 0; + } else - gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP; + { + gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN; + } if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED)) gSvcStatus.dwCheckPoint = 0; @@ -573,14 +566,14 @@ WindowsService::SvcCtrlHandler(DWORD dwCtrl) switch (dwCtrl) { case SERVICE_CONTROL_STOP: - ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0); + case SERVICE_CONTROL_SHUTDOWN: + // SCM gives us dwWaitHint milliseconds to complete the stop; 30s leaves + // plenty of headroom for the HTTP server and service state to tear down. + ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 30000); - // Signal the service to stop. - - SetEvent(ghSvcStopEvent); + // The HTTP server run loops poll IsApplicationExitRequested(), so this + // flag is how we ask them to wind down. 
zen::RequestApplicationExit(0); - - ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0); return; case SERVICE_CONTROL_INTERROGATE: diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index 1e19f7b2f..e28f6e345 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -9,8 +9,15 @@ target('zenutil') add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") + if is_plat("linux", "macosx") then + add_syslinks("dl") + end add_packages("json11") + if is_plat("linux", "macosx") then + add_packages("openssl3") + end + if is_plat("linux") then add_includedirs("$(projectdir)/thirdparty/systemd/include") add_linkdirs("$(projectdir)/thirdparty/systemd/lib") diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index 8eaf2cf5b..2d4334ffa 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -15,6 +15,7 @@ #include <zencore/timer.h> #include <atomic> +#include <string> #include <gsl/gsl-lite.hpp> @@ -180,7 +181,7 @@ ZenServerState::Initialize() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 0766 : 0666); if (Fd < 0) { // Work around a potential issue if the service user is changed in certain configurations. @@ -190,7 +191,7 @@ ZenServerState::Initialize() // shared memory object and retry, we'll be able to get past shm_open() so long as we have // the appropriate permissions to create the shared memory object. shm_unlink("/UnrealEngineZen"); - Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 
0766 : 0666); if (Fd < 0) { ThrowLastError("Could not open a shared memory object"); @@ -243,7 +244,7 @@ ZenServerState::InitializeReadOnly() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDONLY, 0666); if (Fd < 0) { return false; @@ -266,6 +267,8 @@ ZenServerState::InitializeReadOnly() ZenServerState::ZenServerEntry* ZenServerState::Lookup(int DesiredListenPort) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].DesiredListenPort; @@ -273,6 +276,14 @@ ZenServerState::Lookup(int DesiredListenPort) const { if (DesiredListenPort == 0 || (EntryPort == DesiredListenPort)) { + // If the entry's PID matches our own but we haven't registered yet, + // this is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Skip it. + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -288,6 +299,8 @@ ZenServerState::Lookup(int DesiredListenPort) const ZenServerState::ZenServerEntry* ZenServerState::LookupByEffectivePort(int Port) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].EffectiveListenPort; @@ -295,6 +308,11 @@ ZenServerState::LookupByEffectivePort(int Port) const { if (EntryPort == Port) { + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -357,12 +375,26 @@ ZenServerState::Sweep() ZEN_ASSERT(m_IsReadOnly == false); + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.DesiredListenPort) { + // If the entry's PID matches our own but we haven't registered yet, + // this 
is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Reclaim it. + if (Entry.Pid == OurPid && m_OurEntry == nullptr) + { + ZEN_CONSOLE_DEBUG("Sweep - pid {} matches current process but no registration yet, reclaiming stale entry (port {})", + Entry.Pid.load(), + Entry.DesiredListenPort.load()); + Entry.Reset(); + continue; + } + std::error_code ErrorCode; if (Entry.Pid != 0 && IsProcessRunning(Entry.Pid, ErrorCode) == false) { @@ -619,7 +651,7 @@ ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data ThrowLastError("Could not map instance info shared memory"); } #else - int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666); if (Fd < 0) { ThrowLastError("Could not create instance info shared memory"); @@ -686,7 +718,7 @@ ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId) return false; } #else - int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDONLY, 0666); if (Fd < 0) { return false; @@ -964,6 +996,7 @@ ZenServerInstance::Shutdown() ZEN_DEBUG("zenserver process {} ({}) exited", m_Name, m_Process.Pid()); int ExitCode = m_Process.GetExitCode(); m_Process.Reset(); + m_ShutdownEvent.reset(); return ExitCode; } @@ -993,6 +1026,7 @@ ZenServerInstance::Shutdown() ZEN_DEBUG("zenserver process {} ({}) exited", m_Name, m_Process.Pid()); int ExitCode = m_Process.GetExitCode(); m_Process.Reset(); + m_ShutdownEvent.reset(); return ExitCode; } else if (Ec) @@ -1020,6 +1054,7 @@ ZenServerInstance::Shutdown() int ExitCode = m_Process.GetExitCode(); ZEN_DEBUG("zenserver process {} ({}) exited", m_Name, m_Process.Pid()); m_Process.Reset(); + m_ShutdownEvent.reset(); return ExitCode; } ZEN_DEBUG("Detached from zenserver process {} ({})", m_Name, m_Process.Pid()); @@ -1078,8 +1113,23 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view 
ServerArgs, ChildEventName << "Zen_Child_" << ChildId; NamedEvent ChildEvent{ChildEventName}; + const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); + const std::filesystem::path Executable = + m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; + ExtendableStringBuilder<512> CommandLine; - CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path + { + const std::string ExeUtf8 = PathToUtf8(Executable); + constexpr AsciiSet QuoteChars = " \t\""; + if (AsciiSet::HasAny(ExeUtf8.c_str(), QuoteChars)) + { + CommandLine << '"' << ExeUtf8 << '"'; + } + else + { + CommandLine << ExeUtf8; + } + } if (m_ServerMode == ServerMode::kHubServer) { @@ -1092,6 +1142,11 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, CommandLine << " --child-id " << ChildEventName; + if (!m_EnableExecutionHistory) + { + CommandLine << " --enable-execution-history=false"; + } + if (!ServerArgs.empty()) { CommandLine << " " << ServerArgs; @@ -1106,10 +1161,6 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CreationFlags |= CreateProcOptions::Flag_NewConsole; } - - const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); - const std::filesystem::path Executable = - m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput()) ? std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); @@ -1371,18 +1422,31 @@ ZenServerInstance::OnServerReady() const ZenServerState::ZenServerEntry* Entry = nullptr; - if (m_BasePort) - { - Entry = State.Lookup(m_BasePort); - } - else + // The child process signals its ready event after writing its state entry, but under + // heavy instrumentation (e.g. 
sanitizers) the shared memory writes may not be immediately + // visible to this process. Retry briefly before giving up. + for (int Attempt = 0; Attempt < 10; ++Attempt) { - State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { - if (InEntry.Pid == (uint32_t)m_Process.Pid()) - { - Entry = &InEntry; - } - }); + if (m_BasePort) + { + Entry = State.Lookup(m_BasePort); + } + else + { + State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { + if (InEntry.Pid == (uint32_t)m_Process.Pid()) + { + Entry = &InEntry; + } + }); + } + + if (Entry) + { + break; + } + + Sleep(100); } if (!Entry) @@ -1428,6 +1492,16 @@ ZenServerInstance::IsRunning() const return m_Process.IsRunning(); } +void +ZenServerInstance::ResetDeadProcess() +{ + if (m_Process.IsValid() && !m_Process.IsRunning()) + { + m_Process.Reset(); + m_ShutdownEvent.reset(); + } +} + std::string ZenServerInstance::GetLogOutput() const { @@ -1553,4 +1627,136 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason) return true; } +std::optional<int> +StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options) +{ + ZEN_SCOPED_LOG(LogRef); + + // Check if a matching server is already running + { + ZenServerState State; + if (State.InitializeReadOnly()) + { + uint32_t RunningPid = 0; + uint16_t RunningEffectivePort = 0; + State.Snapshot([&, DesiredPort = Options.Port](const ZenServerState::ZenServerEntry& Entry) { + if (RunningPid == 0 && (DesiredPort == 0 || Entry.DesiredListenPort.load() == DesiredPort)) + { + RunningPid = Entry.Pid.load(); + RunningEffectivePort = Entry.EffectiveListenPort.load(); + } + }); + if (RunningPid != 0) + { + ZEN_INFO("Zen server already running at port {}, pid {}", RunningEffectivePort, RunningPid); + return std::nullopt; + } + } + } + + std::filesystem::path ProgramBaseDir = Options.ProgramBaseDir; + if (ProgramBaseDir.empty()) + { + ProgramBaseDir = GetRunningExecutablePath().parent_path(); + } + + ZenServerEnvironment 
ServerEnvironment; + ServerEnvironment.Initialize(ProgramBaseDir); + ZenServerInstance Server(ServerEnvironment, Options.Mode); + Server.SetEnableExecutionHistory(Options.EnableExecutionHistory); + + std::string ServerArguments(Options.ExtraArgs); + if ((Options.Port != 0) && (ServerArguments.find("--port") == std::string::npos)) + { + ServerArguments.append(fmt::format(" --port {}", Options.Port)); + } + Server.SpawnServer(ServerArguments, Options.OpenConsole, /*WaitTimeoutMs*/ 0); + + constexpr int Timeout = 10000; + + if (!Server.WaitUntilReady(Timeout)) + { + ZEN_WARN("{}", Server.GetLogOutput()); + if (Server.IsRunning()) + { + ZEN_WARN("Zen server launch failed (timed out), terminating"); + Server.Terminate(); + return 1; + } + int ExitCode = Server.Shutdown(); + ZEN_WARN("Zen server failed to get to a ready state and exited with return code {}", ExitCode); + return ExitCode != 0 ? ExitCode : 1; + } + + if (Options.ShowLog) + { + ZEN_INFO("{}", Server.GetLogOutput()); + } + return 0; +} + +bool +ShutdownZenServer(LoggerRef LogRef, + ZenServerState& State, + ZenServerState::ZenServerEntry* Entry, + const std::filesystem::path& ProgramBaseDir) +{ + ZEN_SCOPED_LOG(LogRef); + int EntryPort = (int)Entry->DesiredListenPort.load(); + const uint32_t ServerProcessPid = Entry->Pid.load(); + try + { + ZenServerEnvironment ServerEnvironment; + ServerEnvironment.Initialize(ProgramBaseDir); + ZenServerInstance Server(ServerEnvironment); + Server.AttachToRunningServer(EntryPort); + + ZEN_INFO("attached to server on port {} (pid {}), requesting shutdown", EntryPort, ServerProcessPid); + + std::error_code Ec; + if (Server.SignalShutdown(Ec) && !Ec) + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 10000) + { + if (Server.WaitUntilExited(100, Ec) && !Ec) + { + ZEN_INFO("shutdown complete"); + return true; + } + else if (Ec) + { + ZEN_WARN("Waiting for server on port {} (pid {}) failed. 
Reason: '{}'", EntryPort, ServerProcessPid, Ec.message()); + } + } + } + else if (Ec) + { + ZEN_WARN("Requesting shutdown of server on port {} failed. Reason: '{}'", EntryPort, Ec.message()); + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("Exception caught when requesting shutdown: {}", Ex.what()); + } + + ZEN_INFO("Requesting detached shutdown of server on port {}", EntryPort); + Entry->SignalShutdownRequest(); + + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 10000) + { + State.Sweep(); + Entry = State.Lookup(EntryPort); + if (Entry == nullptr || Entry->Pid.load() != ServerProcessPid) + { + ZEN_INFO("Shutdown complete"); + return true; + } + Sleep(100); + } + + return false; +} + } // namespace zen diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 734813b69..b9617b1ed 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,10 +5,17 @@ #if ZEN_WITH_TESTS # include <zenutil/cloud/imdscredentials.h> +# include <zenutil/consul.h> # include <zenutil/cloud/s3client.h> # include <zenutil/cloud/sigv4.h> -# include <zenutil/rpcrecording.h> # include <zenutil/config/commandlineoptions.h> +# include <zenutil/filesystemutils.h> +# include <zenutil/invocationhistory.h> +# include <zenutil/parallelsort.h> +# include <zenutil/rpcrecording.h> +# include <zenutil/splitconsole/logstreamlistener.h> +# include <zenutil/process/subprocessmanager.h> +# include <zenutil/testartifactprovider.h> # include <zenutil/wildcard.h> namespace zen { @@ -18,9 +25,16 @@ zenutil_forcelinktests() { cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + consul::consul_forcelink(); + filesystemutils_forcelink(); imdscredentials_forcelink(); + invocationhistory_forcelink(); + parallelsort_forcelink(); + logstreamlistener_forcelink(); + subprocessmanager_forcelink(); s3client_forcelink(); sigv4_forcelink(); + testartifactprovider_forcelink(); wildcard_forcelink(); } |