diff options
| author | Stefan Boberg <[email protected]> | 2026-03-18 11:27:07 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-18 11:27:07 +0100 |
| commit | e64d76ae1b6993582bf161a61049f0771414a779 (patch) | |
| tree | 083f3df42cc9e2c7ddbee225708b4848eb217d11 /src/zenutil/cloud/imdscredentials.cpp | |
| parent | Compute batching (#849) (diff) | |
| download | zen-e64d76ae1b6993582bf161a61049f0771414a779.tar.xz zen-e64d76ae1b6993582bf161a61049f0771414a779.zip | |
Simple S3 client (#836)
This functionality is intended to be used to manage datasets for test cases, but may be useful elsewhere in the future.
- **Add S3 client with AWS Signature V4 (SigV4) signing** — new `S3Client` in `zenutil/cloud/` supporting `GetObject`, `PutObject`, `DeleteObject`, `HeadObject`, and `ListObjects` operations
- **Add EC2 IMDS credential provider** — automatically fetches and refreshes temporary AWS credentials from the EC2 Instance Metadata Service (IMDSv2) for use by the S3 client
- **Add SigV4 signing library** — standalone implementation of AWS Signature Version 4 request signing (headers and query-string presigning)
- **Add path-style addressing support** — enables compatibility with S3-compatible stores like MinIO (in addition to virtual-hosted style)
- **Add S3 integration tests** — includes a `MinioProcess` test helper that spins up a local MinIO server, plus integration tests exercising the S3 client end-to-end
- **Add S3-backed `HttpObjectStoreService` tests** — integration tests verifying the zenserver object store works against an S3 backend
- **Refactor mock IMDS into `zenutil/cloud/`** — moved and generalized the mock IMDS server from `zencompute` so it can be reused by both compute and S3 credential tests
Diffstat (limited to 'src/zenutil/cloud/imdscredentials.cpp')
| -rw-r--r-- | src/zenutil/cloud/imdscredentials.cpp | 387 |
1 files changed, 387 insertions, 0 deletions
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp new file mode 100644 index 000000000..dde1dc019 --- /dev/null +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -0,0 +1,387 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/imdscredentials.h> + +#include <zenutil/cloud/mockimds.h> + +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zenhttp/httpserver.h> + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace { + + /// Margin before expiration at which we proactively refresh credentials. + constexpr auto kRefreshMargin = std::chrono::minutes(5); + + /// Parse an ISO 8601 UTC timestamp (e.g. "2026-03-14T20:00:00Z") into a system_clock time_point. + /// Returns epoch on failure. + std::chrono::system_clock::time_point ParseIso8601(std::string_view Timestamp) + { + // Expected format: YYYY-MM-DDTHH:MM:SSZ + if (Timestamp.size() < 19) + { + return {}; + } + + std::tm Tm = {}; + // Manual parse since std::get_time is locale-dependent + Tm.tm_year = ParseInt<int>(Timestamp.substr(0, 4)).value_or(1970) - 1900; + Tm.tm_mon = ParseInt<int>(Timestamp.substr(5, 2)).value_or(1) - 1; + Tm.tm_mday = ParseInt<int>(Timestamp.substr(8, 2)).value_or(1); + Tm.tm_hour = ParseInt<int>(Timestamp.substr(11, 2)).value_or(0); + Tm.tm_min = ParseInt<int>(Timestamp.substr(14, 2)).value_or(0); + Tm.tm_sec = ParseInt<int>(Timestamp.substr(17, 2)).value_or(0); + +#if ZEN_PLATFORM_WINDOWS + time_t EpochSeconds = _mkgmtime(&Tm); +#else + time_t EpochSeconds = timegm(&Tm); +#endif + if (EpochSeconds == -1) + { + return {}; + } + + return std::chrono::system_clock::from_time_t(EpochSeconds); + } + +} // namespace + +ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options) +: m_Log(logging::Get("imds")) +, m_HttpClient(Options.Endpoint, + HttpClientSettings{ + .LogCategory = "imds", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.RequestTimeout, + }) +{ + ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri()); +} + +ImdsCredentialProvider::~ImdsCredentialProvider() = default; + +SigV4Credentials +ImdsCredentialProvider::GetCredentials() +{ + // Fast path: shared lock for cache hit + { + RwLock::SharedLockScope SharedLock(m_Lock); + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + } + + // Slow path: exclusive lock to refresh + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + + // Double-check after acquiring exclusive lock + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + + if (!FetchCredentials()) + { + ZEN_WARN("failed to fetch credentials from IMDS"); + return {}; + } + + return m_CachedCredentials; +} + +void +ImdsCredentialProvider::InvalidateCache() +{ + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + m_CachedCredentials = {}; + m_ExpiresAt = {}; +} + +bool +ImdsCredentialProvider::FetchToken() +{ + HttpClient::KeyValueMap Headers; + Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); + + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + if (!Response.IsSuccess()) + { + ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); + return false; + } + + m_ImdsToken = std::string(Response.AsText()); + if (m_ImdsToken.empty()) + { + ZEN_WARN("IMDS returned empty token"); + return false; + } + + return true; +} + +bool +ImdsCredentialProvider::FetchCredentials() +{ + // Step 1: Get IMDSv2 session token + if (!FetchToken()) + { + return false; + } + + HttpClient::KeyValueMap TokenHeader; + TokenHeader->emplace("X-aws-ec2-metadata-token", m_ImdsToken); + + // Step 2: Discover IAM role name (if not already known) + if (m_RoleName.empty()) + { + HttpClient::Response RoleResponse = m_HttpClient.Get("/latest/meta-data/iam/security-credentials/", TokenHeader); + if (!RoleResponse.IsSuccess()) + { + ZEN_WARN("IMDS role discovery failed: {}", RoleResponse.ErrorMessage("GET iam/security-credentials/")); + return false; + } + + m_RoleName = std::string(RoleResponse.AsText()); + // Trim any trailing whitespace/newlines + while (!m_RoleName.empty() && (m_RoleName.back() == '\n' || m_RoleName.back() == '\r' || m_RoleName.back() == ' ')) + { + m_RoleName.pop_back(); + } + + if (m_RoleName.empty()) + { + ZEN_WARN("IMDS returned empty IAM role name"); + return false; + } + + ZEN_INFO("IMDS discovered IAM role: {}", m_RoleName); + } + + // Step 3: Fetch credentials for the role + std::string CredentialPath = fmt::format("/latest/meta-data/iam/security-credentials/{}", m_RoleName); + + HttpClient::Response CredResponse = m_HttpClient.Get(CredentialPath, TokenHeader); + if (!CredResponse.IsSuccess()) + { + ZEN_WARN("IMDS credential fetch failed: {}", CredResponse.ErrorMessage("GET iam/security-credentials/" + m_RoleName)); + return false; + } + + // Step 4: Parse JSON response + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(CredResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + ZEN_WARN("IMDS credential response JSON parse error: {}", JsonError); + return false; + } + + std::string AccessKeyId = Json["AccessKeyId"].string_value(); + std::string SecretAccessKey = Json["SecretAccessKey"].string_value(); + std::string SessionToken = Json["Token"].string_value(); + std::string Expiration = Json["Expiration"].string_value(); + + if (AccessKeyId.empty() || SecretAccessKey.empty()) + { + ZEN_WARN("IMDS credential response missing AccessKeyId or SecretAccessKey"); + return false; + } + + // Compute local expiration time based on the Expiration field + auto ExpirationTime = ParseIso8601(Expiration); + auto Now = std::chrono::system_clock::now(); + + std::chrono::steady_clock::time_point NewExpiresAt; + if (ExpirationTime > Now) + { + auto TimeUntilExpiry = ExpirationTime - Now; + NewExpiresAt = std::chrono::steady_clock::now() + TimeUntilExpiry - kRefreshMargin; + } + else + { + // Expiration is in the past or unparseable — force refresh next time + NewExpiresAt = std::chrono::steady_clock::now(); + } + + bool KeyChanged = (m_CachedCredentials.AccessKeyId != AccessKeyId); + + m_CachedCredentials.AccessKeyId = std::move(AccessKeyId); + m_CachedCredentials.SecretAccessKey = std::move(SecretAccessKey); + m_CachedCredentials.SessionToken = std::move(SessionToken); + m_ExpiresAt = NewExpiresAt; + + if (KeyChanged) + { + ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8)); + } + else + { + ZEN_DEBUG("IMDS credentials refreshed (unchanged key)"); + } + + return true; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +imdscredentials_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.imdscredentials"); + +TEST_CASE("imdscredentials.parse_iso8601") +{ + // Verify basic ISO 8601 parsing + auto Tp = ParseIso8601("2026-03-14T20:00:00Z"); + CHECK(Tp != std::chrono::system_clock::time_point{}); + + auto Epoch = std::chrono::system_clock::to_time_t(Tp); + std::tm Tm; +# if ZEN_PLATFORM_WINDOWS + gmtime_s(&Tm, &Epoch); +# else + gmtime_r(&Epoch, &Tm); +# endif + CHECK(Tm.tm_year + 1900 == 2026); + CHECK(Tm.tm_mon + 1 == 3); + CHECK(Tm.tm_mday == 14); + CHECK(Tm.tm_hour == 20); + CHECK(Tm.tm_min == 0); + CHECK(Tm.tm_sec == 0); + + // Invalid input + auto Bad = ParseIso8601("bad"); + CHECK(Bad == std::chrono::system_clock::time_point{}); +} + +// --------------------------------------------------------------------------- +// Integration test with mock IMDS server +// --------------------------------------------------------------------------- + +struct TestImdsServer +{ + compute::MockImdsService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(7576, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +TEST_CASE("imdscredentials.fetch_from_mock") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = compute::CloudProvider::AWS; + Imds.Start(); + + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = Imds.Endpoint(); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + + SUBCASE("basic_credential_fetch") + { + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + CHECK(Creds.AccessKeyId == "ASIAIOSFODNN7EXAMPLE"); + CHECK(Creds.SecretAccessKey == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"); + CHECK(Creds.SessionToken == "FwoGZXIvYXdzEBYaDEXAMPLETOKEN"); + } + + SUBCASE("credentials_are_cached") + { + SigV4Credentials First = Provider->GetCredentials(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(First.AccessKeyId == Second.AccessKeyId); + CHECK(First.SecretAccessKey == Second.SecretAccessKey); + } + + SUBCASE("invalidate_forces_refresh") + { + SigV4Credentials First = Provider->GetCredentials(); + CHECK(!First.AccessKeyId.empty()); + + // Change the credentials on the mock + Imds.Mock.Aws.IamAccessKeyId = "ASIANEWKEYEXAMPLE12"; + + Provider->InvalidateCache(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(Second.AccessKeyId == "ASIANEWKEYEXAMPLE12"); + } + + SUBCASE("custom_role_name") + { + Imds.Mock.Aws.IamRoleName = "my-custom-role"; + + Ref<ImdsCredentialProvider> Provider2(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider2->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + } +} + +TEST_CASE("imdscredentials.unreachable_endpoint") +{ + // Point at a non-existent server — should return empty credentials, not crash + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening + Opts.ConnectTimeout = std::chrono::milliseconds(100); + Opts.RequestTimeout = std::chrono::milliseconds(200); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(Creds.AccessKeyId.empty()); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen |