aboutsummaryrefslogtreecommitdiff
path: root/src/zenutil/cloud/imdscredentials.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-18 11:27:07 +0100
committerGitHub Enterprise <[email protected]>2026-03-18 11:27:07 +0100
commite64d76ae1b6993582bf161a61049f0771414a779 (patch)
tree083f3df42cc9e2c7ddbee225708b4848eb217d11 /src/zenutil/cloud/imdscredentials.cpp
parentCompute batching (#849) (diff)
downloadzen-e64d76ae1b6993582bf161a61049f0771414a779.tar.xz
zen-e64d76ae1b6993582bf161a61049f0771414a779.zip
Simple S3 client (#836)
This functionality is intended to be used to manage datasets for test cases, but may be useful elsewhere in the future. - **Add S3 client with AWS Signature V4 (SigV4) signing** — new `S3Client` in `zenutil/cloud/` supporting `GetObject`, `PutObject`, `DeleteObject`, `HeadObject`, and `ListObjects` operations - **Add EC2 IMDS credential provider** — automatically fetches and refreshes temporary AWS credentials from the EC2 Instance Metadata Service (IMDSv2) for use by the S3 client - **Add SigV4 signing library** — standalone implementation of AWS Signature Version 4 request signing (headers and query-string presigning) - **Add path-style addressing support** — enables compatibility with S3-compatible stores like MinIO (in addition to virtual-hosted style) - **Add S3 integration tests** — includes a `MinioProcess` test helper that spins up a local MinIO server, plus integration tests exercising the S3 client end-to-end - **Add S3-backed `HttpObjectStoreService` tests** — integration tests verifying the zenserver object store works against an S3 backend - **Refactor mock IMDS into `zenutil/cloud/`** — moved and generalized the mock IMDS server from `zencompute` so it can be reused by both compute and S3 credential tests
Diffstat (limited to 'src/zenutil/cloud/imdscredentials.cpp')
-rw-r--r--src/zenutil/cloud/imdscredentials.cpp387
1 files changed, 387 insertions, 0 deletions
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp
new file mode 100644
index 000000000..dde1dc019
--- /dev/null
+++ b/src/zenutil/cloud/imdscredentials.cpp
@@ -0,0 +1,387 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cloud/imdscredentials.h>
+
+#include <zenutil/cloud/mockimds.h>
+
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+#include <zenhttp/httpserver.h>
+
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace {
+
+ /// Margin before expiration at which we proactively refresh credentials.
+ constexpr auto kRefreshMargin = std::chrono::minutes(5);
+
+ /// Parse an ISO 8601 UTC timestamp (e.g. "2026-03-14T20:00:00Z") into a system_clock time_point.
+ /// Returns epoch on failure.
+ std::chrono::system_clock::time_point ParseIso8601(std::string_view Timestamp)
+ {
+ // Expected format: YYYY-MM-DDTHH:MM:SSZ
+ if (Timestamp.size() < 19)
+ {
+ return {};
+ }
+
+ std::tm Tm = {};
+ // Manual parse since std::get_time is locale-dependent
+ Tm.tm_year = ParseInt<int>(Timestamp.substr(0, 4)).value_or(1970) - 1900;
+ Tm.tm_mon = ParseInt<int>(Timestamp.substr(5, 2)).value_or(1) - 1;
+ Tm.tm_mday = ParseInt<int>(Timestamp.substr(8, 2)).value_or(1);
+ Tm.tm_hour = ParseInt<int>(Timestamp.substr(11, 2)).value_or(0);
+ Tm.tm_min = ParseInt<int>(Timestamp.substr(14, 2)).value_or(0);
+ Tm.tm_sec = ParseInt<int>(Timestamp.substr(17, 2)).value_or(0);
+
+#if ZEN_PLATFORM_WINDOWS
+ time_t EpochSeconds = _mkgmtime(&Tm);
+#else
+ time_t EpochSeconds = timegm(&Tm);
+#endif
+ if (EpochSeconds == -1)
+ {
+ return {};
+ }
+
+ return std::chrono::system_clock::from_time_t(EpochSeconds);
+ }
+
+} // namespace
+
+ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options)
+: m_Log(logging::Get("imds"))
+, m_HttpClient(Options.Endpoint,
+ HttpClientSettings{
+ .LogCategory = "imds",
+ .ConnectTimeout = Options.ConnectTimeout,
+ .Timeout = Options.RequestTimeout,
+ })
+{
+ ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri());
+}
+
+ImdsCredentialProvider::~ImdsCredentialProvider() = default;
+
+SigV4Credentials
+ImdsCredentialProvider::GetCredentials()
+{
+ // Fast path: shared lock for cache hit
+ {
+ RwLock::SharedLockScope SharedLock(m_Lock);
+ if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt)
+ {
+ return m_CachedCredentials;
+ }
+ }
+
+ // Slow path: exclusive lock to refresh
+ RwLock::ExclusiveLockScope ExclusiveLock(m_Lock);
+
+ // Double-check after acquiring exclusive lock
+ if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt)
+ {
+ return m_CachedCredentials;
+ }
+
+ if (!FetchCredentials())
+ {
+ ZEN_WARN("failed to fetch credentials from IMDS");
+ return {};
+ }
+
+ return m_CachedCredentials;
+}
+
+void
+ImdsCredentialProvider::InvalidateCache()
+{
+ RwLock::ExclusiveLockScope ExclusiveLock(m_Lock);
+ m_CachedCredentials = {};
+ m_ExpiresAt = {};
+}
+
+bool
+ImdsCredentialProvider::FetchToken()
+{
+ HttpClient::KeyValueMap Headers;
+ Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600");
+
+ HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers);
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token"));
+ return false;
+ }
+
+ m_ImdsToken = std::string(Response.AsText());
+ if (m_ImdsToken.empty())
+ {
+ ZEN_WARN("IMDS returned empty token");
+ return false;
+ }
+
+ return true;
+}
+
+bool
+ImdsCredentialProvider::FetchCredentials()
+{
+ // Step 1: Get IMDSv2 session token
+ if (!FetchToken())
+ {
+ return false;
+ }
+
+ HttpClient::KeyValueMap TokenHeader;
+ TokenHeader->emplace("X-aws-ec2-metadata-token", m_ImdsToken);
+
+ // Step 2: Discover IAM role name (if not already known)
+ if (m_RoleName.empty())
+ {
+ HttpClient::Response RoleResponse = m_HttpClient.Get("/latest/meta-data/iam/security-credentials/", TokenHeader);
+ if (!RoleResponse.IsSuccess())
+ {
+ ZEN_WARN("IMDS role discovery failed: {}", RoleResponse.ErrorMessage("GET iam/security-credentials/"));
+ return false;
+ }
+
+ m_RoleName = std::string(RoleResponse.AsText());
+ // Trim any trailing whitespace/newlines
+ while (!m_RoleName.empty() && (m_RoleName.back() == '\n' || m_RoleName.back() == '\r' || m_RoleName.back() == ' '))
+ {
+ m_RoleName.pop_back();
+ }
+
+ if (m_RoleName.empty())
+ {
+ ZEN_WARN("IMDS returned empty IAM role name");
+ return false;
+ }
+
+ ZEN_INFO("IMDS discovered IAM role: {}", m_RoleName);
+ }
+
+ // Step 3: Fetch credentials for the role
+ std::string CredentialPath = fmt::format("/latest/meta-data/iam/security-credentials/{}", m_RoleName);
+
+ HttpClient::Response CredResponse = m_HttpClient.Get(CredentialPath, TokenHeader);
+ if (!CredResponse.IsSuccess())
+ {
+ ZEN_WARN("IMDS credential fetch failed: {}", CredResponse.ErrorMessage("GET iam/security-credentials/" + m_RoleName));
+ return false;
+ }
+
+ // Step 4: Parse JSON response
+ std::string JsonError;
+ const json11::Json Json = json11::Json::parse(std::string(CredResponse.AsText()), JsonError);
+
+ if (!JsonError.empty())
+ {
+ ZEN_WARN("IMDS credential response JSON parse error: {}", JsonError);
+ return false;
+ }
+
+ std::string AccessKeyId = Json["AccessKeyId"].string_value();
+ std::string SecretAccessKey = Json["SecretAccessKey"].string_value();
+ std::string SessionToken = Json["Token"].string_value();
+ std::string Expiration = Json["Expiration"].string_value();
+
+ if (AccessKeyId.empty() || SecretAccessKey.empty())
+ {
+ ZEN_WARN("IMDS credential response missing AccessKeyId or SecretAccessKey");
+ return false;
+ }
+
+ // Compute local expiration time based on the Expiration field
+ auto ExpirationTime = ParseIso8601(Expiration);
+ auto Now = std::chrono::system_clock::now();
+
+ std::chrono::steady_clock::time_point NewExpiresAt;
+ if (ExpirationTime > Now)
+ {
+ auto TimeUntilExpiry = ExpirationTime - Now;
+ NewExpiresAt = std::chrono::steady_clock::now() + TimeUntilExpiry - kRefreshMargin;
+ }
+ else
+ {
+ // Expiration is in the past or unparseable — force refresh next time
+ NewExpiresAt = std::chrono::steady_clock::now();
+ }
+
+ bool KeyChanged = (m_CachedCredentials.AccessKeyId != AccessKeyId);
+
+ m_CachedCredentials.AccessKeyId = std::move(AccessKeyId);
+ m_CachedCredentials.SecretAccessKey = std::move(SecretAccessKey);
+ m_CachedCredentials.SessionToken = std::move(SessionToken);
+ m_ExpiresAt = NewExpiresAt;
+
+ if (KeyChanged)
+ {
+ ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8));
+ }
+ else
+ {
+ ZEN_DEBUG("IMDS credentials refreshed (unchanged key)");
+ }
+
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+#if ZEN_WITH_TESTS
+
+void
+imdscredentials_forcelink()
+{
+}
+
+TEST_SUITE_BEGIN("util.cloud.imdscredentials");
+
+TEST_CASE("imdscredentials.parse_iso8601")
+{
+ // Verify basic ISO 8601 parsing
+ auto Tp = ParseIso8601("2026-03-14T20:00:00Z");
+ CHECK(Tp != std::chrono::system_clock::time_point{});
+
+ auto Epoch = std::chrono::system_clock::to_time_t(Tp);
+ std::tm Tm;
+# if ZEN_PLATFORM_WINDOWS
+ gmtime_s(&Tm, &Epoch);
+# else
+ gmtime_r(&Epoch, &Tm);
+# endif
+ CHECK(Tm.tm_year + 1900 == 2026);
+ CHECK(Tm.tm_mon + 1 == 3);
+ CHECK(Tm.tm_mday == 14);
+ CHECK(Tm.tm_hour == 20);
+ CHECK(Tm.tm_min == 0);
+ CHECK(Tm.tm_sec == 0);
+
+ // Invalid input
+ auto Bad = ParseIso8601("bad");
+ CHECK(Bad == std::chrono::system_clock::time_point{});
+}
+
+// ---------------------------------------------------------------------------
+// Integration test with mock IMDS server
+// ---------------------------------------------------------------------------
+
+struct TestImdsServer
+{
+ compute::MockImdsService Mock;
+
+ void Start()
+ {
+ m_TmpDir.emplace();
+ m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
+ m_Port = m_Server->Initialize(7576, m_TmpDir->Path() / "http");
+ REQUIRE(m_Port != -1);
+ m_Server->RegisterService(Mock);
+ m_ServerThread = std::thread([this]() { m_Server->Run(false); });
+ }
+
+ std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); }
+
+ ~TestImdsServer()
+ {
+ if (m_Server)
+ {
+ m_Server->RequestExit();
+ }
+ if (m_ServerThread.joinable())
+ {
+ m_ServerThread.join();
+ }
+ if (m_Server)
+ {
+ m_Server->Close();
+ }
+ }
+
+private:
+ std::optional<ScopedTemporaryDirectory> m_TmpDir;
+ Ref<HttpServer> m_Server;
+ std::thread m_ServerThread;
+ int m_Port = -1;
+};
+
+TEST_CASE("imdscredentials.fetch_from_mock")
+{
+ TestImdsServer Imds;
+ Imds.Mock.ActiveProvider = compute::CloudProvider::AWS;
+ Imds.Start();
+
+ ImdsCredentialProviderOptions Opts;
+ Opts.Endpoint = Imds.Endpoint();
+
+ Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts));
+
+ SUBCASE("basic_credential_fetch")
+ {
+ SigV4Credentials Creds = Provider->GetCredentials();
+ CHECK(!Creds.AccessKeyId.empty());
+ CHECK(Creds.AccessKeyId == "ASIAIOSFODNN7EXAMPLE");
+ CHECK(Creds.SecretAccessKey == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY");
+ CHECK(Creds.SessionToken == "FwoGZXIvYXdzEBYaDEXAMPLETOKEN");
+ }
+
+ SUBCASE("credentials_are_cached")
+ {
+ SigV4Credentials First = Provider->GetCredentials();
+ SigV4Credentials Second = Provider->GetCredentials();
+ CHECK(First.AccessKeyId == Second.AccessKeyId);
+ CHECK(First.SecretAccessKey == Second.SecretAccessKey);
+ }
+
+ SUBCASE("invalidate_forces_refresh")
+ {
+ SigV4Credentials First = Provider->GetCredentials();
+ CHECK(!First.AccessKeyId.empty());
+
+ // Change the credentials on the mock
+ Imds.Mock.Aws.IamAccessKeyId = "ASIANEWKEYEXAMPLE12";
+
+ Provider->InvalidateCache();
+ SigV4Credentials Second = Provider->GetCredentials();
+ CHECK(Second.AccessKeyId == "ASIANEWKEYEXAMPLE12");
+ }
+
+ SUBCASE("custom_role_name")
+ {
+ Imds.Mock.Aws.IamRoleName = "my-custom-role";
+
+ Ref<ImdsCredentialProvider> Provider2(new ImdsCredentialProvider(Opts));
+ SigV4Credentials Creds = Provider2->GetCredentials();
+ CHECK(!Creds.AccessKeyId.empty());
+ }
+}
+
+TEST_CASE("imdscredentials.unreachable_endpoint")
+{
+ // Point at a non-existent server — should return empty credentials, not crash
+ ImdsCredentialProviderOptions Opts;
+ Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening
+ Opts.ConnectTimeout = std::chrono::milliseconds(100);
+ Opts.RequestTimeout = std::chrono::milliseconds(200);
+
+ Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts));
+ SigV4Credentials Creds = Provider->GetCredentials();
+ CHECK(Creds.AccessKeyId.empty());
+}
+
+TEST_SUITE_END();
+
+#endif
+
+} // namespace zen