diff options
| author | Stefan Boberg <[email protected]> | 2026-03-18 11:27:07 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-18 11:27:07 +0100 |
| commit | e64d76ae1b6993582bf161a61049f0771414a779 (patch) | |
| tree | 083f3df42cc9e2c7ddbee225708b4848eb217d11 /src | |
| parent | Compute batching (#849) (diff) | |
| download | zen-e64d76ae1b6993582bf161a61049f0771414a779.tar.xz zen-e64d76ae1b6993582bf161a61049f0771414a779.zip | |
Simple S3 client (#836)
The S3 client added here is intended primarily for managing test-case datasets, but it may be useful elsewhere in the future.
- **Add S3 client with AWS Signature V4 (SigV4) signing** — new `S3Client` in `zenutil/cloud/` supporting `GetObject`, `PutObject`, `DeleteObject`, `HeadObject`, and `ListObjects` operations, plus multipart upload (`PutObjectMultipart`) and presigned GET/PUT URL generation
- **Add EC2 IMDS credential provider** — automatically fetches and refreshes temporary AWS credentials from the EC2 Instance Metadata Service (IMDSv2) for use by the S3 client
- **Add SigV4 signing library** — standalone implementation of AWS Signature Version 4 request signing (headers and query-string presigning)
- **Add path-style addressing support** — enables compatibility with S3-compatible stores like MinIO (in addition to virtual-hosted style)
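- **Add S3 integration tests** — includes a `MinioProcess` test helper that spins up a local MinIO server, plus integration tests exercising the S3 client end-to-end
- **Add `zens3-testbed` command-line tool** — a small binary in `src/zens3-testbed/` for manually exercising the S3 client (`put`, `get`, `head`, `delete`, `list`, `multipart-put`, `roundtrip`, `presign`) against a real bucket or an S3-compatible endpoint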
- **Add S3-backed `HttpObjectStoreService` tests** — integration tests verifying the zenserver object store works against an S3 backend
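- **Refactor mock IMDS into `zenutil/cloud/`** — moved and generalized the mock IMDS server from `zencompute` (along with the `CloudProvider` enum and its `ToString` helper, now in `zenutil/cloud/cloudprovider.h`) so it can be reused by both compute and S3 credential tests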
Diffstat (limited to 'src')
23 files changed, 3555 insertions, 330 deletions
diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp index 65bac895f..eb4c05f9f 100644 --- a/src/zencompute/cloudmetadata.cpp +++ b/src/zencompute/cloudmetadata.cpp @@ -23,22 +23,6 @@ static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; // is a local service on the hypervisor so 200ms is generous for actual cloud VMs. static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; -std::string_view -ToString(CloudProvider Provider) -{ - switch (Provider) - { - case CloudProvider::AWS: - return "AWS"; - case CloudProvider::Azure: - return "Azure"; - case CloudProvider::GCP: - return "GCP"; - default: - return "None"; - } -} - CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) { } @@ -610,7 +594,7 @@ CloudMetadata::PollGCPTermination() #if ZEN_WITH_TESTS -# include <zencompute/mockimds.h> +# include <zenutil/cloud/mockimds.h> # include <zencore/filesystem.h> # include <zencore/testing.h> diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h index a5bc5a34d..3b9642ac3 100644 --- a/src/zencompute/include/zencompute/cloudmetadata.h +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -2,6 +2,8 @@ #pragma once +#include <zenutil/cloud/cloudprovider.h> + #include <zencore/compactbinarybuilder.h> #include <zencore/logging.h> #include <zencore/thread.h> @@ -13,16 +15,6 @@ namespace zen::compute { -enum class CloudProvider -{ - None, - AWS, - Azure, - GCP -}; - -std::string_view ToString(CloudProvider Provider); - /** Snapshot of detected cloud instance properties. */ struct CloudInstanceInfo { diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h index 521722e63..704306913 100644 --- a/src/zencompute/include/zencompute/mockimds.h +++ b/src/zencompute/include/zencompute/mockimds.h @@ -1,102 +1,6 @@ // Copyright Epic Games, Inc. 
All Rights Reserved. +// Moved to zenutil — this header is kept for backward compatibility. #pragma once -#include <zencompute/cloudmetadata.h> -#include <zenhttp/httpserver.h> - -#include <string> - -#if ZEN_WITH_TESTS - -namespace zen::compute { - -/** - * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. - * - * Implements an HttpService that responds to the same URL paths as the real - * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). - * Tests configure which provider is "active" and set the desired response - * values, then pass the mock server's address as the ImdsEndpoint to the - * CloudMetadata constructor. - * - * When a request arrives for a provider that is not the ActiveProvider, the - * mock returns 404, causing CloudMetadata to write a sentinel file and move - * on to the next provider — exactly like a failed probe on bare metal. - * - * All config fields are public and can be mutated between poll cycles to - * simulate state changes (e.g. a spot interruption appearing mid-run). - * - * Usage: - * MockImdsService Mock; - * Mock.ActiveProvider = CloudProvider::AWS; - * Mock.Aws.InstanceId = "i-test"; - * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint - */ -class MockImdsService : public HttpService -{ -public: - /** AWS IMDSv2 response configuration. */ - struct AwsConfig - { - std::string Token = "mock-aws-token-v2"; - std::string InstanceId = "i-0123456789abcdef0"; - std::string AvailabilityZone = "us-east-1a"; - std::string LifeCycle = "on-demand"; // "spot" or "on-demand" - - // Empty string → endpoint returns 404 (instance not in an ASG). - // Non-empty → returned as the response body. "InService" means healthy; - // anything else (e.g. "Terminated:Wait") triggers termination detection. - std::string AutoscalingState; - - // Empty string → endpoint returns 404 (no spot interruption). - // Non-empty → returned as the response body, signalling a spot reclaim. 
- std::string SpotAction; - }; - - /** Azure IMDS response configuration. */ - struct AzureConfig - { - std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; - std::string Location = "eastus"; - std::string Priority = "Regular"; // "Spot" or "Regular" - - // Empty → instance is not in a VM Scale Set (no autoscaling). - std::string VmScaleSetName; - - // Empty → no scheduled events. Set to "Preempt", "Terminate", or - // "Reboot" to simulate a termination-class event. - std::string ScheduledEventType; - std::string ScheduledEventStatus = "Scheduled"; - }; - - /** GCP metadata response configuration. */ - struct GcpConfig - { - std::string InstanceId = "1234567890123456789"; - std::string Zone = "projects/123456/zones/us-central1-a"; - std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" - std::string MaintenanceEvent = "NONE"; // "NONE" or event description - }; - - /** Which provider's endpoints respond successfully. - * Requests targeting other providers receive 404. - */ - CloudProvider ActiveProvider = CloudProvider::None; - - AwsConfig Aws; - AzureConfig Azure; - GcpConfig Gcp; - - const char* BaseUri() const override; - void HandleRequest(HttpServerRequest& Request) override; - -private: - void HandleAwsRequest(HttpServerRequest& Request); - void HandleAzureRequest(HttpServerRequest& Request); - void HandleGcpRequest(HttpServerRequest& Request); -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_TESTS +#include <zenutil/cloud/mockimds.h> diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp index dd09312df..5415f48f3 100644 --- a/src/zencompute/testing/mockimds.cpp +++ b/src/zencompute/testing/mockimds.cpp @@ -1,205 +1,2 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
- -#include <zencompute/mockimds.h> - -#include <zencore/fmtutils.h> - -#if ZEN_WITH_TESTS - -namespace zen::compute { - -const char* -MockImdsService::BaseUri() const -{ - return "/"; -} - -void -MockImdsService::HandleRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // AWS endpoints live under /latest/ - if (Uri.starts_with("latest/")) - { - if (ActiveProvider == CloudProvider::AWS) - { - HandleAwsRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - // Azure endpoints live under /metadata/ - if (Uri.starts_with("metadata/")) - { - if (ActiveProvider == CloudProvider::Azure) - { - HandleAzureRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - // GCP endpoints live under /computeMetadata/ - if (Uri.starts_with("computeMetadata/")) - { - if (ActiveProvider == CloudProvider::GCP) - { - HandleGcpRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// AWS -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleAwsRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // IMDSv2 token acquisition (PUT only) - if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); - return; - } - - // Instance identity - if (Uri == "latest/meta-data/instance-id") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); - return; - } - - if (Uri == "latest/meta-data/placement/availability-zone") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); - return; - } - - if (Uri == 
"latest/meta-data/instance-life-cycle") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); - return; - } - - // Autoscaling lifecycle state — 404 when not in an ASG - if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") - { - if (Aws.AutoscalingState.empty()) - { - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); - return; - } - - // Spot interruption notice — 404 when no interruption pending - if (Uri == "latest/meta-data/spot/instance-action") - { - if (Aws.SpotAction.empty()) - { - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// Azure -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleAzureRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // Instance metadata (single JSON document) - if (Uri == "metadata/instance") - { - std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", - Azure.VmId, - Azure.Location, - Azure.Priority, - Azure.VmScaleSetName); - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); - return; - } - - // Scheduled events for termination monitoring - if (Uri == "metadata/scheduledevents") - { - std::string Json; - if (Azure.ScheduledEventType.empty()) - { - Json = R"({"Events":[]})"; - } - else - { - Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", - Azure.ScheduledEventType, - Azure.ScheduledEventStatus); - } - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); - return; - } - - 
Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// GCP -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleGcpRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - if (Uri == "computeMetadata/v1/instance/id") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); - return; - } - - if (Uri == "computeMetadata/v1/instance/zone") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); - return; - } - - if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); - return; - } - - if (Uri == "computeMetadata/v1/instance/maintenance-event") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -} // namespace zen::compute - -#endif // ZEN_WITH_TESTS +// Moved to zenutil/cloud/mockimds.cpp diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 080607f13..0c55e6c7e 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -492,8 +492,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; // Build environment block when custom environment variables are specified + ExtendableWideStringBuilder<512> EnvironmentBlock; void* Environment = nullptr; + if (!Options.Environment.empty()) { // Capture current environment into a map diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp new file mode 100644 index 000000000..4cd6b411f --- /dev/null +++ b/src/zens3-testbed/main.cpp @@ -0,0 +1,526 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +// Simple test bed for exercising the zens3 module against a real S3 bucket. +// +// Usage: +// zens3-testbed --bucket <name> --region <region> [command] [args...] +// +// Credentials are read from environment variables: +// AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY +// +// Commands: +// put <key> <file> Upload a local file +// get <key> [file] Download an object (prints to stdout if no file given) +// head <key> Check if object exists, show metadata +// delete <key> Delete an object +// list [prefix] List objects with optional prefix +// multipart-put <key> <file> [part-size-mb] Upload via multipart +// roundtrip <key> Upload test data, download, verify, delete + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/s3client.h> + +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +#include <zencore/memory/newdelete.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <cstdlib> +#include <fstream> +#include <iostream> + +namespace { + +using namespace zen; + +std::string +GetEnvVar(const char* Name) +{ + const char* Value = std::getenv(Name); + return Value ? 
std::string(Value) : std::string(); +} + +IoBuffer +ReadFileToBuffer(const std::filesystem::path& Path) +{ + return zen::ReadFile(Path).Flatten(); +} + +void +WriteBufferToFile(const IoBuffer& Buffer, const std::filesystem::path& Path) +{ + std::ofstream File(Path, std::ios::binary); + if (!File) + { + throw zen::runtime_error("failed to open '{}' for writing", Path.string()); + } + File.write(reinterpret_cast<const char*>(Buffer.GetData()), static_cast<std::streamsize>(Buffer.GetSize())); +} + +S3Client +CreateClient(const cxxopts::ParseResult& Args) +{ + S3ClientOptions Options; + Options.BucketName = Args["bucket"].as<std::string>(); + Options.Region = Args["region"].as<std::string>(); + + if (Args.count("imds")) + { + // Use IMDS credential provider for EC2 instances + ImdsCredentialProviderOptions ImdsOpts; + if (Args.count("imds-endpoint")) + { + ImdsOpts.Endpoint = Args["imds-endpoint"].as<std::string>(); + } + Options.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider(ImdsOpts)); + } + else + { + std::string AccessKey = GetEnvVar("AWS_ACCESS_KEY_ID"); + std::string SecretKey = GetEnvVar("AWS_SECRET_ACCESS_KEY"); + std::string SessionToken = GetEnvVar("AWS_SESSION_TOKEN"); + + if (AccessKey.empty() || SecretKey.empty()) + { + throw zen::runtime_error("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables must be set"); + } + + Options.Credentials.AccessKeyId = std::move(AccessKey); + Options.Credentials.SecretAccessKey = std::move(SecretKey); + Options.Credentials.SessionToken = std::move(SessionToken); + } + + if (Args.count("endpoint")) + { + Options.Endpoint = Args["endpoint"].as<std::string>(); + } + + if (Args.count("path-style")) + { + Options.PathStyle = true; + } + + if (Args.count("timeout")) + { + Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000); + } + + return S3Client(Options); +} + +int +CmdPut(S3Client& Client, const std::vector<std::string>& Positional) +{ + if 
(Positional.size() < 3) + { + fmt::print(stderr, "Usage: zens3-testbed ... put <key> <file>\n"); + return 1; + } + + const auto& Key = Positional[1]; + const auto& FilePath = Positional[2]; + + IoBuffer Content = ReadFileToBuffer(FilePath); + fmt::print("Uploading '{}' ({} bytes) to s3://{}/{}\n", FilePath, Content.GetSize(), Client.BucketName(), Key); + + S3Result Result = Client.PutObject(Key, Content); + if (!Result) + { + fmt::print(stderr, "PUT failed: {}\n", Result.Error); + return 1; + } + + fmt::print("OK\n"); + return 0; +} + +int +CmdGet(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... get <key> [file]\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3GetObjectResult Result = Client.GetObject(Key); + if (!Result) + { + fmt::print(stderr, "GET failed: {}\n", Result.Error); + return 1; + } + + if (Positional.size() >= 3) + { + const auto& FilePath = Positional[2]; + WriteBufferToFile(Result.Content, FilePath); + fmt::print("Downloaded {} bytes to '{}'\n", Result.Content.GetSize(), FilePath); + } + else + { + // Print to stdout + std::string_view Text = Result.AsText(); + std::cout.write(Text.data(), static_cast<std::streamsize>(Text.size())); + std::cout << std::endl; + } + + return 0; +} + +int +CmdHead(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
head <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3HeadObjectResult Result = Client.HeadObject(Key); + + if (!Result) + { + fmt::print(stderr, "HEAD failed: {}\n", Result.Error); + return 1; + } + + if (Result.Status == HeadObjectResult::NotFound) + { + fmt::print("Object '{}' does not exist\n", Key); + return 1; + } + + fmt::print("Key: {}\n", Result.Info.Key); + fmt::print("Size: {} bytes\n", Result.Info.Size); + fmt::print("ETag: {}\n", Result.Info.ETag); + fmt::print("Last-Modified: {}\n", Result.Info.LastModified); + return 0; +} + +int +CmdDelete(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... delete <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3Result Result = Client.DeleteObject(Key); + if (!Result) + { + fmt::print(stderr, "DELETE failed: {}\n", Result.Error); + return 1; + } + + fmt::print("Deleted '{}'\n", Key); + return 0; +} + +int +CmdList(S3Client& Client, const std::vector<std::string>& Positional) +{ + std::string Prefix; + if (Positional.size() >= 2) + { + Prefix = Positional[1]; + } + + S3ListObjectsResult Result = Client.ListObjects(Prefix); + if (!Result) + { + fmt::print(stderr, "LIST failed: {}\n", Result.Error); + return 1; + } + + fmt::print("{} objects found:\n", Result.Objects.size()); + for (const auto& Obj : Result.Objects) + { + fmt::print(" {:>12} {} {}\n", Obj.Size, Obj.LastModified, Obj.Key); + } + + return 0; +} + +int +CmdMultipartPut(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 3) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
multipart-put <key> <file> [part-size-mb]\n"); + return 1; + } + + const auto& Key = Positional[1]; + const auto& FilePath = Positional[2]; + + uint64_t PartSize = 8 * 1024 * 1024; // 8 MB default + if (Positional.size() >= 4) + { + PartSize = std::stoull(Positional[3]) * 1024 * 1024; + } + + IoBuffer Content = ReadFileToBuffer(FilePath); + fmt::print("Multipart uploading '{}' ({} bytes, part size {} MB) to s3://{}/{}\n", + FilePath, + Content.GetSize(), + PartSize / (1024 * 1024), + Client.BucketName(), + Key); + + S3Result Result = Client.PutObjectMultipart(Key, Content, PartSize); + if (!Result) + { + fmt::print(stderr, "Multipart PUT failed: {}\n", Result.Error); + return 1; + } + + fmt::print("OK\n"); + return 0; +} + +int +CmdRoundtrip(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... roundtrip <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + // Generate test data + const size_t TestSize = 1024 * 64; // 64 KB + std::vector<uint8_t> TestData(TestSize); + for (size_t i = 0; i < TestSize; ++i) + { + TestData[i] = static_cast<uint8_t>(i & 0xFF); + } + + IoBuffer UploadContent(IoBuffer::Clone, TestData.data(), TestData.size()); + + fmt::print("=== Roundtrip test for key '{}' ===\n\n", Key); + + // PUT + fmt::print("[1/4] PUT {} bytes...\n", TestSize); + S3Result Result = Client.PutObject(Key, UploadContent); + if (!Result) + { + fmt::print(stderr, " FAILED: {}\n", Result.Error); + return 1; + } + fmt::print(" OK\n"); + + // HEAD + fmt::print("[2/4] HEAD...\n"); + S3HeadObjectResult HeadResult = Client.HeadObject(Key); + if (HeadResult.Status != HeadObjectResult::Found) + { + fmt::print(stderr, " FAILED: {}\n", !HeadResult ? 
HeadResult.Error : "not found"); + return 1; + } + fmt::print(" OK (size={}, etag={})\n", HeadResult.Info.Size, HeadResult.Info.ETag); + + if (HeadResult.Info.Size != TestSize) + { + fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, HeadResult.Info.Size); + return 1; + } + + // GET + fmt::print("[3/4] GET and verify...\n"); + S3GetObjectResult GetResult = Client.GetObject(Key); + if (!GetResult) + { + fmt::print(stderr, " FAILED: {}\n", GetResult.Error); + return 1; + } + + if (GetResult.Content.GetSize() != TestSize) + { + fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, GetResult.Content.GetSize()); + return 1; + } + + if (memcmp(GetResult.Content.GetData(), TestData.data(), TestSize) != 0) + { + fmt::print(stderr, " DATA MISMATCH\n"); + return 1; + } + fmt::print(" OK (verified {} bytes)\n", TestSize); + + // DELETE + fmt::print("[4/4] DELETE...\n"); + Result = Client.DeleteObject(Key); + if (!Result) + { + fmt::print(stderr, " FAILED: {}\n", Result.Error); + return 1; + } + fmt::print(" OK\n"); + + fmt::print("\n=== Roundtrip test PASSED ===\n"); + return 0; +} + +int +CmdPresign(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
presign <key> [method] [expires-seconds]\n"); + return 1; + } + + const auto& Key = Positional[1]; + + std::string Method = "GET"; + if (Positional.size() >= 3) + { + Method = Positional[2]; + } + + std::chrono::seconds ExpiresIn(3600); + if (Positional.size() >= 4) + { + ExpiresIn = std::chrono::seconds(std::stoul(Positional[3])); + } + + std::string Url; + if (Method == "PUT") + { + Url = Client.GeneratePresignedPutUrl(Key, ExpiresIn); + } + else + { + Url = Client.GeneratePresignedGetUrl(Key, ExpiresIn); + } + + fmt::print("{}\n", Url); + return 0; +} + +} // namespace + +int +main(int argc, char* argv[]) +{ + using namespace zen; + + logging::InitializeLogging(); + + cxxopts::Options Options("zens3-testbed", "Test bed for exercising S3 operations via the zens3 module"); + + // clang-format off + Options.add_options() + ("b,bucket", "S3 bucket name", cxxopts::value<std::string>()) + ("r,region", "AWS region", cxxopts::value<std::string>()->default_value("us-east-1")) + ("e,endpoint", "Custom S3 endpoint URL", cxxopts::value<std::string>()) + ("path-style", "Use path-style addressing (for MinIO, etc.)") + ("imds", "Use EC2 IMDS for credentials instead of env vars") + ("imds-endpoint", "Custom IMDS endpoint URL (for testing)", cxxopts::value<std::string>()) + ("timeout", "Request timeout in seconds", cxxopts::value<int>()->default_value("30")) + ("v,verbose", "Enable verbose logging") + ("h,help", "Show help") + ("positional", "Command and arguments", cxxopts::value<std::vector<std::string>>()); + // clang-format on + + Options.parse_positional({"positional"}); + Options.positional_help("<command> [args...]"); + + try + { + auto Result = Options.parse(argc, argv); + + if (Result.count("help") || !Result.count("positional")) + { + fmt::print("{}\n", Options.help()); + fmt::print("Commands:\n"); + fmt::print(" put <key> <file> Upload a local file\n"); + fmt::print(" get <key> [file] Download (to file or stdout)\n"); + fmt::print(" head <key> Show object 
metadata\n"); + fmt::print(" delete <key> Delete an object\n"); + fmt::print(" list [prefix] List objects\n"); + fmt::print(" multipart-put <key> <file> [part-mb] Multipart upload\n"); + fmt::print(" roundtrip <key> Upload/download/verify/delete\n"); + fmt::print(" presign <key> [method] [expires-sec] Generate pre-signed URL\n"); + fmt::print("\nCredentials via AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars,\n"); + fmt::print("or use --imds to fetch from EC2 Instance Metadata Service.\n"); + return 0; + } + + if (!Result.count("bucket")) + { + fmt::print(stderr, "Error: --bucket is required\n"); + return 1; + } + + if (Result.count("verbose")) + { + logging::SetLogLevel(logging::Debug); + } + + auto Client = CreateClient(Result); + + const auto& Positional = Result["positional"].as<std::vector<std::string>>(); + const auto& Command = Positional[0]; + + if (Command == "put") + { + return CmdPut(Client, Positional); + } + else if (Command == "get") + { + return CmdGet(Client, Positional); + } + else if (Command == "head") + { + return CmdHead(Client, Positional); + } + else if (Command == "delete") + { + return CmdDelete(Client, Positional); + } + else if (Command == "list") + { + return CmdList(Client, Positional); + } + else if (Command == "multipart-put") + { + return CmdMultipartPut(Client, Positional); + } + else if (Command == "roundtrip") + { + return CmdRoundtrip(Client, Positional); + } + else if (Command == "presign") + { + return CmdPresign(Client, Positional); + } + else + { + fmt::print(stderr, "Unknown command: '{}'\n", Command); + return 1; + } + } + catch (const std::exception& Ex) + { + fmt::print(stderr, "Error: {}\n", Ex.what()); + return 1; + } +} diff --git a/src/zens3-testbed/xmake.lua b/src/zens3-testbed/xmake.lua new file mode 100644 index 000000000..168ab9de9 --- /dev/null +++ b/src/zens3-testbed/xmake.lua @@ -0,0 +1,8 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. 
+ +target("zens3-testbed") + set_kind("binary") + set_group("tools") + add_files("*.cpp") + add_deps("zenutil", "zencore") + add_deps("cxxopts", "fmt") diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp index f3db5fdf6..1f6a7675c 100644 --- a/src/zenserver-test/objectstore-tests.cpp +++ b/src/zenserver-test/objectstore-tests.cpp @@ -2,10 +2,12 @@ #if ZEN_WITH_TESTS # include "zenserver-test.h" +# include <zencore/memoryview.h> # include <zencore/testing.h> # include <zencore/testutils.h> -# include <zenutil/zenserverprocess.h> # include <zenhttp/httpclient.h> +# include <zenutil/cloud/s3client.h> +# include <zenutil/zenserverprocess.h> ZEN_THIRD_PARTY_INCLUDES_START # include <tsl/robin_set.h> @@ -68,6 +70,94 @@ TEST_CASE("objectstore.blobs") } } +TEST_CASE("objectstore.s3client") +{ + ZenServerInstance Instance(TestEnv); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + // S3Client in path-style builds paths as /{bucket}/{key}. + // The objectstore routes objects at bucket/{bucket}/{key} relative to its base. + // Point the S3Client endpoint at {server}/obj/bucket so the paths line up. 
+ S3ClientOptions Opts; + Opts.BucketName = "s3test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "testkey"; + Opts.Credentials.SecretAccessKey = "testsecret"; + + S3Client Client(Opts); + + // -- PUT + GET roundtrip -- + std::string_view TestData = "hello from s3client via objectstore"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error); + + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error); + CHECK(GetRes.AsText() == TestData); + + // -- PUT overwrites -- + IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv)); + IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv)); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess()); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess()); + + S3GetObjectResult OverwriteGet = Client.GetObject("overwrite/file.txt"); + REQUIRE(OverwriteGet.IsSuccess()); + CHECK(OverwriteGet.AsText() == "overwritten"sv); + + // -- GET not found -- + S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat"); + CHECK_FALSE(NotFoundGet.IsSuccess()); + + // -- HEAD found -- + std::string_view HeadData = "head test data"sv; + IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData)); + REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess()); + + S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt"); + REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == HeadData.size()); + + // -- HEAD not found -- + S3HeadObjectResult HeadNotFound = 
Client.HeadObject("nonexistent/file.dat"); + CHECK(HeadNotFound.IsSuccess()); + CHECK(HeadNotFound.Status == HeadObjectResult::NotFound); + + // -- LIST objects -- + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("listing/item-{}.txt", i); + std::string Payload = fmt::format("content-{}", i); + IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess()); + } + + S3ListObjectsResult ListRes = Client.ListObjects("listing/"); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + REQUIRE(ListRes.Objects.size() == 3); + + std::vector<std::string> Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + CHECK(Obj.Size > 0); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "listing/item-0.txt"); + CHECK(Keys[1] == "listing/item-1.txt"); + CHECK(Keys[2] == "listing/item-2.txt"); + + // -- LIST empty prefix -- + S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/"); + REQUIRE(EmptyList.IsSuccess()); + CHECK(EmptyList.Objects.empty()); +} + TEST_SUITE_END(); } // namespace zen::tests diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index b619c5548..7bfc3575e 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -36,6 +36,7 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("minio") add_packages("oidctoken") add_packages("nomad") @@ -215,6 +216,16 @@ target("zenserver") copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + local minio_pkg = target:pkg("minio") + if minio_pkg then + local installdir = minio_pkg:installdir() + local minio_bin = "minio" + if is_plat("windows") then + minio_bin = "minio.exe" + end + copy_if_newer(path.join(installdir, "bin", minio_bin), path.join(target:targetdir(), minio_bin), minio_bin) + end + local oidctoken_pkg = target:pkg("oidctoken") if 
oidctoken_pkg then local installdir = oidctoken_pkg:installdir() diff --git a/src/zenutil/cloud/cloudprovider.cpp b/src/zenutil/cloud/cloudprovider.cpp new file mode 100644 index 000000000..e32a50c64 --- /dev/null +++ b/src/zenutil/cloud/cloudprovider.cpp @@ -0,0 +1,23 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/cloudprovider.h> + +namespace zen::compute { + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +} // namespace zen::compute diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp new file mode 100644 index 000000000..dde1dc019 --- /dev/null +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -0,0 +1,387 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/imdscredentials.h> + +#include <zenutil/cloud/mockimds.h> + +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zenhttp/httpserver.h> + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace { + + /// Margin before expiration at which we proactively refresh credentials. + constexpr auto kRefreshMargin = std::chrono::minutes(5); + + /// Parse an ISO 8601 UTC timestamp (e.g. "2026-03-14T20:00:00Z") into a system_clock time_point. + /// Returns epoch on failure. 
+ std::chrono::system_clock::time_point ParseIso8601(std::string_view Timestamp) + { + // Expected format: YYYY-MM-DDTHH:MM:SSZ + if (Timestamp.size() < 19) + { + return {}; + } + + std::tm Tm = {}; + // Manual parse since std::get_time is locale-dependent + Tm.tm_year = ParseInt<int>(Timestamp.substr(0, 4)).value_or(1970) - 1900; + Tm.tm_mon = ParseInt<int>(Timestamp.substr(5, 2)).value_or(1) - 1; + Tm.tm_mday = ParseInt<int>(Timestamp.substr(8, 2)).value_or(1); + Tm.tm_hour = ParseInt<int>(Timestamp.substr(11, 2)).value_or(0); + Tm.tm_min = ParseInt<int>(Timestamp.substr(14, 2)).value_or(0); + Tm.tm_sec = ParseInt<int>(Timestamp.substr(17, 2)).value_or(0); + +#if ZEN_PLATFORM_WINDOWS + time_t EpochSeconds = _mkgmtime(&Tm); +#else + time_t EpochSeconds = timegm(&Tm); +#endif + if (EpochSeconds == -1) + { + return {}; + } + + return std::chrono::system_clock::from_time_t(EpochSeconds); + } + +} // namespace + +ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options) +: m_Log(logging::Get("imds")) +, m_HttpClient(Options.Endpoint, + HttpClientSettings{ + .LogCategory = "imds", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.RequestTimeout, + }) +{ + ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri()); +} + +ImdsCredentialProvider::~ImdsCredentialProvider() = default; + +SigV4Credentials +ImdsCredentialProvider::GetCredentials() +{ + // Fast path: shared lock for cache hit + { + RwLock::SharedLockScope SharedLock(m_Lock); + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + } + + // Slow path: exclusive lock to refresh + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + + // Double-check after acquiring exclusive lock + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + + if (!FetchCredentials()) + { + 
ZEN_WARN("failed to fetch credentials from IMDS"); + return {}; + } + + return m_CachedCredentials; +} + +void +ImdsCredentialProvider::InvalidateCache() +{ + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + m_CachedCredentials = {}; + m_ExpiresAt = {}; +} + +bool +ImdsCredentialProvider::FetchToken() +{ + HttpClient::KeyValueMap Headers; + Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); + + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + if (!Response.IsSuccess()) + { + ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); + return false; + } + + m_ImdsToken = std::string(Response.AsText()); + if (m_ImdsToken.empty()) + { + ZEN_WARN("IMDS returned empty token"); + return false; + } + + return true; +} + +bool +ImdsCredentialProvider::FetchCredentials() +{ + // Step 1: Get IMDSv2 session token + if (!FetchToken()) + { + return false; + } + + HttpClient::KeyValueMap TokenHeader; + TokenHeader->emplace("X-aws-ec2-metadata-token", m_ImdsToken); + + // Step 2: Discover IAM role name (if not already known) + if (m_RoleName.empty()) + { + HttpClient::Response RoleResponse = m_HttpClient.Get("/latest/meta-data/iam/security-credentials/", TokenHeader); + if (!RoleResponse.IsSuccess()) + { + ZEN_WARN("IMDS role discovery failed: {}", RoleResponse.ErrorMessage("GET iam/security-credentials/")); + return false; + } + + m_RoleName = std::string(RoleResponse.AsText()); + // Trim any trailing whitespace/newlines + while (!m_RoleName.empty() && (m_RoleName.back() == '\n' || m_RoleName.back() == '\r' || m_RoleName.back() == ' ')) + { + m_RoleName.pop_back(); + } + + if (m_RoleName.empty()) + { + ZEN_WARN("IMDS returned empty IAM role name"); + return false; + } + + ZEN_INFO("IMDS discovered IAM role: {}", m_RoleName); + } + + // Step 3: Fetch credentials for the role + std::string CredentialPath = fmt::format("/latest/meta-data/iam/security-credentials/{}", m_RoleName); + + 
HttpClient::Response CredResponse = m_HttpClient.Get(CredentialPath, TokenHeader); + if (!CredResponse.IsSuccess()) + { + ZEN_WARN("IMDS credential fetch failed: {}", CredResponse.ErrorMessage("GET iam/security-credentials/" + m_RoleName)); + return false; + } + + // Step 4: Parse JSON response + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(CredResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + ZEN_WARN("IMDS credential response JSON parse error: {}", JsonError); + return false; + } + + std::string AccessKeyId = Json["AccessKeyId"].string_value(); + std::string SecretAccessKey = Json["SecretAccessKey"].string_value(); + std::string SessionToken = Json["Token"].string_value(); + std::string Expiration = Json["Expiration"].string_value(); + + if (AccessKeyId.empty() || SecretAccessKey.empty()) + { + ZEN_WARN("IMDS credential response missing AccessKeyId or SecretAccessKey"); + return false; + } + + // Compute local expiration time based on the Expiration field + auto ExpirationTime = ParseIso8601(Expiration); + auto Now = std::chrono::system_clock::now(); + + std::chrono::steady_clock::time_point NewExpiresAt; + if (ExpirationTime > Now) + { + auto TimeUntilExpiry = ExpirationTime - Now; + NewExpiresAt = std::chrono::steady_clock::now() + TimeUntilExpiry - kRefreshMargin; + } + else + { + // Expiration is in the past or unparseable — force refresh next time + NewExpiresAt = std::chrono::steady_clock::now(); + } + + bool KeyChanged = (m_CachedCredentials.AccessKeyId != AccessKeyId); + + m_CachedCredentials.AccessKeyId = std::move(AccessKeyId); + m_CachedCredentials.SecretAccessKey = std::move(SecretAccessKey); + m_CachedCredentials.SessionToken = std::move(SessionToken); + m_ExpiresAt = NewExpiresAt; + + if (KeyChanged) + { + ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8)); + } + else + { + ZEN_DEBUG("IMDS credentials refreshed (unchanged key)"); + } + + 
return true; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +imdscredentials_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.imdscredentials"); + +TEST_CASE("imdscredentials.parse_iso8601") +{ + // Verify basic ISO 8601 parsing + auto Tp = ParseIso8601("2026-03-14T20:00:00Z"); + CHECK(Tp != std::chrono::system_clock::time_point{}); + + auto Epoch = std::chrono::system_clock::to_time_t(Tp); + std::tm Tm; +# if ZEN_PLATFORM_WINDOWS + gmtime_s(&Tm, &Epoch); +# else + gmtime_r(&Epoch, &Tm); +# endif + CHECK(Tm.tm_year + 1900 == 2026); + CHECK(Tm.tm_mon + 1 == 3); + CHECK(Tm.tm_mday == 14); + CHECK(Tm.tm_hour == 20); + CHECK(Tm.tm_min == 0); + CHECK(Tm.tm_sec == 0); + + // Invalid input + auto Bad = ParseIso8601("bad"); + CHECK(Bad == std::chrono::system_clock::time_point{}); +} + +// --------------------------------------------------------------------------- +// Integration test with mock IMDS server +// --------------------------------------------------------------------------- + +struct TestImdsServer +{ + compute::MockImdsService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(7576, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +TEST_CASE("imdscredentials.fetch_from_mock") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = compute::CloudProvider::AWS; + 
Imds.Start(); + + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = Imds.Endpoint(); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + + SUBCASE("basic_credential_fetch") + { + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + CHECK(Creds.AccessKeyId == "ASIAIOSFODNN7EXAMPLE"); + CHECK(Creds.SecretAccessKey == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"); + CHECK(Creds.SessionToken == "FwoGZXIvYXdzEBYaDEXAMPLETOKEN"); + } + + SUBCASE("credentials_are_cached") + { + SigV4Credentials First = Provider->GetCredentials(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(First.AccessKeyId == Second.AccessKeyId); + CHECK(First.SecretAccessKey == Second.SecretAccessKey); + } + + SUBCASE("invalidate_forces_refresh") + { + SigV4Credentials First = Provider->GetCredentials(); + CHECK(!First.AccessKeyId.empty()); + + // Change the credentials on the mock + Imds.Mock.Aws.IamAccessKeyId = "ASIANEWKEYEXAMPLE12"; + + Provider->InvalidateCache(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(Second.AccessKeyId == "ASIANEWKEYEXAMPLE12"); + } + + SUBCASE("custom_role_name") + { + Imds.Mock.Aws.IamRoleName = "my-custom-role"; + + Ref<ImdsCredentialProvider> Provider2(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider2->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + } +} + +TEST_CASE("imdscredentials.unreachable_endpoint") +{ + // Point at a non-existent server — should return empty credentials, not crash + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening + Opts.ConnectTimeout = std::chrono::milliseconds(100); + Opts.RequestTimeout = std::chrono::milliseconds(200); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(Creds.AccessKeyId.empty()); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen 
diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp new file mode 100644 index 000000000..565705731 --- /dev/null +++ b/src/zenutil/cloud/minioprocess.cpp @@ -0,0 +1,174 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/minioprocess.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/timer.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +struct MinioProcess::Impl +{ + Impl(const MinioProcessOptions& Options) : m_Options(Options), m_HttpClient(fmt::format("http://localhost:{}/", Options.Port)) {} + ~Impl() = default; + + void SpawnMinioServer() + { + if (m_ProcessHandle.IsValid()) + { + return; + } + + // Create a clean temp data directory, removing any stale data from a previous run + std::error_code Ec; + m_DataDir = std::filesystem::temp_directory_path(Ec) / fmt::format("zen-minio-{}", GetCurrentProcessId()); + if (Ec) + { + ZEN_WARN("MinIO: Failed to get temp directory: {}", Ec.message()); + return; + } + std::filesystem::remove_all(m_DataDir, Ec); + Ec.clear(); + std::filesystem::create_directories(m_DataDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to create data directory '{}': {}", m_DataDir.string(), Ec.message()); + return; + } + + CreateProcOptions Options; + Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); + Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); + + const std::filesystem::path MinioExe = GetRunningExecutablePath().parent_path() / ("minio" ZEN_EXE_SUFFIX_LITERAL); + + std::string CommandLine = + fmt::format("minio" ZEN_EXE_SUFFIX_LITERAL " server {} --address :{} --quiet", m_DataDir.string(), m_Options.Port); + + CreateProcResult Result = CreateProc(MinioExe, 
CommandLine, Options); + + if (Result) + { + m_ProcessHandle.Initialize(Result); + + Stopwatch Timer; + + // Poll to check when the server is ready + do + { + Sleep(100); + HttpClient::Response Resp = m_HttpClient.Get("minio/health/live"); + if (Resp) + { + ZEN_INFO("MinIO server started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + return; + } + } while (Timer.GetElapsedTimeMs() < 10000); + } + + // Report failure + ZEN_WARN("MinIO server failed to start within timeout period"); + } + + void StopMinioServer() + { + if (!m_ProcessHandle.IsValid()) + { + return; + } + + m_ProcessHandle.Kill(); + + // Clean up temp data directory + std::error_code Ec; + std::filesystem::remove_all(m_DataDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to clean up data directory '{}': {}", m_DataDir.string(), Ec.message()); + } + } + + void CreateBucket(std::string_view BucketName) + { + if (m_DataDir.empty()) + { + ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first"); + return; + } + + std::filesystem::path BucketDir = m_DataDir / std::string(BucketName); + std::error_code Ec; + std::filesystem::create_directories(BucketDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to create bucket directory '{}': {}", BucketDir.string(), Ec.message()); + } + } + + MinioProcessOptions m_Options; + ProcessHandle m_ProcessHandle; + HttpClient m_HttpClient; + std::filesystem::path m_DataDir; +}; + +MinioProcess::MinioProcess(const MinioProcessOptions& Options) : m_Impl(std::make_unique<Impl>(Options)) +{ +} + +MinioProcess::~MinioProcess() +{ + m_Impl->StopMinioServer(); +} + +void +MinioProcess::SpawnMinioServer() +{ + m_Impl->SpawnMinioServer(); +} + +void +MinioProcess::StopMinioServer() +{ + m_Impl->StopMinioServer(); +} + +void +MinioProcess::CreateBucket(std::string_view BucketName) +{ + m_Impl->CreateBucket(BucketName); +} + +uint16_t +MinioProcess::Port() const +{ + return m_Impl->m_Options.Port; +} + 
+std::string_view +MinioProcess::RootUser() const +{ + return m_Impl->m_Options.RootUser; +} + +std::string_view +MinioProcess::RootPassword() const +{ + return m_Impl->m_Options.RootPassword; +} + +std::string +MinioProcess::Endpoint() const +{ + return fmt::format("http://localhost:{}", m_Impl->m_Options.Port); +} + +} // namespace zen diff --git a/src/zenutil/cloud/mockimds.cpp b/src/zenutil/cloud/mockimds.cpp new file mode 100644 index 000000000..6919fab4d --- /dev/null +++ b/src/zenutil/cloud/mockimds.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 
token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + // IAM role discovery — returns the role name + if (Uri == "latest/meta-data/iam/security-credentials/") + { + if (Aws.IamRoleName.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.IamRoleName); + return; + } + + // IAM credentials for a specific role + constexpr std::string_view kIamCredPrefix = "latest/meta-data/iam/security-credentials/"; + if (Uri.starts_with(kIamCredPrefix) && Uri.size() > kIamCredPrefix.size()) + { + std::string_view RequestedRole = Uri.substr(kIamCredPrefix.size()); + if 
(RequestedRole == Aws.IamRoleName) + { + std::string Json = + fmt::format(R"({{"Code":"Success","AccessKeyId":"{}","SecretAccessKey":"{}","Token":"{}","Expiration":"{}"}})", + Aws.IamAccessKeyId, + Aws.IamSecretAccessKey, + Aws.IamSessionToken, + Aws.IamExpiration); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) + if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + 
{ + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp new file mode 100644 index 000000000..88d844b61 --- /dev/null +++ b/src/zenutil/cloud/s3client.cpp @@ -0,0 +1,986 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/s3client.h> + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/minioprocess.h> + +#include <zencore/except_fmt.h> +#include <zencore/iobuffer.h> +#include <zencore/memoryview.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + +namespace zen { + +namespace { + + /// The SHA-256 hash of an empty payload, precomputed + constexpr std::string_view EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + + /// Simple XML value extractor. Finds the text content between <Tag> and </Tag>. + /// This is intentionally minimal - we only need to parse ListBucketResult responses. + /// Returns a string_view into the original XML when no entity decoding is needed. 
+ std::string_view ExtractXmlValue(std::string_view Xml, std::string_view Tag) + { + std::string OpenTag = fmt::format("<{}>", Tag); + std::string CloseTag = fmt::format("</{}>", Tag); + + size_t Start = Xml.find(OpenTag); + if (Start == std::string_view::npos) + { + return {}; + } + Start += OpenTag.size(); + + size_t End = Xml.find(CloseTag, Start); + if (End == std::string_view::npos) + { + return {}; + } + + return Xml.substr(Start, End - Start); + } + + /// Decode the five standard XML entities (& < > " ') into a StringBuilderBase. + void DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) + { + if (Input.find('&') == std::string_view::npos) + { + Out.Append(Input); + return; + } + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '&') + { + std::string_view Remaining = Input.substr(i); + if (Remaining.starts_with("&")) + { + Out.Append('&'); + i += 4; + } + else if (Remaining.starts_with("<")) + { + Out.Append('<'); + i += 3; + } + else if (Remaining.starts_with(">")) + { + Out.Append('>'); + i += 3; + } + else if (Remaining.starts_with(""")) + { + Out.Append('"'); + i += 5; + } + else if (Remaining.starts_with("'")) + { + Out.Append('\''); + i += 5; + } + else + { + Out.Append(Input[i]); + } + } + else + { + Out.Append(Input[i]); + } + } + } + + /// Convenience: decode XML entities and return as std::string. + std::string DecodeXmlEntities(std::string_view Input) + { + if (Input.find('&') == std::string_view::npos) + { + return std::string(Input); + } + + ExtendableStringBuilder<256> Sb; + DecodeXmlEntities(Input, Sb); + return Sb.ToString(); + } + + /// Join a path and canonical query string into a full request path for the HTTP client. + std::string BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) + { + if (CanonicalQS.empty()) + { + return std::string(Path); + } + return fmt::format("{}?{}", Path, CanonicalQS); + } + + /// Case-insensitive header lookup in an HttpClient response header map. 
+ const std::string* FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) + { + for (const auto& [K, V] : *Headers) + { + if (StrCaseCompare(K, Name) == 0) + { + return &V; + } + } + return nullptr; + } + +} // namespace + +S3Client::S3Client(const S3ClientOptions& Options) +: m_Log(logging::Get("s3")) +, m_BucketName(Options.BucketName) +, m_Region(Options.Region) +, m_Endpoint(Options.Endpoint) +, m_PathStyle(Options.PathStyle) +, m_Credentials(Options.Credentials) +, m_CredentialProvider(Options.CredentialProvider) +, m_HttpClient(BuildEndpoint(), + HttpClientSettings{ + .LogCategory = "s3", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.Timeout, + .RetryCount = Options.RetryCount, + }) +{ + m_Host = BuildHostHeader(); + ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", + m_BucketName, + m_Region, + m_HttpClient.GetBaseUri(), + m_PathStyle ? "path-style" : "virtual-hosted"); +} + +S3Client::~S3Client() = default; + +SigV4Credentials +S3Client::GetCurrentCredentials() +{ + if (m_CredentialProvider) + { + SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); + if (!Creds.AccessKeyId.empty()) + { + // Invalidate the signing key cache when the access key changes + if (Creds.AccessKeyId != m_Credentials.AccessKeyId) + { + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + m_CachedDateStamp.clear(); + } + m_Credentials = Creds; + } + return m_Credentials; + } + return m_Credentials; +} + +std::string +S3Client::BuildEndpoint() const +{ + if (!m_Endpoint.empty()) + { + return m_Endpoint; + } + + if (m_PathStyle) + { + // Path-style: https://s3.region.amazonaws.com + return fmt::format("https://s3.{}.amazonaws.com", m_Region); + } + + // Virtual-hosted style: https://bucket.s3.region.amazonaws.com + return fmt::format("https://{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::BuildHostHeader() const +{ + if (!m_Endpoint.empty()) + { + // 
Extract host from custom endpoint URL (strip scheme) + std::string_view Ep = m_Endpoint; + if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) + { + Ep = Ep.substr(Pos + 3); + } + // Strip trailing slash + if (!Ep.empty() && Ep.back() == '/') + { + Ep = Ep.substr(0, Ep.size() - 1); + } + return std::string(Ep); + } + + if (m_PathStyle) + { + return fmt::format("s3.{}.amazonaws.com", m_Region); + } + + return fmt::format("{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::KeyToPath(std::string_view Key) const +{ + if (m_PathStyle) + { + return fmt::format("/{}/{}", m_BucketName, Key); + } + return fmt::format("/{}", Key); +} + +std::string +S3Client::BucketRootPath() const +{ + if (m_PathStyle) + { + return fmt::format("/{}/", m_BucketName); + } + return "/"; +} + +Sha256Digest +S3Client::GetSigningKey(std::string_view DateStamp) +{ + // Fast path: shared lock for cache hit (common case — key only changes once per day) + { + RwLock::SharedLockScope SharedLock(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + } + + // Slow path: exclusive lock to recompute the signing key + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + + // Double-check after acquiring exclusive lock (another thread may have updated it) + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + + std::string SecretPrefix = fmt::format("AWS4{}", m_Credentials.SecretAccessKey); + + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); + m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + m_CachedDateStamp = std::string(DateStamp); + + return m_CachedSigningKey; +} + +HttpClient::KeyValueMap 
+S3Client::SignRequest(std::string_view Method, std::string_view Path, std::string_view CanonicalQueryString, std::string_view PayloadHash) +{ + SigV4Credentials Credentials = GetCurrentCredentials(); + + std::string AmzDate = GetAmzTimestamp(); + + // Build sorted headers to sign (must be sorted by lowercase name) + std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.emplace_back("host", m_Host); + HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); + HeadersToSign.emplace_back("x-amz-date", AmzDate); + if (!Credentials.SessionToken.empty()) + { + HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); + } + std::sort(HeadersToSign.begin(), HeadersToSign.end()); + + std::string_view DateStamp(AmzDate.data(), 8); + Sha256Digest SigningKey = GetSigningKey(DateStamp); + + SigV4SignedHeaders Signed = + SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); + + HttpClient::KeyValueMap Result; + Result->emplace("Authorization", std::move(Signed.Authorization)); + Result->emplace("x-amz-date", std::move(Signed.AmzDate)); + Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); + if (!Credentials.SessionToken.empty()) + { + Result->emplace("x-amz-security-token", Credentials.SessionToken); + } + + return Result; +} + +S3Result +S3Client::PutObject(std::string_view Key, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + + // Hash the payload + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, "", PayloadHash); + + HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 PUT failed"); + ZEN_WARN("S3 PUT '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 PUT '{}' 
succeeded ({} bytes)", Key, Content.GetSize()); + return {}; +} + +S3GetObjectResult +S3Client::GetObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 GET failed"); + ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; +} + +S3Result +S3Client::DeleteObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 DELETE failed"); + ZEN_WARN("S3 DELETE '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + return {}; +} + +S3HeadObjectResult +S3Client::HeadObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("HEAD", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Head(Path, Headers); + if (!Response.IsSuccess()) + { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3HeadObjectResult{{}, {}, HeadObjectResult::NotFound}; + } + + std::string Err = Response.ErrorMessage("S3 HEAD failed"); + ZEN_WARN("S3 HEAD '{}' failed: {}", Key, Err); + return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; + } + + S3ObjectInfo Info; + Info.Key = std::string(Key); + + if (const std::string* V = FindResponseHeader(Response.Header, "content-length")) + { + Info.Size = 
ParseInt<uint64_t>(*V).value_or(0); + } + + if (const std::string* V = FindResponseHeader(Response.Header, "etag")) + { + Info.ETag = *V; + } + + if (const std::string* V = FindResponseHeader(Response.Header, "last-modified")) + { + Info.LastModified = *V; + } + + ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; +} + +S3ListObjectsResult +S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) +{ + S3ListObjectsResult Result; + + std::string ContinuationToken; + + for (;;) + { + // Build query parameters for ListObjectsV2 + std::vector<std::pair<std::string, std::string>> QueryParams; + QueryParams.emplace_back("list-type", "2"); + if (!Prefix.empty()) + { + QueryParams.emplace_back("prefix", std::string(Prefix)); + } + if (MaxKeys > 0) + { + QueryParams.emplace_back("max-keys", fmt::format("{}", MaxKeys)); + } + if (!ContinuationToken.empty()) + { + QueryParams.emplace_back("continuation-token", ContinuationToken); + } + + std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); + std::string RootPath = BucketRootPath(); + HttpClient::KeyValueMap Headers = SignRequest("GET", RootPath, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 ListObjectsV2 failed"); + ZEN_WARN("S3 ListObjectsV2 prefix='{}' failed: {}", Prefix, Err); + Result.Error = std::move(Err); + return Result; + } + + // Parse the XML response to extract object keys + std::string_view ResponseBody = Response.AsText(); + + // Find all <Contents> elements + std::string_view Remaining = ResponseBody; + while (true) + { + size_t ContentsStart = Remaining.find("<Contents>"); + if (ContentsStart == std::string_view::npos) + { + break; + } + + size_t ContentsEnd = Remaining.find("</Contents>", 
ContentsStart); + if (ContentsEnd == std::string_view::npos) + { + break; + } + + std::string_view ContentsXml = Remaining.substr(ContentsStart, ContentsEnd - ContentsStart + 11); + + S3ObjectInfo Info; + Info.Key = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "Key")); + Info.ETag = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "ETag")); + Info.LastModified = std::string(ExtractXmlValue(ContentsXml, "LastModified")); + + std::string_view SizeStr = ExtractXmlValue(ContentsXml, "Size"); + if (!SizeStr.empty()) + { + Info.Size = ParseInt<uint64_t>(SizeStr).value_or(0); + } + + if (!Info.Key.empty()) + { + Result.Objects.push_back(std::move(Info)); + } + + Remaining = Remaining.substr(ContentsEnd + 11); + } + + // Check if there are more pages + std::string_view IsTruncated = ExtractXmlValue(ResponseBody, "IsTruncated"); + if (IsTruncated != "true") + { + break; + } + + std::string_view NextToken = ExtractXmlValue(ResponseBody, "NextContinuationToken"); + if (NextToken.empty()) + { + break; + } + + ContinuationToken = std::string(NextToken); + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } + + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// Multipart Upload + +S3CreateMultipartUploadResult +S3Client::CreateMultipartUpload(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CreateMultipartUpload failed"); + ZEN_WARN("S3 CreateMultipartUpload '{}' failed: {}", Key, 
Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + // Parse UploadId from XML response: + // <InitiateMultipartUploadResult> + // <Bucket>...</Bucket> + // <Key>...</Key> + // <UploadId>...</UploadId> + // </InitiateMultipartUploadResult> + std::string_view ResponseBody = Response.AsText(); + std::string_view UploadId = ExtractXmlValue(ResponseBody, "UploadId"); + if (UploadId.empty()) + { + std::string Err = "failed to parse UploadId from CreateMultipartUpload response"; + ZEN_WARN("S3 CreateMultipartUpload '{}': {}", Key, Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; +} + +S3UploadPartResult +S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({ + {"partNumber", fmt::format("{}", PartNumber)}, + {"uploadId", std::string(UploadId)}, + }); + + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, CanonicalQS, PayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber)); + ZEN_WARN("S3 UploadPart '{}' part {} failed: {}", Key, PartNumber, Err); + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + // Extract ETag from response headers + const std::string* ETag = FindResponseHeader(Response.Header, "etag"); + if (!ETag) + { + std::string Err = "S3 UploadPart response missing ETag header"; + ZEN_WARN("S3 UploadPart '{}' part {}: {}", Key, PartNumber, Err); + return 
S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + return S3UploadPartResult{{}, *ETag}; +} + +S3Result +S3Client::CompleteMultipartUpload(std::string_view Key, + std::string_view UploadId, + const std::vector<std::pair<uint32_t, std::string>>& PartETags) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + // Build the CompleteMultipartUpload XML payload + ExtendableStringBuilder<1024> XmlBody; + XmlBody.Append("<CompleteMultipartUpload>"); + for (const auto& [PartNumber, ETag] : PartETags) + { + XmlBody.Append(fmt::format("<Part><PartNumber>{}</PartNumber><ETag>{}</ETag></Part>", PartNumber, ETag)); + } + XmlBody.Append("</CompleteMultipartUpload>"); + + std::string_view XmlView = XmlBody.ToView(); + std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, PayloadHash); + Headers->emplace("Content-Type", "application/xml"); + + IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Payload, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CompleteMultipartUpload failed"); + ZEN_WARN("S3 CompleteMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + // Check for error in response body - S3 can return 200 with an error in the XML body + std::string_view ResponseBody = Response.AsText(); + if (ResponseBody.find("<Error>") != std::string_view::npos) + { + std::string_view ErrorCode = ExtractXmlValue(ResponseBody, "Code"); + std::string_view ErrorMessage = ExtractXmlValue(ResponseBody, "Message"); + std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned 
error: {} - {}", Key, ErrorCode, ErrorMessage); + ZEN_WARN("{}", Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + return {}; +} + +S3Result +S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 AbortMultipartUpload failed"); + ZEN_WARN("S3 AbortMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return {}; +} + +std::string +S3Client::GeneratePresignedGetUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "GET", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "PUT", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn) +{ + std::string Path = KeyToPath(Key); + std::string Scheme = "https"; + + if (!m_Endpoint.empty() && m_Endpoint.starts_with("http://")) + { + Scheme = "http"; + } + + SigV4Credentials Credentials = GetCurrentCredentials(); + return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); +} + +S3Result +S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize) +{ + const uint64_t ContentSize = Content.GetSize(); + + // If the content fits in a 
single part, just use PutObject + if (ContentSize <= PartSize) + { + return PutObject(Key, Content); + } + + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, ContentSize, (ContentSize + PartSize - 1) / PartSize); + + // Initiate multipart upload + + S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); + if (!InitResult) + { + return S3Result{std::move(InitResult.Error)}; + } + + const std::string& UploadId = InitResult.UploadId; + + // Upload parts sequentially + // TODO: upload parts in parallel for improved throughput on large uploads + + std::vector<std::pair<uint32_t, std::string>> PartETags; + uint64_t Offset = 0; + uint32_t PartNumber = 1; + + while (Offset < ContentSize) + { + uint64_t ThisPartSize = std::min(PartSize, ContentSize - Offset); + + // Create a sub-buffer referencing the part data within the original content + IoBuffer PartContent(Content, Offset, ThisPartSize); + + S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, PartContent); + if (!PartResult) + { + // Attempt to abort the multipart upload on failure + AbortMultipartUpload(Key, UploadId); + return S3Result{std::move(PartResult.Error)}; + } + + PartETags.emplace_back(PartNumber, std::move(PartResult.ETag)); + Offset += ThisPartSize; + PartNumber++; + } + + // Complete multipart upload + S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); + if (!CompleteResult) + { + AbortMultipartUpload(Key, UploadId); + return CompleteResult; + } + + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), ContentSize); + return {}; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +s3client_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.s3client"); + +TEST_CASE("s3client.xml_extract") +{ + std::string_view Xml = + "<Contents><Key>test/file.txt</Key><Size>1234</Size>" + 
"<ETag>\"abc123\"</ETag><LastModified>2024-01-01T00:00:00Z</LastModified></Contents>"; + + CHECK(ExtractXmlValue(Xml, "Key") == "test/file.txt"); + CHECK(ExtractXmlValue(Xml, "Size") == "1234"); + CHECK(ExtractXmlValue(Xml, "ETag") == "\"abc123\""); + CHECK(ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); + CHECK(ExtractXmlValue(Xml, "NonExistent") == ""); +} + +TEST_CASE("s3client.xml_entity_decode") +{ + CHECK(DecodeXmlEntities("no entities") == "no entities"); + CHECK(DecodeXmlEntities("a&b") == "a&b"); + CHECK(DecodeXmlEntities("<tag>") == "<tag>"); + CHECK(DecodeXmlEntities(""hello'") == "\"hello'"); + CHECK(DecodeXmlEntities("&&") == "&&"); + CHECK(DecodeXmlEntities("") == ""); + + // Key with entities as S3 would return it + std::string_view Xml = "<Key>path/file&name<1>.txt</Key>"; + CHECK(DecodeXmlEntities(ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); +} + +TEST_CASE("s3client.path_style_addressing") +{ + // Verify path-style builds /{bucket}/{key} paths + S3ClientOptions Opts; + Opts.BucketName = "test-bucket"; + Opts.Region = "us-east-1"; + Opts.Endpoint = "http://localhost:9000"; + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "minioadmin"; + Opts.Credentials.SecretAccessKey = "minioadmin"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "test-bucket"); + CHECK(Client.Region() == "us-east-1"); +} + +TEST_CASE("s3client.virtual_hosted_addressing") +{ + // Verify virtual-hosted style derives endpoint from region + bucket + S3ClientOptions Opts; + Opts.BucketName = "my-bucket"; + Opts.Region = "eu-west-1"; + Opts.PathStyle = false; + Opts.Credentials.AccessKeyId = "key"; + Opts.Credentials.SecretAccessKey = "secret"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "my-bucket"); + CHECK(Client.Region() == "eu-west-1"); +} + +TEST_CASE("s3client.minio_integration") +{ + using namespace std::literals; + + // Spawn a local MinIO server + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19000; + 
MinioOpts.RootUser = "testuser"; + MinioOpts.RootPassword = "testpassword"; + + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + + // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) + Minio.CreateBucket("integration-test"); + + // Configure S3Client for the test bucket + S3ClientOptions Opts; + Opts.BucketName = "integration-test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = std::string(Minio.RootUser()); + Opts.Credentials.SecretAccessKey = std::string(Minio.RootPassword()); + + S3Client Client(Opts); + + SUBCASE("put_get_delete") + { + // PUT + std::string_view TestData = "hello, minio integration test!"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + // GET + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.AsText() == TestData); + + // HEAD + S3HeadObjectResult HeadRes = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes.IsSuccess()); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == TestData.size()); + + // DELETE + S3Result DelRes = Client.DeleteObject("test/hello.txt"); + REQUIRE(DelRes.IsSuccess()); + + // HEAD after delete + S3HeadObjectResult HeadRes2 = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes2.IsSuccess()); + CHECK(HeadRes2.Status == HeadObjectResult::NotFound); + } + + SUBCASE("head_not_found") + { + S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); + CHECK(Res.IsSuccess()); + CHECK(Res.Status == HeadObjectResult::NotFound); + } + + SUBCASE("list_objects") + { + // Upload several objects with a common prefix + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("list-test/item-{}.txt", i); + std::string Payload = fmt::format("payload-{}", i); + IoBuffer Buf = 
IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + S3Result Res = Client.PutObject(Key, std::move(Buf)); + REQUIRE(Res.IsSuccess()); + } + + // List with prefix + S3ListObjectsResult ListRes = Client.ListObjects("list-test/"); + REQUIRE(ListRes.IsSuccess()); + CHECK(ListRes.Objects.size() == 3); + + // Verify keys are present + std::vector<std::string> Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "list-test/item-0.txt"); + CHECK(Keys[1] == "list-test/item-1.txt"); + CHECK(Keys[2] == "list-test/item-2.txt"); + + // Cleanup + for (int i = 0; i < 3; ++i) + { + Client.DeleteObject(fmt::format("list-test/item-{}.txt", i)); + } + } + + SUBCASE("multipart_upload") + { + // Create a payload large enough to exercise multipart (use minimum part size) + constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum + constexpr uint64_t PayloadSize = PartSize + 1024; // slightly over one part + + std::string LargePayload(PayloadSize, 'X'); + // Add some variation + for (uint64_t i = 0; i < PayloadSize; i += 1024) + { + LargePayload[i] = char('A' + (i / 1024) % 26); + } + + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(LargePayload)); + S3Result Res = Client.PutObjectMultipart("multipart/large.bin", std::move(Content), PartSize); + REQUIRE(Res.IsSuccess()); + + // Verify via GET + S3GetObjectResult GetRes = Client.GetObject("multipart/large.bin"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.Content.GetSize() == PayloadSize); + CHECK(GetRes.AsText() == std::string_view(LargePayload)); + + // Cleanup + Client.DeleteObject("multipart/large.bin"); + } + + SUBCASE("presigned_urls") + { + // Upload an object + std::string_view TestData = "presigned-url-test-data"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("presigned/test.txt", std::move(Content)); + 
REQUIRE(PutRes.IsSuccess()); + + // Generate a pre-signed GET URL + std::string Url = Client.GeneratePresignedGetUrl("presigned/test.txt", std::chrono::seconds(60)); + CHECK(!Url.empty()); + CHECK(Url.find("X-Amz-Signature") != std::string::npos); + + // Fetch via the pre-signed URL (no auth headers needed) + HttpClient Hc(Minio.Endpoint()); + // Extract the path+query from the full URL + std::string_view UrlView = Url; + size_t PathStart = UrlView.find('/', UrlView.find("://") + 3); + std::string PathAndQuery(UrlView.substr(PathStart)); + HttpClient::Response Resp = Hc.Get(PathAndQuery); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == TestData); + + // Cleanup + Client.DeleteObject("presigned/test.txt"); + } + + Minio.StopMinioServer(); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/cloud/sigv4.cpp b/src/zenutil/cloud/sigv4.cpp new file mode 100644 index 000000000..055ccb2ad --- /dev/null +++ b/src/zenutil/cloud/sigv4.cpp @@ -0,0 +1,531 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/cloud/sigv4.h> + +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <algorithm> +#include <chrono> +#include <cstring> +#include <ctime> + +// Platform-specific crypto backends +#if ZEN_PLATFORM_WINDOWS +# define ZEN_S3_USE_BCRYPT 1 +#else +# define ZEN_S3_USE_BCRYPT 0 +#endif + +#ifndef ZEN_S3_USE_OPENSSL +# if ZEN_S3_USE_BCRYPT +# define ZEN_S3_USE_OPENSSL 0 +# else +# define ZEN_S3_USE_OPENSSL 1 +# endif +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> + +#if ZEN_S3_USE_OPENSSL +# include <openssl/evp.h> +#elif ZEN_S3_USE_BCRYPT +# include <zencore/windows.h> +# include <bcrypt.h> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// SHA-256 + +#if ZEN_S3_USE_OPENSSL + +Sha256Digest +ComputeSha256(const void* Data, size_t Size) +{ + Sha256Digest Result; + unsigned int Len = 0; + EVP_Digest(Data, Size, Result.data(), &Len, EVP_sha256(), nullptr); + ZEN_ASSERT(Len == 32); + return Result; +} + +Sha256Digest +ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) +{ + Sha256Digest Result; + + EVP_MAC* Mac = EVP_MAC_fetch(nullptr, "HMAC", nullptr); + ZEN_ASSERT(Mac != nullptr); + + EVP_MAC_CTX* Ctx = EVP_MAC_CTX_new(Mac); + ZEN_ASSERT(Ctx != nullptr); + + OSSL_PARAM Params[] = { + OSSL_PARAM_construct_utf8_string("digest", const_cast<char*>("SHA256"), 0), + OSSL_PARAM_construct_end(), + }; + + int Rc = EVP_MAC_init(Ctx, reinterpret_cast<const unsigned char*>(Key), KeySize, Params); + ZEN_ASSERT(Rc == 1); + + Rc = EVP_MAC_update(Ctx, reinterpret_cast<const unsigned char*>(Data), DataSize); + ZEN_ASSERT(Rc == 1); + + size_t OutLen = 0; + Rc = EVP_MAC_final(Ctx, Result.data(), &OutLen, Result.size()); + ZEN_ASSERT(Rc == 1); + ZEN_ASSERT(OutLen == 32); + + EVP_MAC_CTX_free(Ctx); + EVP_MAC_free(Mac); + + return Result; +} + +#elif ZEN_S3_USE_BCRYPT + +namespace { + +# define NT_SUCCESS(Status) 
(((NTSTATUS)(Status)) >= 0) + + Sha256Digest BcryptHash(BCRYPT_ALG_HANDLE Algorithm, const void* Data, size_t DataSize) + { + Sha256Digest Result; + BCRYPT_HASH_HANDLE HashHandle = nullptr; + NTSTATUS Status; + + Status = BCryptCreateHash(Algorithm, &HashHandle, nullptr, 0, nullptr, 0, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptHashData(HashHandle, (PUCHAR)Data, (ULONG)DataSize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptFinishHash(HashHandle, Result.data(), (ULONG)Result.size(), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + BCryptDestroyHash(HashHandle); + return Result; + } + + Sha256Digest BcryptHmac(BCRYPT_ALG_HANDLE Algorithm, const void* Key, size_t KeySize, const void* Data, size_t DataSize) + { + Sha256Digest Result; + BCRYPT_HASH_HANDLE HashHandle = nullptr; + NTSTATUS Status; + + Status = BCryptCreateHash(Algorithm, &HashHandle, nullptr, 0, (PUCHAR)Key, (ULONG)KeySize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptHashData(HashHandle, (PUCHAR)Data, (ULONG)DataSize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptFinishHash(HashHandle, Result.data(), (ULONG)Result.size(), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + BCryptDestroyHash(HashHandle); + return Result; + } + + struct BcryptAlgorithmHandles + { + BCRYPT_ALG_HANDLE Sha256 = nullptr; + BCRYPT_ALG_HANDLE HmacSha256 = nullptr; + + BcryptAlgorithmHandles() + { + NTSTATUS Status; + Status = BCryptOpenAlgorithmProvider(&Sha256, BCRYPT_SHA256_ALGORITHM, nullptr, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + Status = BCryptOpenAlgorithmProvider(&HmacSha256, BCRYPT_SHA256_ALGORITHM, nullptr, BCRYPT_ALG_HANDLE_HMAC_FLAG); + ZEN_ASSERT(NT_SUCCESS(Status)); + } + + ~BcryptAlgorithmHandles() + { + if (Sha256) + { + BCryptCloseAlgorithmProvider(Sha256, 0); + } + if (HmacSha256) + { + BCryptCloseAlgorithmProvider(HmacSha256, 0); + } + } + }; + + BcryptAlgorithmHandles& GetBcryptHandles() + { + static BcryptAlgorithmHandles s_Handles; + return s_Handles; + } + +} // namespace + 
+Sha256Digest +ComputeSha256(const void* Data, size_t Size) +{ + return BcryptHash(GetBcryptHandles().Sha256, Data, Size); +} + +Sha256Digest +ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) +{ + return BcryptHmac(GetBcryptHandles().HmacSha256, Key, KeySize, Data, DataSize); +} + +#endif + +Sha256Digest +ComputeSha256(std::string_view Data) +{ + return ComputeSha256(Data.data(), Data.size()); +} + +Sha256Digest +ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data) +{ + return ComputeHmacSha256(Key.data(), Key.size(), Data.data(), Data.size()); +} + +std::string +Sha256ToHex(const Sha256Digest& Digest) +{ + std::string Result; + Result.reserve(64); + for (uint8_t Byte : Digest) + { + fmt::format_to(std::back_inserter(Result), "{:02x}", Byte); + } + return Result; +} + +void +SecureZeroSecret(void* Data, size_t Size) +{ +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Data, Size); +#elif ZEN_PLATFORM_LINUX + explicit_bzero(Data, Size); +#else + // Portable fallback: volatile pointer prevents the compiler from optimizing away the memset + static void* (*const volatile VolatileMemset)(void*, int, size_t) = memset; + VolatileMemset(Data, 0, Size); +#endif +} + +////////////////////////////////////////////////////////////////////////// +// SigV4 signing + +namespace { + + std::string GetDateStamp(std::string_view AmzDate) + { + // AmzDate is "YYYYMMDDTHHMMSSZ", date stamp is first 8 chars + return std::string(AmzDate.substr(0, 8)); + } + +} // namespace + +std::string +GetAmzTimestamp() +{ + auto Now = std::chrono::system_clock::now(); + std::time_t NowTime = std::chrono::system_clock::to_time_t(Now); + + struct tm Tm; +#if ZEN_PLATFORM_WINDOWS + gmtime_s(&Tm, &NowTime); +#else + gmtime_r(&NowTime, &Tm); +#endif + + char Buf[32]; + std::strftime(Buf, sizeof(Buf), "%Y%m%dT%H%M%SZ", &Tm); + return std::string(Buf); +} + +std::string +AwsUriEncode(std::string_view Input, bool EncodeSlash) +{ + ExtendableStringBuilder<256> 
Result; + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '_' || C == '-' || C == '~' || C == '.') + { + Result.Append(C); + } + else if (C == '/' && !EncodeSlash) + { + Result.Append(C); + } + else + { + Result.Append(fmt::format("%{:02X}", static_cast<unsigned char>(C))); + } + } + return std::string(Result.ToView()); +} + +std::string +BuildCanonicalQueryString(std::vector<std::pair<std::string, std::string>> Parameters) +{ + if (Parameters.empty()) + { + return {}; + } + + // Sort by key name, then by value (as required by SigV4) + std::sort(Parameters.begin(), Parameters.end()); + + ExtendableStringBuilder<512> Result; + for (size_t i = 0; i < Parameters.size(); ++i) + { + if (i > 0) + { + Result.Append('&'); + } + Result.Append(AwsUriEncode(Parameters[i].first)); + Result.Append('='); + Result.Append(AwsUriEncode(Parameters[i].second)); + } + return std::string(Result.ToView()); +} + +SigV4SignedHeaders +SignRequestV4(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Url, + std::string_view CanonicalQueryString, + std::string_view Region, + std::string_view Service, + std::string_view AmzDate, + const std::vector<std::pair<std::string, std::string>>& Headers, + std::string_view PayloadHash, + const Sha256Digest* SigningKeyPtr) +{ + SigV4SignedHeaders Result; + Result.AmzDate = std::string(AmzDate); + Result.PayloadHash = std::string(PayloadHash); + + std::string DateStamp = GetDateStamp(Result.AmzDate); + + // Step 1: Create canonical request + // CanonicalRequest = + // HTTPRequestMethod + '\n' + + // CanonicalURI + '\n' + + // CanonicalQueryString + '\n' + + // CanonicalHeaders + '\n' + + // SignedHeaders + '\n' + + // HexEncode(Hash(RequestPayload)) + + std::string CanonicalUri = AwsUriEncode(Url, false); + + // Build canonical headers and signed headers (headers must be sorted by lowercase name) + ExtendableStringBuilder<512> CanonicalHeadersSb; + 
ExtendableStringBuilder<256> SignedHeadersSb; + + for (size_t i = 0; i < Headers.size(); ++i) + { + CanonicalHeadersSb.Append(Headers[i].first); + CanonicalHeadersSb.Append(':'); + CanonicalHeadersSb.Append(Headers[i].second); + CanonicalHeadersSb.Append('\n'); + + if (i > 0) + { + SignedHeadersSb.Append(';'); + } + SignedHeadersSb.Append(Headers[i].first); + } + + std::string SignedHeaders = std::string(SignedHeadersSb.ToView()); + + std::string CanonicalRequest = fmt::format("{}\n{}\n{}\n{}\n{}\n{}", + Method, + CanonicalUri, + CanonicalQueryString, + CanonicalHeadersSb.ToView(), + SignedHeaders, + PayloadHash); + + // Step 2: Create the string to sign + std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service); + + Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest); + std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash); + + std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", Result.AmzDate, CredentialScope, CanonicalRequestHex); + + // Step 3: Calculate the signing key + // kDate = HMAC("AWS4" + SecretKey, DateStamp) + // kRegion = HMAC(kDate, Region) + // kService = HMAC(kRegion, Service) + // kSigning = HMAC(kService, "aws4_request") + + Sha256Digest DerivedSigningKey; + if (!SigningKeyPtr) + { + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, Service); + DerivedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + SigningKeyPtr = &DerivedSigningKey; + } + + // Step 4: Calculate the signature + Sha256Digest Signature = ComputeHmacSha256(*SigningKeyPtr, StringToSign); + std::string SignatureHex = Sha256ToHex(Signature); + + // Step 
5: Build the Authorization header + Result.Authorization = fmt::format("AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + Credentials.AccessKeyId, + CredentialScope, + SignedHeaders, + SignatureHex); + + return Result; +} + +std::string +GeneratePresignedUrl(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Scheme, + std::string_view Host, + std::string_view Path, + std::string_view Region, + std::string_view Service, + std::chrono::seconds ExpiresIn, + const std::vector<std::pair<std::string, std::string>>& ExtraQueryParams) +{ + // Pre-signed URLs use query string authentication: + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + + std::string AmzDate = GetAmzTimestamp(); + std::string DateStamp = GetDateStamp(AmzDate); + + std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service); + std::string Credential = fmt::format("{}/{}", Credentials.AccessKeyId, CredentialScope); + + // The only signed header for pre-signed URLs is "host" + constexpr std::string_view SignedHeaders = "host"; + + // Build query parameters that will be part of the canonical request. + // These are the auth params (minus X-Amz-Signature which is added after signing). 
+ std::vector<std::pair<std::string, std::string>> QueryParams = ExtraQueryParams; + QueryParams.emplace_back("X-Amz-Algorithm", "AWS4-HMAC-SHA256"); + QueryParams.emplace_back("X-Amz-Credential", Credential); + QueryParams.emplace_back("X-Amz-Date", AmzDate); + QueryParams.emplace_back("X-Amz-Expires", fmt::format("{}", ExpiresIn.count())); + if (!Credentials.SessionToken.empty()) + { + QueryParams.emplace_back("X-Amz-Security-Token", Credentials.SessionToken); + } + QueryParams.emplace_back("X-Amz-SignedHeaders", std::string(SignedHeaders)); + + std::string CanonicalQueryString = BuildCanonicalQueryString(QueryParams); + std::string CanonicalUri = AwsUriEncode(Path, false); + + // For pre-signed URLs, the payload is always UNSIGNED-PAYLOAD + constexpr std::string_view PayloadHash = "UNSIGNED-PAYLOAD"; + + // Build the canonical request + // Only "host" is in the canonical headers for pre-signed URLs + std::string CanonicalHeaders = fmt::format("host:{}\n", Host); + + std::string CanonicalRequest = + fmt::format("{}\n{}\n{}\n{}\n{}\n{}", Method, CanonicalUri, CanonicalQueryString, CanonicalHeaders, SignedHeaders, PayloadHash); + + // Create the string to sign + Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest); + std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash); + + std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", AmzDate, CredentialScope, CanonicalRequestHex); + + // Calculate the signing key + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, Service); + Sha256Digest SigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + + // Calculate the signature + std::string 
SignatureHex = Sha256ToHex(ComputeHmacSha256(SigningKey, StringToSign)); + + // Build the final URL (use the URI-encoded path so special characters are properly escaped) + return fmt::format("{}://{}{}?{}&X-Amz-Signature={}", Scheme, Host, CanonicalUri, CanonicalQueryString, SignatureHex); +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +sigv4_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.sigv4"); + +TEST_CASE("sigv4.sha256") +{ + // Test with known test vector (empty string) + Sha256Digest Empty = ComputeSha256("", 0); + std::string Hex = Sha256ToHex(Empty); + CHECK(Hex == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + + // Test with "hello" + Sha256Digest Hello = ComputeSha256("hello"); + std::string HelloHex = Sha256ToHex(Hello); + CHECK(HelloHex == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"); +} + +TEST_CASE("sigv4.hmac_sha256") +{ + // RFC 4231 Test Case 2 + std::string_view Key = "Jefe"; + std::string_view Data = "what do ya want for nothing?"; + + Sha256Digest Result = ComputeHmacSha256(Key.data(), Key.size(), Data.data(), Data.size()); + std::string Hex = Sha256ToHex(Result); + CHECK(Hex == "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"); +} + +TEST_CASE("sigv4.signing") +{ + // Based on the AWS SigV4 test suite example + // https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + + SigV4Credentials Creds; + Creds.AccessKeyId = "AKIDEXAMPLE"; + Creds.SecretAccessKey = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + + // We can't test with a fixed timestamp since SignRequestV4 uses current time, + // but we can verify the crypto primitives produce correct results by testing + // the signing key derivation manually. 
+ + // Test signing key derivation: HMAC chain for "20150830" / "us-east-1" / "iam" + std::string SecretPrefix = "AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), "20150830", 8); + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, "us-east-1"); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "iam"); + Sha256Digest SigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + + std::string SigningKeyHex = Sha256ToHex(SigningKey); + CHECK(SigningKeyHex == "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9"); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/cloudprovider.h b/src/zenutil/include/zenutil/cloud/cloudprovider.h new file mode 100644 index 000000000..5825eb308 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/cloudprovider.h @@ -0,0 +1,19 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <string_view> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +} // namespace zen::compute diff --git a/src/zenutil/include/zenutil/cloud/imdscredentials.h b/src/zenutil/include/zenutil/cloud/imdscredentials.h new file mode 100644 index 000000000..33df5a1e2 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/imdscredentials.h @@ -0,0 +1,58 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+
+#pragma once
+
+#include <zenutil/cloud/sigv4.h>
+
+#include <zenbase/refcount.h>
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpclient.h>
+
+#include <chrono>
+#include <string>
+
+namespace zen {
+
+/// Construction settings for ImdsCredentialProvider.
+struct ImdsCredentialProviderOptions
+{
+    std::string Endpoint = "http://169.254.169.254"; // Override for testing
+    // Both timeouts are in milliseconds; presumably applied to the provider's HTTP client - confirm in the .cpp.
+    std::chrono::milliseconds ConnectTimeout{1000};
+    std::chrono::milliseconds RequestTimeout{5000};
+};
+
+/// Fetches and caches temporary AWS credentials from the EC2 Instance Metadata
+/// Service (IMDSv2). Thread-safe; credentials are refreshed automatically before
+/// they expire.
+class ImdsCredentialProvider : public RefCounted
+{
+public:
+    explicit ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options = {});
+    ~ImdsCredentialProvider();
+
+    /// Fetch or return cached credentials. Thread-safe.
+    /// Returns empty credentials (empty AccessKeyId) on failure.
+    SigV4Credentials GetCredentials();
+
+    /// Force a refresh on next GetCredentials() call.
+    void InvalidateCache();
+
+private:
+    // NOTE(review): defined out of line. Judging by the names these obtain the IMDSv2
+    // session token and the role credentials, with the bool reporting success - confirm.
+    bool FetchToken();
+    bool FetchCredentials();
+
+    LoggerRef Log() { return m_Log; }
+
+    LoggerRef m_Log;
+    HttpClient m_HttpClient;
+
+    // presumably guards the cached state below - TODO confirm against the .cpp
+    mutable RwLock m_Lock;
+    std::string m_ImdsToken;
+    SigV4Credentials m_CachedCredentials;
+    std::string m_RoleName;
+    // Monotonic-clock time point; presumably the moment the cached credentials must be refreshed - confirm.
+    std::chrono::steady_clock::time_point m_ExpiresAt;
+};
+
+// Declared for zenutil_forcelinktests(), which calls it when ZEN_WITH_TESTS is set.
+void imdscredentials_forcelink();
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cloud/minioprocess.h b/src/zenutil/include/zenutil/cloud/minioprocess.h
new file mode 100644
index 000000000..7af350e60
--- /dev/null
+++ b/src/zenutil/include/zenutil/cloud/minioprocess.h
@@ -0,0 +1,48 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/zenbase.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+/// Settings for MinioProcess: listening port and root account.
+struct MinioProcessOptions
+{
+    uint16_t Port = 9000;
+    std::string RootUser = "minioadmin";
+    std::string RootPassword = "minioadmin";
+};
+
+/// Test helper that manages a MinIO server process. Non-copyable; all state lives
+/// behind the private Impl.
+class MinioProcess
+{
+public:
+    explicit MinioProcess(const MinioProcessOptions& Options = {});
+    ~MinioProcess();
+
+    MinioProcess(const MinioProcess&) = delete;
+    MinioProcess& operator=(const MinioProcess&) = delete;
+
+    // Start / stop the server. NOTE(review): whether the destructor also stops a
+    // running server is not visible from this header - confirm in the .cpp.
+    void SpawnMinioServer();
+    void StopMinioServer();
+
+    /// Pre-create a bucket by creating a subdirectory in the MinIO data directory.
+    /// Can be called before or after SpawnMinioServer(). MinIO discovers these at startup
+    /// and also picks up new directories at runtime.
+    void CreateBucket(std::string_view BucketName);
+
+    // Accessors; Port/RootUser/RootPassword presumably echo the options given at construction - confirm.
+    uint16_t Port() const;
+    std::string_view RootUser() const;
+    std::string_view RootPassword() const;
+    std::string Endpoint() const;
+
+private:
+    struct Impl;
+    std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h
new file mode 100644
index 000000000..d0c0155b0
--- /dev/null
+++ b/src/zenutil/include/zenutil/cloud/mockimds.h
@@ -0,0 +1,110 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <zenutil/cloud/cloudprovider.h>
+
+#include <string>
+
+#if ZEN_WITH_TESTS
+
+namespace zen::compute {
+
+/**
+ * Mock IMDS (Instance Metadata Service) for testing cloud metadata and
+ * credential providers.
+ *
+ * Implements an HttpService that responds to the same URL paths as the real
+ * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata).
+ * Tests configure which provider is "active" and set the desired response
+ * values, then pass the mock server's address as the ImdsEndpoint to the
+ * CloudMetadata constructor.
+ *
+ * When a request arrives for a provider that is not the ActiveProvider, the
+ * mock returns 404, causing CloudMetadata to write a sentinel file and move
+ * on to the next provider — exactly like a failed probe on bare metal.
+ *
+ * All config fields are public and can be mutated between poll cycles to
+ * simulate state changes (e.g. a spot interruption appearing mid-run).
+ *
+ * Usage:
+ *   MockImdsService Mock;
+ *   Mock.ActiveProvider = CloudProvider::AWS;
+ *   Mock.Aws.InstanceId = "i-test";
+ *   // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint
+ */
+class MockImdsService : public HttpService
+{
+public:
+    /** AWS IMDSv2 response configuration. */
+    struct AwsConfig
+    {
+        std::string Token = "mock-aws-token-v2";
+        std::string InstanceId = "i-0123456789abcdef0";
+        std::string AvailabilityZone = "us-east-1a";
+        std::string LifeCycle = "on-demand"; // "spot" or "on-demand"
+
+        // Empty string → endpoint returns 404 (instance not in an ASG).
+        // Non-empty → returned as the response body. "InService" means healthy;
+        // anything else (e.g. "Terminated:Wait") triggers termination detection.
+        std::string AutoscalingState;
+
+        // Empty string → endpoint returns 404 (no spot interruption).
+        // Non-empty → returned as the response body, signalling a spot reclaim.
+        std::string SpotAction;
+
+        // IAM credential fields for ImdsCredentialProvider testing
+        // (placeholder values, not real credentials; the expiration default is far in the future)
+        std::string IamRoleName = "test-role";
+        std::string IamAccessKeyId = "ASIAIOSFODNN7EXAMPLE";
+        std::string IamSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
+        std::string IamSessionToken = "FwoGZXIvYXdzEBYaDEXAMPLETOKEN";
+        std::string IamExpiration = "2099-01-01T00:00:00Z";
+    };
+
+    /** Azure IMDS response configuration. */
+    struct AzureConfig
+    {
+        std::string VmId = "vm-12345678-1234-1234-1234-123456789abc";
+        std::string Location = "eastus";
+        std::string Priority = "Regular"; // "Spot" or "Regular"
+
+        // Empty → instance is not in a VM Scale Set (no autoscaling).
+        std::string VmScaleSetName;
+
+        // Empty → no scheduled events. Set to "Preempt", "Terminate", or
+        // "Reboot" to simulate a termination-class event.
+        std::string ScheduledEventType;
+        std::string ScheduledEventStatus = "Scheduled";
+    };
+
+    /** GCP metadata response configuration. */
+    struct GcpConfig
+    {
+        std::string InstanceId = "1234567890123456789";
+        std::string Zone = "projects/123456/zones/us-central1-a";
+        std::string Preemptible = "FALSE"; // "TRUE" or "FALSE"
+        std::string MaintenanceEvent = "NONE"; // "NONE" or event description
+    };
+
+    /** Which provider's endpoints respond successfully.
+     * Requests targeting other providers receive 404.
+     */
+    CloudProvider ActiveProvider = CloudProvider::None;
+
+    AwsConfig Aws;
+    AzureConfig Azure;
+    GcpConfig Gcp;
+
+    // HttpService interface.
+    const char* BaseUri() const override;
+    void HandleRequest(HttpServerRequest& Request) override;
+
+private:
+    // Per-provider handlers; presumably HandleRequest dispatches to these by URL path - confirm in the .cpp.
+    void HandleAwsRequest(HttpServerRequest& Request);
+    void HandleAzureRequest(HttpServerRequest& Request);
+    void HandleGcpRequest(HttpServerRequest& Request);
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h
new file mode 100644
index 000000000..47501c5b5
--- /dev/null
+++ b/src/zenutil/include/zenutil/cloud/s3client.h
@@ -0,0 +1,215 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+ +#pragma once + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/sigv4.h> + +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +#include <zencore/thread.h> + +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct S3ClientOptions +{ + std::string Region = "us-east-1"; + std::string BucketName; + std::string Endpoint; // e.g., "https://s3.us-east-1.amazonaws.com". If empty, derived from Region. + + /// Use path-style addressing (endpoint/bucket/key) instead of virtual-hosted style + /// (bucket.endpoint/key). Required for S3-compatible services like MinIO that don't + /// support virtual-hosted style. + bool PathStyle = false; + + SigV4Credentials Credentials; + + /// When set, credentials are fetched from EC2 IMDS on demand. + /// Overrides the static Credentials field. + Ref<ImdsCredentialProvider> CredentialProvider; + + std::chrono::milliseconds ConnectTimeout{5000}; + std::chrono::milliseconds Timeout{}; + uint8_t RetryCount = 3; +}; + +struct S3ObjectInfo +{ + std::string Key; + uint64_t Size = 0; + std::string ETag; + std::string LastModified; +}; + +/// Result type for S3 operations. Empty Error string indicates success. +struct S3Result +{ + std::string Error; + + bool IsSuccess() const { return Error.empty(); } + explicit operator bool() const { return IsSuccess(); } +}; + +enum class HeadObjectResult +{ + Found, + NotFound, + Error, +}; + +/// Result of GetObject — carries the downloaded content. +struct S3GetObjectResult : S3Result +{ + IoBuffer Content; + + std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); } +}; + +/// Result of HeadObject — carries object metadata and existence status. +struct S3HeadObjectResult : S3Result +{ + S3ObjectInfo Info; + HeadObjectResult Status = HeadObjectResult::NotFound; +}; + +/// Result of ListObjects — carries the list of matching objects. 
+struct S3ListObjectsResult : S3Result
+{
+    // One entry per object returned by the listing.
+    std::vector<S3ObjectInfo> Objects;
+};
+
+/// Result of CreateMultipartUpload — carries the upload ID.
+struct S3CreateMultipartUploadResult : S3Result
+{
+    // Passed back to UploadPart / CompleteMultipartUpload / AbortMultipartUpload.
+    std::string UploadId;
+};
+
+/// Result of UploadPart — carries the part ETag.
+struct S3UploadPartResult : S3Result
+{
+    // Supplied to CompleteMultipartUpload together with the part number.
+    std::string ETag;
+};
+
+/// Client for S3-compatible object storage.
+///
+/// Supports basic object operations (GET, PUT, DELETE, HEAD), listing, multipart
+/// uploads, and pre-signed URL generation. Requests are authenticated with AWS
+/// Signature Version 4; the signing key is cached per day to avoid redundant HMAC
+/// derivation.
+///
+/// Limitations:
+/// - Multipart uploads are sequential (no parallel part upload).
+/// - XML responses are parsed with a minimal tag extractor that only decodes the five
+///   standard XML entities; CDATA sections and nested/namespaced tags are not handled.
+/// - Automatic credential refresh is supported via ImdsCredentialProvider.
+class S3Client +{ +public: + explicit S3Client(const S3ClientOptions& Options); + ~S3Client(); + + /// Upload an object to S3 + S3Result PutObject(std::string_view Key, IoBuffer Content); + + /// Download an object from S3 + S3GetObjectResult GetObject(std::string_view Key); + + /// Delete an object from S3 + S3Result DeleteObject(std::string_view Key); + + /// Check if an object exists and get its metadata + S3HeadObjectResult HeadObject(std::string_view Key); + + /// List objects with the given prefix + /// @param MaxKeys Maximum number of keys to return (0 = default/1000) + S3ListObjectsResult ListObjects(std::string_view Prefix, uint32_t MaxKeys = 0); + + /// Multipart upload: initiate a multipart upload and return the upload ID + S3CreateMultipartUploadResult CreateMultipartUpload(std::string_view Key); + + /// Multipart upload: upload a single part + /// @param PartNumber Part number (1-based, 1 to 10000) + /// @param Content The part data (minimum 5 MB except for the last part) + S3UploadPartResult UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content); + + /// Multipart upload: complete a multipart upload by assembling previously uploaded parts + /// @param PartETags List of {part_number, etag} pairs from UploadPart calls + S3Result CompleteMultipartUpload(std::string_view Key, + std::string_view UploadId, + const std::vector<std::pair<uint32_t, std::string>>& PartETags); + + /// Multipart upload: abort an in-progress multipart upload, discarding all uploaded parts + S3Result AbortMultipartUpload(std::string_view Key, std::string_view UploadId); + + /// High-level multipart upload: automatically splits content into parts and uploads + /// @param PartSize Size of each part in bytes (minimum 5 MB, default 8 MB) + S3Result PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize = 8 * 1024 * 1024); + + /// Generate a pre-signed URL for downloading an object (GET) + /// @param Key The object key + 
/// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) + std::string GeneratePresignedGetUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1)); + + /// Generate a pre-signed URL for uploading an object (PUT) + /// @param Key The object key + /// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) + std::string GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1)); + + std::string_view BucketName() const { return m_BucketName; } + std::string_view Region() const { return m_Region; } + +private: + /// Shared implementation for pre-signed URL generation + std::string GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn); + + LoggerRef Log() { return m_Log; } + + /// Build the endpoint URL for the bucket + std::string BuildEndpoint() const; + + /// Build the host header value + std::string BuildHostHeader() const; + + /// Build the S3 object path from a key, accounting for path-style addressing + std::string KeyToPath(std::string_view Key) const; + + /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style) + std::string BucketRootPath() const; + + /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256 + HttpClient::KeyValueMap SignRequest(std::string_view Method, + std::string_view Path, + std::string_view QueryString, + std::string_view PayloadHash); + + /// Get or compute the signing key for the given date stamp, caching across requests on the same day + Sha256Digest GetSigningKey(std::string_view DateStamp); + + /// Get the current credentials, either from the provider or from static config + SigV4Credentials GetCurrentCredentials(); + + LoggerRef m_Log; + std::string m_BucketName; + std::string m_Region; + std::string m_Endpoint; + std::string m_Host; + bool m_PathStyle; + SigV4Credentials m_Credentials; + Ref<ImdsCredentialProvider> 
m_CredentialProvider; + HttpClient m_HttpClient; + + // Cached signing key (only changes once per day, protected by RwLock for thread safety) + mutable RwLock m_SigningKeyLock; + std::string m_CachedDateStamp; + Sha256Digest m_CachedSigningKey{}; +}; + +void s3client_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/sigv4.h b/src/zenutil/include/zenutil/cloud/sigv4.h new file mode 100644 index 000000000..9ac08df76 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/sigv4.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <array> +#include <chrono> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/// SHA-256 digest (32 bytes) +using Sha256Digest = std::array<uint8_t, 32>; + +/// Compute SHA-256 hash of the given data +Sha256Digest ComputeSha256(const void* Data, size_t Size); +Sha256Digest ComputeSha256(std::string_view Data); + +/// Compute HMAC-SHA256 with the given key and data +Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize); +Sha256Digest ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data); + +/// Convert a SHA-256 digest to lowercase hex string +std::string Sha256ToHex(const Sha256Digest& Digest); + +/// Securely zero memory containing secret key material (prevents compiler from optimizing away) +void SecureZeroSecret(void* Data, size_t Size); + +/// AWS Signature Version 4 signing + +struct SigV4Credentials +{ + std::string AccessKeyId; + std::string SecretAccessKey; + std::string SessionToken; // Optional; required for temporary credentials (STS/SSO) +}; + +struct SigV4SignedHeaders +{ + /// The value for the "Authorization" header + std::string Authorization; + + /// The ISO 8601 date-time string used in signing (for x-amz-date header) + std::string AmzDate; + + /// The SHA-256 hex digest of the payload (for x-amz-content-sha256 header) + 
std::string PayloadHash; +}; + +/// Get the current UTC timestamp in ISO 8601 format (YYYYMMDDTHHMMSSZ) +std::string GetAmzTimestamp(); + +/// URI-encode a string per AWS requirements (RFC 3986 unreserved chars are not encoded) +/// @param EncodeSlash If false, '/' is left unencoded (use for URI paths) +std::string AwsUriEncode(std::string_view Input, bool EncodeSlash = true); + +/// Build a canonical query string from key=value pairs. +/// Parameters are URI-encoded and sorted by key name as required by SigV4. +/// Takes parameters by value to sort in-place without copying. +std::string BuildCanonicalQueryString(std::vector<std::pair<std::string, std::string>> Parameters); + +/// Sign an HTTP request using AWS Signature Version 4 +/// +/// @param Credentials AWS access key and secret key +/// @param Method HTTP method (GET, PUT, DELETE, HEAD, etc.) +/// @param Url The path portion of the URL (e.g., "/bucket/key") +/// @param CanonicalQueryString Pre-built canonical query string (use BuildCanonicalQueryString) +/// @param Region The AWS region (e.g., "us-east-1") +/// @param Service The AWS service (e.g., "s3") +/// @param AmzDate The ISO 8601 date-time string (from GetAmzTimestamp()) +/// @param Headers Sorted list of {lowercase-header-name, value} pairs to sign. +/// Must include "host" and "x-amz-content-sha256". +/// Should NOT include "authorization". +/// @param PayloadHash Hex SHA-256 hash of the request payload. Use +/// "UNSIGNED-PAYLOAD" for unsigned payloads. +/// @param SigningKey Optional pre-computed signing key. If null, derived from Credentials + date + Region + Service. 
+SigV4SignedHeaders SignRequestV4(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Url, + std::string_view CanonicalQueryString, + std::string_view Region, + std::string_view Service, + std::string_view AmzDate, + const std::vector<std::pair<std::string, std::string>>& Headers, + std::string_view PayloadHash, + const Sha256Digest* SigningKey = nullptr); + +/// Generate a pre-signed URL using AWS Signature Version 4 query string authentication. +/// +/// The returned URL can be used by anyone (no credentials needed) until it expires. +/// +/// @param Credentials AWS access key and secret key +/// @param Method HTTP method the URL will be used with (typically "GET" or "PUT") +/// @param Scheme URL scheme ("https" or "http") +/// @param Host The host (e.g., "bucket.s3.us-east-1.amazonaws.com") +/// @param Path The path portion (e.g., "/key") +/// @param Region The AWS region (e.g., "us-east-1") +/// @param Service The AWS service (e.g., "s3") +/// @param ExpiresIn URL validity duration +/// @param ExtraQueryParams Additional query parameters to include (e.g., response-content-type) +std::string GeneratePresignedUrl(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Scheme, + std::string_view Host, + std::string_view Path, + std::string_view Region, + std::string_view Service, + std::chrono::seconds ExpiresIn, + const std::vector<std::pair<std::string, std::string>>& ExtraQueryParams = {}); + +void sigv4_forcelink(); + +} // namespace zen diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index 1d5be5977..1e19f7b2f 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -9,6 +9,7 @@ target('zenutil') add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") + add_packages("json11") if is_plat("linux") then add_includedirs("$(projectdir)/thirdparty/systemd/include") diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 291dbeadd..734813b69 100644 --- 
a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -4,6 +4,9 @@ #if ZEN_WITH_TESTS +# include <zenutil/cloud/imdscredentials.h> +# include <zenutil/cloud/s3client.h> +# include <zenutil/cloud/sigv4.h> # include <zenutil/rpcrecording.h> # include <zenutil/config/commandlineoptions.h> # include <zenutil/wildcard.h> @@ -15,6 +18,9 @@ zenutil_forcelinktests() { cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + imdscredentials_forcelink(); + s3client_forcelink(); + sigv4_forcelink(); wildcard_forcelink(); } |