diff options
Diffstat (limited to 'src/zenutil')
43 files changed, 4895 insertions, 445 deletions
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp index dde1dc019..a23cb9c28 100644 --- a/src/zenutil/cloud/imdscredentials.cpp +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -64,6 +64,7 @@ ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptio .LogCategory = "imds", .ConnectTimeout = Options.ConnectTimeout, .Timeout = Options.RequestTimeout, + .RetryCount = 3, }) { ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri()); @@ -115,7 +116,7 @@ ImdsCredentialProvider::FetchToken() HttpClient::KeyValueMap Headers; Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); - HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers); if (!Response.IsSuccess()) { ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); @@ -213,7 +214,7 @@ ImdsCredentialProvider::FetchCredentials() } else { - // Expiration is in the past or unparseable — force refresh next time + // Expiration is in the past or unparseable - force refresh next time NewExpiresAt = std::chrono::steady_clock::now(); } @@ -226,7 +227,7 @@ ImdsCredentialProvider::FetchCredentials() if (KeyChanged) { - ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8)); + ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {})", HideSensitiveString(m_CachedCredentials.AccessKeyId)); } else { @@ -369,7 +370,7 @@ TEST_CASE("imdscredentials.fetch_from_mock") TEST_CASE("imdscredentials.unreachable_endpoint") { - // Point at a non-existent server — should return empty credentials, not crash + // Point at a non-existent server - should return empty credentials, not crash ImdsCredentialProviderOptions Opts; Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening Opts.ConnectTimeout = 
std::chrono::milliseconds(100); diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp index 457453bd8..2db0010dc 100644 --- a/src/zenutil/cloud/minioprocess.cpp +++ b/src/zenutil/cloud/minioprocess.cpp @@ -45,7 +45,7 @@ struct MinioProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); @@ -102,7 +102,7 @@ struct MinioProcess::Impl { if (m_DataDir.empty()) { - ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first"); + ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized - call SpawnMinioServer() first"); return; } diff --git a/src/zenutil/cloud/mockimds.cpp b/src/zenutil/cloud/mockimds.cpp index 6919fab4d..88b348ed6 100644 --- a/src/zenutil/cloud/mockimds.cpp +++ b/src/zenutil/cloud/mockimds.cpp @@ -93,7 +93,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Autoscaling lifecycle state — 404 when not in an ASG + // Autoscaling lifecycle state - 404 when not in an ASG if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") { if (Aws.AutoscalingState.empty()) @@ -105,7 +105,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // Spot interruption notice — 404 when no interruption pending + // Spot interruption notice - 404 when no interruption pending if (Uri == "latest/meta-data/spot/instance-action") { if (Aws.SpotAction.empty()) @@ -117,7 +117,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request) return; } - // IAM role discovery — returns the role name + // IAM role discovery - returns the role name if (Uri == "latest/meta-data/iam/security-credentials/") { if (Aws.IamRoleName.empty()) diff --git 
a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index d9fde05d9..f8bed92da 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -135,6 +135,41 @@ namespace { return nullptr; } + /// Extract Code/Message from an S3 XML error body. Returns true if an <Error> element was + /// found, even if Code/Message are empty. + bool ExtractS3Error(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) + { + if (Body.find("<Error>") == std::string_view::npos) + { + return false; + } + OutCode = ExtractXmlValue(Body, "Code"); + OutMessage = ExtractXmlValue(Body, "Message"); + return true; + } + + /// Build a human-readable error message for a failed S3 response. When the response body + /// contains an S3 `<Error>` element, the Code and Message fields are included in the string + /// so transient 4xx/5xx failures (SignatureDoesNotMatch, AuthorizationHeaderMalformed, etc.) + /// show up in logs instead of being swallowed. Falls back to the generic HTTP/transport + /// message when no XML body is available (HEAD responses, transport errors). 
+ std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) + { + if (!Response.Error.has_value() && Response.ResponsePayload) + { + std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string_view Code; + std::string_view Message; + if (ExtractS3Error(Body, Code, Message) && (!Code.empty() || !Message.empty())) + { + ExtendableStringBuilder<256> Decoded; + DecodeXmlEntities(Message, Decoded); + return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); + } + } + return Response.ErrorMessage(Prefix); + } + } // namespace std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; @@ -148,6 +183,7 @@ S3Client::S3Client(const S3ClientOptions& Options) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) , m_HttpClient(BuildEndpoint(), Options.HttpSettings) +, m_Verbose(Options.HttpSettings.Verbose) { m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", @@ -186,6 +222,14 @@ S3Client::GetCurrentCredentials() } std::string +S3Client::BuildNoCredentialsError(std::string Context) +{ + std::string Err = fmt::format("{}: no credentials available", Context); + ZEN_WARN("{}", Err); + return Err; +} + +std::string S3Client::BuildEndpoint() const { if (!m_Endpoint.empty()) @@ -285,14 +329,18 @@ S3Client::GetSigningKey(std::string_view DateStamp) } HttpClient::KeyValueMap -S3Client::SignRequest(std::string_view Method, std::string_view Path, std::string_view CanonicalQueryString, std::string_view PayloadHash) +S3Client::SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view CanonicalQueryString, + std::string_view PayloadHash, + std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders) { - SigV4Credentials Credentials = 
GetCurrentCredentials(); - std::string AmzDate = GetAmzTimestamp(); // Build sorted headers to sign (must be sorted by lowercase name) std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); HeadersToSign.emplace_back("host", m_Host); HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); HeadersToSign.emplace_back("x-amz-date", AmzDate); @@ -300,6 +348,10 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin { HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); } + for (const auto& [K, V] : ExtraSignedHeaders) + { + HeadersToSign.emplace_back(K, V); + } std::sort(HeadersToSign.begin(), HeadersToSign.end()); std::string_view DateStamp(AmzDate.data(), 8); @@ -316,6 +368,10 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin { Result->emplace("x-amz-security-token", Credentials.SessionToken); } + for (const auto& [K, V] : ExtraSignedHeaders) + { + Result->emplace(K, V); + } return Result; } @@ -323,31 +379,46 @@ S3Client::SignRequest(std::string_view Method, std::string_view Path, std::strin S3Result S3Client::PutObject(std::string_view Key, IoBuffer Content) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 PUT '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); // Hash the payload std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, "", PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", PayloadHash); HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 PUT failed"); + std::string Err = S3ErrorMessage("S3 PUT failed", Response); ZEN_WARN("S3 PUT '{}' failed: 
{}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + } return {}; } S3GetObjectResult S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 GET '{}' failed", Key); !Err.empty()) + { + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) @@ -357,12 +428,15 @@ S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFileP return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; } - std::string Err = Response.ErrorMessage("S3 GET failed"); + std::string Err = S3ErrorMessage("S3 GET failed", Response); ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + } return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; } @@ -370,9 +444,17 @@ S3GetObjectResult S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t RangeSize) { ZEN_ASSERT(RangeSize > 0); + + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 GET range '{}' [{}-{}] failed", Key, RangeStart, RangeStart + RangeSize - 1); + !Err.empty()) + { + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = 
KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); Headers->emplace("Range", fmt::format("bytes={}-{}", RangeStart, RangeStart + RangeSize - 1)); HttpClient::Response Response = m_HttpClient.Get(Path, Headers); @@ -383,7 +465,7 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; } - std::string Err = Response.ErrorMessage("S3 GET range failed"); + std::string Err = S3ErrorMessage("S3 GET range failed", Response); ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } @@ -403,39 +485,104 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 GET range '{}' [{}-{}] succeeded ({} bytes)", - Key, - RangeStart, - RangeStart + RangeSize - 1, - Response.ResponsePayload.GetSize()); + if (m_Verbose) + { + ZEN_INFO("S3 GET range '{}' [{}-{}] succeeded ({} bytes)", + Key, + RangeStart, + RangeStart + RangeSize - 1, + Response.ResponsePayload.GetSize()); + } return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; } S3Result S3Client::DeleteObject(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 DELETE '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, "", EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); if (!Response.IsSuccess()) { - std::string Err = 
Response.ErrorMessage("S3 DELETE failed"); + std::string Err = S3ErrorMessage("S3 DELETE failed", Response); ZEN_WARN("S3 DELETE '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + if (m_Verbose) + { + ZEN_INFO("S3 DELETE '{}' succeeded", Key); + } + return {}; +} + +S3Result +S3Client::Touch(std::string_view Key) +{ + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 Touch '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + + std::string Path = KeyToPath(Key); + + // x-amz-copy-source is always "/bucket/key" regardless of addressing style. + // Key must be URI-encoded except for '/' separators. When source and destination + // are identical, REPLACE is required; COPY is rejected with InvalidRequest. + const std::array<std::pair<std::string, std::string>, 2> ExtraSigned{{ + {"x-amz-copy-source", fmt::format("/{}/{}", m_BucketName, AwsUriEncode(Key, false))}, + {"x-amz-metadata-directive", "REPLACE"}, + }}; + + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", EmptyPayloadHash, ExtraSigned); + + HttpClient::Response Response = m_HttpClient.Put(Path, IoBuffer{}, Headers); + if (!Response.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 Touch failed", Response); + ZEN_WARN("S3 Touch '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + // Copy operations can return HTTP 200 with an error in the XML body. 
+ std::string_view ResponseBody = Response.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + { + std::string Err = fmt::format("S3 Touch '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); + ZEN_WARN("{}", Err); + return S3Result{std::move(Err)}; + } + + if (m_Verbose) + { + ZEN_INFO("S3 Touch '{}' succeeded", Key); + } return {}; } S3HeadObjectResult S3Client::HeadObject(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 HEAD '{}' failed", Key); !Err.empty()) + { + return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; + } + std::string Path = KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest("HEAD", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "HEAD", Path, "", EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Head(Path, Headers); if (!Response.IsSuccess()) @@ -445,7 +592,7 @@ S3Client::HeadObject(std::string_view Key) return S3HeadObjectResult{{}, {}, HeadObjectResult::NotFound}; } - std::string Err = Response.ErrorMessage("S3 HEAD failed"); + std::string Err = S3ErrorMessage("S3 HEAD failed", Response); ZEN_WARN("S3 HEAD '{}' failed: {}", Key, Err); return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; } @@ -468,7 +615,10 @@ S3Client::HeadObject(std::string_view Key) Info.LastModified = *V; } - ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + if (m_Verbose) + { + ZEN_INFO("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + } return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; } @@ -481,6 +631,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) for (;;) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 ListObjectsV2 prefix='{}' failed", Prefix); !Err.empty()) + { + Result.Error 
= std::move(Err); + return Result; + } + // Build query parameters for ListObjectsV2 std::vector<std::pair<std::string, std::string>> QueryParams; QueryParams.emplace_back("list-type", "2"); @@ -499,13 +656,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); std::string RootPath = BucketRootPath(); - HttpClient::KeyValueMap Headers = SignRequest("GET", RootPath, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", RootPath, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 ListObjectsV2 failed"); + std::string Err = S3ErrorMessage("S3 ListObjectsV2 failed", Response); ZEN_WARN("S3 ListObjectsV2 prefix='{}' failed: {}", Prefix, Err); Result.Error = std::move(Err); return Result; @@ -565,10 +722,16 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } ContinuationToken = std::string(NextToken); - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } } - ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + if (m_Verbose) + { + ZEN_INFO("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + } return Result; } @@ -578,16 +741,22 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) S3CreateMultipartUploadResult S3Client::CreateMultipartUpload(std::string_view Key) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 CreateMultipartUpload '{}' failed", Key); !Err.empty()) + { + return 
S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); - HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 CreateMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 CreateMultipartUpload failed", Response); ZEN_WARN("S3 CreateMultipartUpload '{}' failed: {}", Key, Err); return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } @@ -607,13 +776,22 @@ S3Client::CreateMultipartUpload(std::string_view Key) return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; } S3UploadPartResult S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 UploadPart '{}' part {} failed", Key, PartNumber); !Err.empty()) + { + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({ {"partNumber", fmt::format("{}", PartNumber)}, @@ -622,13 +800,13 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, CanonicalQS, 
PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber)); + std::string Err = S3ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber), Response); ZEN_WARN("S3 UploadPart '{}' part {} failed: {}", Key, PartNumber, Err); return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } @@ -642,7 +820,10 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } - ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + if (m_Verbose) + { + ZEN_INFO("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + } return S3UploadPartResult{{}, *ETag}; } @@ -651,6 +832,12 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view UploadId, const std::vector<std::pair<uint32_t, std::string>>& PartETags) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 CompleteMultipartUpload '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); @@ -666,7 +853,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view XmlView = XmlBody.ToView(); std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); - HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); Headers->emplace("Content-Type", "application/xml"); IoBuffer 
Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); @@ -675,44 +862,56 @@ S3Client::CompleteMultipartUpload(std::string_view Key, HttpClient::Response Response = m_HttpClient.Post(FullPath, Payload, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 CompleteMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 CompleteMultipartUpload failed", Response); ZEN_WARN("S3 CompleteMultipartUpload '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } // Check for error in response body - S3 can return 200 with an error in the XML body std::string_view ResponseBody = Response.AsText(); - if (ResponseBody.find("<Error>") != std::string_view::npos) + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) { - std::string_view ErrorCode = ExtractXmlValue(ResponseBody, "Code"); - std::string_view ErrorMessage = ExtractXmlValue(ResponseBody, "Message"); - std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); + std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + if (m_Verbose) + { + ZEN_INFO("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + } return {}; } S3Result S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 AbortMultipartUpload '{}' failed", Key); !Err.empty()) + { + return S3Result{std::move(Err)}; + } + std::string Path = KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); - HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, CanonicalQS, 
EmptyPayloadHash); + HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, CanonicalQS, EmptyPayloadHash); std::string FullPath = BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); if (!Response.IsSuccess()) { - std::string Err = Response.ErrorMessage("S3 AbortMultipartUpload failed"); + std::string Err = S3ErrorMessage("S3 AbortMultipartUpload failed", Response); ZEN_WARN("S3 AbortMultipartUpload '{}' failed: {}", Key, Err); return S3Result{std::move(Err)}; } - ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + if (m_Verbose) + { + ZEN_INFO("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + } return {}; } @@ -731,6 +930,12 @@ S3Client::GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds Exp std::string S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn) { + SigV4Credentials Credentials; + if (std::string Err = RequireCredentials(Credentials, "S3 GeneratePresignedUrl '{}' {} failed", Key, Method); !Err.empty()) + { + return {}; + } + std::string Path = KeyToPath(Key); std::string Scheme = "https"; @@ -739,7 +944,6 @@ S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view M Scheme = "http"; } - SigV4Credentials Credentials = GetCurrentCredentials(); return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); } @@ -755,7 +959,10 @@ S3Client::PutObjectMultipart(std::string_view Key, return PutObject(Key, TotalSize > 0 ? 
FetchRange(0, TotalSize) : IoBuffer{}); } - ZEN_DEBUG("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + } S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); if (!InitResult) @@ -803,7 +1010,10 @@ S3Client::PutObjectMultipart(std::string_view Key, throw; } - ZEN_DEBUG("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + if (m_Verbose) + { + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + } return {}; } @@ -892,7 +1102,7 @@ TEST_CASE("s3client.minio_integration") using namespace std::literals; // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered - // the TEST_CASE from the top, spawning and killing MinIO per subcase — slow and flaky on + // the TEST_CASE from the top, spawning and killing MinIO per subcase - slow and flaky on // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance // that is torn down via RAII at scope exit. MinioProcessOptions MinioOpts; @@ -943,6 +1153,42 @@ TEST_CASE("s3client.minio_integration") CHECK(HeadRes2.Status == HeadObjectResult::NotFound); } + // -- touch ---------------------------------------------------------------- + { + std::string_view TestData = "touch-me"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("touch/obj.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + S3HeadObjectResult Before = Client.HeadObject("touch/obj.txt"); + REQUIRE(Before.IsSuccess()); + REQUIRE(Before.Status == HeadObjectResult::Found); + + // S3 LastModified has second precision; sleep past the second boundary so + // the touched timestamp is strictly greater. 
+ Sleep(1100); + + S3Result TouchRes = Client.Touch("touch/obj.txt"); + REQUIRE(TouchRes.IsSuccess()); + + S3HeadObjectResult After = Client.HeadObject("touch/obj.txt"); + REQUIRE(After.IsSuccess()); + REQUIRE(After.Status == HeadObjectResult::Found); + CHECK(After.Info.Size == Before.Info.Size); + CHECK(After.Info.LastModified != Before.Info.LastModified); + + // Content must be unchanged by a self-copy. + S3GetObjectResult GetRes = Client.GetObject("touch/obj.txt"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.AsText() == TestData); + + // Touching a missing key must fail. + S3Result MissRes = Client.Touch("touch/does-not-exist.txt"); + CHECK_FALSE(MissRes.IsSuccess()); + + Client.DeleteObject("touch/obj.txt"); + } + // -- head_not_found ------------------------------------------------------- { S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); diff --git a/src/zenutil/config/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp index 25f5522d8..42ce0d06a 100644 --- a/src/zenutil/config/commandlineoptions.cpp +++ b/src/zenutil/config/commandlineoptions.cpp @@ -8,6 +8,12 @@ #include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/fixed_vector.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + #if ZEN_WITH_TESTS # include <zencore/testing.h> #endif // ZEN_WITH_TESTS @@ -188,6 +194,422 @@ CommandLineConverter::CommandLineConverter(int& argc, char**& argv) argv = RawArgs.data(); } +////////////////////////////////////////////////////////////////////////// +// Command-line scrubber (used by invocation history and Sentry integration). + +namespace { + + // Suffixes (matched against the normalized option name) that mark an option + // as carrying a sensitive value. Names are normalized before comparison: + // leading dashes dropped, remaining '-' / '_' stripped, lowercased. So + // `--access-token` / `--Access-Token` / `--access_token` all normalize to + // "accesstoken" and match the "token" suffix. 
+ // + // Picked deliberately to be sharp: catches every credential-bearing option + // in zen and zenserver (access tokens, OAuth client secrets, AES keys, + // Sentry DSNs, upstream Jupiter tokens) without false-positive masks on + // look-alikes (--access-token-env, --access-token-path, --oidctoken-exe-path, + // --valuekey, --opkey, the bare --key cloud lookup hash, --encryption-aes-iv). + // + // For freeform / unknown sensitive values that follow a known format (AWS + // access keys, Google API keys, JWT bearer tokens) the value-pattern scanner + // (kSecretPatterns below) provides an orthogonal safety net. + constexpr std::string_view kSensitiveNameSuffixes[] = { + "token", + "aeskey", + "secret", + "dsn", + }; + + constexpr char ToLowerAscii(char C) { return (C >= 'A' && C <= 'Z') ? char(C + ('a' - 'A')) : C; } + + bool IsSensitiveOptionName(std::string_view Name) + { + // Normalize: skip syntactic quotes, drop leading dashes, strip remaining + // '-' / '_', lowercase ASCII. "--access-token" / "--Access-Token" / + // "--access_token" / `"--Access_Token` all collapse to "accesstoken". + // The 64-byte inline buffer covers every realistic option name; the + // builder spills to the heap on a pathological name. + ExtendableStringBuilder<64> Norm; + bool LeadingDashes = true; + for (char C : Name) + { + if (C == '"') + { + continue; + } + if (LeadingDashes) + { + if (C == '-') + { + continue; + } + LeadingDashes = false; + } + if (C == '-' || C == '_') + { + continue; + } + Norm.Append(ToLowerAscii(C)); + } + const std::string_view View = Norm.ToView(); + for (std::string_view Suffix : kSensitiveNameSuffixes) + { + if (View.ends_with(Suffix)) + { + return true; + } + } + return false; + } + + constexpr std::string_view kUserAndPass = "***:***"; + constexpr std::string_view kUserOnly = "***"; + + // Locate scheme://[user[:pass]]@ credentials in Text. 
Sets HostStart to the + // offset right after "://", UserInfoLen to the length of the userinfo span + // to redact (i.e. AtPos - HostStart), and HasPassword to true when a ':' + // separates user and password within the userinfo. Returns false if Text + // has no credentialed authority component. + struct UrlCredentials + { + size_t HostStart; + size_t UserInfoLen; + bool HasPassword; + }; + bool FindUrlCredentials(std::string_view Text, UrlCredentials& Out) + { + const size_t SchemePos = Text.find("://"); + if (SchemePos == std::string_view::npos) + { + return false; + } + const size_t HostStart = SchemePos + 3; + const size_t AtPos = Text.find('@', HostStart); + if (AtPos == std::string_view::npos) + { + return false; + } + // '@' must be in the authority, not the path/query/fragment. + const size_t TermPos = Text.find_first_of("/?#", HostStart); + if (TermPos != std::string_view::npos && TermPos < AtPos) + { + return false; + } + const size_t ColonPos = Text.find(':', HostStart); + Out.HostStart = HostStart; + Out.UserInfoLen = AtPos - HostStart; + Out.HasPassword = (ColonPos != std::string_view::npos && ColonPos < AtPos); + return true; + } + + // Inline capacity for the tokenized cmdline. 32 covers any realistic + // invocation; if exceeded the eastl::fixed_vector spills to the heap. + constexpr size_t kTokenInlineCapacity = 32; + using TokenVector = eastl::fixed_vector<std::string_view, kTokenInlineCapacity>; + + constexpr AsciiSet kUpperAlnum = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + constexpr AsciiSet kBase64Url = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + constexpr AsciiSet kJwtCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_."; + + // Known credential value formats. The cmdline scanner walks each non-argv[0] + // token and at every position checks whether one of these prefixes matches. 
+ // On match it greedily consumes [MinBodyLen, MaxBodyLen] characters from + // BodyChars; if the consumed body is at least MinBodyLen the prefix + body + // range is masked. Patterns are kept strict (distinctive prefix + length + // + charset) to keep false positives near zero. + // + // Only formats that zen/zenserver interacts with (or may plausibly accept + // via a cloud config) are listed here: + // - AWS access keys: used by the S3 client (SigV4) and EC2 IMDS provider. + // - Google API keys: covers keys passed via GCP-adjacent configuration. + // - JWTs: OAuth/OIDC bearer tokens accepted via --access-token and the + // upstream-jupiter-oauth-* options are typically JWTs. + struct SecretPattern + { + std::string_view Prefix; + size_t MinBodyLen; + size_t MaxBodyLen; + AsciiSet BodyChars; + }; + + constexpr SecretPattern kSecretPatterns[] = { + // AWS access key (long-term and temporary / STS). + {"AKIA", 16, 16, kUpperAlnum}, + {"ASIA", 16, 16, kUpperAlnum}, + // Google API key: exactly "AIza" + 35 base64url body chars. + {"AIza", 35, 35, kBase64Url}, + // JWT (header.payload.signature). Loose - prefix "eyJ" plus body chars. + // Verified to contain at least two '.' before accepting. + {"eyJ", 20, 8192, kJwtCharset}, + }; + + // Scan a token for embedded secret patterns. For each match append an Edit + // that replaces the matched range with HideSensitiveString(matched value). + // The first 4 characters of any 17+ char match leak through - safe here + // since they are part of the public format prefix that triggered the match + // (e.g. "AKIA...", "AIza...", "eyJh..."). Edit offsets are absolute + // Cmdline offsets. 
+ template<typename EditVecT> + void FindSecretPatternEdits(std::string_view Token, size_t TokStart, EditVecT& Edits) + { + size_t I = 0; + while (I < Token.size()) + { + bool Matched = false; + for (const SecretPattern& Pat : kSecretPatterns) + { + if (I + Pat.Prefix.size() > Token.size()) + { + continue; + } + if (Token.compare(I, Pat.Prefix.size(), Pat.Prefix) != 0) + { + continue; + } + const size_t BodyStart = I + Pat.Prefix.size(); + size_t J = BodyStart; + while (J < Token.size() && (J - BodyStart) < Pat.MaxBodyLen && Pat.BodyChars.Contains(Token[J])) + { + ++J; + } + const size_t BodyLen = J - BodyStart; + if (BodyLen < Pat.MinBodyLen) + { + continue; + } + if (Pat.Prefix == "eyJ") + { + int DotCount = 0; + for (size_t K = BodyStart; K < J; ++K) + { + if (Token[K] == '.' && ++DotCount == 2) + { + break; + } + } + if (DotCount < 2) + { + continue; + } + } + const size_t MatchLen = Pat.Prefix.size() + BodyLen; + Edits.push_back({TokStart + I, MatchLen, HideSensitiveString(Token.substr(I, MatchLen))}); + I = J; + Matched = true; + break; + } + if (!Matched) + { + ++I; + } + } + } + + // Tokenize a Windows-style command line into views into the original + // string. Each token preserves any surrounding/embedded quote characters. + // Splits on unquoted ASCII whitespace. No heap allocation up to + // kTokenInlineCapacity tokens. 
+ TokenVector SplitCommandLineTokens(std::string_view Input) + { + TokenVector Tokens; + size_t I = 0; + while (I < Input.size()) + { + while (I < Input.size() && (Input[I] == ' ' || Input[I] == '\t')) + { + ++I; + } + if (I >= Input.size()) + { + break; + } + const size_t Start = I; + bool InQuote = false; + while (I < Input.size()) + { + const char C = Input[I]; + if (C == '"') + { + InQuote = !InQuote; + ++I; + continue; + } + if (!InQuote && (C == ' ' || C == '\t')) + { + break; + } + ++I; + } + Tokens.push_back(Input.substr(Start, I - Start)); + } + return Tokens; + } + + // First non-quote character in a token (or '\0' if there isn't one). + // Used to test whether a token starts with '-' through any leading quotes. + char FirstNonQuote(std::string_view Token) + { + for (char C : Token) + { + if (C != '"') + { + return C; + } + } + return '\0'; + } + + void ScrubSensitiveValuesImpl(std::string& Cmdline) + { + const TokenVector Tokens = SplitCommandLineTokens(Cmdline); + if (Tokens.size() <= 1) + { + return; + } + + // Edits are accumulated and applied right-to-left at the end, so + // untouched offsets remain valid. fixed_vector keeps the storage + // inline; spills to heap only on a pathological number of edits. The + // std::string Replacement uses small-string optimization for the + // short masks we produce (max ~12 chars) so no heap traffic per edit + // in the common case. + struct Edit + { + size_t Start; + size_t Len; + std::string Replacement; + }; + eastl::fixed_vector<Edit, kTokenInlineCapacity> Edits; + + auto PushMask = [&](size_t Start, size_t Len, std::string_view Replacement) { + Edits.push_back({Start, Len, std::string(Replacement)}); + }; + + bool MaskNext = false; + // Skip Tokens[0] (executable path) - never scrub. 
+ for (size_t Idx = 1; Idx < Tokens.size(); ++Idx) + { + const std::string_view Tok = Tokens[Idx]; + const size_t TokStart = static_cast<size_t>(Tok.data() - Cmdline.data()); + const char Lead = FirstNonQuote(Tok); + const bool IsFlag = (Lead == '-'); + + if (MaskNext) + { + MaskNext = false; + if (!IsFlag && Lead != '\0') + { + PushMask(TokStart, Tok.size(), kUserOnly); + continue; + } + // Otherwise fall through to re-evaluate this token as a flag. + } + + if (!IsFlag) + { + // Positional. URL credentials and secret patterns target + // different ranges so they can coexist; overlap is filtered + // out below. + if (UrlCredentials U; FindUrlCredentials(Tok, U)) + { + PushMask(TokStart + U.HostStart, U.UserInfoLen, U.HasPassword ? kUserAndPass : kUserOnly); + } + FindSecretPatternEdits(Tok, TokStart, Edits); + continue; + } + + // It's a flag. Look for inline =value. + const size_t EqPos = Tok.find('='); + if (EqPos == std::string_view::npos) + { + if (IsSensitiveOptionName(Tok)) + { + MaskNext = true; + } + else + { + // Bare flag - still scan for embedded secret prefixes + // (e.g. "--AKIAEXAMPLE..." would be unusual but harmless). + FindSecretPatternEdits(Tok, TokStart, Edits); + } + continue; + } + + const std::string_view Name = Tok.substr(0, EqPos); + if (IsSensitiveOptionName(Name)) + { + const size_t ValueStart = TokStart + EqPos + 1; + size_t ValueLen = Tok.size() - (EqPos + 1); + // Preserve the closing quote when the whole token is outer-quoted + // (e.g. `"--name=value"`); otherwise the replacement would eat it + // and leave an unbalanced quote in the cmdline string. + if (ValueLen > 0 && Tok.front() == '"' && Tok.back() == '"') + { + --ValueLen; + } + PushMask(ValueStart, ValueLen, kUserOnly); + } + else + { + // Non-sensitive flag with a value: URL scrub + pattern scan + // on the value (same coexistence as the positional branch). 
+ const std::string_view Value = Tok.substr(EqPos + 1); + const size_t ValueAbsStart = TokStart + EqPos + 1; + if (UrlCredentials U; FindUrlCredentials(Value, U)) + { + PushMask(ValueAbsStart + U.HostStart, U.UserInfoLen, U.HasPassword ? kUserAndPass : kUserOnly); + } + FindSecretPatternEdits(Value, ValueAbsStart, Edits); + } + } + + if (Edits.empty()) + { + return; + } + + // Sort by start ascending and drop overlapping/duplicate edits so the + // right-to-left replacement pass does not corrupt the string. Earlier + // (lower-Start) edits win over later ones that fall inside their range. + std::sort(Edits.begin(), Edits.end(), [](const Edit& A, const Edit& B) { return A.Start < B.Start; }); + size_t Write = 0; + for (size_t Read = 0; Read < Edits.size(); ++Read) + { + if (Write > 0 && Edits[Read].Start < Edits[Write - 1].Start + Edits[Write - 1].Len) + { + continue; // overlaps the prior kept edit + } + if (Write != Read) + { + Edits[Write] = Edits[Read]; + } + ++Write; + } + Edits.resize(Write); + + // Apply right-to-left so earlier offsets stay valid as later edits + // resize the string. + for (size_t E = Edits.size(); E-- > 0;) + { + Cmdline.replace(Edits[E].Start, Edits[E].Len, Edits[E].Replacement); + } + } + +} // namespace + +void +ScrubSensitiveValues(std::string& Cmdline) noexcept +{ + try + { + ScrubSensitiveValuesImpl(Cmdline); + } + catch (...) + { + } +} + #if ZEN_WITH_TESTS void @@ -195,6 +617,22 @@ commandlineoptions_forcelink() { } +namespace { + // Test helper: in-place redaction of user:password@ in a URL. Production + // code calls FindUrlCredentials + PushMask directly inside the cmdline + // walker; this wrapper exists only for direct unit-test coverage of the + // URL redaction rule. + void ScrubUrlCredentials(std::string& Token) + { + UrlCredentials U; + if (!FindUrlCredentials(Token, U)) + { + return; + } + Token.replace(U.HostStart, U.UserInfoLen, U.HasPassword ? 
kUserAndPass : kUserOnly); + } +} // namespace + TEST_SUITE_BEGIN("util.commandlineoptions"); TEST_CASE("CommandLine") @@ -238,6 +676,321 @@ TEST_CASE("CommandLine") CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64")); } +TEST_CASE("IsSensitiveOptionName.matches") +{ + // Real zen / zenserver options ending in one of the sensitive suffixes. + CHECK(IsSensitiveOptionName("--access-token")); // token + CHECK(IsSensitiveOptionName("--openid-refresh-token")); // token + CHECK(IsSensitiveOptionName("--upstream-jupiter-token")); // token + CHECK(IsSensitiveOptionName("--encryption-aes-key")); // aeskey + CHECK(IsSensitiveOptionName("--oauth-clientsecret")); // secret + CHECK(IsSensitiveOptionName("--upstream-jupiter-oauth-clientsecret")); // secret + CHECK(IsSensitiveOptionName("--sentry-dsn")); // dsn + + // Generic forms ending in one of the suffixes. + CHECK(IsSensitiveOptionName("--token")); + CHECK(IsSensitiveOptionName("--my-aeskey")); + + // Normalization equivalents - dashes / underscores stripped, case folded. + CHECK(IsSensitiveOptionName("--ACCESS-TOKEN")); + CHECK(IsSensitiveOptionName("--access_token")); + CHECK(IsSensitiveOptionName("--AccessToken")); + CHECK(IsSensitiveOptionName("--Encryption_AES_Key")); + CHECK(IsSensitiveOptionName("--Sentry-DSN")); + + // Quotes around the option name are stripped before matching. + CHECK(IsSensitiveOptionName("\"--access-token\"")); + CHECK(IsSensitiveOptionName("\"--sentry-dsn")); +} + +TEST_CASE("IsSensitiveOptionName.no-match") +{ + // Real zen options whose name contains "token" or "key" mid-name but + // whose value is NOT a secret (suffix doesn't match). 
+ CHECK_FALSE(IsSensitiveOptionName("--access-token-env")); // env name + CHECK_FALSE(IsSensitiveOptionName("--access-token-path")); // file path + CHECK_FALSE(IsSensitiveOptionName("--oidctoken-exe-path")); // exe path + CHECK_FALSE(IsSensitiveOptionName("--allow-external-oidctoken-exe")); // boolean + CHECK_FALSE(IsSensitiveOptionName("--encryption-aes-iv")); // IV is public + CHECK_FALSE(IsSensitiveOptionName("--key")); // cloud lookup hash + CHECK_FALSE(IsSensitiveOptionName("--valuekey")); // IoHash filter + CHECK_FALSE(IsSensitiveOptionName("--opkey")); // chunk OID filter + + // "key" alone is not a sensitive suffix in this scheme - keeps lookup + // hashes / cache filter / api-key / ssh-key visible. + CHECK_FALSE(IsSensitiveOptionName("--api-key")); + CHECK_FALSE(IsSensitiveOptionName("--ssh-key")); + + // Non-sensitive option names. + CHECK_FALSE(IsSensitiveOptionName("--port")); + CHECK_FALSE(IsSensitiveOptionName("--data-dir")); + CHECK_FALSE(IsSensitiveOptionName("--debug")); + CHECK_FALSE(IsSensitiveOptionName("--no-sentry")); + CHECK_FALSE(IsSensitiveOptionName("--filter")); + CHECK_FALSE(IsSensitiveOptionName("--monkey")); + CHECK_FALSE(IsSensitiveOptionName("--keychain")); + + CHECK_FALSE(IsSensitiveOptionName("")); + CHECK_FALSE(IsSensitiveOptionName("--")); + CHECK_FALSE(IsSensitiveOptionName("---")); + + // Pathologically long non-sensitive name spills the inline buffer but + // still resolves cleanly. 
+ CHECK_FALSE(IsSensitiveOptionName(std::string(200, 'x'))); +} + +TEST_CASE("ScrubUrlCredentials.basic") +{ + std::string T = "https://user:[email protected]/path"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://***:***@host.example.com/path"); + + T = "https://[email protected]"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://***@host.example.com"); + + T = "ftp://u:[email protected]:21/file"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "ftp://***:***@example.com:21/file"); +} + +TEST_CASE("ScrubUrlCredentials.passthrough") +{ + std::string T = "just a string"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "just a string"); + + T = "https://example.com/path"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://example.com/path"); + + T = "https://example.com/users/foo@bar"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "https://example.com/users/foo@bar"); + + T = "user@host"; + ScrubUrlCredentials(T); + CHECK_EQ(T, "user@host"); +} + +TEST_CASE("ScrubSensitiveValues.inline-equals") +{ + std::string C = "zen.exe --access-token=secret --port=8558 version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** --port=8558 version"); +} + +TEST_CASE("ScrubSensitiveValues.next-token") +{ + std::string C = "zen.exe --access-token mysecret version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** version"); + + // Multiple next-token sensitive options in a row. + C = "zen.exe --access-token tok1 --oauth-clientsecret sec1 --port 8558 version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** --oauth-clientsecret *** --port 8558 version"); +} + +TEST_CASE("ScrubSensitiveValues.sensitive-flag-followed-by-flag") +{ + // `--access-token` here is acting like a switch (no value); the next + // flag must NOT be masked. 
+ std::string C = "zen.exe --access-token --debug version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token --debug version"); +} + +TEST_CASE("ScrubSensitiveValues.normalized-name-forms") +{ + // Same option masked under different casing/separator styles. + std::string C = "zen.exe --Access-Token bar --ACCESS_TOKEN=foo --AccessToken=baz version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --Access-Token *** --ACCESS_TOKEN=*** --AccessToken=*** version"); +} + +TEST_CASE("ScrubSensitiveValues.aes-key-and-iv-and-bare-key") +{ + // --encryption-aes-key matches the "aeskey" suffix and is masked. + // --encryption-aes-iv ends in "iv", not in the suffix set, and stays. + // --key (cloud lookup hash) ends in just "key", not in the suffix set, + // and stays. + std::string C = "zen.exe --encryption-aes-key=BASE64STUFF --encryption-aes-iv ABCDEF --key=cloudval --key bareval version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --encryption-aes-key=*** --encryption-aes-iv ABCDEF --key=cloudval --key bareval version"); +} + +TEST_CASE("ScrubSensitiveValues.no-false-positives") +{ + // Names that don't end in any sensitive suffix stay untouched. 
+ std::string C = "zen.exe --password mypass --api-key bar --monkey=banana --filter zen --no-sentry --port 8558 version"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.url-credentials-positional") +{ + std::string C = "zen.exe serve https://user:[email protected]/path"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe serve https://***:***@host.example.com/path"); +} + +TEST_CASE("ScrubSensitiveValues.url-credentials-in-option-value") +{ + std::string C = "zen.exe --url=https://user:[email protected] version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --url=https://***:***@host.example.com version"); +} + +TEST_CASE("ScrubSensitiveValues.outer-quoted-token-preserves-closing-quote") +{ + std::string C = "zen.exe status \"--access-token=Bearer xyz\""; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe status \"--access-token=***\""); +} + +TEST_CASE("ScrubSensitiveValues.executable-path-never-touched") +{ + // argv[0] stays untouched even if it embeds a token-shaped substring. 
+ std::string C = "/path/with/AKIAIOSFODNN7EXAMPLE/zen.exe version"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.empty-and-trivial") +{ + std::string C; + ScrubSensitiveValues(C); + CHECK_EQ(C, ""); + + C = "zen.exe"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe"); +} + +TEST_CASE("ScrubSensitiveValues.long-value") +{ + const std::string LongSecret(4096, 'X'); + std::string C = "zen.exe --access-token=" + LongSecret + " version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** version"); + + C = "zen.exe --access-token " + LongSecret + " version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token *** version"); + + C = "zen.exe \"--access-token=" + LongSecret + "\""; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe \"--access-token=***\""); +} + +TEST_CASE("ScrubSensitiveValues.long-option-name-spills-buffer") +{ + // Names longer than the 64-byte inline buffer normalize correctly via + // ExtendableStringBuilder's heap-spill path, and the suffix check still + // identifies them as sensitive. + const std::string LongName = "--" + std::string(80, 'a') + "-token"; + std::string C = "zen.exe " + LongName + "=secretvalue version"; + const std::string Expected = "zen.exe " + LongName + "=*** version"; + ScrubSensitiveValues(C); + CHECK_EQ(C, Expected); +} + +TEST_CASE("ScrubSensitiveValues.no-allocation-on-noop") +{ + std::string C = "zen.exe version --port=8558 --debug"; + C.reserve(256); + const auto* DataBefore = C.data(); + const auto CapacityBefore = C.capacity(); + ScrubSensitiveValues(C); + CHECK_EQ(C.data(), DataBefore); + CHECK_EQ(C.capacity(), CapacityBefore); +} + +// Value-based pattern matching ------------------------------------------ + +// Pattern-matched values are masked via HideSensitiveString, which leaks +// the first 4 characters for any value over 16 chars. 
Safe here because +// those 4 characters are part of the public format prefix that triggered +// the match (AKIA, AIza, eyJh) and serve as a useful debugging hint when +// reading a crash report. + +TEST_CASE("ScrubSensitiveValues.aws-access-key") +{ + std::string C = "zen.exe upload --bucket foo AKIAIOSFODNN7EXAMPLE done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe upload --bucket foo AKIAXXXX... done"); + + // Inside an option value with non-sensitive name. + C = "zen.exe --note=AKIAIOSFODNN7EXAMPLE"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --note=AKIAXXXX..."); + + // AKIA followed by too few chars: not a match. + C = "zen.exe AKIA12345"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AKIA12345"); +} + +TEST_CASE("ScrubSensitiveValues.google-api-key") +{ + // Google API keys are exactly AIza + 35 body chars. + std::string C = "zen.exe lookup AIzaSyA_abcdefghijklmnopqrstuvwx-123456 done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe lookup AIzaXXXX... done"); + + // AIza body too short: not a match. + C = "zen.exe AIzaShort"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AIzaShort"); +} + +TEST_CASE("ScrubSensitiveValues.jwt") +{ + std::string C = + "zen.exe call eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c " + "done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe call eyJhXXXX... done"); + + // "eyJ" but no two dots: not a JWT. 
+ C = "zen.exe eyJonly"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe eyJonly"); +} + +TEST_CASE("ScrubSensitiveValues.no-pattern-false-positives") +{ + // Common identifier formats used by zen that must NOT be flagged: + // - 24-hex Oid + // - SHA-like hex hashes + // - file paths + // - port numbers + // - lowercase words / common shell commands + std::string C = "zen.exe cache get 09f7831b0139270d22cf2fe2 --port=8558 run /var/tmp/zen.log"; + const std::string Original = C; + ScrubSensitiveValues(C); + CHECK_EQ(C, Original); +} + +TEST_CASE("ScrubSensitiveValues.pattern-and-name-on-same-token") +{ + // Sensitive-name mask wins; pattern scan is suppressed for the value. + std::string C = "zen.exe --access-token=AKIAIOSFODNN7EXAMPLE done"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe --access-token=*** done"); +} + +TEST_CASE("ScrubSensitiveValues.multiple-patterns-in-one-token") +{ + // Two AWS keys glued together via separator - both should mask. + std::string C = "zen.exe AKIAIOSFODNN7EXAMPLE,AKIAIOSFODNN7OTHER12 list"; + ScrubSensitiveValues(C); + CHECK_EQ(C, "zen.exe AKIAXXXX...,AKIAXXXX... list"); +} + TEST_SUITE_END(); #endif diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp index 84af1d372..354d24d8d 100644 --- a/src/zenutil/consoletui.cpp +++ b/src/zenutil/consoletui.cpp @@ -74,6 +74,15 @@ EnableVirtualTerminal() // ANSI escape codes are native on POSIX terminals; nothing to do } +// SIGWINCH (terminal resize) flag — set by signal handler, consumed by TuiReadKey() +static volatile sig_atomic_t s_GotSigWinch = 0; + +static void +SigWinchHandler(int /*Sig*/) +{ + s_GotSigWinch = 1; +} + // RAII guard: switches the terminal to raw/unbuffered input mode and restores // the original attributes on destruction. 
class RawModeGuard @@ -146,15 +155,6 @@ static bool s_InLiveMode = false; static char s_LastChar = 0; -// SIGWINCH (terminal resize) flag — set by signal handler, consumed by TuiReadKey() -static volatile sig_atomic_t s_GotSigWinch = 0; - -static void -SigWinchHandler(int /*Sig*/) -{ - s_GotSigWinch = 1; -} - #endif // ZEN_PLATFORM_WINDOWS / POSIX ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp index c9144e589..762f06817 100644 --- a/src/zenutil/consul/consul.cpp +++ b/src/zenutil/consul/consul.cpp @@ -9,10 +9,18 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpserver.h> + +#include <unordered_set> + +ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen::consul { @@ -31,7 +39,7 @@ struct ConsulProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL); CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); @@ -107,12 +115,30 @@ ConsulProcess::StopConsulAgent() ////////////////////////////////////////////////////////////////////////// -ConsulClient::ConsulClient(std::string_view BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri) +ConsulClient::ConsulClient(const Configuration& Config) +: m_Config(Config) +, m_HttpClient(m_Config.BaseUri, HttpClientSettings{.ConnectTimeout = m_Config.ConnectTimeout, .Timeout = m_Config.Timeout}, [this] { + return m_Stop.load(); +}) { + m_Worker = std::thread(&ConsulClient::WorkerLoop, this); } 
ConsulClient::~ConsulClient() { + try + { + m_Stop.store(true); + m_Wakeup.Set(); + if (m_Worker.joinable()) + { + m_Worker.join(); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("ConsulClient::~ConsulClient threw exception: {}", Ex.what()); + } } void @@ -158,9 +184,27 @@ ConsulClient::DeleteKey(std::string_view Key) } } -bool +void ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) { + PendingOp Op{PendingOp::Kind::Register, Info}; + m_QueueLock.WithExclusiveLock([&] { m_Queue.push_back(std::move(Op)); }); + m_Wakeup.Set(); +} + +void +ConsulClient::DeregisterService(std::string_view ServiceId) +{ + PendingOp Op; + Op.Type = PendingOp::Kind::Deregister; + Op.Info.ServiceId = std::string(ServiceId); + m_QueueLock.WithExclusiveLock([&] { m_Queue.push_back(std::move(Op)); }); + m_Wakeup.Set(); +} + +bool +ConsulClient::DoRegister(const ServiceRegistrationInfo& Info) +{ using namespace std::literals; HttpClient::KeyValueMap AdditionalHeaders; @@ -193,12 +237,18 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) // when no interval is configured (e.g. during Provisioning). Writer.BeginObject("Check"sv); { - Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint)); + Writer.AddString( + "HTTP"sv, + fmt::format("http://{}:{}/{}", Info.Address.empty() ? 
"localhost" : Info.Address, Info.Port, Info.HealthEndpoint)); Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds)); if (Info.DeregisterAfterSeconds != 0) { Writer.AddString("DeregisterCriticalServiceAfter"sv, fmt::format("{}s", Info.DeregisterAfterSeconds)); } + if (!Info.InitialStatus.empty()) + { + Writer.AddString("Status"sv, Info.InitialStatus); + } } Writer.EndObject(); // Check } @@ -213,7 +263,7 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) if (!Result) { - ZEN_WARN("ConsulClient::RegisterService() failed to register service '{}' ({})", Info.ServiceId, Result.ErrorMessage("")); + ZEN_WARN("ConsulClient::DoRegister() failed to register service '{}' ({})", Info.ServiceId, Result.ErrorMessage("")); return false; } @@ -221,29 +271,114 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) } bool -ConsulClient::DeregisterService(std::string_view ServiceId) +ConsulClient::DoDeregister(std::string_view ServiceId) { + using namespace std::literals; + HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON)); - HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), IoBuffer{}, AdditionalHeaders); + if (Result) + { + return true; + } + + // Agent deregister failed - fall back to catalog deregister. + // This handles cases where the service was registered via a different Consul agent + // (e.g. load-balanced endpoint routing to different agents). 
+ std::string NodeName = GetNodeName(); + if (!NodeName.empty()) + { + CbObjectWriter Writer; + Writer.AddString("Node"sv, NodeName); + Writer.AddString("ServiceID"sv, ServiceId); + + ExtendableStringBuilder<256> SB; + CompactBinaryToJson(Writer.Save(), SB); + + IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()); + PayloadBuffer.SetContentType(HttpContentType::kJSON); + + HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders); + if (CatalogResult) + { + ZEN_INFO("ConsulClient::DoDeregister() deregistered service '{}' via catalog fallback (agent error: {})", + ServiceId, + Result.ErrorMessage("")); + return true; + } + + ZEN_WARN("ConsulClient::DoDeregister() failed to deregister service '{}' (agent: {}, catalog: {})", + ServiceId, + Result.ErrorMessage(""), + CatalogResult.ErrorMessage("")); + } + else + { + ZEN_WARN( + "ConsulClient::DoDeregister() failed to deregister service '{}' (agent: {}, could not determine node name for catalog " + "fallback)", + ServiceId, + Result.ErrorMessage("")); + } + + return false; +} +std::string +ConsulClient::GetNodeName() +{ + using namespace std::literals; + + HttpClient::KeyValueMap AdditionalHeaders; + ApplyCommonHeaders(AdditionalHeaders); + + HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders); if (!Result) { - ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage("")); - return false; + return {}; } - return true; + std::string JsonError; + CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError); + if (!Root || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView Field : Root) + { + if (Field.GetName() == "Config"sv) + { + CbObjectView Config = Field.AsObjectView(); + if (Config) + { + return std::string(Config["NodeName"sv].AsString()); + } + } + } + + return {}; } void ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& 
InOutHeaderMap) { - if (!m_Token.empty()) + std::string Token; + if (!m_Config.StaticToken.empty()) + { + Token = m_Config.StaticToken; + } + else if (!m_Config.TokenEnvName.empty()) + { + Token = GetEnvVariable(m_Config.TokenEnvName); + } + + if (!Token.empty()) { - InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token); + InOutHeaderMap.Entries.emplace("X-Consul-Token", Token); } } @@ -295,8 +430,10 @@ ConsulClient::WatchService(std::string_view ServiceId, uint64_t& InOutIndex, int HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); - // Note: m_HttpClient uses unlimited HTTP timeout (Timeout{0}); the WaitSeconds parameter - // governs the server-side bound on the blocking query. Do not add a separate client timeout. + // Note: m_HttpClient runs with a 500ms client-side timeout to keep Register/Deregister from + // stalling the hub state machine when the agent is unreachable. That bound applies here too: + // WaitSeconds is effectively capped at ~500ms regardless of the argument, so callers must + // treat this as a short-poll and loop rather than rely on a true blocking query. 
HttpClient::KeyValueMap Parameters({{"index", std::to_string(InOutIndex)}, {"wait", fmt::format("{}s", WaitSeconds)}}); HttpClient::Response Result = m_HttpClient.Get("v1/agent/services", AdditionalHeaders, Parameters); if (!Result) @@ -345,6 +482,82 @@ ConsulClient::GetAgentChecksJson() return Result.ToText(); } +void +ConsulClient::WorkerLoop() +{ + SetCurrentThreadName("ConsulClient"); + + std::unordered_set<std::string> RegisteredServices; + + while (true) + { + m_Wakeup.Wait(-1); + m_Wakeup.Reset(); + + const bool Stopping = m_Stop.load(); + + std::vector<PendingOp> Batch; + m_QueueLock.WithExclusiveLock([&] { Batch.swap(m_Queue); }); + + for (size_t Index = 0; Index < Batch.size(); ++Index) + { + PendingOp& Op = Batch[Index]; + + if (Stopping && Op.Type == PendingOp::Kind::Register) + { + continue; + } + + const std::string_view OpName = (Op.Type == PendingOp::Kind::Register) ? "register" : "deregister"; + + try + { + if (Op.Type == PendingOp::Kind::Register) + { + bool Ok = DoRegister(Op.Info); + if (Ok) + { + RegisteredServices.insert(Op.Info.ServiceId); + } + else + { + const size_t Remaining = Batch.size() - Index - 1; + ZEN_WARN("ConsulClient worker: {} for '{}' failed; dropping {} remaining queued op(s)", + OpName, + Op.Info.ServiceId, + Remaining); + break; + } + } + else + { + ZEN_ASSERT(Op.Type == PendingOp::Kind::Deregister); + if (RegisteredServices.erase(Op.Info.ServiceId) == 1u) + { + if (!DoDeregister(Op.Info.ServiceId)) + { + ZEN_WARN("ConsulClient worker: {} for '{}' failed", OpName, Op.Info.ServiceId); + } + } + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("ConsulClient worker: {} for '{}' threw: {}", OpName, Op.Info.ServiceId, Ex.what()); + } + catch (...) 
+ { + ZEN_WARN("ConsulClient worker: {} for '{}' threw unknown exception", OpName, Op.Info.ServiceId); + } + } + + if (Stopping) + { + break; + } + } +} + ////////////////////////////////////////////////////////////////////////// ServiceRegistration::ServiceRegistration(ConsulClient* Client, const ServiceRegistrationInfo& Info) : m_Client(Client), m_Info(Info) @@ -365,7 +578,7 @@ ServiceRegistration::~ServiceRegistration() if (m_IsRegistered.load()) { - if (!m_Client->DeregisterService(m_Info.ServiceId)) + if (!m_Client->DoDeregister(m_Info.ServiceId)) { ZEN_WARN("ServiceRegistration: Failed to deregister service '{}' during cleanup", m_Info.ServiceId); } @@ -405,7 +618,7 @@ ServiceRegistration::RegistrationLoop() // Try to register with exponential backoff for (int Attempt = 0; Attempt < MaxAttempts; ++Attempt) { - if (m_Client->RegisterService(m_Info)) + if (m_Client->DoRegister(m_Info)) { Succeeded = true; break; @@ -417,7 +630,7 @@ ServiceRegistration::RegistrationLoop() } } - if (Succeeded || m_Client->RegisterService(m_Info)) + if (Succeeded || m_Client->DoRegister(m_Info)) { break; } @@ -446,4 +659,202 @@ ServiceRegistration::RegistrationLoop() } } +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +consul_forcelink() +{ +} + +struct MockHealthService : public HttpService +{ + std::atomic<bool> FailHealth{false}; + std::atomic<int> HealthCheckCount{0}; + + const char* BaseUri() const override { return "/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + std::string_view Uri = Request.RelativeUri(); + if (Uri == "health/" || Uri == "health") + { + HealthCheckCount.fetch_add(1); + if (FailHealth.load()) + { + Request.WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } +}; + +struct TestHealthServer +{ + 
MockHealthService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + int Port() const { return m_Port; } + + ~TestHealthServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +static bool +WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs)) + { + if (Predicate()) + { + return true; + } + Sleep(PollIntervalMs); + } + return Predicate(); +} + +static std::string +GetCheckStatus(ConsulClient& Client, std::string_view ServiceId) +{ + using namespace std::literals; + + std::string JsonError; + CbFieldIterator ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError); + if (!ChecksRoot || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView F : ChecksRoot) + { + if (!F.IsObject()) + { + continue; + } + for (CbFieldView C : F.AsObjectView()) + { + CbObjectView Check = C.AsObjectView(); + if (Check["ServiceID"sv].AsString() == ServiceId) + { + return std::string(Check["Status"sv].AsString()); + } + } + } + return {}; +} + +TEST_SUITE_BEGIN("util.consul"); + +TEST_CASE("util.consul.service_lifecycle") +{ + ConsulProcess ConsulProc; + ConsulProc.SpawnConsulAgent(); + + TestHealthServer HealthServer; + HealthServer.Start(); + + ConsulClient Client( + {.BaseUri = "http://localhost:8500/", .ConnectTimeout = std::chrono::seconds{5}, .Timeout = std::chrono::seconds{5}}); + + const std::string 
ServiceId = "test-health-svc"; + + ServiceRegistrationInfo Info; + Info.ServiceId = ServiceId; + Info.ServiceName = "zen-test-health"; + Info.Address = "127.0.0.1"; + Info.Port = static_cast<uint16_t>(HealthServer.Port()); + Info.HealthEndpoint = "health/"; + Info.HealthIntervalSeconds = 1; + Info.DeregisterAfterSeconds = 60; + + // Register/Deregister are async; wait for the worker to propagate to Consul. + + // Phase 1: Register and verify Consul sends health checks to our service + Client.RegisterService(Info); + REQUIRE(WaitForCondition([&]() { return Client.HasService(ServiceId); }, 10000, 50)); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + // Phase 2: Explicit deregister + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + // Phase 3: Register with InitialStatus, verify immediately passing before any health check fires, + // then fail health and verify check goes critical + HealthServer.Mock.HealthCheckCount.store(0); + HealthServer.Mock.FailHealth.store(false); + + Info.InitialStatus = "passing"; + Client.RegisterService(Info); + REQUIRE(WaitForCondition([&]() { return Client.HasService(ServiceId); }, 10000, 50)); + + // Registration is async; by the time HasService observes the service the 1s health interval + // may already have fired, so we can't robustly assert HealthCheckCount==0. The "passing" status + // below still proves InitialStatus applied (it can only be "passing" via InitialStatus or a + // successful health check - both are acceptable demonstrations). 
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + HealthServer.Mock.FailHealth.store(true); + + // Wait for Consul to observe the failing check + REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000, 50)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical"); + + // Phase 4: Explicit deregister while critical + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + // Phase 5: Deregister an already-deregistered service - should not crash + Client.DeregisterService(ServiceId); + REQUIRE(WaitForCondition([&]() { return !Client.HasService(ServiceId); }, 10000, 50)); + + ConsulProc.StopConsulAgent(); +} + +TEST_SUITE_END(); + +#endif + } // namespace zen::consul diff --git a/src/zenutil/filesystemutils.cpp b/src/zenutil/filesystemutils.cpp new file mode 100644 index 000000000..f8f7bfb18 --- /dev/null +++ b/src/zenutil/filesystemutils.cpp @@ -0,0 +1,724 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/filesystemutils.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/parallelwork.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +BufferedOpenFile::BufferedOpenFile(const std::filesystem::path Path, + std::atomic<uint64_t>& OpenReadCount, + std::atomic<uint64_t>& CurrentOpenFileCount, + std::atomic<uint64_t>& ReadCount, + std::atomic<uint64_t>& ReadByteCount) +: m_Source(Path, BasicFile::Mode::kRead) +, m_SourceSize(m_Source.FileSize()) +, m_OpenReadCount(OpenReadCount) +, m_CurrentOpenFileCount(CurrentOpenFileCount) +, m_ReadCount(ReadCount) +, m_ReadByteCount(ReadByteCount) + +{ + m_OpenReadCount++; + m_CurrentOpenFileCount++; +} + +BufferedOpenFile::~BufferedOpenFile() +{ + m_CurrentOpenFileCount--; +} + +CompositeBuffer +BufferedOpenFile::GetRange(uint64_t Offset, uint64_t Size) +{ + ZEN_TRACE_CPU("BufferedOpenFile::GetRange"); + + ZEN_ASSERT((m_CacheBlockIndex == (uint64_t)-1) || m_Cache); + auto _ = MakeGuard([&]() { ZEN_ASSERT((m_CacheBlockIndex == (uint64_t)-1) || m_Cache); }); + + ZEN_ASSERT((Offset + Size) <= m_SourceSize); + const uint64_t BlockIndexStart = Offset / BlockSize; + const uint64_t BlockIndexEnd = (Offset + Size - 1) / BlockSize; + + std::vector<SharedBuffer> BufferRanges; + BufferRanges.reserve(BlockIndexEnd - BlockIndexStart + 1); + + uint64_t ReadOffset = Offset; + for (uint64_t BlockIndex = BlockIndexStart; BlockIndex <= BlockIndexEnd; BlockIndex++) + { + const uint64_t BlockStartOffset = BlockIndex * BlockSize; + if (m_CacheBlockIndex != BlockIndex) + { + uint64_t CacheSize = Min(BlockSize, m_SourceSize - BlockStartOffset); + ZEN_ASSERT(CacheSize > 0); + m_Cache = IoBuffer(CacheSize); + m_Source.Read(m_Cache.GetMutableView().GetData(), CacheSize, BlockStartOffset); + m_ReadCount++; + m_ReadByteCount += 
CacheSize; + m_CacheBlockIndex = BlockIndex; + } + + const uint64_t BytesRead = ReadOffset - Offset; + ZEN_ASSERT(BlockStartOffset <= ReadOffset); + const uint64_t OffsetIntoBlock = ReadOffset - BlockStartOffset; + ZEN_ASSERT(OffsetIntoBlock < m_Cache.GetSize()); + const uint64_t BlockBytes = Min(m_Cache.GetSize() - OffsetIntoBlock, Size - BytesRead); + BufferRanges.emplace_back(SharedBuffer(IoBuffer(m_Cache, OffsetIntoBlock, BlockBytes))); + ReadOffset += BlockBytes; + } + CompositeBuffer Result(std::move(BufferRanges)); + ZEN_ASSERT(Result.GetSize() == Size); + return Result; +} + +bool +IsFileWithRetry(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Result = IsFile(Path, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + Sleep(100 + int(Retries * 50)); + Ec.clear(); + Result = IsFile(Path, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed to check path '{}' is file, reason: ({}) {}", Path, Ec.value(), Ec.message())); + } + return Result; +} + +bool +SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly) +{ + std::error_code Ec; + bool Result = SetFileReadOnly(Path, ReadOnly, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsFileWithRetry(Path)) + { + return false; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + Result = SetFileReadOnly(Path, ReadOnly, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed {} read only flag for file '{}', reason: ({}) {}", + ReadOnly ? 
"setting" : "clearing", + Path, + Ec.value(), + Ec.message())); + } + return Result; +} + +std::error_code +RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameFile(SourcePath, TargetPath, Ec); + for (size_t Retries = 0; Ec && Retries < 5; Retries++) + { + ZEN_ASSERT_SLOW(IsFile(SourcePath)); + Sleep(50 + int(Retries * 150)); + Ec.clear(); + RenameFile(SourcePath, TargetPath, Ec); + } + return Ec; +} + +std::error_code +RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameDirectory(SourcePath, TargetPath, Ec); + for (size_t Retries = 0; Ec && Retries < 5; Retries++) + { + ZEN_ASSERT_SLOW(IsDir(SourcePath)); + Sleep(50 + int(Retries * 150)); + Ec.clear(); + RenameDirectory(SourcePath, TargetPath, Ec); + } + return Ec; +} + +std::error_code +TryRemoveFile(const std::filesystem::path& Path) +{ + std::error_code Ec; + RemoveFile(Path, Ec); + if (Ec) + { + if (IsFile(Path, Ec)) + { + Ec.clear(); + RemoveFile(Path, Ec); + if (Ec) + { + return Ec; + } + } + } + return {}; +} + +void +RemoveFileWithRetry(const std::filesystem::path& Path) +{ + std::error_code Ec; + RemoveFile(Path, Ec); + for (size_t Retries = 0; Ec && Retries < 6; Retries++) + { + if (!IsFileWithRetry(Path)) + { + return; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + RemoveFile(Path, Ec); + } + if (Ec) + { + throw std::system_error(std::error_code(Ec.value(), std::system_category()), + fmt::format("Failed removing file '{}', reason: ({}) {}", Path, Ec.value(), Ec.message())); + } +} + +void +FastCopyFile(bool AllowFileClone, + bool UseSparseFiles, + const std::filesystem::path& SourceFilePath, + const std::filesystem::path& TargetFilePath, + uint64_t RawSize, + std::atomic<uint64_t>& WriteCount, + std::atomic<uint64_t>& WriteByteCount, + std::atomic<uint64_t>& CloneCount, + std::atomic<uint64_t>& CloneByteCount) +{ + 
ZEN_TRACE_CPU("CopyFile"); + std::error_code CloneEc = + AllowFileClone ? TryCloneFile(SourceFilePath, TargetFilePath) : std::make_error_code(std::errc::operation_not_supported); + if (!CloneEc) + { + WriteCount += 1; + WriteByteCount += RawSize; + CloneCount += 1; + CloneByteCount += RawSize; + } + else + { + BasicFile TargetFile(TargetFilePath, BasicFile::Mode::kTruncate); + if (UseSparseFiles) + { + PrepareFileForScatteredWrite(TargetFile.Handle(), RawSize); + } + uint64_t Offset = 0; + std::error_code ScanEc = ScanFile(SourceFilePath, 512u * 1024u, [&](const void* Data, size_t Size) { + TargetFile.Write(Data, Size, Offset); + Offset += Size; + WriteCount++; + WriteByteCount += Size; + }); + if (ScanEc) + { + throw std::system_error(ScanEc, fmt::format("Failed to copy file '{}' to '{}'", SourceFilePath, TargetFilePath)); + } + } +} + +void +GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent) +{ + struct Visitor : public GetDirectoryContentVisitor + { + Visitor(zen::DirectoryContent& OutContent, const std::filesystem::path& InRootPath) : Content(OutContent), RootPath(InRootPath) {} + virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const + { + ZEN_UNUSED(Parent, DirectoryName); + return true; + } + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& InContent) + { + std::vector<std::filesystem::path> Files; + std::vector<std::filesystem::path> Directories; + + if (!InContent.FileNames.empty()) + { + Files.reserve(InContent.FileNames.size()); + for (const std::filesystem::path& FileName : InContent.FileNames) + { + if (RelativeRoot.empty()) + { + Files.push_back(RootPath / FileName); + } + else + { + Files.push_back(RootPath / RelativeRoot / FileName); + } + } + } + + if (!InContent.DirectoryNames.empty()) + { + 
Directories.reserve(InContent.DirectoryNames.size()); + for (const std::filesystem::path& DirName : InContent.DirectoryNames) + { + if (RelativeRoot.empty()) + { + Directories.push_back(RootPath / DirName); + } + else + { + Directories.push_back(RootPath / RelativeRoot / DirName); + } + } + } + + // Files/Directories were built above without holding the lock; only the cheap appends run under it. + Lock.WithExclusiveLock([&]() { + for (std::filesystem::path& File : Files) + { + Content.Files.push_back(std::move(File)); + } + if (!InContent.FileSizes.empty()) + { + Content.FileSizes.insert(Content.FileSizes.end(), InContent.FileSizes.begin(), InContent.FileSizes.end()); + } + if (!InContent.FileAttributes.empty()) + { + Content.FileAttributes.insert(Content.FileAttributes.end(), + InContent.FileAttributes.begin(), + InContent.FileAttributes.end()); + } + if (!InContent.FileModificationTicks.empty()) + { + Content.FileModificationTicks.insert(Content.FileModificationTicks.end(), + InContent.FileModificationTicks.begin(), + InContent.FileModificationTicks.end()); + } + + for (std::filesystem::path& Directory : Directories) + { + Content.Directories.push_back(std::move(Directory)); + } + if (!InContent.DirectoryAttributes.empty()) + { + Content.DirectoryAttributes.insert(Content.DirectoryAttributes.end(), + InContent.DirectoryAttributes.begin(), + InContent.DirectoryAttributes.end()); + } + }); + } + RwLock Lock; + zen::DirectoryContent& Content; + const std::filesystem::path& RootPath; + }; + + Visitor RootVisitor(OutContent, Path); + + Latch PendingWork(1); + GetDirectoryContent(Path, Flags, RootVisitor, WorkerPool, PendingWork); + PendingWork.CountDown(); + PendingWork.Wait(); +} + +CleanDirectoryResult 
+CleanDirectory( + WorkerThreadPool& IOWorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Path, + std::span<const std::string> ExcludeDirectories, + std::function<void(const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted)>&& + ProgressFunc, + uint32_t ProgressUpdateDelayMS) +{ + ZEN_TRACE_CPU("CleanDirectory"); + Stopwatch Timer; + + std::atomic<uint64_t> DiscoveredItemCount = 0; + std::atomic<uint64_t> DeletedItemCount = 0; + std::atomic<uint64_t> DeletedByteCount = 0; + + std::vector<std::filesystem::path> DirectoriesToDelete; + CleanDirectoryResult Result; + RwLock ResultLock; + auto _ = MakeGuard([&]() { + Result.DeletedCount = DeletedItemCount.load(); + Result.DeletedByteCount = DeletedByteCount.load(); + Result.FoundCount = DiscoveredItemCount.load(); + }); + + ParallelWork Work(AbortFlag, + PauseFlag, + ProgressFunc ? WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog); + + struct AsyncVisitor : public GetDirectoryContentVisitor + { + AsyncVisitor(const std::filesystem::path& InPath, + std::atomic<bool>& InAbortFlag, + std::atomic<uint64_t>& InDiscoveredItemCount, + std::atomic<uint64_t>& InDeletedItemCount, + std::atomic<uint64_t>& InDeletedByteCount, + std::span<const std::string> InExcludeDirectories, + std::vector<std::filesystem::path>& OutDirectoriesToDelete, + CleanDirectoryResult& InResult, + RwLock& InResultLock) + : Path(InPath) + , AbortFlag(InAbortFlag) + , DiscoveredItemCount(InDiscoveredItemCount) + , DeletedItemCount(InDeletedItemCount) + , DeletedByteCount(InDeletedByteCount) + , ExcludeDirectories(InExcludeDirectories) + , DirectoriesToDelete(OutDirectoriesToDelete) + , Result(InResult) + , ResultLock(InResultLock) + { + } + + virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const override + { + ZEN_UNUSED(Parent); + + if (AbortFlag) + { + 
return false; + } + const std::string DirectoryString = DirectoryName.string(); + for (const std::string_view ExcludeDirectory : ExcludeDirectories) + { + if (DirectoryString == ExcludeDirectory) + { + return false; + } + } + return true; + } + + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& Content) override + { + ZEN_TRACE_CPU("CleanDirectory_AsyncVisitDirectory"); + if (!AbortFlag) + { + DiscoveredItemCount += Content.FileNames.size(); + + ZEN_TRACE_CPU("DeleteFiles"); + std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths; + for (size_t FileIndex = 0; FileIndex < Content.FileNames.size(); FileIndex++) + { + const std::filesystem::path& FileName = Content.FileNames[FileIndex]; + const std::filesystem::path FilePath = (Path / RelativeRoot / FileName).make_preferred(); + + bool IsRemoved = false; + std::error_code Ec; + (void)SetFileReadOnly(FilePath, false, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsFileWithRetry(FilePath)) + { + IsRemoved = true; + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + (void)SetFileReadOnly(FilePath, false, Ec); + } + if (!IsRemoved && !Ec) + { + (void)RemoveFile(FilePath, Ec); + for (size_t Retries = 0; Ec && Retries < 6; Retries++) + { + if (!IsFileWithRetry(FilePath)) + { + IsRemoved = true; + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + (void)RemoveFile(FilePath, Ec); + } + } + if (!IsRemoved && Ec) + { + FailedRemovePaths.push_back(std::make_pair(FilePath, Ec)); + } + else + { + DeletedItemCount++; + DeletedByteCount += Content.FileSizes[FileIndex]; + } + } + + if (!FailedRemovePaths.empty()) + { + RwLock::ExclusiveLockScope _(ResultLock); + Result.FailedRemovePaths.insert(Result.FailedRemovePaths.end(), FailedRemovePaths.begin(), FailedRemovePaths.end()); + } + else if (!RelativeRoot.empty()) + { + DiscoveredItemCount++; + RwLock::ExclusiveLockScope _(ResultLock); + 
DirectoriesToDelete.push_back(RelativeRoot); + } + } + } + const std::filesystem::path& Path; + std::atomic<bool>& AbortFlag; + std::atomic<uint64_t>& DiscoveredItemCount; + std::atomic<uint64_t>& DeletedItemCount; + std::atomic<uint64_t>& DeletedByteCount; + std::span<const std::string> ExcludeDirectories; + std::vector<std::filesystem::path>& DirectoriesToDelete; + CleanDirectoryResult& Result; + RwLock& ResultLock; + } Visitor(Path, + AbortFlag, + DiscoveredItemCount, + DeletedItemCount, + DeletedByteCount, + ExcludeDirectories, + DirectoriesToDelete, + Result, + ResultLock); + + GetDirectoryContent(Path, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | DirectoryContentFlags::IncludeFileSizes, + Visitor, + IOWorkerPool, + Work.PendingWork()); + + uint64_t LastUpdateTimeMs = Timer.GetElapsedTimeMs(); + + if (ProgressFunc && ProgressUpdateDelayMS != 0) + { + Work.Wait(ProgressUpdateDelayMS, [&](bool IsAborted, bool IsPaused, ptrdiff_t PendingWork) { + ZEN_UNUSED(PendingWork); + LastUpdateTimeMs = Timer.GetElapsedTimeMs(); + + uint64_t Deleted = DeletedItemCount.load(); + uint64_t DeletedBytes = DeletedByteCount.load(); + uint64_t Discovered = DiscoveredItemCount.load(); + std::string Details = fmt::format("Found {}, Deleted {} ({})", Discovered, Deleted, NiceBytes(DeletedBytes)); + ProgressFunc(Details, Discovered, Discovered - Deleted, IsPaused, IsAborted); + }); + } + else + { + Work.Wait(); + } + + { + ZEN_TRACE_CPU("DeleteDirs"); + + std::sort(DirectoriesToDelete.begin(), + DirectoriesToDelete.end(), + [](const std::filesystem::path& Lhs, const std::filesystem::path& Rhs) { + auto DistanceLhs = std::distance(Lhs.begin(), Lhs.end()); + auto DistanceRhs = std::distance(Rhs.begin(), Rhs.end()); + return DistanceLhs > DistanceRhs; + }); + + for (const std::filesystem::path& DirectoryToDelete : DirectoriesToDelete) + { + if (AbortFlag) + { + break; + } + else + { + while (PauseFlag && !AbortFlag) + { + Sleep(2000); + } + } + + const 
std::filesystem::path FullPath = Path / DirectoryToDelete; + + std::error_code Ec; + RemoveDir(FullPath, Ec); + if (Ec) + { + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + if (!IsDir(FullPath)) + { + Ec.clear(); + break; + } + Sleep(100 + int(Retries * 50)); + Ec.clear(); + RemoveDir(FullPath, Ec); + } + } + if (Ec) + { + RwLock::ExclusiveLockScope __(ResultLock); + Result.FailedRemovePaths.push_back(std::make_pair(DirectoryToDelete, Ec)); + } + else + { + DeletedItemCount++; + } + + if (ProgressFunc) + { + uint64_t NowMs = Timer.GetElapsedTimeMs(); + + // Throttle progress callbacks to at most one per ProgressUpdateDelayMS, matching the file-deletion phase. + if ((NowMs - LastUpdateTimeMs) >= ProgressUpdateDelayMS) + { + LastUpdateTimeMs = NowMs; + + uint64_t Deleted = DeletedItemCount.load(); + uint64_t DeletedBytes = DeletedByteCount.load(); + uint64_t Discovered = DiscoveredItemCount.load(); + std::string Details = fmt::format("Found {}, Deleted {} ({})", Discovered, Deleted, NiceBytes(DeletedBytes)); + ProgressFunc(Details, Discovered, Discovered - Deleted, PauseFlag, AbortFlag); + } + } + } + } + + return Result; +} + +bool +CleanAndRemoveDirectory(WorkerThreadPool& WorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Directory) +{ + if (!IsDir(Directory)) + { + return true; + } + if (CleanDirectoryResult Res = CleanDirectory( + WorkerPool, + AbortFlag, + PauseFlag, + Directory, + {}, + [](const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted) { + ZEN_UNUSED(Details, TotalCount, RemainingCount, IsPaused, IsAborted); + }, + 1000); + Res.FailedRemovePaths.empty()) + { + std::error_code Ec; + RemoveDir(Directory, Ec); + return !Ec; + } + return false; +} + +#if ZEN_WITH_TESTS + +void +filesystemutils_forcelink() +{ +} + +namespace { + void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); } +} // namespace + +TEST_SUITE_BEGIN("util.filesystemutils"); + +TEST_CASE("filesystemutils.CleanDirectory") +{ + 
ScopedTemporaryDirectory TmpDir; + + CreateDirectories(TmpDir.Path() / ".keepme"); + GenerateFile(TmpDir.Path() / ".keepme" / "keep"); + GenerateFile(TmpDir.Path() / "deleteme1"); + GenerateFile(TmpDir.Path() / "deleteme2"); + GenerateFile(TmpDir.Path() / "deleteme3"); + CreateDirectories(TmpDir.Path() / ".keepmenot"); + CreateDirectories(TmpDir.Path() / "no.keepme"); + + CreateDirectories(TmpDir.Path() / "DeleteMe"); + GenerateFile(TmpDir.Path() / "DeleteMe" / "delete1"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete1"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "delete3"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe" / ".keepme"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2" / "delete2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe" / "DeleteMe2" / "delete3"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe2" / ".keepme"); + CreateDirectories(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept1"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept2"); + GenerateFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme"); + + WorkerThreadPool Pool(4); + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + + CleanDirectory(Pool, AbortFlag, PauseFlag, TmpDir.Path(), std::vector<std::string>{".keepme"}, {}, 0); + + CHECK(IsDir(TmpDir.Path() / ".keepme")); + CHECK(IsFile(TmpDir.Path() / ".keepme" / "keep")); + CHECK(!IsFile(TmpDir.Path() / "deleteme1")); + CHECK(!IsFile(TmpDir.Path() / "deleteme2")); + CHECK(!IsFile(TmpDir.Path() / "deleteme3")); + CHECK(!IsFile(TmpDir.Path() / ".keepmenot")); + CHECK(!IsFile(TmpDir.Path() / "no.keepme")); + + CHECK(!IsDir(TmpDir.Path() / "DeleteMe")); + CHECK(!IsDir(TmpDir.Path() / "DeleteMe2")); + + 
CHECK(IsDir(TmpDir.Path() / "CantDeleteMe")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe" / ".keepme")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2" / ".keepme")); + CHECK(IsDir(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept")); + CHECK(IsFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept1")); + CHECK(IsFile(TmpDir.Path() / "CantDeleteMe2" / ".keepme" / "Kept" / "kept2")); + CHECK(!IsFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme")); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/filteredrate.cpp b/src/zenutil/filteredrate.cpp new file mode 100644 index 000000000..de01af57b --- /dev/null +++ b/src/zenutil/filteredrate.cpp @@ -0,0 +1,92 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/filteredrate.h> + +namespace zen { + +void +FilteredRate::Start() +{ + if (StartTimeUS == (uint64_t)-1) + { + uint64_t Expected = (uint64_t)-1; + if (StartTimeUS.compare_exchange_strong(Expected, Timer.GetElapsedTimeUs())) + { + LastTimeUS = StartTimeUS.load(); + } + } +} + +void +FilteredRate::Stop() +{ + if (EndTimeUS == (uint64_t)-1) + { + uint64_t Expected = (uint64_t)-1; + EndTimeUS.compare_exchange_strong(Expected, Timer.GetElapsedTimeUs()); + } +} + +void +FilteredRate::Update(uint64_t Count) +{ + if (LastTimeUS == (uint64_t)-1) + { + return; + } + uint64_t TimeUS = Timer.GetElapsedTimeUs(); + uint64_t TimeDeltaUS = TimeUS - LastTimeUS; + if (TimeDeltaUS >= 2000000) + { + uint64_t Delta = Count - LastCount; + uint64_t PerSecond = (Delta * 1000000) / TimeDeltaUS; + + FilteredPerSecond = (PerSecond + (LastPerSecond * 7)) / 8; + + LastPerSecond = PerSecond; + LastCount = Count; + LastTimeUS = TimeUS; + } +} + +uint64_t +FilteredRate::GetCurrent() const +{ + if (LastTimeUS == (uint64_t)-1) + { + return 0; + } + return FilteredPerSecond; +} + +uint64_t +FilteredRate::GetElapsedTimeUS() const +{ + if (StartTimeUS == (uint64_t)-1) + { + return 0; + } + 
if (EndTimeUS == (uint64_t)-1) + { + return 0; + } + return EndTimeUS - StartTimeUS; +} + +bool +FilteredRate::IsActive() const +{ + return (StartTimeUS != (uint64_t)-1) && (EndTimeUS == (uint64_t)-1); +} + +uint64_t +GetBytesPerSecond(uint64_t ElapsedWallTimeUS, uint64_t Count) +{ + if (ElapsedWallTimeUS == 0) + { + return 0; + } + return Count * 1000000 / ElapsedWallTimeUS; +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h index d0c0155b0..28e1e8ba6 100644 --- a/src/zenutil/include/zenutil/cloud/mockimds.h +++ b/src/zenutil/include/zenutil/cloud/mockimds.h @@ -23,7 +23,7 @@ namespace zen::compute { * * When a request arrives for a provider that is not the ActiveProvider, the * mock returns 404, causing CloudMetadata to write a sentinel file and move - * on to the next provider — exactly like a failed probe on bare metal. + * on to the next provider - exactly like a failed probe on bare metal. * * All config fields are public and can be mutated between poll cycles to * simulate state changes (e.g. a spot interruption appearing mid-run). @@ -45,13 +45,13 @@ public: std::string AvailabilityZone = "us-east-1a"; std::string LifeCycle = "on-demand"; // "spot" or "on-demand" - // Empty string → endpoint returns 404 (instance not in an ASG). - // Non-empty → returned as the response body. "InService" means healthy; + // Empty string -> endpoint returns 404 (instance not in an ASG). + // Non-empty -> returned as the response body. "InService" means healthy; // anything else (e.g. "Terminated:Wait") triggers termination detection. std::string AutoscalingState; - // Empty string → endpoint returns 404 (no spot interruption). - // Non-empty → returned as the response body, signalling a spot reclaim. + // Empty string -> endpoint returns 404 (no spot interruption). + // Non-empty -> returned as the response body, signalling a spot reclaim. 
std::string SpotAction; // IAM credential fields for ImdsCredentialProvider testing @@ -69,10 +69,10 @@ public: std::string Location = "eastus"; std::string Priority = "Regular"; // "Spot" or "Regular" - // Empty → instance is not in a VM Scale Set (no autoscaling). + // Empty -> instance is not in a VM Scale Set (no autoscaling). std::string VmScaleSetName; - // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // Empty -> no scheduled events. Set to "Preempt", "Terminate", or // "Reboot" to simulate a termination-class event. std::string ScheduledEventType; std::string ScheduledEventStatus = "Scheduled"; diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h index f1f0df0e4..1ce2a768e 100644 --- a/src/zenutil/include/zenutil/cloud/s3client.h +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -11,7 +11,12 @@ #include <zencore/thread.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + #include <functional> +#include <span> #include <string> #include <string_view> #include <vector> @@ -129,6 +134,11 @@ public: /// Delete an object from S3 S3Result DeleteObject(std::string_view Key); + /// Refresh an object's LastModified timestamp via a PUT Object - Copy onto itself + /// (x-amz-metadata-directive: REPLACE). Useful to reset lifecycle-expiration timers + /// without re-uploading the content. 
+ S3Result Touch(std::string_view Key); + /// Check if an object exists and get its metadata S3HeadObjectResult HeadObject(std::string_view Key); @@ -198,11 +208,15 @@ private: /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style) std::string BucketRootPath() const; - /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256 - HttpClient::KeyValueMap SignRequest(std::string_view Method, - std::string_view Path, - std::string_view QueryString, - std::string_view PayloadHash); + /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256. + /// Additional x-amz-* headers that must participate in the signature are passed via + /// ExtraSignedHeaders (lowercase name, value); they are also copied into the returned map. + HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view QueryString, + std::string_view PayloadHash, + std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {}); /// Get or compute the signing key for the given date stamp, caching across requests on the same day Sha256Digest GetSigningKey(std::string_view DateStamp); @@ -210,6 +224,23 @@ private: /// Get the current credentials, either from the provider or from static config SigV4Credentials GetCurrentCredentials(); + /// Populate OutCredentials and return empty string on success; on failure return a + /// "<context>: no credentials available" error (also logged). Context args are only + /// formatted on the failure path. + template<typename... ContextArgs> + std::string RequireCredentials(SigV4Credentials& OutCredentials, fmt::format_string<ContextArgs...> ContextFmt, ContextArgs&&... 
Args) + { + OutCredentials = GetCurrentCredentials(); + if (!OutCredentials.AccessKeyId.empty()) + { + return {}; + } + OutCredentials = {}; + return BuildNoCredentialsError(fmt::format(ContextFmt, std::forward<ContextArgs>(Args)...)); + } + + std::string BuildNoCredentialsError(std::string Context); + LoggerRef m_Log; std::string m_BucketName; std::string m_Region; @@ -219,6 +250,7 @@ private: SigV4Credentials m_Credentials; Ref<ImdsCredentialProvider> m_CredentialProvider; HttpClient m_HttpClient; + bool m_Verbose = false; // Cached signing key (only changes once per day, protected by RwLock for thread safety) mutable RwLock m_SigningKeyLock; diff --git a/src/zenutil/include/zenutil/config/commandlineoptions.h b/src/zenutil/include/zenutil/config/commandlineoptions.h index 01cceedb1..ed7f46a08 100644 --- a/src/zenutil/include/zenutil/config/commandlineoptions.h +++ b/src/zenutil/include/zenutil/config/commandlineoptions.h @@ -22,6 +22,17 @@ std::vector<char*> StripCommandlineQuotes(std::vector<std::string>& InOutArgs) std::filesystem::path StringToPath(const std::string_view& Path); std::string_view RemoveQuotes(const std::string_view& Arg); +/// Scrub sensitive values from a command-line string in place. Masks the value +/// of options whose normalized name ends in "token", "aeskey", "secret" or +/// "dsn"; redacts user:password@ in URL authorities; partially masks values +/// matching credential formats that zen/zenserver handles (AWS access keys, +/// Google API keys, JWT bearer tokens). On any exception the string is left +/// unchanged. +/// +/// Used by the invocation-history writer and by the Sentry integration before +/// the command line is attached to a crash report. 
+void ScrubSensitiveValues(std::string& Cmdline) noexcept; + class CommandLineConverter { public: diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 921d33ce2..c5c9b7d64 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -76,7 +76,7 @@ bool IsTuiAvailable(); // Items (must be same length as Items). Useful when the display label differs // from the searchable content. // -// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels. +// Arrow keys (^/v) navigate the selection, Enter confirms, Esc cancels. // Type to incrementally filter the list. Returns the index of the selected // item in the original Items array, or -1 if the user cancelled. // diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h index 4002d5d23..38d450583 100644 --- a/src/zenutil/include/zenutil/consul.h +++ b/src/zenutil/include/zenutil/consul.h @@ -3,13 +3,16 @@ #pragma once #include <zenbase/zenbase.h> +#include <zencore/thread.h> #include <zenhttp/httpclient.h> #include <atomic> +#include <chrono> #include <cstdint> #include <string> #include <string_view> #include <thread> +#include <vector> namespace zen::consul { @@ -23,12 +26,22 @@ struct ServiceRegistrationInfo std::vector<std::pair<std::string, std::string>> Tags; uint32_t HealthIntervalSeconds = 10; uint32_t DeregisterAfterSeconds = 30; + std::string InitialStatus; }; class ConsulClient { public: - ConsulClient(std::string_view BaseUri, std::string_view Token = ""); + struct Configuration + { + std::string BaseUri; + std::string StaticToken; + std::string TokenEnvName; + std::chrono::milliseconds ConnectTimeout{1000}; + std::chrono::milliseconds Timeout{2000}; + }; + + ConsulClient(const Configuration& Config); ~ConsulClient(); ConsulClient(const ConsulClient&) = delete; @@ -38,8 +51,15 @@ public: std::string GetKeyValue(std::string_view Key); void DeleteKey(std::string_view Key); - 
bool RegisterService(const ServiceRegistrationInfo& Info); - bool DeregisterService(std::string_view ServiceId); + // Async. Enqueue onto the worker thread and return immediately. + // Transport outcome is not reported to the caller. + void RegisterService(const ServiceRegistrationInfo& Info); + void DeregisterService(std::string_view ServiceId); + + // Synchronous counterparts. Block on the HTTP call and return true on + // success. Use when the caller needs a result (e.g. a retry loop). + bool DoRegister(const ServiceRegistrationInfo& Info); + bool DoDeregister(std::string_view ServiceId); // Query methods for testing bool HasService(std::string_view ServiceId); @@ -53,11 +73,29 @@ public: bool WatchService(std::string_view ServiceId, uint64_t& InOutIndex, int WaitSeconds); private: + struct PendingOp + { + enum class Kind + { + Register, + Deregister + }; + Kind Type; + ServiceRegistrationInfo Info; + }; + static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId); void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap); - - std::string m_Token; - HttpClient m_HttpClient; + std::string GetNodeName(); + void WorkerLoop(); + + Configuration m_Config; + std::atomic<bool> m_Stop{false}; + HttpClient m_HttpClient; + RwLock m_QueueLock; + std::vector<PendingOp> m_Queue; + Event m_Wakeup; + std::thread m_Worker; }; class ConsulProcess @@ -109,4 +147,6 @@ private: void RegistrationLoop(); }; +void consul_forcelink(); + } // namespace zen::consul diff --git a/src/zenutil/include/zenutil/filesystemutils.h b/src/zenutil/include/zenutil/filesystemutils.h new file mode 100644 index 000000000..05defd1a8 --- /dev/null +++ b/src/zenutil/include/zenutil/filesystemutils.h @@ -0,0 +1,98 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> + +namespace zen { + +class CompositeBuffer; + +class BufferedOpenFile +{ +public: + static constexpr uint64_t BlockSize = 256u * 1024u; + + BufferedOpenFile(const std::filesystem::path Path, + std::atomic<uint64_t>& OpenReadCount, + std::atomic<uint64_t>& CurrentOpenFileCount, + std::atomic<uint64_t>& ReadCount, + std::atomic<uint64_t>& ReadByteCount); + ~BufferedOpenFile(); + BufferedOpenFile() = delete; + BufferedOpenFile(const BufferedOpenFile&) = delete; + BufferedOpenFile(BufferedOpenFile&&) = delete; + BufferedOpenFile& operator=(BufferedOpenFile&&) = delete; + BufferedOpenFile& operator=(const BufferedOpenFile&) = delete; + + CompositeBuffer GetRange(uint64_t Offset, uint64_t Size); + +public: + void* Handle() { return m_Source.Handle(); } + +private: + BasicFile m_Source; + const uint64_t m_SourceSize; + std::atomic<uint64_t>& m_OpenReadCount; + std::atomic<uint64_t>& m_CurrentOpenFileCount; + std::atomic<uint64_t>& m_ReadCount; + std::atomic<uint64_t>& m_ReadByteCount; + uint64_t m_CacheBlockIndex = (uint64_t)-1; + IoBuffer m_Cache; +}; + +bool IsFileWithRetry(const std::filesystem::path& Path); + +bool SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly); + +std::error_code RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); +std::error_code RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); + +std::error_code TryRemoveFile(const std::filesystem::path& Path); + +void RemoveFileWithRetry(const std::filesystem::path& Path); + +void FastCopyFile(bool AllowFileClone, + bool UseSparseFiles, + const std::filesystem::path& SourceFilePath, + const std::filesystem::path& TargetFilePath, + uint64_t RawSize, + std::atomic<uint64_t>& WriteCount, + std::atomic<uint64_t>& WriteByteCount, + std::atomic<uint64_t>& CloneCount, + std::atomic<uint64_t>& 
CloneByteCount); + +struct CleanDirectoryResult +{ + uint64_t FoundCount = 0; + uint64_t DeletedCount = 0; + uint64_t DeletedByteCount = 0; + std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths; +}; + +class WorkerThreadPool; + +void GetDirectoryContent(WorkerThreadPool& WorkerPool, + const std::filesystem::path& Path, + DirectoryContentFlags Flags, + DirectoryContent& OutContent); + +CleanDirectoryResult CleanDirectory( + WorkerThreadPool& IOWorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Path, + std::span<const std::string> ExcludeDirectories, + std::function<void(const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted)>&& + ProgressFunc, + uint32_t ProgressUpdateDelayMS); + +bool CleanAndRemoveDirectory(WorkerThreadPool& WorkerPool, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag, + const std::filesystem::path& Directory); + +void filesystemutils_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/filteredrate.h b/src/zenutil/include/zenutil/filteredrate.h new file mode 100644 index 000000000..3349823d0 --- /dev/null +++ b/src/zenutil/include/zenutil/filteredrate.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/timer.h> + +#include <atomic> +#include <cstdint> + +namespace zen { + +class FilteredRate +{ +public: + FilteredRate() {} + + void Start(); + void Stop(); + void Update(uint64_t Count); + + uint64_t GetCurrent() const; + uint64_t GetElapsedTimeUS() const; + bool IsActive() const; + +private: + Stopwatch Timer; + std::atomic<uint64_t> StartTimeUS = (uint64_t)-1; + std::atomic<uint64_t> EndTimeUS = (uint64_t)-1; + std::atomic<uint64_t> LastTimeUS = (uint64_t)-1; + uint64_t LastCount = 0; + uint64_t LastPerSecond = 0; + uint64_t FilteredPerSecond = 0; +}; + +uint64_t GetBytesPerSecond(uint64_t ElapsedWallTimeUS, uint64_t Count); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/invocationhistory.h b/src/zenutil/include/zenutil/invocationhistory.h new file mode 100644 index 000000000..9843d14bc --- /dev/null +++ b/src/zenutil/include/zenutil/invocationhistory.h @@ -0,0 +1,59 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <filesystem> +#include <initializer_list> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct HistoryRecord +{ + std::string Id; + std::string Ts; + std::string Exe; + std::string Mode; + std::string Cwd; + std::string Path; + std::string CmdLine; + uint32_t Pid = 0; +}; + +/// Append an invocation record to the per-user history file. +/// +/// Must be called BEFORE CommandLineConverter has had a chance to mutate argv +/// (which strips quoting and re-parses the Windows command line). Typical call +/// site is right after InstallCrashHandler() in main(). +/// +/// Never throws. Any I/O or parsing failure is swallowed so that logging cannot +/// break the actual command being run. If argv contains the single-token +/// `--enable-execution-history=false` (or `=0` / `=no`), returns without doing +/// anything. +/// +/// Exe should be "zen" or "zenserver". 
Mode is zenserver-only (hub/store/ +/// compute/proxy/test); pass an empty string_view for zen. +/// +/// ExcludeSubcommands is an optional list of subcommand names. If argv[1] +/// matches any entry, logging is skipped. Used e.g. so `zen history` does not +/// pollute the history file it inspects. +void LogInvocation(std::string_view Exe, + std::string_view Mode, + int argc, + char** argv, + std::initializer_list<std::string_view> ExcludeSubcommands = {}) noexcept; + +/// Read the most recent MaxRecords entries from the history file (newest last). +/// Returns an empty vector if the file does not exist or is unreadable. +std::vector<HistoryRecord> ReadInvocationHistory(size_t MaxRecords = 100); + +/// Resolve the per-user history file path. Does not create the file or the +/// parent directory. Never throws; returns an empty path on any error. +std::filesystem::path GetInvocationHistoryPath() noexcept; + +void invocationhistory_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/parallelsort.h b/src/zenutil/include/zenutil/parallelsort.h new file mode 100644 index 000000000..ed455ce9d --- /dev/null +++ b/src/zenutil/include/zenutil/parallelsort.h @@ -0,0 +1,119 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/workthreadpool.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/sort.h> +#include <EASTL/vector.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + +namespace zen { + +// Bottom-up parallel merge sort using WorkerThreadPool + Latch. +// +// Splits the range into chunks, sorts each chunk in parallel via +// eastl::sort, then iteratively merges adjacent pairs in parallel +// using std::inplace_merge until a single sorted range remains. +// +// Falls back to eastl::sort for ranges below kMinParallelSortSize. 
+template<typename RandomIt, typename Compare> +void +ParallelSort(WorkerThreadPool& Pool, RandomIt First, RandomIt Last, Compare Comp) +{ + constexpr size_t kMinParallelSortSize = 65536; + constexpr size_t kMinChunkSize = 65536; + constexpr size_t kMaxChunks = 64; + + size_t Count = size_t(Last - First); + if (Count <= kMinParallelSortSize) + { + eastl::sort(First, Last, Comp); + return; + } + + // Determine chunk count: enough to saturate workers, but not so many + // that scheduling overhead dominates. + size_t ChunkCount = (Count + kMinChunkSize - 1) / kMinChunkSize; + if (ChunkCount > kMaxChunks) + { + ChunkCount = kMaxChunks; + } + if (ChunkCount < 2) + { + ChunkCount = 2; + } + + // Compute chunk boundaries. + eastl::vector<RandomIt> Boundaries; + Boundaries.reserve(ChunkCount + 1); + size_t ChunkSize = Count / ChunkCount; + for (size_t I = 0; I < ChunkCount; ++I) + { + Boundaries.push_back(First + ptrdiff_t(I * ChunkSize)); + } + Boundaries.push_back(Last); + + // Phase 1: Sort each chunk in parallel. + { + Latch Done(1); + for (size_t I = 0; I < ChunkCount; ++I) + { + Done.AddCount(1); + Pool.ScheduleWork( + [&Done, Begin = Boundaries[I], End = Boundaries[I + 1], &Comp]() { + auto Guard = MakeGuard([&Done]() { Done.CountDown(); }); + eastl::sort(Begin, End, Comp); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + Done.CountDown(); + Done.Wait(); + } + + // Phase 2: Pairwise merge rounds until a single sorted range remains. + // Each round merges non-overlapping adjacent pairs in parallel, then + // compacts the boundary list. An odd trailing chunk is carried forward. 
+ while (ChunkCount > 1) + { + size_t Pairs = ChunkCount / 2; + + { + Latch Done(1); + for (size_t I = 0; I < Pairs; ++I) + { + size_t Left = 2 * I; + Done.AddCount(1); + Pool.ScheduleWork( + [&Done, F = Boundaries[Left], M = Boundaries[Left + 1], L = Boundaries[Left + 2], &Comp]() { + auto Guard = MakeGuard([&Done]() { Done.CountDown(); }); + std::inplace_merge(F, M, L, Comp); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + Done.CountDown(); + Done.Wait(); + } + + // Compact boundaries: merged pairs collapse, odd chunk carried forward. + eastl::vector<RandomIt> NewBounds; + NewBounds.reserve((ChunkCount + 1) / 2 + 1); + for (size_t I = 0; I < ChunkCount; I += 2) + { + NewBounds.push_back(Boundaries[I]); + } + NewBounds.push_back(Last); + + ChunkCount = NewBounds.size() - 1; + Boundaries = eastl::move(NewBounds); + } +} + +void parallelsort_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h index e16c0c446..95d7fa43d 100644 --- a/src/zenutil/include/zenutil/process/subprocessmanager.h +++ b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -97,7 +97,7 @@ public: /// If Options.StdoutPipe is set, the pipe is consumed and async reading /// begins automatically. Similarly for Options.StderrPipe. When providing /// pipes, pass the corresponding data callback here so it is installed - /// before the first async read completes — setting it later via + /// before the first async read completes - setting it later via /// SetStdoutCallback risks losing early output. /// /// Returns a non-owning pointer valid until Remove() or manager destruction. @@ -112,7 +112,7 @@ public: /// Adopt an already-running process by handle. Takes ownership of handle internals. ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); - /// Stop monitoring a process by pid. 
Does NOT kill the process — call + /// Stop monitoring a process by pid. Does NOT kill the process - call /// process->Kill() first if needed. The exit callback will not fire after /// this returns. void Remove(int Pid); @@ -219,7 +219,7 @@ private: /// A group of managed processes with OS-level backing. /// /// On Windows: backed by a JobObject. All processes assigned on spawn. -/// Kill-on-close guarantee — if the group is destroyed, the OS terminates +/// Kill-on-close guarantee - if the group is destroyed, the OS terminates /// all member processes. /// On Linux/macOS: uses setpgid() so children share a process group. /// Enables bulk signal delivery via kill(-pgid, sig). diff --git a/src/zenutil/include/zenutil/progress.h b/src/zenutil/include/zenutil/progress.h new file mode 100644 index 000000000..4103723b3 --- /dev/null +++ b/src/zenutil/include/zenutil/progress.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logbase.h> + +#include <memory> +#include <string> + +namespace zen { + +class ProgressBase +{ +public: + virtual ~ProgressBase() = default; + + virtual void SetLogOperationName(std::string_view Name) = 0; + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0; + virtual void PushLogOperation(std::string_view Name) = 0; + virtual void PopLogOperation() = 0; + virtual uint32_t GetProgressUpdateDelayMS() const = 0; + + class ProgressBar + { + public: + struct State + { + bool operator==(const State&) const = default; + std::string Task; + std::string Details; + uint64_t TotalCount = 0; + uint64_t RemainingCount = 0; + uint64_t OptionalElapsedTime = (uint64_t)-1; + enum class EStatus + { + Running, + Aborted, + Paused + }; + EStatus Status = EStatus::Running; + + static constexpr EStatus CalculateStatus(bool IsAborted, bool IsPaused) + { + if (IsAborted) + { + return EStatus::Aborted; + } + if (IsPaused) + { + return EStatus::Paused; + } + return EStatus::Running; + } 
+ }; + + virtual ~ProgressBar() = default; + + virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0; + virtual void ForceLinebreak() = 0; + virtual void Finish() = 0; + }; + + virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) = 0; +}; + +ProgressBase* CreateStandardProgress(LoggerRef Log); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/sessionsclient.h b/src/zenutil/include/zenutil/sessionsclient.h index aca45e61d..c144a9baa 100644 --- a/src/zenutil/include/zenutil/sessionsclient.h +++ b/src/zenutil/include/zenutil/sessionsclient.h @@ -2,31 +2,53 @@ #pragma once +#include <zencore/blockingqueue.h> #include <zencore/compactbinary.h> #include <zencore/logging.h> #include <zencore/logging/sink.h> - +#include <zencore/string.h> #include <zencore/uid.h> +#include <zenhttp/httpclient.h> #include <memory> #include <string> +#include <thread> namespace zen { -class HttpClient; - /// Client for announcing and maintaining a session on a remote zenserver's /sessions/ endpoint. -/// Follows the same best-effort pattern as ZenComputeServer's coordinator announce. +/// All HTTP I/O runs on a single background worker thread. Public methods enqueue commands +/// and return immediately. class SessionsServiceClient { public: struct Options { - std::string TargetUrl; // Base URL of the target zenserver (e.g. "http://localhost:8558") - std::string AppName; // Application name to register - std::string Mode; // Server mode (e.g. "Server", "Compute", "Proxy") - Oid SessionId = Oid::Zero; // Session ID to register under - Oid JobId = Oid::Zero; // Optional job ID + std::string TargetUrl; // Base URL of the target zenserver (e.g. "http://localhost:8558") + std::string AppName; // Application name to register + std::string Mode; // Server mode (e.g. 
"Server", "Compute", "Proxy") + Oid SessionId = Oid::Zero; // Session ID to register under + Oid JobId = Oid::Zero; // Optional job ID + HttpClientSettings ClientSettings; // Optional; timeouts are overridden internally (e.g. for unix sockets) + }; + + /// Command sent to the background worker thread. + struct SessionCommand + { + enum class Type : uint8_t + { + Announce, + UpdateMetadata, + Remove, + Log, + FlushLogs, + Shutdown + }; + + Type CommandType = Type::Log; + CbObject Metadata; // Announce, UpdateMetadata + logging::LogLevel LogLevel{}; // Log + CompactString LogMessage; // Log }; explicit SessionsServiceClient(Options Opts); @@ -35,17 +57,17 @@ public: SessionsServiceClient(const SessionsServiceClient&) = delete; SessionsServiceClient& operator=(const SessionsServiceClient&) = delete; - /// POST /sessions/{id} — register or re-announce the session with optional metadata. - [[nodiscard]] bool Announce(CbObjectView Metadata = {}); + /// POST /sessions/{id} — enqueues an announce command (fire-and-forget). + void Announce(CbObjectView Metadata = {}); - /// PUT /sessions/{id} — update metadata on an existing session. - [[nodiscard]] bool UpdateMetadata(CbObjectView Metadata = {}); + /// PUT /sessions/{id} — enqueues a metadata update command (fire-and-forget). + void UpdateMetadata(CbObjectView Metadata = {}); - /// DELETE /sessions/{id} — remove the session. - [[nodiscard]] bool Remove(); + /// DELETE /sessions/{id} — enqueues a remove command (fire-and-forget). + void Remove(); /// Create a logging sink that forwards log messages to the session's log endpoint. - /// The sink batches messages on a background thread and POSTs them periodically. + /// The sink enqueues messages onto this client's worker thread. /// The returned sink can be added to any logger via Logger::AddSink(). 
logging::SinkPtr CreateLogSink(); @@ -55,11 +77,19 @@ public: private: CbObject BuildRequestBody(CbObjectView Metadata) const; - LoggerRef Log() { return m_Log; } - LoggerRef m_Log; - Options m_Options; - std::string m_SessionPath; // "sessions/<hex>" - std::unique_ptr<HttpClient> m_Http; + void WorkerLoop(); + void DoAnnounce(HttpClient& Http, CbObjectView Metadata); + void DoUpdateMetadata(HttpClient& Http, CbObjectView Metadata); + void DoRemove(HttpClient& Http); + void SendLogBatch(HttpClient& Http, const std::string& LogPath, const std::vector<SessionCommand>& Batch); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + Options m_Options; + std::string m_SessionPath; // "/sessions/<hex>" + BlockingQueue<SessionCommand> m_Queue; + std::thread m_WorkerThread; }; } // namespace zen diff --git a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h index f4ac5ff22..e59ebc7f4 100644 --- a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h +++ b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h @@ -67,7 +67,7 @@ public: CbObjectWriter Writer; Writer.AddString("text", Text); Writer.AddString("source", m_Source); - Writer.AddString("level", logging::ToStringView(Msg.GetLevel())); + Writer.AddString("level", logging::ToString(Msg.GetLevel())); Writer.AddInteger("seq", Seq); CbObject Obj = Writer.Save(); @@ -85,12 +85,12 @@ public: void Flush() override { - // Nothing to flush — writes happen asynchronously + // Nothing to flush - writes happen asynchronously } void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override { - // Not used — we output the raw payload directly + // Not used - we output the raw payload directly } private: @@ -124,7 +124,7 @@ private: { break; // don't retry during shutdown } - continue; // drop batch — will retry on next batch + continue; // drop batch - will retry on next batch } // Build a gathered buffer sequence so the 
entire batch is written @@ -176,7 +176,7 @@ private: std::string m_Source; uint32_t m_MaxQueueSize; - // Sequence counter — incremented atomically by Log() callers. + // Sequence counter - incremented atomically by Log() callers. // Gaps in the sequence seen by the receiver indicate dropped messages. std::atomic<uint64_t> m_NextSequence{0}; diff --git a/src/zenutil/include/zenutil/suggest.h b/src/zenutil/include/zenutil/suggest.h new file mode 100644 index 000000000..c24bbcc33 --- /dev/null +++ b/src/zenutil/include/zenutil/suggest.h @@ -0,0 +1,19 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <span> +#include <string_view> +#include <vector> + +namespace zen { + +// Returns up to 5 names from `Candidates` that look like likely typos of `Typed`, +// sorted by best match first. Case-insensitive. Uses Damerau-Levenshtein distance +// with a small prefix bonus so that short inputs (e.g. "ca") still surface longer +// commands ("cache"). Returns an empty vector if nothing is a plausible match. +std::vector<std::string_view> SuggestSimilarCommands(std::string_view Typed, std::span<const std::string_view> Candidates); + +void suggest_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/testartifactprovider.h b/src/zenutil/include/zenutil/testartifactprovider.h new file mode 100644 index 000000000..77af7b850 --- /dev/null +++ b/src/zenutil/include/zenutil/testartifactprovider.h @@ -0,0 +1,109 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenutil/cloud/s3client.h> + +#include <zenbase/refcount.h> +#include <zencore/iobuffer.h> + +#include <filesystem> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct TestArtifactInfo +{ + std::string RelativePath; + uint64_t Size = 0; +}; + +struct TestArtifactResult +{ + std::string Error; + + bool IsSuccess() const { return Error.empty(); } + explicit operator bool() const { return IsSuccess(); } +}; + +struct TestArtifactFetchResult : TestArtifactResult +{ + IoBuffer Content; +}; + +struct TestArtifactListResult : TestArtifactResult +{ + std::vector<TestArtifactInfo> Artifacts; +}; + +/// Provider for read/write access to test artifacts. +/// +/// Artifacts are identified by a forward-slash relative path (e.g. "traces/powercycle.utrace"). +/// A provider is backed by a writable local cache directory and, optionally, an S3 source +/// used as the read-through primary. Fetch consults the cache first; on miss it pulls from the +/// primary (if configured) and writes the result back into the cache. +class TestArtifactProvider : public RefCounted +{ +public: + virtual std::string Describe() const = 0; + virtual bool Exists(std::string_view RelativePath) = 0; + virtual TestArtifactFetchResult Fetch(std::string_view RelativePath) = 0; + virtual TestArtifactListResult List(std::string_view Prefix) = 0; + virtual TestArtifactResult Store(std::string_view RelativePath, IoBuffer Content) = 0; +}; + +/// Environment variable specifying the local cache directory (overrides GetDefaultLocalTestArtifactPath()). +inline constexpr std::string_view kTestArtifactsPathEnvVar = "ZEN_TEST_ARTIFACTS_PATH"; + +/// Environment variable specifying the primary S3 source. Expected format: +/// "s3://<bucket>[/<prefix>]" (the "s3://" scheme is optional). 
+inline constexpr std::string_view kTestArtifactsS3EnvVar = "ZEN_TEST_ARTIFACTS_S3"; + +/// Directory name placed next to the default zenserver state directory for the local cache. +inline constexpr std::string_view kDefaultLocalTestArtifactDirName = "zen-artifact-cache"; + +/// Returns the default local cache root, derived from PickDefaultSystemRootDirectory() by +/// appending kDefaultLocalTestArtifactDirName. Returns an empty path if no system root can be +/// determined on the current platform. +std::filesystem::path GetDefaultLocalTestArtifactPath(); + +struct TestArtifactProviderOptions +{ + /// Local cache directory. If empty, ZEN_TEST_ARTIFACTS_PATH is consulted, then + /// GetDefaultLocalTestArtifactPath() is used as a final fallback. + std::filesystem::path CacheDir; + + /// Optional remote S3 primary source. Empty fields are filled from environment variables: + /// ZEN_TEST_ARTIFACTS_S3 for bucket + key prefix, AWS_DEFAULT_REGION / AWS_REGION for region, + /// AWS_ENDPOINT_URL for endpoint, AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN + /// for credentials. + S3ClientOptions S3Client; + std::string S3KeyPrefix; +}; + +/// Create a test artifact provider. Combines a writable local cache with an optional S3 primary +/// source; Fetch consults the cache first and populates it on remote hits. +/// +/// With default-constructed options the provider is configured entirely from environment variables +/// (see kTestArtifactsPathEnvVar / kTestArtifactsS3EnvVar / AWS_*). Returns a null reference only +/// on platforms where no cache directory can be determined. +Ref<TestArtifactProvider> CreateTestArtifactProvider(TestArtifactProviderOptions Options = {}); + +/// Returns true when at least one test artifact source is configured and reachable: either a +/// local artifact directory (ZEN_TEST_ARTIFACTS_PATH) or an S3 source with signing credentials +/// available (see S3TestArtifactsAvailable). 
Intended for doctest::skip decorators: +/// TEST_CASE("needs_artifacts" * doctest::skip(!TestArtifactsAvailable())) +bool TestArtifactsAvailable(); + +/// Returns true when the S3 test artifact source is configured (ZEN_TEST_ARTIFACTS_S3) AND the +/// current environment can sign requests: either static AWS credentials are present +/// (AWS_ACCESS_KEY_ID / AWS_SESSION_TOKEN) or the EC2 Instance Metadata Service is usable +/// (not macOS, not explicitly disabled via AWS_EC2_METADATA_DISABLED=true). +/// TEST_CASE("needs_s3" * doctest::skip(!S3TestArtifactsAvailable())) +bool S3TestArtifactsAvailable(); + +void testartifactprovider_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/windows/windowsservice.h b/src/zenutil/include/zenutil/windows/windowsservice.h index ca0270a36..d7b3347a9 100644 --- a/src/zenutil/include/zenutil/windows/windowsservice.h +++ b/src/zenutil/include/zenutil/windows/windowsservice.h @@ -8,7 +8,7 @@ class WindowsService { public: WindowsService(); - ~WindowsService(); + virtual ~WindowsService(); virtual int Run() = 0; diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index d6f66fbea..2fa212f92 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -143,6 +143,12 @@ struct ZenServerInstance inline void SetServerExecutablePath(std::filesystem::path ExecutablePath) { m_ServerExecutablePath = ExecutablePath; } + // Controls whether the spawned server records its invocation in the per-user + // execution history file. Default is false so tests, hubs, and auto-spawned + // instances do not pollute the user's history. User-invoked launches (zen up) + // should enable this so the spawn shows up in `zen history`. 
+ inline void SetEnableExecutionHistory(bool Enable) { m_EnableExecutionHistory = Enable; } + void SetDataDir(std::filesystem::path TestDir); inline void SpawnServer(std::string_view AdditionalServerArgs = std::string_view()) @@ -191,6 +197,7 @@ private: std::string m_Name; std::filesystem::path m_OutputCapturePath; std::filesystem::path m_ServerExecutablePath; + bool m_EnableExecutionHistory = false; #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; #endif @@ -341,7 +348,8 @@ struct StartupZenServerOptions bool OpenConsole = false; // open a console window for the server process bool ShowLog = false; // emit captured server log to LogRef on successful start std::string ExtraArgs; // e.g. GlobalOptions.PassthroughCommandLine - ZenServerInstance::ServerMode Mode = ZenServerInstance::ServerMode::kStorageServer; + ZenServerInstance::ServerMode Mode = ZenServerInstance::ServerMode::kStorageServer; + bool EnableExecutionHistory = false; // record the spawned zenserver in `zen history` }; // Returns std::nullopt if a matching server is already running (no action taken); logs instance info via LogRef. diff --git a/src/zenutil/invocationhistory.cpp b/src/zenutil/invocationhistory.cpp new file mode 100644 index 000000000..077061752 --- /dev/null +++ b/src/zenutil/invocationhistory.cpp @@ -0,0 +1,308 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/invocationhistory.h> + +#include <zencore/basicfile.h> +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/memoryview.h> +#include <zencore/process.h> +#include <zencore/uid.h> +#include <zenutil/config/commandlineoptions.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <system_error> + +namespace zen { + +namespace { + + constexpr size_t kMaxRecords = 100; + constexpr std::string_view kHistoryFileName = "invocations.jsonl"; + + // Safety cap. 
With 100 records at typical ~500-1000 bytes each the file + // normally sits around 50-100 KB. If it has grown past this threshold + // (external corruption, runaway producer, another tool writing garbage) + // we refuse to read it and start fresh with just the new record. Keeps + // LogInvocation from slowing startup on a pathological file. + constexpr uintmax_t kMaxReadSize = 1 * 1024 * 1024; // 1 MB + + bool ExecutionHistoryDisabled(int argc, char** argv) + { + for (int I = 1; I < argc; ++I) + { + if (argv[I] == nullptr) + { + continue; + } + std::string_view A = argv[I]; + if (A == "--enable-execution-history=false" || A == "--enable-execution-history=0" || A == "--enable-execution-history=no") + { + return true; + } + } + return false; + } + + std::filesystem::path ResolveHistoryDir() + { +#if ZEN_PLATFORM_WINDOWS + std::string LocalAppData = GetEnvVariable("LOCALAPPDATA"); + if (!LocalAppData.empty()) + { + return std::filesystem::path(LocalAppData) / "Epic" / "Zen" / "History"; + } +#endif + std::filesystem::path SystemRoot = PickDefaultSystemRootDirectory(); + if (SystemRoot.empty()) + { + return {}; + } + return SystemRoot / "History"; + } + + std::string BuildJsonRecord(const HistoryRecord& Rec) + { + json11::Json::object Obj{ + {"id", Rec.Id}, + {"ts", Rec.Ts}, + {"exe", Rec.Exe}, + {"pid", static_cast<int>(Rec.Pid)}, + {"cwd", Rec.Cwd}, + {"path", Rec.Path}, + {"cmdline", Rec.CmdLine}, + }; + if (!Rec.Mode.empty()) + { + Obj.emplace("mode", Rec.Mode); + } + return json11::Json(Obj).dump(); + } + + bool ParseJsonRecord(std::string_view Line, HistoryRecord& OutRec) + { + std::string Err; + json11::Json J = json11::Json::parse(std::string(Line), Err); + if (!Err.empty() || !J.is_object()) + { + return false; + } + OutRec.Id = J["id"].string_value(); + OutRec.Ts = J["ts"].string_value(); + OutRec.Exe = J["exe"].string_value(); + OutRec.Mode = J["mode"].string_value(); + OutRec.Cwd = J["cwd"].string_value(); + OutRec.Path = J["path"].string_value(); + 
OutRec.CmdLine = J["cmdline"].string_value(); + OutRec.Pid = static_cast<uint32_t>(J["pid"].int_value()); + return true; + } + + std::vector<std::string> ReadHistoryLines(const std::filesystem::path& Path) + { + std::vector<std::string> Lines; + + std::error_code SizeEc; + const std::uintmax_t FileSize = std::filesystem::file_size(Path, SizeEc); + if (SizeEc || FileSize > kMaxReadSize) + { + return Lines; + } + + FileContents Contents = ReadFile(Path); + if (!Contents) + { + return Lines; + } + IoBuffer Flat = Contents.Flatten(); + const char* Data = static_cast<const char*>(Flat.GetData()); + const size_t Size = Flat.GetSize(); + size_t Start = 0; + for (size_t I = 0; I < Size; ++I) + { + if (Data[I] == '\n') + { + if (I > Start) + { + size_t LineEnd = I; + if (LineEnd > Start && Data[LineEnd - 1] == '\r') + { + --LineEnd; + } + if (LineEnd > Start) + { + Lines.emplace_back(Data + Start, LineEnd - Start); + } + } + Start = I + 1; + } + } + if (Start < Size) + { + Lines.emplace_back(Data + Start, Size - Start); + } + return Lines; + } + +} // namespace + +std::filesystem::path +GetInvocationHistoryPath() noexcept +{ + try + { + std::filesystem::path Dir = ResolveHistoryDir(); + if (Dir.empty()) + { + return {}; + } + return Dir / kHistoryFileName; + } + catch (...) 
+ { + return {}; + } +} + +void +LogInvocation(std::string_view Exe, + std::string_view Mode, + int argc, + char** argv, + std::initializer_list<std::string_view> ExcludeSubcommands) noexcept +{ + try + { + if (ExecutionHistoryDisabled(argc, argv)) + { + return; + } + + if (argc >= 2 && argv[1] != nullptr) + { + std::string_view A1 = argv[1]; + for (std::string_view Excluded : ExcludeSubcommands) + { + if (A1 == Excluded) + { + return; + } + } + } + + std::filesystem::path Dir = ResolveHistoryDir(); + if (Dir.empty()) + { + return; + } + + std::error_code Ec; + CreateDirectories(Dir, Ec); + if (Ec) + { + return; + } + + std::filesystem::path Path = Dir / kHistoryFileName; + + HistoryRecord Rec; + Rec.Id = Oid::NewOid().ToString(); + Rec.Ts = DateTime::Now().ToIso8601(); + Rec.Exe = std::string(Exe); + Rec.Mode = std::string(Mode); + + std::error_code CwdEc; + Rec.Cwd = std::filesystem::current_path(CwdEc).string(); + + Rec.Path = GetRunningExecutablePath().string(); + Rec.Pid = static_cast<uint32_t>(GetCurrentProcessId()); + + std::string Raw = GetRawCommandLine(); + if (Raw.empty()) + { + std::vector<std::string> Args; + Args.reserve(argc); + for (int I = 0; I < argc; ++I) + { + if (argv[I] != nullptr) + { + Args.emplace_back(argv[I]); + } + } + Raw = BuildCommandLine(Args); + } + ScrubSensitiveValues(Raw); + Rec.CmdLine = std::move(Raw); + + std::vector<std::string> Lines = ReadHistoryLines(Path); + if (Lines.size() >= kMaxRecords) + { + Lines.erase(Lines.begin(), Lines.begin() + (Lines.size() - (kMaxRecords - 1))); + } + Lines.push_back(BuildJsonRecord(Rec)); + + std::string NewContents; + size_t TotalSize = 0; + for (const std::string& L : Lines) + { + TotalSize += L.size() + 1; + } + NewContents.reserve(TotalSize); + for (const std::string& L : Lines) + { + NewContents.append(L); + NewContents.push_back('\n'); + } + + std::error_code WriteEc; + TemporaryFile::SafeWriteFile(Path, MemoryView(NewContents.data(), NewContents.size()), WriteEc); + } + catch (...) 
+ { + } +} + +std::vector<HistoryRecord> +ReadInvocationHistory(size_t MaxRecords) +{ + std::vector<HistoryRecord> Records; + try + { + std::filesystem::path Path = GetInvocationHistoryPath(); + if (Path.empty()) + { + return Records; + } + + std::vector<std::string> Lines = ReadHistoryLines(Path); + if (Lines.size() > MaxRecords) + { + Lines.erase(Lines.begin(), Lines.begin() + (Lines.size() - MaxRecords)); + } + + Records.reserve(Lines.size()); + for (const std::string& L : Lines) + { + HistoryRecord Rec; + if (ParseJsonRecord(L, Rec)) + { + Records.push_back(std::move(Rec)); + } + } + } + catch (...) + { + } + return Records; +} + +void +invocationhistory_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp index 2a4840241..283a8bc37 100644 --- a/src/zenutil/logging/fullformatter.cpp +++ b/src/zenutil/logging/fullformatter.cpp @@ -12,6 +12,7 @@ #include <atomic> #include <chrono> #include <string> +#include "zencore/logging.h" namespace zen::logging { @@ -25,7 +26,7 @@ struct FullFormatter::Impl { } - explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' ') {} std::chrono::time_point<std::chrono::system_clock> m_Epoch; std::tm m_CachedLocalTm{}; @@ -155,15 +156,7 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(' '); } - // append logger name if exists - if (Msg.GetLoggerName().size() > 0) - { - OutBuffer.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - + // level OutBuffer.push_back('['); if (IsColorEnabled()) { @@ -177,6 +170,23 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) OutBuffer.push_back(']'); OutBuffer.push_back(' '); + // logger name + if (Msg.GetLoggerName().size() > 0) + { + if 
(IsColorEnabled()) + { + OutBuffer.append("\033[1m"sv); + } + OutBuffer.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); + OutBuffer.push_back(']'); + if (IsColorEnabled()) + { + OutBuffer.append("\033[0m"sv); + } + OutBuffer.push_back(' '); + } + // add source location if present if (Msg.GetSource()) { diff --git a/src/zenutil/logging/logging.cpp b/src/zenutil/logging/logging.cpp index aa34fc50c..936e3c4fd 100644 --- a/src/zenutil/logging/logging.cpp +++ b/src/zenutil/logging/logging.cpp @@ -124,7 +124,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) LoggerRef DefaultLogger = zen::logging::Default(); - // Build the broadcast sink — a shared indirection point that all + // Build the broadcast sink - a shared indirection point that all // loggers cloned from the default will share. Adding or removing // a child sink later is immediately visible to every logger. std::vector<logging::SinkPtr> BroadcastChildren; @@ -158,6 +158,10 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) } #endif + // Trace forwarding is installed at a lower level (in Logger::Log itself) + // so it can see the typed fmt argument pack before vformat collapses it + // to a string - see src/zencore/logging/tracelog.cpp. 
+ g_BroadcastSink = Ref<logging::BroadcastSink>(new logging::BroadcastSink(std::move(BroadcastChildren))); bool IsAsync = LogOptions.AllowAsync && !LogOptions.IsDebug && !LogOptions.IsTest; @@ -179,7 +183,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) { return; } - static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"}; + static constinit logging::LogPoint ErrorPoint{0, 0, logging::Err, "{}"}; if (auto ErrLogger = zen::logging::ErrorLog()) { try @@ -249,7 +253,7 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); logging::Registry::Instance().ApplyAll([&](auto Logger) { - static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + static constinit logging::LogPoint LogStartPoint{0, 0, logging::Info, "log starting at {}"}; Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); } diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp index 23cf60d16..df59af5fe 100644 --- a/src/zenutil/logging/rotatingfilesink.cpp +++ b/src/zenutil/logging/rotatingfilesink.cpp @@ -85,7 +85,7 @@ struct RotatingFileSink::Impl m_CurrentSize = m_CurrentFile.FileSize(OutEc); if (OutEc) { - // FileSize failed but we have an open file — reset to 0 + // FileSize failed but we have an open file - reset to 0 // so we can at least attempt writes from the start m_CurrentSize = 0; OutEc.clear(); diff --git a/src/zenutil/parallelsort.cpp b/src/zenutil/parallelsort.cpp new file mode 100644 index 000000000..8a9f547bc --- /dev/null +++ b/src/zenutil/parallelsort.cpp @@ -0,0 +1,148 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/parallelsort.h> + +#include <zencore/testing.h> + +namespace zen { + +#if ZEN_WITH_TESTS + +void +parallelsort_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.parallelsort"); + +TEST_CASE("empty range") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + CHECK(Vec.empty()); +} + +TEST_CASE("single element") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec = {42}; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + CHECK(Vec.size() == 1); + CHECK(Vec[0] == 42); +} + +TEST_CASE("small array below threshold") +{ + WorkerThreadPool Pool(2); + eastl::vector<int> Vec = {5, 3, 8, 1, 9, 2, 7, 4, 6, 0}; + ParallelSort(Pool, Vec.begin(), Vec.end(), [](int A, int B) { return A < B; }); + for (size_t I = 0; I < Vec.size(); ++I) + { + CHECK(Vec[I] == int(I)); + } +} + +TEST_CASE("large array triggers parallel path") +{ + WorkerThreadPool Pool(4); + + // 200K elements — well above the 64K threshold. + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + + // Fill with descending values. 
+ for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(N - 1 - I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("already sorted") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("reverse sorted") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(N - 1 - I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(I), "index=", I, " got=", Vec[I]); + } +} + +TEST_CASE("duplicate keys") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I % 100); // only 100 distinct values + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A < B; }); + + for (size_t I = 1; I < N; ++I) + { + CHECK_MESSAGE(Vec[I - 1] <= Vec[I], "not sorted at index=", I); + } +} + +TEST_CASE("custom comparator descending") +{ + WorkerThreadPool Pool(4); + + constexpr size_t N = 200'000; + eastl::vector<uint32_t> Vec(N); + for (size_t I = 0; I < N; ++I) + { + Vec[I] = uint32_t(I); + } + + ParallelSort(Pool, Vec.begin(), Vec.end(), [](uint32_t A, uint32_t B) { return A > B; }); + + for (size_t I = 0; I < N; ++I) + { + CHECK_MESSAGE(Vec[I] == uint32_t(N - 1 - I), "index=", I, " got=", Vec[I]); + } +} + +TEST_SUITE_END(); + +#endif + +} 
// namespace zen diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp index 2fdcda30d..8eac350c6 100644 --- a/src/zenutil/process/asyncpipereader.cpp +++ b/src/zenutil/process/asyncpipereader.cpp @@ -50,7 +50,7 @@ struct AsyncPipeReader::Impl int Fd = Pipe.ReadFd; - // Close the write end — child already has it + // Close the write end - child already has it Pipe.CloseWriteEnd(); // Set non-blocking @@ -156,7 +156,7 @@ CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe) // The read end should not be inherited by the child SetHandleInformation(ReadHandle, HANDLE_FLAG_INHERIT, 0); - // Open the client (write) end — inheritable, for the child process + // Open the client (write) end - inheritable, for the child process SECURITY_ATTRIBUTES Sa; Sa.nLength = sizeof(Sa); Sa.lpSecurityDescriptor = nullptr; @@ -202,7 +202,7 @@ struct AsyncPipeReader::Impl HANDLE ReadHandle = static_cast<HANDLE>(Pipe.ReadHandle); - // Close the write end — child already has it + // Close the write end - child already has it Pipe.CloseWriteEnd(); // Take ownership of the read handle diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp index b053ac6bd..d0b912a0d 100644 --- a/src/zenutil/process/subprocessmanager.cpp +++ b/src/zenutil/process/subprocessmanager.cpp @@ -903,7 +903,11 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, Options.AssignToJob = &m_JobObject; } #else - if (m_Pgid > 0) + if (m_Pgid == 0) + { + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; + } + else { Options.ProcessGroupId = m_Pgid; } @@ -1205,7 +1209,17 @@ TEST_CASE("SubprocessManager.SpawnAndDetectExit") CallbackFired = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); CHECK(ReceivedExitCode == 42); @@ -1230,7 +1244,17 @@ 
TEST_CASE("SubprocessManager.SpawnAndDetectCleanExit") CallbackFired = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); CHECK(ReceivedExitCode == 0); @@ -1255,7 +1279,17 @@ TEST_CASE("SubprocessManager.StdoutCapture") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); std::string Captured = Proc->GetCapturedStdout(); @@ -1284,7 +1318,17 @@ TEST_CASE("SubprocessManager.StderrCapture") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); std::string CapturedErr = Proc->GetCapturedStderr(); @@ -1316,7 +1360,17 @@ TEST_CASE("SubprocessManager.StdoutCallback") [&](ManagedProcess&, int) { Exited = true; }, [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); CHECK(ReceivedData.find("callback_test") != std::string::npos); @@ -1339,8 +1393,18 @@ TEST_CASE("SubprocessManager.MetricsSampling") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - // Run for enough time to get metrics samples - IoContext.run_for(1s); + // Poll until metrics are available + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Proc->GetLatestMetrics().WorkingSetSize > 0) + { + break; 
+ } + } + } ProcessMetrics Metrics = Proc->GetLatestMetrics(); CHECK(Metrics.WorkingSetSize > 0); @@ -1349,7 +1413,17 @@ TEST_CASE("SubprocessManager.MetricsSampling") CHECK(Snapshot.size() == 1); // Let it finish - IoContext.run_for(3s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 10'000) + { + IoContext.run_for(10ms); + if (Exited) + { + break; + } + } + } CHECK(Exited); } @@ -1373,7 +1447,7 @@ TEST_CASE("SubprocessManager.RemoveWhileRunning") // Let it start IoContext.run_for(100ms); - // Remove without killing — callback should NOT fire after this + // Remove without killing - callback should NOT fire after this Manager.Remove(Pid); IoContext.run_for(500ms); @@ -1398,12 +1472,31 @@ TEST_CASE("SubprocessManager.KillAndWaitForExit") ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { CallbackFired = true; }); // Let it start - IoContext.run_for(200ms); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Proc->IsRunning()) + { + break; + } + } + } Proc->Kill(); - IoContext.run_for(2s); - + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (CallbackFired) + { + break; + } + } + } CHECK(CallbackFired); } @@ -1424,7 +1517,17 @@ TEST_CASE("SubprocessManager.AdoptProcess") Manager.Adopt(ProcessHandle(Result), [&](ManagedProcess&, int ExitCode) { ReceivedExitCode = ExitCode; }); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (ReceivedExitCode != -1) + { + break; + } + } + } CHECK(ReceivedExitCode == 7); } @@ -1447,7 +1550,17 @@ TEST_CASE("SubprocessManager.UserTag") Proc->SetTag("my-worker-1"); CHECK(Proc->GetTag() == "my-worker-1"); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (!ReceivedTag.empty()) + { + break; + } + } + } CHECK(ReceivedTag == 
"my-worker-1"); } @@ -1477,7 +1590,17 @@ TEST_CASE("ProcessGroup.SpawnAndMembership") CHECK(Group->GetProcessCount() == 2); CHECK(Manager.GetProcessCount() == 2); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (ExitCount == 2) + { + break; + } + } + } CHECK(ExitCount == 2); } @@ -1527,7 +1650,17 @@ TEST_CASE("ProcessGroup.AggregateMetrics") Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {}); // Wait for metrics sampling - IoContext.run_for(1s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (Group->GetAggregateMetrics().TotalWorkingSetSize > 0) + { + break; + } + } + } AggregateProcessMetrics GroupAgg = Group->GetAggregateMetrics(); CHECK(GroupAgg.ProcessCount == 2); @@ -1593,7 +1726,17 @@ TEST_CASE("ProcessGroup.MixedGroupedAndUngrouped") CHECK(Group->GetProcessCount() == 2); CHECK(Manager.GetProcessCount() == 3); - IoContext.run_for(5s); + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + IoContext.run_for(10ms); + if (GroupExitCount == 2 && UngroupedExitCode != -1) + { + break; + } + } + } CHECK(GroupExitCount == 2); CHECK(UngroupedExitCode == 0); @@ -1613,7 +1756,7 @@ TEST_CASE("ProcessGroup.FindGroup") TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) { - // Seed for reproducibility — change to explore different orderings + // Seed for reproducibility - change to explore different orderings // // Note that while this is a stress test, it is still single-threaded @@ -1642,7 +1785,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 1: Spawn multiple groups with varied workloads // ======================================================================== - ZEN_INFO("StressTest: Phase 1 — spawning initial groups"); + ZEN_INFO("StressTest: Phase 1 - spawning initial groups"); constexpr int NumInitialGroups = 8; std::vector<std::string> GroupNames; @@ -1696,7 
+1839,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 2: Randomly kill some groups, create replacements, add ungrouped // ======================================================================== - ZEN_INFO("StressTest: Phase 2 — random group kills and replacements"); + ZEN_INFO("StressTest: Phase 2 - random group kills and replacements"); constexpr int NumGroupsToKill = 3; @@ -1761,7 +1904,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 3: Rapid spawn/exit churn // ======================================================================== - ZEN_INFO("StressTest: Phase 3 — rapid spawn/exit churn"); + ZEN_INFO("StressTest: Phase 3 - rapid spawn/exit churn"); std::atomic<int> ChurnExitCount{0}; int TotalChurnSpawned = 0; @@ -1785,7 +1928,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Brief pump to allow some exits to be processed IoContext.run_for(200ms); - // Destroy the group — any still-running processes get killed + // Destroy the group - any still-running processes get killed Manager.DestroyGroup(Name); } @@ -1795,7 +1938,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // Phase 4: Drain and verify // ======================================================================== - ZEN_INFO("StressTest: Phase 4 — draining remaining processes"); + ZEN_INFO("StressTest: Phase 4 - draining remaining processes"); // Check metrics were collected before we wind down AggregateProcessMetrics Agg = Manager.GetAggregateMetrics(); @@ -1826,7 +1969,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip()) // (exact count is hard to predict due to killed groups, but should be > 0) CHECK(TotalExitCallbacks.load() > 0); - ZEN_INFO("StressTest: PASSED — seed={}", Seed); + ZEN_INFO("StressTest: PASSED - seed={}", Seed); } TEST_SUITE_END(); diff --git a/src/zenutil/progress.cpp b/src/zenutil/progress.cpp new file mode 100644 index 000000000..01f6529f2 --- /dev/null +++ b/src/zenutil/progress.cpp 
@@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/progress.h> + +#include <zencore/logging.h> + +namespace zen { + +class StandardProgressBase; + +class StandardProgressBar : public ProgressBase::ProgressBar +{ +public: + StandardProgressBar(StandardProgressBase& Owner, std::string_view InSubTask) : m_Owner(Owner), m_SubTask(InSubTask) {} + + virtual void UpdateState(const State& NewState, bool DoLinebreak) override; + virtual void ForceLinebreak() override {} + virtual void Finish() override; + +private: + LoggerRef Log(); + StandardProgressBase& m_Owner; + std::string m_SubTask; + State m_State; +}; + +class StandardProgressBase : public ProgressBase +{ +public: + StandardProgressBase(LoggerRef Log) : m_Log(Log) {} + + virtual void SetLogOperationName(std::string_view Name) override + { + m_LogOperationName = Name; + ZEN_INFO("{}", m_LogOperationName); + } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override + { + const size_t PercentDone = StepCount > 0u ? (100u * StepIndex) / StepCount : 0u; + ZEN_INFO("{}: {}%", m_LogOperationName, PercentDone); + } + virtual void PushLogOperation(std::string_view Name) override { ZEN_UNUSED(Name); } + virtual void PopLogOperation() override {} + virtual uint32_t GetProgressUpdateDelayMS() const override { return 2000; } + virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) override + { + return std::make_unique<StandardProgressBar>(*this, InSubTask); + } + +private: + friend class StandardProgressBar; + LoggerRef m_Log; + std::string m_LogOperationName; + LoggerRef Log() { return m_Log; } +}; + +LoggerRef +StandardProgressBar::Log() +{ + return m_Owner.Log(); +} + +void +StandardProgressBar::UpdateState(const State& NewState, bool DoLinebreak) +{ + ZEN_UNUSED(DoLinebreak); + const size_t PercentDone = + NewState.TotalCount > 0u ? 
(100u * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount : 0u; + std::string Task = NewState.Task; + switch (NewState.Status) + { + case State::EStatus::Aborted: + Task = "Aborting"; + break; + case State::EStatus::Paused: + Task = "Paused"; + break; + default: + break; + } + ZEN_INFO("{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? "" : fmt::format(" {}", NewState.Details)); + m_State = NewState; +} +void +StandardProgressBar::Finish() +{ + if (m_State.RemainingCount > 0) + { + State NewState = m_State; + NewState.RemainingCount = 0; + NewState.Details = ""; + UpdateState(NewState, /*DoLinebreak*/ true); + } +} + +ProgressBase* +CreateStandardProgress(LoggerRef Log) +{ + return new StandardProgressBase(Log); +} + +} // namespace zen diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp index a9e95b9ce..3c273abb6 100644 --- a/src/zenutil/rpcrecording.cpp +++ b/src/zenutil/rpcrecording.cpp @@ -13,12 +13,12 @@ #include <zencore/testutils.h> ZEN_THIRD_PARTY_INCLUDES_START +#include <EASTL/deque.h> #include <fmt/format.h> #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END #include <condition_variable> -#include <deque> #include <mutex> #include <thread> @@ -371,7 +371,7 @@ private: std::mutex m_QueueMutex; std::condition_variable m_QueueCondition; bool m_IsWriterReady = false; - std::deque<QueuedRequest> m_RequestQueue; + eastl::deque<QueuedRequest> m_RequestQueue; void WriterThreadMain(); }; diff --git a/src/zenutil/sessionsclient.cpp b/src/zenutil/sessionsclient.cpp index c62cc4099..6ba997a62 100644 --- a/src/zenutil/sessionsclient.cpp +++ b/src/zenutil/sessionsclient.cpp @@ -2,15 +2,12 @@ #include <zenutil/sessionsclient.h> -#include <zencore/blockingqueue.h> #include <zencore/compactbinarybuilder.h> #include <zencore/fmtutils.h> #include <zencore/iobuffer.h> #include <zencore/logging/logmsg.h> #include <zencore/thread.h> -#include <zenhttp/httpclient.h> -#include <thread> #include <vector> 
ZEN_THIRD_PARTY_INCLUDES_START @@ -21,232 +18,76 @@ namespace zen { ////////////////////////////////////////////////////////////////////////// // -// SessionLogSink — batching log sink that forwards to /sessions/{id}/log +// SessionLogSink — thin enqueuer that posts log messages to the +// SessionsServiceClient worker thread via its BlockingQueue. // -static const char* -LogLevelToString(logging::LogLevel Level) -{ - switch (Level) - { - case logging::Trace: - return "trace"; - case logging::Debug: - return "debug"; - case logging::Info: - return "info"; - case logging::Warn: - return "warn"; - case logging::Err: - return "error"; - case logging::Critical: - return "critical"; - default: - return "info"; - } -} - -struct BufferedLogEntry -{ - enum class Type : uint8_t - { - Log, - Flush, - Shutdown - }; - - Type Type = Type::Log; - std::string Level; - std::string Message; -}; - class SessionLogSink final : public logging::Sink { public: - SessionLogSink(std::string TargetUrl, std::string LogPath) : m_LogPath(std::move(LogPath)) - { - HttpClientSettings Settings; - Settings.ConnectTimeout = std::chrono::milliseconds(3000); - m_Http = std::make_unique<HttpClient>(std::move(TargetUrl), Settings); - - SetLevel(logging::Info); - - m_WorkerThread = std::thread([this]() { - zen::SetCurrentThreadName("SessionLog"); - WorkerLoop(); - }); - } + explicit SessionLogSink(BlockingQueue<SessionsServiceClient::SessionCommand>* Queue) : m_Queue(Queue) { SetLevel(logging::Info); } - ~SessionLogSink() override - { - BufferedLogEntry ShutdownMsg; - ShutdownMsg.Type = BufferedLogEntry::Type::Shutdown; - m_Queue.Enqueue(std::move(ShutdownMsg)); - - if (m_WorkerThread.joinable()) - { - m_WorkerThread.join(); - } - } + ~SessionLogSink() override = default; void Log(const logging::LogMessage& Msg) override { - BufferedLogEntry Entry; - Entry.Type = BufferedLogEntry::Type::Log; - Entry.Level = LogLevelToString(Msg.GetLevel()); - Entry.Message = std::string(Msg.GetPayload()); - 
m_Queue.Enqueue(std::move(Entry)); + SessionsServiceClient::SessionCommand Cmd; + Cmd.CommandType = SessionsServiceClient::SessionCommand::Type::Log; + Cmd.LogLevel = Msg.GetLevel(); + Cmd.LogMessage = CompactString(Msg.GetPayload()); + m_Queue->Enqueue(std::move(Cmd)); } void Flush() override { - // Best-effort: enqueue a flush marker so the worker sends any pending entries - BufferedLogEntry FlushMsg; - FlushMsg.Type = BufferedLogEntry::Type::Flush; - m_Queue.Enqueue(std::move(FlushMsg)); + SessionsServiceClient::SessionCommand Cmd; + Cmd.CommandType = SessionsServiceClient::SessionCommand::Type::FlushLogs; + m_Queue->Enqueue(std::move(Cmd)); } void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override { - // No formatting needed — we send raw message text + // No formatting needed - we send raw message text } private: - static constexpr size_t BatchSize = 50; - - void WorkerLoop() - { - std::vector<BufferedLogEntry> Batch; - Batch.reserve(BatchSize); - - BufferedLogEntry Msg; - while (m_Queue.WaitAndDequeue(Msg)) - { - if (Msg.Type == BufferedLogEntry::Type::Shutdown) - { - // Drain remaining log entries - BufferedLogEntry Remaining; - while (m_Queue.WaitAndDequeue(Remaining)) - { - if (Remaining.Type == BufferedLogEntry::Type::Log) - { - Batch.push_back(std::move(Remaining)); - } - } - if (!Batch.empty()) - { - SendBatch(Batch); - } - return; - } - - if (Msg.Type == BufferedLogEntry::Type::Flush) - { - if (!Batch.empty()) - { - SendBatch(Batch); - Batch.clear(); - } - continue; - } - - // Log entry - Batch.push_back(std::move(Msg)); - - if (Batch.size() >= BatchSize) - { - SendBatch(Batch); - Batch.clear(); - } - else - { - // Drain any additional queued entries without blocking - while (Batch.size() < BatchSize && m_Queue.Size() > 0) - { - BufferedLogEntry Extra; - if (m_Queue.WaitAndDequeue(Extra)) - { - if (Extra.Type == BufferedLogEntry::Type::Shutdown) - { - if (!Batch.empty()) - { - SendBatch(Batch); - } - // Drain remaining - while 
(m_Queue.WaitAndDequeue(Extra)) - { - if (Extra.Type == BufferedLogEntry::Type::Log) - { - Batch.push_back(std::move(Extra)); - } - } - if (!Batch.empty()) - { - SendBatch(Batch); - } - return; - } - if (Extra.Type == BufferedLogEntry::Type::Log) - { - Batch.push_back(std::move(Extra)); - } - else if (Extra.Type == BufferedLogEntry::Type::Flush) - { - break; - } - } - } - - if (!Batch.empty()) - { - SendBatch(Batch); - Batch.clear(); - } - } - } - } - - void SendBatch(const std::vector<BufferedLogEntry>& Batch) - { - try - { - CbObjectWriter Writer; - Writer.BeginArray("entries"); - for (const BufferedLogEntry& Entry : Batch) - { - Writer.BeginObject(); - Writer << "level" << Entry.Level; - Writer << "message" << Entry.Message; - Writer.EndObject(); - } - Writer.EndArray(); - - HttpClient::Response Result = m_Http->Post(m_LogPath, Writer.Save()); - (void)Result; // Best-effort - } - catch (const std::exception&) - { - // Best-effort — silently discard on failure - } - } - - std::string m_LogPath; - std::unique_ptr<HttpClient> m_Http; - BlockingQueue<BufferedLogEntry> m_Queue; - std::thread m_WorkerThread; + BlockingQueue<SessionsServiceClient::SessionCommand>* m_Queue; }; +////////////////////////////////////////////////////////////////////////// +// +// SessionsServiceClient +// + SessionsServiceClient::SessionsServiceClient(Options Opts) : m_Log(logging::Get("sessionsclient")) , m_Options(std::move(Opts)) -, m_SessionPath(fmt::format("sessions/{}", m_Options.SessionId)) +, m_SessionPath(fmt::format("/sessions/{}", m_Options.SessionId)) { - HttpClientSettings Settings; - Settings.ConnectTimeout = std::chrono::milliseconds(3000); - m_Http = std::make_unique<HttpClient>(m_Options.TargetUrl, Settings); + // Strip trailing slash to avoid double-slash when appending paths like /sessions/{id} + while (m_Options.TargetUrl.ends_with('/')) + { + m_Options.TargetUrl.pop_back(); + } + + m_WorkerThread = std::thread([this]() { + zen::SetCurrentThreadName("SessionIO"); + 
WorkerLoop(); + }); } -SessionsServiceClient::~SessionsServiceClient() = default; +SessionsServiceClient::~SessionsServiceClient() +{ + SessionCommand ShutdownCmd; + ShutdownCmd.CommandType = SessionCommand::Type::Shutdown; + m_Queue.Enqueue(std::move(ShutdownCmd)); + m_Queue.CompleteAdding(); + + if (m_WorkerThread.joinable()) + { + m_WorkerThread.join(); + } +} CbObject SessionsServiceClient::BuildRequestBody(CbObjectView Metadata) const @@ -261,21 +102,65 @@ SessionsServiceClient::BuildRequestBody(CbObjectView Metadata) const { Writer << "jobid" << m_Options.JobId; } - if (Metadata.GetSize() > 0) + if (Metadata) { Writer.AddObject("metadata", Metadata); } return Writer.Save(); } -bool +////////////////////////////////////////////////////////////////////////// +// Public API — non-blocking enqueuers + +void SessionsServiceClient::Announce(CbObjectView Metadata) { + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::Announce; + if (Metadata) + { + Cmd.Metadata = CbObject::Clone(Metadata); + } + m_Queue.Enqueue(std::move(Cmd)); +} + +void +SessionsServiceClient::UpdateMetadata(CbObjectView Metadata) +{ + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::UpdateMetadata; + if (Metadata) + { + Cmd.Metadata = CbObject::Clone(Metadata); + } + m_Queue.Enqueue(std::move(Cmd)); +} + +void +SessionsServiceClient::Remove() +{ + SessionCommand Cmd; + Cmd.CommandType = SessionCommand::Type::Remove; + m_Queue.Enqueue(std::move(Cmd)); +} + +logging::SinkPtr +SessionsServiceClient::CreateLogSink() +{ + return Ref(new SessionLogSink(&m_Queue)); +} + +////////////////////////////////////////////////////////////////////////// +// Worker thread — processes all session HTTP I/O + +void +SessionsServiceClient::DoAnnounce(HttpClient& Http, CbObjectView Metadata) +{ try { CbObject Body = BuildRequestBody(Metadata); - HttpClient::Response Result = m_Http->Post(m_SessionPath, std::move(Body)); + HttpClient::Response Result = Http.Post(m_SessionPath, std::move(Body)); 
if (Result.Error) { @@ -283,26 +168,24 @@ SessionsServiceClient::Announce(CbObjectView Metadata) m_Options.TargetUrl, static_cast<int>(Result.Error->ErrorCode), Result.Error->ErrorMessage); - return false; + return; } if (!IsHttpOk(Result.StatusCode)) { ZEN_WARN("sessions announce failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); - return false; + return; } - ZEN_INFO("session announced to '{}'", m_Options.TargetUrl); - return true; + ZEN_DEBUG("session announced to '{}'", m_Options.TargetUrl); } catch (const std::exception& Ex) { ZEN_WARN("sessions announce failed for '{}': {}", m_Options.TargetUrl, Ex.what()); - return false; } } -bool -SessionsServiceClient::UpdateMetadata(CbObjectView Metadata) +void +SessionsServiceClient::DoUpdateMetadata(HttpClient& Http, CbObjectView Metadata) { try { @@ -311,7 +194,7 @@ SessionsServiceClient::UpdateMetadata(CbObjectView Metadata) MemoryView View = Body.GetView(); IoBuffer Payload = IoBufferBuilder::MakeCloneFromMemory(View, ZenContentType::kCbObject); - HttpClient::Response Result = m_Http->Put(m_SessionPath, Payload); + HttpClient::Response Result = Http.Put(m_SessionPath, Payload); if (Result.Error) { @@ -319,29 +202,26 @@ SessionsServiceClient::UpdateMetadata(CbObjectView Metadata) m_Options.TargetUrl, static_cast<int>(Result.Error->ErrorCode), Result.Error->ErrorMessage); - return false; + return; } if (!IsHttpOk(Result.StatusCode)) { ZEN_WARN("sessions update failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); - return false; + return; } - - return true; } catch (const std::exception& Ex) { ZEN_WARN("sessions update failed for '{}': {}", m_Options.TargetUrl, Ex.what()); - return false; } } -bool -SessionsServiceClient::Remove() +void +SessionsServiceClient::DoRemove(HttpClient& Http) { try { - HttpClient::Response Result = m_Http->Delete(m_SessionPath); + HttpClient::Response Result = Http.Delete(m_SessionPath); if (Result.Error) { @@ 
-349,29 +229,153 @@ SessionsServiceClient::Remove() m_Options.TargetUrl, static_cast<int>(Result.Error->ErrorCode), Result.Error->ErrorMessage); - return false; + return; } if (!IsHttpOk(Result.StatusCode)) { ZEN_WARN("sessions remove failed for '{}': HTTP status {}", m_Options.TargetUrl, static_cast<int>(Result.StatusCode)); - return false; + return; } - ZEN_INFO("session removed from '{}'", m_Options.TargetUrl); - return true; + ZEN_DEBUG("session removed from '{}'", m_Options.TargetUrl); } catch (const std::exception& Ex) { ZEN_WARN("sessions remove failed for '{}': {}", m_Options.TargetUrl, Ex.what()); - return false; } } -logging::SinkPtr -SessionsServiceClient::CreateLogSink() +void +SessionsServiceClient::SendLogBatch(HttpClient& Http, const std::string& LogPath, const std::vector<SessionCommand>& Batch) { + try + { + CbObjectWriter Writer; + Writer.BeginArray("entries"); + for (const SessionCommand& Entry : Batch) + { + Writer.BeginObject(); + Writer << "level" << static_cast<int32_t>(Entry.LogLevel); + Writer << "message" << Entry.LogMessage.c_str(); + Writer.EndObject(); + } + Writer.EndArray(); + + HttpClient::Response Result = Http.Post(LogPath, Writer.Save()); + (void)Result; // Best-effort + } + catch (const std::exception&) + { + // Best-effort — silently discard on failure + } +} + +void +SessionsServiceClient::WorkerLoop() +{ + HttpClientSettings Settings = m_Options.ClientSettings; + Settings.ConnectTimeout = std::chrono::milliseconds(3000); + Settings.Timeout = std::chrono::milliseconds(5000); + HttpClient Http(m_Options.TargetUrl, Settings); + std::string LogPath = m_SessionPath + "/log"; - return Ref(new SessionLogSink(m_Options.TargetUrl, std::move(LogPath))); + bool Removed = false; + + static constexpr size_t BatchSize = 50; + + std::vector<SessionCommand> LogBatch; + LogBatch.reserve(BatchSize); + + auto FlushLogBatch = [&]() { + if (!LogBatch.empty()) + { + SendLogBatch(Http, LogPath, LogBatch); + LogBatch.clear(); + } + }; + + // Returns 
false to signal loop exit (Shutdown received) + auto ProcessCommand = [&](SessionCommand& Cmd) -> bool { + switch (Cmd.CommandType) + { + case SessionCommand::Type::Log: + LogBatch.push_back(std::move(Cmd)); + if (LogBatch.size() >= BatchSize) + { + FlushLogBatch(); + } + return true; + + case SessionCommand::Type::FlushLogs: + FlushLogBatch(); + return true; + + case SessionCommand::Type::Announce: + FlushLogBatch(); + DoAnnounce(Http, Cmd.Metadata); + return true; + + case SessionCommand::Type::UpdateMetadata: + FlushLogBatch(); + DoUpdateMetadata(Http, Cmd.Metadata); + return true; + + case SessionCommand::Type::Remove: + FlushLogBatch(); + if (!Removed) + { + Removed = true; + DoRemove(Http); + } + return true; + + case SessionCommand::Type::Shutdown: + { + // Drain remaining log entries from the queue + SessionCommand Remaining; + while (m_Queue.WaitAndDequeue(Remaining)) + { + if (Remaining.CommandType == SessionCommand::Type::Log) + { + LogBatch.push_back(std::move(Remaining)); + } + } + FlushLogBatch(); + + if (!Removed) + { + Removed = true; + DoRemove(Http); + } + return false; + } + } + return true; + }; + + SessionCommand Cmd; + while (m_Queue.WaitAndDequeue(Cmd)) + { + if (!ProcessCommand(Cmd)) + { + return; + } + + // Drain additional queued entries without blocking (batching optimization) + while (LogBatch.size() < BatchSize && m_Queue.Size() > 0) + { + SessionCommand Extra; + if (m_Queue.WaitAndDequeue(Extra)) + { + if (!ProcessCommand(Extra)) + { + return; + } + } + } + + FlushLogBatch(); + } } } // namespace zen diff --git a/src/zenutil/splitconsole/logstreamlistener.cpp b/src/zenutil/splitconsole/logstreamlistener.cpp index 04718b543..df985a196 100644 --- a/src/zenutil/splitconsole/logstreamlistener.cpp +++ b/src/zenutil/splitconsole/logstreamlistener.cpp @@ -17,7 +17,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { ////////////////////////////////////////////////////////////////////////// -// LogStreamSession — reads CbObject-framed messages 
from a single TCP connection +// LogStreamSession - reads CbObject-framed messages from a single TCP connection class LogStreamSession : public RefCounted { @@ -34,7 +34,7 @@ private: [Self](const asio::error_code& Ec, std::size_t BytesRead) { if (Ec) { - return; // connection closed or error — session ends + return; // connection closed or error - session ends } Self->m_BufferUsed += BytesRead; Self->ProcessBuffer(); @@ -119,7 +119,7 @@ private: m_BufferUsed -= Consumed; } - // If buffer is full and we can't parse a message, the message is too large — drop connection + // If buffer is full and we can't parse a message, the message is too large - drop connection if (m_BufferUsed == m_ReadBuf.size()) { ZEN_WARN("LogStreamSession: buffer full with no complete message, dropping connection"); @@ -141,7 +141,7 @@ private: struct LogStreamListener::Impl { - // Owned io_context mode — creates and runs its own thread + // Owned io_context mode - creates and runs its own thread Impl(LogStreamTarget& Target, uint16_t Port) : m_Target(Target) , m_OwnedIoContext(std::make_unique<asio::io_context>()) @@ -154,7 +154,7 @@ struct LogStreamListener::Impl }); } - // External io_context mode — caller drives the io_context + // External io_context mode - caller drives the io_context Impl(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port) : m_Target(Target), m_Acceptor(IoContext) { SetupAcceptor(Port); @@ -312,7 +312,7 @@ namespace { logging::LogMessage MakeLogMessage(std::string_view Text, logging::LogLevel Level = logging::Info) { - static logging::LogPoint Point{{}, Level, {}}; + static logging::LogPoint Point{0, 0, Level, {}}; Point.Level = Level; return logging::LogMessage(Point, "test", Text); } @@ -367,7 +367,7 @@ TEST_CASE("DroppedMessageDetection") asio::ip::tcp::socket Socket(IoContext); Socket.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), Listener.GetPort())); - // Send seq=0, then seq=5 — the listener should detect a gap of 4 + // 
Send seq=0, then seq=5 - the listener should detect a gap of 4 for (uint64_t Seq : {uint64_t(0), uint64_t(5)}) { CbObjectWriter Writer; diff --git a/src/zenutil/suggest.cpp b/src/zenutil/suggest.cpp new file mode 100644 index 000000000..15582e4f8 --- /dev/null +++ b/src/zenutil/suggest.cpp @@ -0,0 +1,230 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/suggest.h> + +#include <zencore/string.h> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <ostream> // needed for doctest to stringify std::string_view on CHECK failure +#endif + +#include <algorithm> + +namespace zen { + +namespace { + + constexpr size_t kMaxSuggestionInputLen = 64; + + // Damerau-Levenshtein distance with case-insensitive ASCII comparison. Returns the + // raw edit count (insertions, deletions, substitutions, adjacent transpositions). + // Falls back to max(len) for pathologically long inputs so the scratch buffers can + // stay stack-allocated. + size_t ComputeEditDistance(std::string_view Source, std::string_view Target) + { + if (Source.size() > kMaxSuggestionInputLen || Target.size() > kMaxSuggestionInputLen) + { + return std::max(Source.size(), Target.size()); + } + + auto AsciiToLower = [](unsigned char Ch) -> unsigned char { + return (Ch >= 'A' && Ch <= 'Z') ? 
static_cast<unsigned char>(Ch + ('a' - 'A')) : Ch; + }; + + const size_t SourceLen = Source.size(); + const size_t TargetLen = Target.size(); + if (SourceLen == 0) + { + return TargetLen; + } + if (TargetLen == 0) + { + return SourceLen; + } + + size_t Buf0[kMaxSuggestionInputLen + 1]; + size_t Buf1[kMaxSuggestionInputLen + 1]; + size_t Buf2[kMaxSuggestionInputLen + 1]; + size_t* PrevPrev = Buf0; + size_t* Prev = Buf1; + size_t* Curr = Buf2; + + for (size_t j = 0; j <= TargetLen; ++j) + { + Prev[j] = j; + } + + for (size_t i = 1; i <= SourceLen; ++i) + { + Curr[0] = i; + for (size_t j = 1; j <= TargetLen; ++j) + { + const size_t Cost = (AsciiToLower(Source[i - 1]) == AsciiToLower(Target[j - 1])) ? 0 : 1; + Curr[j] = std::min({Prev[j] + 1, // deletion from Source + Curr[j - 1] + 1, // insertion into Source + Prev[j - 1] + Cost}); // substitution + if (i > 1 && j > 1 && AsciiToLower(Source[i - 1]) == AsciiToLower(Target[j - 2]) && + AsciiToLower(Source[i - 2]) == AsciiToLower(Target[j - 1])) + { + Curr[j] = std::min(Curr[j], PrevPrev[j - 2] + 1); // transposition + } + } + // Rotate buffers: (PrevPrev, Prev, Curr) -> (Prev, Curr, PrevPrev) + size_t* Tmp = PrevPrev; + PrevPrev = Prev; + Prev = Curr; + Curr = Tmp; + } + return Prev[TargetLen]; + } + +} // namespace + +std::vector<std::string_view> +SuggestSimilarCommands(std::string_view Typed, std::span<const std::string_view> Candidates) +{ + constexpr size_t kMaxSuggestions = 5; + + if (Typed.empty()) + { + return {}; + } + + struct Scored + { + std::string_view Name; + size_t Distance; + }; + + std::vector<Scored> Ranked; + Ranked.reserve(Candidates.size()); + + for (std::string_view Name : Candidates) + { + size_t Distance = ComputeEditDistance(Typed, Name); + + // Prefix bonus so short inputs ("ca") still surface longer commands ("cache") + // without us having to use a distance threshold that would admit unrelated matches. 
+ if (Name.size() >= Typed.size() && StrCaseCompare(Name.substr(0, Typed.size()), Typed) == 0) + { + Distance = std::min<size_t>(Distance, 1); + } + + Ranked.push_back({Name, Distance}); + } + + // Scale with typed length: very short inputs get a tighter bound so we don't + // flood the output with every short command that happens to be 2 edits away. + // Short inputs rely on the prefix bonus instead. Cap at 3 for longer strings + // since beyond that the suggestions stop being plausible. + const size_t MaxDistance = std::clamp<size_t>(Typed.size() / 2, 1, 3); + + std::stable_sort(Ranked.begin(), Ranked.end(), [](const Scored& Lhs, const Scored& Rhs) { return Lhs.Distance < Rhs.Distance; }); + + std::vector<std::string_view> Result; + for (const Scored& Entry : Ranked) + { + if (Entry.Distance > MaxDistance) + { + break; + } + Result.push_back(Entry.Name); + if (Result.size() >= kMaxSuggestions) + { + break; + } + } + return Result; +} + +#if ZEN_WITH_TESTS + +void +suggest_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.suggest"); + +TEST_CASE("transposition_is_distance_one") +{ + // "stauts" -> "status" is a single adjacent transposition under Damerau-Levenshtein. + const std::string_view Candidates[] = {"status", "start", "stop"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("stauts", Candidates); + + REQUIRE(!Result.empty()); + CHECK(Result.front() == "status"); +} + +TEST_CASE("short_prefix_surfaces_longer_candidate") +{ + // "ca" is raw distance 3 from "cache", outside the default threshold for a 2-char + // input. The prefix bonus is what brings it back into the result set. + const std::string_view Candidates[] = {"cache", "status", "version"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("ca", Candidates); + + REQUIRE(!Result.empty()); + CHECK(Result.front() == "cache"); +} + +TEST_CASE("gibberish_returns_no_suggestions") +{ + // Nothing in the candidate set is plausibly close. 
+ const std::string_view Candidates[] = {"cache", "status", "version"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("xyzzyqqqq", Candidates); + + CHECK(Result.empty()); +} + +TEST_CASE("empty_input_returns_no_suggestions") +{ + const std::string_view Candidates[] = {"cache", "status"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("", Candidates); + + CHECK(Result.empty()); +} + +TEST_CASE("case_insensitive_match") +{ + const std::string_view Candidates[] = {"cache", "status"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("STAUTS", Candidates); + + REQUIRE(!Result.empty()); + CHECK(Result.front() == "status"); +} + +TEST_CASE("substitution_within_threshold") +{ + // "versoin" -> "version": one transposition. Within threshold for a 7-char input. + const std::string_view Candidates[] = {"version", "verify", "value"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("versoin", Candidates); + + REQUIRE(!Result.empty()); + CHECK(Result.front() == "version"); +} + +TEST_CASE("results_sorted_by_distance") +{ + // "stats" is exact match; "stat" and "start" are both distance 1. Best-first ordering. + const std::string_view Candidates[] = {"start", "stat", "stats"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("stats", Candidates); + + REQUIRE(Result.size() >= 1); + CHECK(Result.front() == "stats"); +} + +TEST_CASE("caps_at_five_results") +{ + // Eight prefix-matching candidates; only the first five should survive the cap. 
+ const std::string_view Candidates[] = {"ca1", "ca2", "ca3", "ca4", "ca5", "ca6", "ca7", "ca8"}; + std::vector<std::string_view> Result = SuggestSimilarCommands("ca", Candidates); + + CHECK(Result.size() == 5); +} + +TEST_SUITE_END(); + +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenutil/testartifactprovider.cpp b/src/zenutil/testartifactprovider.cpp new file mode 100644 index 000000000..666a1758d --- /dev/null +++ b/src/zenutil/testartifactprovider.cpp @@ -0,0 +1,588 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/testartifactprovider.h> + +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +#include <memory> +#include <system_error> + +namespace zen { + +namespace { + + std::string JoinKey(std::string_view Prefix, std::string_view RelativePath) + { + if (Prefix.empty()) + { + return std::string(RelativePath); + } + std::string Result; + Result.reserve(Prefix.size() + 1 + RelativePath.size()); + Result.append(Prefix); + if (Result.back() != '/') + { + Result.push_back('/'); + } + Result.append(RelativePath); + return Result; + } + + class LocalBackend + { + public: + explicit LocalBackend(std::filesystem::path RootDir) : m_RootDir(std::move(RootDir)) {} + + const std::filesystem::path& RootDir() const { return m_RootDir; } + + bool Exists(std::string_view RelativePath) const + { + std::error_code Ec; + return std::filesystem::exists(Resolve(RelativePath), Ec) && !Ec; + } + + [[nodiscard]] TestArtifactFetchResult Fetch(std::string_view RelativePath) const + { + TestArtifactFetchResult Result; + + std::filesystem::path FullPath = Resolve(RelativePath); + FileContents Contents = ReadFile(FullPath); + if (!Contents) + { + Result.Error = fmt::format("failed to read '{}': {}", FullPath.string(), Contents.ErrorCode.message()); + return Result; + } + Result.Content = Contents.Flatten(); + return Result; + } + + [[nodiscard]] TestArtifactResult Store(std::string_view 
RelativePath, IoBuffer Content) const + { + TestArtifactResult Result; + + std::filesystem::path FullPath = Resolve(RelativePath); + std::filesystem::path ParentPath = FullPath.parent_path(); + if (!ParentPath.empty()) + { + std::error_code Ec; + std::filesystem::create_directories(ParentPath, Ec); + if (Ec) + { + Result.Error = fmt::format("failed to create '{}': {}", ParentPath.string(), Ec.message()); + return Result; + } + } + + try + { + WriteFile(FullPath, std::move(Content)); + } + catch (const std::exception& Ex) + { + Result.Error = fmt::format("failed to write '{}': {}", FullPath.string(), Ex.what()); + } + return Result; + } + + [[nodiscard]] TestArtifactListResult List(std::string_view Prefix) const + { + TestArtifactListResult Result; + + std::error_code Ec; + if (!std::filesystem::exists(m_RootDir, Ec)) + { + Result.Error = fmt::format("root directory does not exist: {}", m_RootDir.string()); + return Result; + } + + std::filesystem::recursive_directory_iterator It(m_RootDir, Ec); + if (Ec) + { + Result.Error = fmt::format("failed to iterate '{}': {}", m_RootDir.string(), Ec.message()); + return Result; + } + + for (const std::filesystem::directory_entry& Entry : It) + { + if (!Entry.is_regular_file()) + { + continue; + } + + std::filesystem::path Relative = std::filesystem::relative(Entry.path(), m_RootDir, Ec); + if (Ec) + { + continue; + } + + std::string RelString = Relative.generic_string(); + if (!Prefix.empty() && RelString.compare(0, Prefix.size(), Prefix) != 0) + { + continue; + } + + TestArtifactInfo Info; + Info.RelativePath = std::move(RelString); + Info.Size = Entry.file_size(Ec); + Result.Artifacts.push_back(std::move(Info)); + } + return Result; + } + + private: + std::filesystem::path Resolve(std::string_view RelativePath) const { return m_RootDir / std::filesystem::path(RelativePath); } + + std::filesystem::path m_RootDir; + }; + + class S3Backend + { + public: + S3Backend(S3ClientOptions ClientOptions, std::string KeyPrefix) + : 
m_ClientOptions(std::move(ClientOptions)) + , m_KeyPrefix(std::move(KeyPrefix)) + , m_Client(m_ClientOptions) + { + } + + std::string Describe() const { return fmt::format("s3:{}/{}", m_Client.BucketName(), m_KeyPrefix); } + + bool Exists(std::string_view RelativePath) + { + S3HeadObjectResult Head = m_Client.HeadObject(JoinKey(m_KeyPrefix, RelativePath)); + return Head.Status == HeadObjectResult::Found; + } + + TestArtifactFetchResult Fetch(std::string_view RelativePath) + { + TestArtifactFetchResult Result; + + S3GetObjectResult Get = m_Client.GetObject(JoinKey(m_KeyPrefix, RelativePath)); + if (!Get) + { + Result.Error = std::move(Get.Error); + return Result; + } + Result.Content = std::move(Get.Content); + return Result; + } + + TestArtifactListResult List(std::string_view Prefix) + { + TestArtifactListResult Result; + + std::string FullPrefix = JoinKey(m_KeyPrefix, Prefix); + S3ListObjectsResult Listing = m_Client.ListObjects(FullPrefix); + if (!Listing) + { + Result.Error = std::move(Listing.Error); + return Result; + } + + size_t StripLen = m_KeyPrefix.size(); + if (StripLen > 0 && m_KeyPrefix.back() != '/') + { + StripLen += 1; // account for the separator JoinKey inserts + } + + Result.Artifacts.reserve(Listing.Objects.size()); + for (S3ObjectInfo& Obj : Listing.Objects) + { + TestArtifactInfo Info; + Info.RelativePath = (StripLen <= Obj.Key.size()) ? 
Obj.Key.substr(StripLen) : std::move(Obj.Key); + Info.Size = Obj.Size; + Result.Artifacts.push_back(std::move(Info)); + } + return Result; + } + + private: + S3ClientOptions m_ClientOptions; + std::string m_KeyPrefix; + S3Client m_Client; + }; + + class TestArtifactProviderImpl final : public TestArtifactProvider + { + public: + TestArtifactProviderImpl(LocalBackend Cache, std::unique_ptr<S3Backend> Primary) + : m_Cache(std::move(Cache)) + , m_Primary(std::move(Primary)) + { + } + + std::string Describe() const override + { + if (m_Primary) + { + return fmt::format("cache:{} <- {}", m_Cache.RootDir().string(), m_Primary->Describe()); + } + return fmt::format("local:{}", m_Cache.RootDir().string()); + } + + bool Exists(std::string_view RelativePath) override + { + if (m_Cache.Exists(RelativePath)) + { + return true; + } + return m_Primary && m_Primary->Exists(RelativePath); + } + + TestArtifactFetchResult Fetch(std::string_view RelativePath) override + { + if (m_Cache.Exists(RelativePath)) + { + TestArtifactFetchResult Result = m_Cache.Fetch(RelativePath); + if (Result) + { + ZEN_INFO("test artifact '{}' served from cache {} ({} bytes)", + RelativePath, + m_Cache.RootDir().string(), + Result.Content.GetSize()); + } + return Result; + } + + if (!m_Primary) + { + TestArtifactFetchResult Result; + Result.Error = fmt::format("artifact '{}' not found in cache and no remote source configured", RelativePath); + return Result; + } + + ZEN_INFO("downloading test artifact '{}' from {}", RelativePath, m_Primary->Describe()); + + TestArtifactFetchResult Fetched = m_Primary->Fetch(RelativePath); + if (!Fetched) + { + return Fetched; + } + + ZEN_INFO("test artifact '{}' fetched from {} ({} bytes)", RelativePath, m_Primary->Describe(), Fetched.Content.GetSize()); + + TestArtifactResult StoreRes = m_Cache.Store(RelativePath, Fetched.Content); + if (!StoreRes) + { + ZEN_WARN("failed to cache artifact '{}' into {}: {}", RelativePath, m_Cache.RootDir().string(), StoreRes.Error); + } + 
return Fetched; + } + + TestArtifactListResult List(std::string_view Prefix) override + { + if (m_Primary) + { + return m_Primary->List(Prefix); + } + return m_Cache.List(Prefix); + } + + TestArtifactResult Store(std::string_view RelativePath, IoBuffer Content) override + { + return m_Cache.Store(RelativePath, std::move(Content)); + } + + private: + LocalBackend m_Cache; + std::unique_ptr<S3Backend> m_Primary; + }; + + void ApplyS3UrlToOptions(std::string_view Url, S3ClientOptions& Client, std::string& KeyPrefix) + { + constexpr std::string_view kScheme = "s3://"; + if (Url.substr(0, kScheme.size()) == kScheme) + { + Url.remove_prefix(kScheme.size()); + } + + size_t Slash = Url.find('/'); + std::string_view Bucket = (Slash == std::string_view::npos) ? Url : Url.substr(0, Slash); + std::string_view Prefix = (Slash == std::string_view::npos) ? std::string_view{} : Url.substr(Slash + 1); + + if (Client.BucketName.empty() && !Bucket.empty()) + { + Client.BucketName.assign(Bucket); + } + if (KeyPrefix.empty() && !Prefix.empty()) + { + KeyPrefix.assign(Prefix); + } + } + + // Returns true when the caller has explicitly (or implicitly by platform) disabled + // the EC2 Instance Metadata Service fallback. Honors the standard AWS env var + // AWS_EC2_METADATA_DISABLED=true and skips by default on macOS, where Mac EC2 + // instances are rare and the link-local probe would just emit noise on failure. 
+ bool IsImdsDisabled() + { +#if ZEN_PLATFORM_MAC + return true; +#else + std::string Disabled = GetEnvVariable("AWS_EC2_METADATA_DISABLED"); + return Disabled == "true" || Disabled == "TRUE" || Disabled == "1"; +#endif + } + + void ApplyAwsEnvDefaults(S3ClientOptions& Client) + { + if (std::string Region = GetEnvVariable("AWS_DEFAULT_REGION"); !Region.empty()) + { + Client.Region = std::move(Region); + } + else if (std::string FallbackRegion = GetEnvVariable("AWS_REGION"); !FallbackRegion.empty()) + { + Client.Region = std::move(FallbackRegion); + } + + if (Client.Endpoint.empty()) + { + Client.Endpoint = GetEnvVariable("AWS_ENDPOINT_URL"); + } + + if (Client.Credentials.AccessKeyId.empty() && !Client.CredentialProvider) + { + std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID"); + if (!AccessKeyId.empty()) + { + Client.Credentials.AccessKeyId = std::move(AccessKeyId); + Client.Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY"); + Client.Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN"); + } + else if (!IsImdsDisabled()) + { + // Fall back to the EC2 Instance Metadata Service so self-hosted runners + // with an attached IAM role can sign S3 requests without static creds. 
+ Client.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({})); + } + } + } + +} // namespace + +std::filesystem::path +GetDefaultLocalTestArtifactPath() +{ + std::filesystem::path SystemRoot = PickDefaultSystemRootDirectory(); + if (SystemRoot.empty()) + { + return {}; + } + return SystemRoot / kDefaultLocalTestArtifactDirName; +} + +bool +S3TestArtifactsAvailable() +{ + if (GetEnvVariable(kTestArtifactsS3EnvVar).empty()) + { + return false; + } + if (!GetEnvVariable("AWS_ACCESS_KEY_ID").empty() || !GetEnvVariable("AWS_SESSION_TOKEN").empty()) + { + return true; + } + return !IsImdsDisabled(); +} + +bool +TestArtifactsAvailable() +{ + if (!GetEnvVariable(kTestArtifactsPathEnvVar).empty()) + { + return true; + } + return S3TestArtifactsAvailable(); +} + +Ref<TestArtifactProvider> +CreateTestArtifactProvider(TestArtifactProviderOptions Options) +{ + if (Options.CacheDir.empty()) + { + std::string EnvValue = GetEnvVariable(kTestArtifactsPathEnvVar); + if (!EnvValue.empty()) + { + Options.CacheDir = std::filesystem::path(EnvValue); + } + } + if (Options.CacheDir.empty()) + { + Options.CacheDir = GetDefaultLocalTestArtifactPath(); + } + if (Options.CacheDir.empty()) + { + return {}; + } + + if (Options.S3Client.BucketName.empty() || Options.S3KeyPrefix.empty()) + { + std::string EnvValue = GetEnvVariable(kTestArtifactsS3EnvVar); + if (!EnvValue.empty()) + { + ApplyS3UrlToOptions(EnvValue, Options.S3Client, Options.S3KeyPrefix); + } + } + + std::unique_ptr<S3Backend> Primary; + if (!Options.S3Client.BucketName.empty()) + { + ApplyAwsEnvDefaults(Options.S3Client); + Primary = std::make_unique<S3Backend>(std::move(Options.S3Client), std::move(Options.S3KeyPrefix)); + } + + return Ref<TestArtifactProvider>(new TestArtifactProviderImpl(LocalBackend(std::move(Options.CacheDir)), std::move(Primary))); +} + +void +testartifactprovider_forcelink() +{ +} + +} // namespace zen + +////////////////////////////////////////////////////////////////////////// 
+// Tests + +#if ZEN_WITH_TESTS + +# include <zenutil/cloud/minioprocess.h> + +# include <zencore/memoryview.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include <cstring> + +namespace zen::tests { + +using namespace std::literals; + +namespace { + + bool ContentMatches(const IoBuffer& Buf, std::string_view Expected) + { + return Buf.GetSize() == Expected.size() && std::memcmp(Buf.GetData(), Expected.data(), Expected.size()) == 0; + } + +} // namespace + +TEST_SUITE_BEGIN("util.testartifactprovider"); + +TEST_CASE("local_only.store_fetch_list") +{ + ScopedTemporaryDirectory CacheDir; + + // Keep the test hermetic: ignore any S3 configuration the developer may have in the environment. + ScopedEnvVar EnvS3(kTestArtifactsS3EnvVar, ""); + + TestArtifactProviderOptions Opts; + Opts.CacheDir = CacheDir.Path(); + Ref<TestArtifactProvider> Provider = CreateTestArtifactProvider(std::move(Opts)); + REQUIRE(Provider); + + CHECK_FALSE(Provider->Exists("missing.txt")); + + constexpr std::string_view kContent = "local payload"sv; + TestArtifactResult StoreRes = Provider->Store("greet/hello.txt", IoBufferBuilder::MakeFromMemory(MakeMemoryView(kContent))); + REQUIRE_MESSAGE(StoreRes.IsSuccess(), StoreRes.Error); + + CHECK(Provider->Exists("greet/hello.txt")); + + TestArtifactFetchResult Fetch = Provider->Fetch("greet/hello.txt"); + REQUIRE_MESSAGE(Fetch.IsSuccess(), Fetch.Error); + CHECK(ContentMatches(Fetch.Content, kContent)); + + TestArtifactFetchResult Missing = Provider->Fetch("nope.txt"); + CHECK_FALSE(Missing.IsSuccess()); + + TestArtifactListResult ListRes = Provider->List(""); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + REQUIRE_EQ(ListRes.Artifacts.size(), 1u); + CHECK_EQ(ListRes.Artifacts.front().RelativePath, "greet/hello.txt"); +} + +TEST_CASE("minio.s3_primary") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19020; + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("artifacts-test"); + + // 
Seed S3 directly so we can verify the provider's read path. + S3ClientOptions SeedOpts; + SeedOpts.BucketName = "artifacts-test"; + SeedOpts.Region = "us-east-1"; + SeedOpts.Endpoint = Minio.Endpoint(); + SeedOpts.PathStyle = true; + SeedOpts.Credentials.AccessKeyId = std::string(Minio.RootUser()); + SeedOpts.Credentials.SecretAccessKey = std::string(Minio.RootPassword()); + S3Client SeedClient(SeedOpts); + + constexpr std::string_view kHello = "hello from minio"sv; + REQUIRE(SeedClient.PutObject("artifacts/hello.txt", IoBufferBuilder::MakeFromMemory(MakeMemoryView(kHello))).IsSuccess()); + + ScopedTemporaryDirectory CacheDir; + + // Configure everything via environment variables to exercise the env-based defaults. + ScopedEnvVar EnvPath(kTestArtifactsPathEnvVar, CacheDir.Path().string()); + ScopedEnvVar EnvS3(kTestArtifactsS3EnvVar, "s3://artifacts-test/artifacts"); + ScopedEnvVar EnvEndpoint("AWS_ENDPOINT_URL", Minio.Endpoint()); + ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "us-east-1"); + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + TestArtifactProviderOptions Opts; + Opts.S3Client.PathStyle = true; // MinIO requires path-style addressing + Ref<TestArtifactProvider> Provider = CreateTestArtifactProvider(std::move(Opts)); + REQUIRE(Provider); + + // -- exists checks consult primary when cache is cold -------------------- + CHECK(Provider->Exists("hello.txt")); + CHECK_FALSE(Provider->Exists("nope.txt")); + + // -- first fetch pulls from primary and populates the cache -------------- + TestArtifactFetchResult First = Provider->Fetch("hello.txt"); + REQUIRE_MESSAGE(First.IsSuccess(), First.Error); + CHECK(ContentMatches(First.Content, kHello)); + CHECK(std::filesystem::exists(CacheDir.Path() / "hello.txt")); + + // -- delete from S3, then fetch again: must come from cache -------------- + REQUIRE(SeedClient.DeleteObject("artifacts/hello.txt").IsSuccess()); + 
TestArtifactFetchResult Cached = Provider->Fetch("hello.txt"); + REQUIRE_MESSAGE(Cached.IsSuccess(), Cached.Error); + CHECK(ContentMatches(Cached.Content, kHello)); + + // -- list reflects the primary (S3) source ------------------------------ + REQUIRE(SeedClient.PutObject("artifacts/sub/file-a.bin", IoBufferBuilder::MakeFromMemory(MakeMemoryView("A"sv))).IsSuccess()); + REQUIRE(SeedClient.PutObject("artifacts/sub/file-b.bin", IoBufferBuilder::MakeFromMemory(MakeMemoryView("B"sv))).IsSuccess()); + + TestArtifactListResult ListRes = Provider->List(""); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + + auto HasPath = [&](std::string_view Rel) { + for (const TestArtifactInfo& Info : ListRes.Artifacts) + { + if (Info.RelativePath == Rel) + { + return true; + } + } + return false; + }; + CHECK(HasPath("sub/file-a.bin")); + CHECK(HasPath("sub/file-b.bin")); + + // -- fetching a missing artifact surfaces a remote error ----------------- + TestArtifactFetchResult MissingFetch = Provider->Fetch("does-not-exist.bin"); + CHECK_FALSE(MissingFetch.IsSuccess()); +} + +TEST_SUITE_END(); + +} // namespace zen::tests + +#endif diff --git a/src/zenutil/windows/windowsservice.cpp b/src/zenutil/windows/windowsservice.cpp index ebb88b018..383568650 100644 --- a/src/zenutil/windows/windowsservice.cpp +++ b/src/zenutil/windows/windowsservice.cpp @@ -16,7 +16,6 @@ SERVICE_STATUS gSvcStatus; SERVICE_STATUS_HANDLE gSvcStatusHandle; -HANDLE ghSvcStopEvent = NULL; void SvcInstall(void); @@ -205,21 +204,6 @@ WindowsService::SvcMain() ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000); - // Create an event. The control handler function, SvcCtrlHandler, - // signals this event when it receives the stop control code. 
- - ghSvcStopEvent = CreateEvent(NULL, // default security attributes - TRUE, // manual reset event - FALSE, // not signaled - NULL); // no name - - if (ghSvcStopEvent == NULL) - { - ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0); - - return 1; - } - int ReturnCode = Run(); return ReturnCode; } @@ -549,9 +533,18 @@ ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint) gSvcStatus.dwWaitHint = dwWaitHint; if (dwCurrentState == SERVICE_START_PENDING) + { gSvcStatus.dwControlsAccepted = 0; + } + else if (dwCurrentState == SERVICE_STOP_PENDING || dwCurrentState == SERVICE_STOPPED) + { + // We are already stopping/stopped - don't accept further control codes. + gSvcStatus.dwControlsAccepted = 0; + } else - gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP; + { + gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN; + } if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED)) gSvcStatus.dwCheckPoint = 0; @@ -573,14 +566,14 @@ WindowsService::SvcCtrlHandler(DWORD dwCtrl) switch (dwCtrl) { case SERVICE_CONTROL_STOP: - ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0); + case SERVICE_CONTROL_SHUTDOWN: + // SCM gives us dwWaitHint milliseconds to complete the stop; 30s leaves + // plenty of headroom for the HTTP server and service state to tear down. + ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 30000); - // Signal the service to stop. - - SetEvent(ghSvcStopEvent); + // The HTTP server run loops poll IsApplicationExitRequested(), so this + // flag is how we ask them to wind down. 
zen::RequestApplicationExit(0); - - ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0); return; case SERVICE_CONTROL_INTERROGATE: diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index 83a6b7f93..e28f6e345 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -9,6 +9,9 @@ target('zenutil') add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") + if is_plat("linux", "macosx") then + add_syslinks("dl") + end add_packages("json11") if is_plat("linux", "macosx") then diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index 2b27b2d8b..2d4334ffa 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -181,7 +181,7 @@ ZenServerState::Initialize() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 0766 : 0666); if (Fd < 0) { // Work around a potential issue if the service user is changed in certain configurations. @@ -191,7 +191,7 @@ ZenServerState::Initialize() // shared memory object and retry, we'll be able to get past shm_open() so long as we have // the appropriate permissions to create the shared memory object. shm_unlink("/UnrealEngineZen"); - Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); + Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 
0766 : 0666); if (Fd < 0) { ThrowLastError("Could not open a shared memory object"); @@ -244,7 +244,7 @@ ZenServerState::InitializeReadOnly() ThrowLastError("Could not map view of Zen server state"); } #else - int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open("/UnrealEngineZen", O_RDONLY, 0666); if (Fd < 0) { return false; @@ -267,6 +267,8 @@ ZenServerState::InitializeReadOnly() ZenServerState::ZenServerEntry* ZenServerState::Lookup(int DesiredListenPort) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].DesiredListenPort; @@ -274,6 +276,14 @@ ZenServerState::Lookup(int DesiredListenPort) const { if (DesiredListenPort == 0 || (EntryPort == DesiredListenPort)) { + // If the entry's PID matches our own but we haven't registered yet, + // this is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Skip it. + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -289,6 +299,8 @@ ZenServerState::Lookup(int DesiredListenPort) const ZenServerState::ZenServerEntry* ZenServerState::LookupByEffectivePort(int Port) const { + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { uint16_t EntryPort = m_Data[i].EffectiveListenPort; @@ -296,6 +308,11 @@ ZenServerState::LookupByEffectivePort(int Port) const { if (EntryPort == Port) { + if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr) + { + continue; + } + std::error_code _; if (IsProcessRunning(m_Data[i].Pid, _)) { @@ -358,12 +375,26 @@ ZenServerState::Sweep() ZEN_ASSERT(m_IsReadOnly == false); + const uint32_t OurPid = GetCurrentProcessId(); + for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.DesiredListenPort) { + // If the entry's PID matches our own but we haven't registered yet, + // this 
is a stale entry from a previous process incarnation (e.g. PID 1 + // reuse after unclean shutdown in k8s). Reclaim it. + if (Entry.Pid == OurPid && m_OurEntry == nullptr) + { + ZEN_CONSOLE_DEBUG("Sweep - pid {} matches current process but no registration yet, reclaiming stale entry (port {})", + Entry.Pid.load(), + Entry.DesiredListenPort.load()); + Entry.Reset(); + continue; + } + std::error_code ErrorCode; if (Entry.Pid != 0 && IsProcessRunning(Entry.Pid, ErrorCode) == false) { @@ -620,7 +651,7 @@ ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data ThrowLastError("Could not map instance info shared memory"); } #else - int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666); if (Fd < 0) { ThrowLastError("Could not create instance info shared memory"); @@ -687,7 +718,7 @@ ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId) return false; } #else - int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666); + int Fd = shm_open(Name.c_str(), O_RDONLY, 0666); if (Fd < 0) { return false; @@ -1082,8 +1113,23 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, ChildEventName << "Zen_Child_" << ChildId; NamedEvent ChildEvent{ChildEventName}; + const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); + const std::filesystem::path Executable = + m_ServerExecutablePath.empty() ? 
(BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; + ExtendableStringBuilder<512> CommandLine; - CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path + { + const std::string ExeUtf8 = PathToUtf8(Executable); + constexpr AsciiSet QuoteChars = " \t\""; + if (AsciiSet::HasAny(ExeUtf8.c_str(), QuoteChars)) + { + CommandLine << '"' << ExeUtf8 << '"'; + } + else + { + CommandLine << ExeUtf8; + } + } if (m_ServerMode == ServerMode::kHubServer) { @@ -1096,6 +1142,11 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, CommandLine << " --child-id " << ChildEventName; + if (!m_EnableExecutionHistory) + { + CommandLine << " --enable-execution-history=false"; + } + if (!ServerArgs.empty()) { CommandLine << " " << ServerArgs; @@ -1110,10 +1161,6 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CreationFlags |= CreateProcOptions::Flag_NewConsole; } - - const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); - const std::filesystem::path Executable = - m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput()) ? 
std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); @@ -1583,7 +1630,7 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason) std::optional<int> StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options) { - auto Log = [&LogRef]() { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); // Check if a matching server is already running { @@ -1616,6 +1663,7 @@ StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options) ZenServerEnvironment ServerEnvironment; ServerEnvironment.Initialize(ProgramBaseDir); ZenServerInstance Server(ServerEnvironment, Options.Mode); + Server.SetEnableExecutionHistory(Options.EnableExecutionHistory); std::string ServerArguments(Options.ExtraArgs); if ((Options.Port != 0) && (ServerArguments.find("--port") == std::string::npos)) @@ -1653,7 +1701,7 @@ ShutdownZenServer(LoggerRef LogRef, ZenServerState::ZenServerEntry* Entry, const std::filesystem::path& ProgramBaseDir) { - auto Log = [&LogRef]() { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); int EntryPort = (int)Entry->DesiredListenPort.load(); const uint32_t ServerProcessPid = Entry->Pid.load(); try diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 2ca380c75..3c0fd9ab6 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,12 +5,18 @@ #if ZEN_WITH_TESTS # include <zenutil/cloud/imdscredentials.h> +# include <zenutil/consul.h> # include <zenutil/cloud/s3client.h> # include <zenutil/cloud/sigv4.h> # include <zenutil/config/commandlineoptions.h> +# include <zenutil/filesystemutils.h> +# include <zenutil/invocationhistory.h> +# include <zenutil/parallelsort.h> # include <zenutil/rpcrecording.h> # include <zenutil/splitconsole/logstreamlistener.h> # include <zenutil/process/subprocessmanager.h> +# include <zenutil/suggest.h> +# include <zenutil/testartifactprovider.h> # include <zenutil/wildcard.h> namespace zen { @@ -20,11 +26,17 @@ zenutil_forcelinktests() { 
cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + consul::consul_forcelink(); + filesystemutils_forcelink(); imdscredentials_forcelink(); + invocationhistory_forcelink(); + parallelsort_forcelink(); logstreamlistener_forcelink(); subprocessmanager_forcelink(); s3client_forcelink(); sigv4_forcelink(); + suggest_forcelink(); + testartifactprovider_forcelink(); wildcard_forcelink(); } |