From e64d76ae1b6993582bf161a61049f0771414a779 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 18 Mar 2026 11:27:07 +0100 Subject: Simple S3 client (#836) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This functionality is intended to be used to manage datasets for test cases, but may be useful elsewhere in the future. - **Add S3 client with AWS Signature V4 (SigV4) signing** — new `S3Client` in `zenutil/cloud/` supporting `GetObject`, `PutObject`, `DeleteObject`, `HeadObject`, and `ListObjects` operations - **Add EC2 IMDS credential provider** — automatically fetches and refreshes temporary AWS credentials from the EC2 Instance Metadata Service (IMDSv2) for use by the S3 client - **Add SigV4 signing library** — standalone implementation of AWS Signature Version 4 request signing (headers and query-string presigning) - **Add path-style addressing support** — enables compatibility with S3-compatible stores like MinIO (in addition to virtual-hosted style) - **Add S3 integration tests** — includes a `MinioProcess` test helper that spins up a local MinIO server, plus integration tests exercising the S3 client end-to-end - **Add S3-backed `HttpObjectStoreService` tests** — integration tests verifying the zenserver object store works against an S3 backend - **Refactor mock IMDS into `zenutil/cloud/`** — moved and generalized the mock IMDS server from `zencompute` so it can be reused by both compute and S3 credential tests --- src/zenutil/cloud/s3client.cpp | 986 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 986 insertions(+) create mode 100644 src/zenutil/cloud/s3client.cpp (limited to 'src/zenutil/cloud/s3client.cpp') diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp new file mode 100644 index 000000000..88d844b61 --- /dev/null +++ b/src/zenutil/cloud/s3client.cpp @@ -0,0 +1,986 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include +#include + +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include + +namespace zen { + +namespace { + + /// The SHA-256 hash of an empty payload, precomputed + constexpr std::string_view EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + + /// Simple XML value extractor. Finds the text content between and . + /// This is intentionally minimal - we only need to parse ListBucketResult responses. + /// Returns a string_view into the original XML when no entity decoding is needed. + std::string_view ExtractXmlValue(std::string_view Xml, std::string_view Tag) + { + std::string OpenTag = fmt::format("<{}>", Tag); + std::string CloseTag = fmt::format("", Tag); + + size_t Start = Xml.find(OpenTag); + if (Start == std::string_view::npos) + { + return {}; + } + Start += OpenTag.size(); + + size_t End = Xml.find(CloseTag, Start); + if (End == std::string_view::npos) + { + return {}; + } + + return Xml.substr(Start, End - Start); + } + + /// Decode the five standard XML entities (& < > " ') into a StringBuilderBase. + void DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) + { + if (Input.find('&') == std::string_view::npos) + { + Out.Append(Input); + return; + } + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '&') + { + std::string_view Remaining = Input.substr(i); + if (Remaining.starts_with("&")) + { + Out.Append('&'); + i += 4; + } + else if (Remaining.starts_with("<")) + { + Out.Append('<'); + i += 3; + } + else if (Remaining.starts_with(">")) + { + Out.Append('>'); + i += 3; + } + else if (Remaining.starts_with(""")) + { + Out.Append('"'); + i += 5; + } + else if (Remaining.starts_with("'")) + { + Out.Append('\''); + i += 5; + } + else + { + Out.Append(Input[i]); + } + } + else + { + Out.Append(Input[i]); + } + } + } + + /// Convenience: decode XML entities and return as std::string. + std::string DecodeXmlEntities(std::string_view Input) + { + if (Input.find('&') == std::string_view::npos) + { + return std::string(Input); + } + + ExtendableStringBuilder<256> Sb; + DecodeXmlEntities(Input, Sb); + return Sb.ToString(); + } + + /// Join a path and canonical query string into a full request path for the HTTP client. + std::string BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) + { + if (CanonicalQS.empty()) + { + return std::string(Path); + } + return fmt::format("{}?{}", Path, CanonicalQS); + } + + /// Case-insensitive header lookup in an HttpClient response header map. + const std::string* FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) + { + for (const auto& [K, V] : *Headers) + { + if (StrCaseCompare(K, Name) == 0) + { + return &V; + } + } + return nullptr; + } + +} // namespace + +S3Client::S3Client(const S3ClientOptions& Options) +: m_Log(logging::Get("s3")) +, m_BucketName(Options.BucketName) +, m_Region(Options.Region) +, m_Endpoint(Options.Endpoint) +, m_PathStyle(Options.PathStyle) +, m_Credentials(Options.Credentials) +, m_CredentialProvider(Options.CredentialProvider) +, m_HttpClient(BuildEndpoint(), + HttpClientSettings{ + .LogCategory = "s3", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.Timeout, + .RetryCount = Options.RetryCount, + }) +{ + m_Host = BuildHostHeader(); + ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", + m_BucketName, + m_Region, + m_HttpClient.GetBaseUri(), + m_PathStyle ? "path-style" : "virtual-hosted"); +} + +S3Client::~S3Client() = default; + +SigV4Credentials +S3Client::GetCurrentCredentials() +{ + if (m_CredentialProvider) + { + SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); + if (!Creds.AccessKeyId.empty()) + { + // Invalidate the signing key cache when the access key changes + if (Creds.AccessKeyId != m_Credentials.AccessKeyId) + { + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + m_CachedDateStamp.clear(); + } + m_Credentials = Creds; + } + return m_Credentials; + } + return m_Credentials; +} + +std::string +S3Client::BuildEndpoint() const +{ + if (!m_Endpoint.empty()) + { + return m_Endpoint; + } + + if (m_PathStyle) + { + // Path-style: https://s3.region.amazonaws.com + return fmt::format("https://s3.{}.amazonaws.com", m_Region); + } + + // Virtual-hosted style: https://bucket.s3.region.amazonaws.com + return fmt::format("https://{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::BuildHostHeader() const +{ + if (!m_Endpoint.empty()) + { + // Extract host from custom endpoint URL (strip scheme) + std::string_view Ep = m_Endpoint; + if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) + { + Ep = Ep.substr(Pos + 3); + } + // Strip trailing slash + if (!Ep.empty() && Ep.back() == '/') + { + Ep = Ep.substr(0, Ep.size() - 1); + } + return std::string(Ep); + } + + if (m_PathStyle) + { + return fmt::format("s3.{}.amazonaws.com", m_Region); + } + + return fmt::format("{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::KeyToPath(std::string_view Key) const +{ + if (m_PathStyle) + { + return fmt::format("/{}/{}", m_BucketName, Key); + } + return fmt::format("/{}", Key); +} + +std::string +S3Client::BucketRootPath() const +{ + if (m_PathStyle) + { + return fmt::format("/{}/", m_BucketName); + } + return "/"; +} + +Sha256Digest +S3Client::GetSigningKey(std::string_view DateStamp) +{ + // Fast path: shared lock for cache hit (common case — key only changes once per day) + { + RwLock::SharedLockScope SharedLock(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + } + + // Slow path: exclusive lock to recompute the signing key + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + + // Double-check after acquiring exclusive lock (another thread may have updated it) + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + + std::string SecretPrefix = fmt::format("AWS4{}", m_Credentials.SecretAccessKey); + + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); + m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + m_CachedDateStamp = std::string(DateStamp); + + return m_CachedSigningKey; +} + +HttpClient::KeyValueMap +S3Client::SignRequest(std::string_view Method, std::string_view Path, std::string_view CanonicalQueryString, std::string_view PayloadHash) +{ + SigV4Credentials Credentials = GetCurrentCredentials(); + + std::string AmzDate = GetAmzTimestamp(); + + // Build sorted headers to sign (must be sorted by lowercase name) + std::vector> HeadersToSign; + HeadersToSign.emplace_back("host", m_Host); + HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); + HeadersToSign.emplace_back("x-amz-date", AmzDate); + if (!Credentials.SessionToken.empty()) + { + HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); + } + std::sort(HeadersToSign.begin(), HeadersToSign.end()); + + std::string_view DateStamp(AmzDate.data(), 8); + Sha256Digest SigningKey = GetSigningKey(DateStamp); + + SigV4SignedHeaders Signed = + SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); + + HttpClient::KeyValueMap Result; + Result->emplace("Authorization", std::move(Signed.Authorization)); + Result->emplace("x-amz-date", std::move(Signed.AmzDate)); + Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); + if (!Credentials.SessionToken.empty()) + { + Result->emplace("x-amz-security-token", Credentials.SessionToken); + } + + return Result; +} + +S3Result +S3Client::PutObject(std::string_view Key, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + + // Hash the payload + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, "", PayloadHash); + + HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 PUT failed"); + ZEN_WARN("S3 PUT '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize()); + return {}; +} + +S3GetObjectResult +S3Client::GetObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 GET failed"); + ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; +} + +S3Result +S3Client::DeleteObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 DELETE failed"); + ZEN_WARN("S3 DELETE '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + return {}; +} + +S3HeadObjectResult +S3Client::HeadObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("HEAD", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Head(Path, Headers); + if (!Response.IsSuccess()) + { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3HeadObjectResult{{}, {}, HeadObjectResult::NotFound}; + } + + std::string Err = Response.ErrorMessage("S3 HEAD failed"); + ZEN_WARN("S3 HEAD '{}' failed: {}", Key, Err); + return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; + } + + S3ObjectInfo Info; + Info.Key = std::string(Key); + + if (const std::string* V = FindResponseHeader(Response.Header, "content-length")) + { + Info.Size = ParseInt(*V).value_or(0); + } + + if (const std::string* V = FindResponseHeader(Response.Header, "etag")) + { + Info.ETag = *V; + } + + if (const std::string* V = FindResponseHeader(Response.Header, "last-modified")) + { + Info.LastModified = *V; + } + + ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; +} + +S3ListObjectsResult +S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) +{ + S3ListObjectsResult Result; + + std::string ContinuationToken; + + for (;;) + { + // Build query parameters for ListObjectsV2 + std::vector> QueryParams; + QueryParams.emplace_back("list-type", "2"); + if (!Prefix.empty()) + { + QueryParams.emplace_back("prefix", std::string(Prefix)); + } + if (MaxKeys > 0) + { + QueryParams.emplace_back("max-keys", fmt::format("{}", MaxKeys)); + } + if (!ContinuationToken.empty()) + { + QueryParams.emplace_back("continuation-token", ContinuationToken); + } + + std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); + std::string RootPath = BucketRootPath(); + HttpClient::KeyValueMap Headers = SignRequest("GET", RootPath, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 ListObjectsV2 failed"); + ZEN_WARN("S3 ListObjectsV2 prefix='{}' failed: {}", Prefix, Err); + Result.Error = std::move(Err); + return Result; + } + + // Parse the XML response to extract object keys + std::string_view ResponseBody = Response.AsText(); + + // Find all elements + std::string_view Remaining = ResponseBody; + while (true) + { + size_t ContentsStart = Remaining.find(""); + if (ContentsStart == std::string_view::npos) + { + break; + } + + size_t ContentsEnd = Remaining.find("", ContentsStart); + if (ContentsEnd == std::string_view::npos) + { + break; + } + + std::string_view ContentsXml = Remaining.substr(ContentsStart, ContentsEnd - ContentsStart + 11); + + S3ObjectInfo Info; + Info.Key = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "Key")); + Info.ETag = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "ETag")); + Info.LastModified = std::string(ExtractXmlValue(ContentsXml, "LastModified")); + + std::string_view SizeStr = ExtractXmlValue(ContentsXml, "Size"); + if (!SizeStr.empty()) + { + Info.Size = ParseInt(SizeStr).value_or(0); + } + + if (!Info.Key.empty()) + { + Result.Objects.push_back(std::move(Info)); + } + + Remaining = Remaining.substr(ContentsEnd + 11); + } + + // Check if there are more pages + std::string_view IsTruncated = ExtractXmlValue(ResponseBody, "IsTruncated"); + if (IsTruncated != "true") + { + break; + } + + std::string_view NextToken = ExtractXmlValue(ResponseBody, "NextContinuationToken"); + if (NextToken.empty()) + { + break; + } + + ContinuationToken = std::string(NextToken); + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } + + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// Multipart Upload + +S3CreateMultipartUploadResult +S3Client::CreateMultipartUpload(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CreateMultipartUpload failed"); + ZEN_WARN("S3 CreateMultipartUpload '{}' failed: {}", Key, Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + // Parse UploadId from XML response: + // + // ... + // ... + // ... + // + std::string_view ResponseBody = Response.AsText(); + std::string_view UploadId = ExtractXmlValue(ResponseBody, "UploadId"); + if (UploadId.empty()) + { + std::string Err = "failed to parse UploadId from CreateMultipartUpload response"; + ZEN_WARN("S3 CreateMultipartUpload '{}': {}", Key, Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; +} + +S3UploadPartResult +S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({ + {"partNumber", fmt::format("{}", PartNumber)}, + {"uploadId", std::string(UploadId)}, + }); + + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, CanonicalQS, PayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber)); + ZEN_WARN("S3 UploadPart '{}' part {} failed: {}", Key, PartNumber, Err); + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + // Extract ETag from response headers + const std::string* ETag = FindResponseHeader(Response.Header, "etag"); + if (!ETag) + { + std::string Err = "S3 UploadPart response missing ETag header"; + ZEN_WARN("S3 UploadPart '{}' part {}: {}", Key, PartNumber, Err); + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + return S3UploadPartResult{{}, *ETag}; +} + +S3Result +S3Client::CompleteMultipartUpload(std::string_view Key, + std::string_view UploadId, + const std::vector>& PartETags) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + // Build the CompleteMultipartUpload XML payload + ExtendableStringBuilder<1024> XmlBody; + XmlBody.Append(""); + for (const auto& [PartNumber, ETag] : PartETags) + { + XmlBody.Append(fmt::format("{}{}", PartNumber, ETag)); + } + XmlBody.Append(""); + + std::string_view XmlView = XmlBody.ToView(); + std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, PayloadHash); + Headers->emplace("Content-Type", "application/xml"); + + IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Payload, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CompleteMultipartUpload failed"); + ZEN_WARN("S3 CompleteMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + // Check for error in response body - S3 can return 200 with an error in the XML body + std::string_view ResponseBody = Response.AsText(); + if (ResponseBody.find("") != std::string_view::npos) + { + std::string_view ErrorCode = ExtractXmlValue(ResponseBody, "Code"); + std::string_view ErrorMessage = ExtractXmlValue(ResponseBody, "Message"); + std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); + ZEN_WARN("{}", Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + return {}; +} + +S3Result +S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 AbortMultipartUpload failed"); + ZEN_WARN("S3 AbortMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return {}; +} + +std::string +S3Client::GeneratePresignedGetUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "GET", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "PUT", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn) +{ + std::string Path = KeyToPath(Key); + std::string Scheme = "https"; + + if (!m_Endpoint.empty() && m_Endpoint.starts_with("http://")) + { + Scheme = "http"; + } + + SigV4Credentials Credentials = GetCurrentCredentials(); + return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); +} + +S3Result +S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize) +{ + const uint64_t ContentSize = Content.GetSize(); + + // If the content fits in a single part, just use PutObject + if (ContentSize <= PartSize) + { + return PutObject(Key, Content); + } + + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, ContentSize, (ContentSize + PartSize - 1) / PartSize); + + // Initiate multipart upload + + S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); + if (!InitResult) + { + return S3Result{std::move(InitResult.Error)}; + } + + const std::string& UploadId = InitResult.UploadId; + + // Upload parts sequentially + // TODO: upload parts in parallel for improved throughput on large uploads + + std::vector> PartETags; + uint64_t Offset = 0; + uint32_t PartNumber = 1; + + while (Offset < ContentSize) + { + uint64_t ThisPartSize = std::min(PartSize, ContentSize - Offset); + + // Create a sub-buffer referencing the part data within the original content + IoBuffer PartContent(Content, Offset, ThisPartSize); + + S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, PartContent); + if (!PartResult) + { + // Attempt to abort the multipart upload on failure + AbortMultipartUpload(Key, UploadId); + return S3Result{std::move(PartResult.Error)}; + } + + PartETags.emplace_back(PartNumber, std::move(PartResult.ETag)); + Offset += ThisPartSize; + PartNumber++; + } + + // Complete multipart upload + S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); + if (!CompleteResult) + { + AbortMultipartUpload(Key, UploadId); + return CompleteResult; + } + + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), ContentSize); + return {}; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +s3client_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.s3client"); + +TEST_CASE("s3client.xml_extract") +{ + std::string_view Xml = + "test/file.txt1234" + "\"abc123\"2024-01-01T00:00:00Z"; + + CHECK(ExtractXmlValue(Xml, "Key") == "test/file.txt"); + CHECK(ExtractXmlValue(Xml, "Size") == "1234"); + CHECK(ExtractXmlValue(Xml, "ETag") == "\"abc123\""); + CHECK(ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); + CHECK(ExtractXmlValue(Xml, "NonExistent") == ""); +} + +TEST_CASE("s3client.xml_entity_decode") +{ + CHECK(DecodeXmlEntities("no entities") == "no entities"); + CHECK(DecodeXmlEntities("a&b") == "a&b"); + CHECK(DecodeXmlEntities("<tag>") == ""); + CHECK(DecodeXmlEntities(""hello'") == "\"hello'"); + CHECK(DecodeXmlEntities("&&") == "&&"); + CHECK(DecodeXmlEntities("") == ""); + + // Key with entities as S3 would return it + std::string_view Xml = "path/file&name<1>.txt"; + CHECK(DecodeXmlEntities(ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); +} + +TEST_CASE("s3client.path_style_addressing") +{ + // Verify path-style builds /{bucket}/{key} paths + S3ClientOptions Opts; + Opts.BucketName = "test-bucket"; + Opts.Region = "us-east-1"; + Opts.Endpoint = "http://localhost:9000"; + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "minioadmin"; + Opts.Credentials.SecretAccessKey = "minioadmin"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "test-bucket"); + CHECK(Client.Region() == "us-east-1"); +} + +TEST_CASE("s3client.virtual_hosted_addressing") +{ + // Verify virtual-hosted style derives endpoint from region + bucket + S3ClientOptions Opts; + Opts.BucketName = "my-bucket"; + Opts.Region = "eu-west-1"; + Opts.PathStyle = false; + Opts.Credentials.AccessKeyId = "key"; + Opts.Credentials.SecretAccessKey = "secret"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "my-bucket"); + CHECK(Client.Region() == "eu-west-1"); +} + +TEST_CASE("s3client.minio_integration") +{ + using namespace std::literals; + + // Spawn a local MinIO server + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19000; + MinioOpts.RootUser = "testuser"; + MinioOpts.RootPassword = "testpassword"; + + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + + // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) + Minio.CreateBucket("integration-test"); + + // Configure S3Client for the test bucket + S3ClientOptions Opts; + Opts.BucketName = "integration-test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = std::string(Minio.RootUser()); + Opts.Credentials.SecretAccessKey = std::string(Minio.RootPassword()); + + S3Client Client(Opts); + + SUBCASE("put_get_delete") + { + // PUT + std::string_view TestData = "hello, minio integration test!"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + // GET + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.AsText() == TestData); + + // HEAD + S3HeadObjectResult HeadRes = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes.IsSuccess()); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == TestData.size()); + + // DELETE + S3Result DelRes = Client.DeleteObject("test/hello.txt"); + REQUIRE(DelRes.IsSuccess()); + + // HEAD after delete + S3HeadObjectResult HeadRes2 = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes2.IsSuccess()); + CHECK(HeadRes2.Status == HeadObjectResult::NotFound); + } + + SUBCASE("head_not_found") + { + S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); + CHECK(Res.IsSuccess()); + CHECK(Res.Status == HeadObjectResult::NotFound); + } + + SUBCASE("list_objects") + { + // Upload several objects with a common prefix + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("list-test/item-{}.txt", i); + std::string Payload = fmt::format("payload-{}", i); + IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + S3Result Res = Client.PutObject(Key, std::move(Buf)); + REQUIRE(Res.IsSuccess()); + } + + // List with prefix + S3ListObjectsResult ListRes = Client.ListObjects("list-test/"); + REQUIRE(ListRes.IsSuccess()); + CHECK(ListRes.Objects.size() == 3); + + // Verify keys are present + std::vector Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "list-test/item-0.txt"); + CHECK(Keys[1] == "list-test/item-1.txt"); + CHECK(Keys[2] == "list-test/item-2.txt"); + + // Cleanup + for (int i = 0; i < 3; ++i) + { + Client.DeleteObject(fmt::format("list-test/item-{}.txt", i)); + } + } + + SUBCASE("multipart_upload") + { + // Create a payload large enough to exercise multipart (use minimum part size) + constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum + constexpr uint64_t PayloadSize = PartSize + 1024; // slightly over one part + + std::string LargePayload(PayloadSize, 'X'); + // Add some variation + for (uint64_t i = 0; i < PayloadSize; i += 1024) + { + LargePayload[i] = char('A' + (i / 1024) % 26); + } + + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(LargePayload)); + S3Result Res = Client.PutObjectMultipart("multipart/large.bin", std::move(Content), PartSize); + REQUIRE(Res.IsSuccess()); + + // Verify via GET + S3GetObjectResult GetRes = Client.GetObject("multipart/large.bin"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.Content.GetSize() == PayloadSize); + CHECK(GetRes.AsText() == std::string_view(LargePayload)); + + // Cleanup + Client.DeleteObject("multipart/large.bin"); + } + + SUBCASE("presigned_urls") + { + // Upload an object + std::string_view TestData = "presigned-url-test-data"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("presigned/test.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + // Generate a pre-signed GET URL + std::string Url = Client.GeneratePresignedGetUrl("presigned/test.txt", std::chrono::seconds(60)); + CHECK(!Url.empty()); + CHECK(Url.find("X-Amz-Signature") != std::string::npos); + + // Fetch via the pre-signed URL (no auth headers needed) + HttpClient Hc(Minio.Endpoint()); + // Extract the path+query from the full URL + std::string_view UrlView = Url; + size_t PathStart = UrlView.find('/', UrlView.find("://") + 3); + std::string PathAndQuery(UrlView.substr(PathStart)); + HttpClient::Response Resp = Hc.Get(PathAndQuery); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == TestData); + + // Cleanup + Client.DeleteObject("presigned/test.txt"); + } + + Minio.StopMinioServer(); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen -- cgit v1.2.3