diff options
| author | Dan Engelbrecht <[email protected]> | 2026-05-05 14:59:21 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-05-05 14:59:21 +0200 |
| commit | 46f456ffd4d0717a035253ff9076ca6ee664e536 (patch) | |
| tree | 69d7a9a43b9874fd3990c43aa5ff4135c35d53d9 /src | |
| parent | watchdog ephemeral port exhaust (#1022) (diff) | |
| download | archived-zen-46f456ffd4d0717a035253ff9076ca6ee664e536.tar.xz archived-zen-46f456ffd4d0717a035253ff9076ca6ee664e536.zip | |
hub async s3 client (#1024)
- Feature: `AsyncHttpClient` adds cancellable request tokens, streaming GET to a file (`AsyncDownload`), zero-copy chunk-callback GET (`AsyncStream`), a pull-mode body source for streaming `AsyncPut`, a retry layer mirroring the synchronous client, and a submit-side in-flight cap (`HttpClientSettings::MaxConcurrentRequests`) so that hub-scale fanout against a single host cannot leave queued handles stalled until they run into curl's connect timeout
- Feature: Hub hydration can route S3 transfers through a non-blocking `AsyncHttpClient` (curl_multi + asio) backed by a single io thread; hydrate and dehydrate now pipeline requests instead of blocking worker threads
  - `--hub-hydration-async-enabled` (Lua: `hub.hydration.async.enabled`, default true)
  - `--hub-hydration-async-max-concurrent-requests` (Lua: `hub.hydration.async.maxconcurrentrequests`, default `clamp(cpu*4, 128, 512)`)
- Feature: Hub provision/deprovision/obliterate now run as two phases on separate worker pools so per-module hydration cannot starve child-process spawn/despawn (and vice versa)
  - New `--hub-instance-spawn-threads` (Lua: `hub.instance.spawnthreads`, default `clamp(cpu/8, 4, 16)`) drives child-process spawn/despawn
  - `--hub-instance-provision-threads` (Lua: `hub.instance.provisionthreads`) now drives per-module hydrate/dehydrate scheduling only; default changed from `max(cpu/4, 2)` to `clamp(cpu/8, 4, 12)`
  - `--hub-hydration-threads` (Lua: `hub.hydration.threads`) now controls per-file workers inside a single hydrate/dehydrate; default changed from `max(cpu/4, 2)` to `clamp(cpu/8, 4, 12)`
- Feature: `AsyncHttpClient` owns its `asio::io_context` and one io thread by default; the `(BaseUri, io_context&)` constructor is preserved for callers that want to share an externally-driven `io_context` across clients (caller MUST keep the loop running until the client is destroyed)
- Feature: `Hub::Configuration` C++ struct fields renamed (`OptionalProvisionWorkerPool`/`OptionalHydrationWorkerPool` -> `OptionalProvisionPool`/`OptionalSpawnPool`/`OptionalHydrationPool`). Embedders constructing `Hub` directly must update field names; provision and spawn pools must both be set or both null (asserted at construction).
- Bugfix: `S3Client` signing-key cache no longer returns stale signatures after IMDS-rotated credentials change `AccessKeyId`; cache is now keyed on `(DateStamp, AccessKeyId)`
Diffstat (limited to 'src')
27 files changed, 7739 insertions, 1383 deletions
diff --git a/src/zencore/include/zencore/parallelwork.h b/src/zencore/include/zencore/parallelwork.h index 536b0a056..d9b20b9d7 100644 --- a/src/zencore/include/zencore/parallelwork.h +++ b/src/zencore/include/zencore/parallelwork.h @@ -13,6 +13,8 @@ namespace zen { class ParallelWork { public: + class ExternalWorkToken; + ParallelWork(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, WorkerThreadPool::EMode Mode); ~ParallelWork(); @@ -74,9 +76,53 @@ public: Latch& PendingWork() { return m_PendingWork; } + // Register a unit of work whose completion is signalled out-of-band (typically + // from an async callback firing on a different thread). Counter increments now; + // the returned token's Complete()/Fail() decrements. Used by S3AsyncStorage so + // in-flight S3 requests count against ParallelWork without occupying a worker. + [[nodiscard]] ExternalWorkToken RegisterExternal(); + + class ExternalWorkToken + { + public: + ExternalWorkToken() = default; + + ExternalWorkToken(const ExternalWorkToken&) = delete; + ExternalWorkToken& operator=(const ExternalWorkToken&) = delete; + + ExternalWorkToken(ExternalWorkToken&& Other) noexcept : m_Owner(Other.m_Owner) { Other.m_Owner = nullptr; } + ExternalWorkToken& operator=(ExternalWorkToken&& Other) noexcept + { + if (this != &Other) + { + Release(); + m_Owner = Other.m_Owner; + Other.m_Owner = nullptr; + } + return *this; + } + + ~ExternalWorkToken() { Release(); } + + void Complete(); + void Fail(std::exception_ptr Ex); + + bool IsActive() const { return m_Owner != nullptr; } + + private: + friend class ParallelWork; + + explicit ExternalWorkToken(ParallelWork* Owner) : m_Owner(Owner) {} + + void Release(); + + ParallelWork* m_Owner = nullptr; + }; + private: ExceptionCallback DefaultErrorFunction(); void RethrowErrors(); + void RecordExternalError(std::exception_ptr Ex); std::atomic<bool>& m_AbortFlag; std::atomic<bool>& m_PauseFlag; diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp 
index 94696f479..ec00fe0bc 100644 --- a/src/zencore/parallelwork.cpp +++ b/src/zencore/parallelwork.cpp @@ -2,6 +2,7 @@ #include <zencore/parallelwork.h> +#include <zencore/assertfmt.h> #include <zencore/callstack.h> #include <zencore/except.h> #include <zencore/fmtutils.h> @@ -11,6 +12,8 @@ #if ZEN_WITH_TESTS # include <zencore/testing.h> + +# include <thread> #endif // ZEN_WITH_TESTS namespace zen { @@ -90,6 +93,65 @@ ParallelWork::DefaultErrorFunction() } void +ParallelWork::RecordExternalError(std::exception_ptr Ex) +{ + m_ErrorLock.WithExclusiveLock([&]() { m_Errors.push_back(Ex); }); + m_AbortFlag = true; +} + +ParallelWork::ExternalWorkToken +ParallelWork::RegisterExternal() +{ + ZEN_ASSERT(!m_DispatchComplete); + m_PendingWork.AddCount(1); + return ExternalWorkToken(this); +} + +void +ParallelWork::ExternalWorkToken::Complete() +{ + ZEN_ASSERT(m_Owner != nullptr); + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; +} + +void +ParallelWork::ExternalWorkToken::Fail(std::exception_ptr Ex) +{ + ZEN_ASSERT(m_Owner != nullptr); + // Null exception_ptr would propagate as std::bad_exception via + // rethrow_exception(nullptr) and mask the real failure mode. Catches + // patterns like MakeGuard([Token]{ Token->Fail(std::current_exception()); }) + // firing on a normal-return path where no exception is in flight. + ZEN_ASSERT(Ex != nullptr); + m_Owner->RecordExternalError(Ex); + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; +} + +void +ParallelWork::ExternalWorkToken::Release() +{ + if (m_Owner != nullptr) + { + // Tests should fail loudly so that any leaked path surfaces immediately; + // in production we keep the safety-net countdown so a leak does not deadlock + // the latch but log it as an error rather than a warning - this is always + // a programming bug. 
+#if ZEN_WITH_TESTS + ZEN_ASSERT_FORMAT(false, "ParallelWork::ExternalWorkToken destroyed without Complete()/Fail()"); +#else + ZEN_ERROR("ParallelWork::ExternalWorkToken destroyed without Complete()/Fail(); decrementing latch as safety net"); + // Surface as an error from Wait()/RethrowErrors() so the caller does not see a phantom success. + m_Owner->RecordExternalError( + std::make_exception_ptr(std::runtime_error("ParallelWork::ExternalWorkToken destroyed without Complete()/Fail()"))); +#endif + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; + } +} + +void ParallelWork::Wait(int32_t UpdateIntervalMS, UpdateCallback&& UpdateCallback) { ZEN_ASSERT(!m_DispatchComplete); @@ -257,6 +319,100 @@ TEST_CASE("parallellwork.limitqueue") Work.Wait(); } +TEST_CASE("parallellwork.external_basic") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::vector<ParallelWork::ExternalWorkToken> Tokens; + for (uint32_t I = 0; I < 5; I++) + { + Tokens.push_back(Work.RegisterExternal()); + } + for (auto& Token : Tokens) + { + Token.Complete(); + } + + Work.Wait(); + CHECK_FALSE(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_completes_from_other_thread") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + auto Token = Work.RegisterExternal(); + std::thread Worker([Token = std::move(Token)]() mutable { + Sleep(20); + Token.Complete(); + }); + + Work.Wait(); + Worker.join(); + CHECK_FALSE(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_fail_propagates_exception") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + auto Token = Work.RegisterExternal(); + try + { + throw std::runtime_error("external work failed"); + } + catch (...) 
+ { + Token.Fail(std::current_exception()); + } + + CHECK_THROWS_WITH(Work.Wait(), "external work failed"); + CHECK(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_mixed_with_scheduled") +{ + WorkerThreadPool WorkerPool(2); + + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::atomic<uint32_t> ScheduledCount = 0; + for (uint32_t I = 0; I < 3; I++) + { + Work.ScheduleWork(WorkerPool, [&ScheduledCount](std::atomic<bool>& AbortFlag) { + ZEN_UNUSED(AbortFlag); + ScheduledCount++; + }); + } + + std::vector<ParallelWork::ExternalWorkToken> Tokens; + for (uint32_t I = 0; I < 3; I++) + { + Tokens.push_back(Work.RegisterExternal()); + } + + std::thread Completer([&]() { + for (auto& Token : Tokens) + { + Sleep(5); + Token.Complete(); + } + }); + + Work.Wait(); + Completer.join(); + + CHECK_EQ(ScheduledCount.load(), 3u); +} + TEST_SUITE_END(); void diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp index 151863370..0b6877c7b 100644 --- a/src/zenhttp/asynchttpclient_test.cpp +++ b/src/zenhttp/asynchttpclient_test.cpp @@ -5,21 +5,24 @@ #if ZEN_WITH_TESTS +# include <zencore/basicfile.h> # include <zencore/iobuffer.h> # include <zencore/logging.h> # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zencore/thread.h> # include "servers/httpasio.h" -# include <atomic> -# include <thread> - ZEN_THIRD_PARTY_INCLUDES_START # include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END +# include <atomic> +# include <cstring> +# include <thread> + namespace zen { using namespace std::literals; @@ -67,13 +70,65 @@ public: Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}"); }, HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + // 4 MB body so the response exercises chunked write callbacks. 
+ IoBuffer Body(4u * 1024u * 1024u); + uint8_t* Data = static_cast<uint8_t*>(Body.MutableData()); + for (size_t I = 0; I < Body.GetSize(); ++I) + { + Data[I] = static_cast<uint8_t>(I & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + std::string_view MsStr = Params.GetValue("ms"); + int Ms = MsStr.empty() ? 100 : std::atoi(std::string(MsStr).c_str()); + m_SlowHits.fetch_add(1, std::memory_order_relaxed); + zen::Sleep(Ms); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow ok"); + }, + HttpVerb::kGet); + + // Returns 503 for the first ?fail=N requests, then 200 for the rest. + m_Router.RegisterRoute( + "flaky", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + std::string_view FailStr = Params.GetValue("fail"); + const int FailN = FailStr.empty() ? 
0 : std::atoi(std::string(FailStr).c_str()); + const int Hit = m_FlakyHits.fetch_add(1, std::memory_order_relaxed) + 1; + if (Hit <= FailN) + { + HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable, HttpContentType::kText, "fail"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + }, + HttpVerb::kGet); } + std::atomic<int>& SlowHits() { return m_SlowHits; } + std::atomic<int>& FlakyHits() { return m_FlakyHits; } + virtual const char* BaseUri() const override { return "/api/async-test/"; } virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } private: HttpRequestRouter m_Router; + std::atomic<int> m_SlowHits{0}; + std::atomic<int> m_FlakyHits{0}; }; ////////////////////////////////////////////////////////////////////////// @@ -118,189 +173,593 @@ struct AsyncTestServerFixture TEST_SUITE_BEGIN("http.asynchttpclient"); -TEST_CASE("asynchttpclient.future.verbs") +// Future API + callback API + verb dispatch + payload echo + lifecycle. All +// scopes share one fixture and one default-settings client. Per-scope sets +// up its own promises/futures. +TEST_CASE("asynchttpclient.basic") { AsyncTestServerFixture Fixture; AsyncHttpClient Client = Fixture.MakeClient(); - SUBCASE("GET returns 200 with expected body") + // future.verbs - GET / POST / PUT / DELETE / HEAD echo the verb. 
{ - auto Future = Client.Get("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "GET"); } - - SUBCASE("POST dispatches correctly") { - auto Future = Client.Post("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Post("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "POST"); } - - SUBCASE("PUT dispatches correctly") { - auto Future = Client.Put("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Put("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "PUT"); } - - SUBCASE("DELETE dispatches correctly") { - auto Future = Client.Delete("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Delete("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "DELETE"); } - - SUBCASE("HEAD returns 200 with empty body") { - auto Future = Client.Head("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Head("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), ""sv); } -} -TEST_CASE("asynchttpclient.future.get") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - - SUBCASE("simple GET with text response") + // future.get - text body, JSON body, 204 NoContent. 
{ - auto Future = Client.Get("/api/async-test/hello"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/hello").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); CHECK_EQ(Resp.AsText(), "hello world"); } - - SUBCASE("GET returning JSON") { - auto Future = Client.Get("/api/async-test/json"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/json").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "{\"ok\":true}"); } - - SUBCASE("GET 204 NoContent") { - auto Future = Client.Get("/api/async-test/nocontent"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/nocontent").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); } -} - -TEST_CASE("asynchttpclient.future.post.with.payload") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - std::string_view PayloadStr = "async payload data"; - IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size()); - Payload.SetContentType(ZenContentType::kText); + // future.post.with.payload + future.put.with.payload - echo round-trips. + { + std::string_view Str = "async payload data"; + IoBuffer Payload(IoBuffer::Clone, Str.data(), Str.size()); + Payload.SetContentType(ZenContentType::kText); + auto Resp = Client.Post("/api/async-test/echo", Payload).get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "async payload data"); + } + { + std::string_view Str = "put payload"; + IoBuffer Payload(IoBuffer::Clone, Str.data(), Str.size()); + Payload.SetContentType(ZenContentType::kText); + auto Resp = Client.Put("/api/async-test/echo", Payload).get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload"); + } - auto Future = Client.Post("/api/async-test/echo", Payload); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "async payload data"); -} + // callback - AsyncGet completion fires the callback. 
+ { + std::promise<HttpClient::Response> Promise; + auto Future = Promise.get_future(); + Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } -TEST_CASE("asynchttpclient.future.put.with.payload") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); + // concurrent.requests - multiple verbs in flight at once. + { + auto F1 = Client.Get("/api/async-test/hello"); + auto F2 = Client.Get("/api/async-test/json"); + auto F3 = Client.Post("/api/async-test/echo/method"); + auto F4 = Client.Delete("/api/async-test/echo/method"); + auto R1 = F1.get(); + auto R2 = F2.get(); + auto R3 = F3.get(); + auto R4 = F4.get(); + CHECK(R1.IsSuccess()); + CHECK_EQ(R1.AsText(), "hello world"); + CHECK(R2.IsSuccess()); + CHECK_EQ(R2.AsText(), "{\"ok\":true}"); + CHECK(R3.IsSuccess()); + CHECK_EQ(R3.AsText(), "POST"); + CHECK(R4.IsSuccess()); + CHECK_EQ(R4.AsText(), "DELETE"); + } - std::string_view PutStr = "put payload"; - IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size()); - Payload.SetContentType(ZenContentType::kText); + // cancel.after.completion.is.noop - late Cancel must be quiet. + { + auto Resp = Client.Get("/api/async-test/hello").get(); + REQUIRE(Resp.IsSuccess()); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + AsyncRequestToken Token = Client.AsyncGet("/api/async-test/hello", [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + auto Resp2 = F.get(); + REQUIRE(Resp2.IsSuccess()); + Token.Cancel(); // no-op; must not crash + } - auto Future = Client.Put("/api/async-test/echo", Payload); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "put payload"); + // lifecycle.repeated.construct.destroy - 8 fresh clients against the same + // server. Catches io thread / curl_multi leaks across construct/destroy. 
+ for (int I = 0; I < 8; ++I) + { + AsyncHttpClient Local = Fixture.MakeClient(); + auto Resp = Local.Get("/api/async-test/hello").get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } } -TEST_CASE("asynchttpclient.callback") +// Submit-side behavior: external io_context, shutdown cancel, no-queue +// contract, cross-thread cancel-before-submit race, unlimited fan-out. +// +// MaxConcurrentRequests is applied as curl connection caps only; cap-level +// fan-out throttling lives in the storage layer (see +// server.s3asyncstorage.admission.fanout). +TEST_CASE("asynchttpclient.submit_and_shutdown") { AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - std::promise<HttpClient::Response> Promise; - auto Future = Promise.get_future(); + // external.io_context - caller drives the run loop. Verifies the + // Cleanup-via-promise path in the dtor. + { + asio::io_context IoContext; + auto WorkGuard = asio::make_work_guard(IoContext); + std::thread IoThread([&IoContext]() { IoContext.run(); }); + { + AsyncHttpClient Client = Fixture.MakeClient(IoContext); + auto Resp = Client.Get("/api/async-test/echo/method").get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + WorkGuard.reset(); + IoThread.join(); + } + + // shutdown.cancels.in.flight - dtor synthesizes cancel for all in-flight. 
+ { + const int N = 6; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + { + AsyncHttpClient Client = Fixture.MakeClient(); + std::vector<AsyncRequestToken> Tokens; + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=2000", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + Sleep(50); // let requests actually start before client teardown + } + int CancelCount = 0; + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + HttpClient::Response R = F.get(); + if (R.Error.has_value() && R.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled) + { + ++CancelCount; + } + } + CHECK(CancelCount == N); + } + + // contract.no.queue - all submissions reach the network despite cap=4. + // Fan-out gating is the storage layer's responsibility. 
+ { + HttpClientSettings Settings; + Settings.MaxConcurrentRequests = 4; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + + const int N = 100; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back( + Client.AsyncGet("/api/async-test/hello", [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(15)) == std::future_status::ready); + CHECK(F.get().IsSuccess()); + } + } - Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + // cancel.before.submit - cross-thread race: SubmitFromSpec posted, Cancel + // from another thread fires before/after submit handler runs. All callbacks + // must surface kRequestCancelled exactly once. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 16; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=2000", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + std::thread CancelThread([&]() { + for (auto& T : Tokens) + { + T.Cancel(); + } + }); + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + HttpClient::Response R = F.get(); + REQUIRE(R.Error.has_value()); + CHECK(R.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled); + } + CancelThread.join(); + } - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "hello world"); + // unlimited.parallel.fanout - 8 parallel 100ms requests with default + // settings (no cap) finish well under the 800ms serial floor. Sized to one + // batch on the asio server's 8-thread pool so per-request setup overhead on + // slow CI agents does not dominate; threshold leaves >=2x margin over the + // ~100ms parallel ideal. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 8; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + const auto Start = std::chrono::steady_clock::now(); + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=100", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + CHECK(F.get().IsSuccess()); + } + const auto ElapsedMs = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start).count(); + CHECK(ElapsedMs < 600); + } } -TEST_CASE("asynchttpclient.concurrent.requests") +// AsyncStream coverage: chunk delivery, OnData abort, mid-flight cancel. +TEST_CASE("asynchttpclient.stream") { AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - - // Fire multiple requests concurrently - auto Future1 = Client.Get("/api/async-test/hello"); - auto Future2 = Client.Get("/api/async-test/json"); - auto Future3 = Client.Post("/api/async-test/echo/method"); - auto Future4 = Client.Delete("/api/async-test/echo/method"); - auto Resp1 = Future1.get(); - auto Resp2 = Future2.get(); - auto Resp3 = Future3.get(); - auto Resp4 = Future4.get(); + // stream.basic - 4 MiB stream completes; bytes accounted; no body buffering. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + std::atomic<uint64_t> TotalReceived{0}; + std::atomic<uint32_t> ChunkCount{0}; + std::atomic<uint64_t> TotalSizeSeen{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + + auto Token = Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t* /*Data*/, size_t Size, uint64_t TotalSize) -> bool { + TotalReceived.fetch_add(Size, std::memory_order_relaxed); + ChunkCount.fetch_add(1, std::memory_order_relaxed); + if (TotalSize != 0) + { + TotalSizeSeen.store(TotalSize, std::memory_order_relaxed); + } + return true; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); - CHECK(Resp1.IsSuccess()); - CHECK_EQ(Resp1.AsText(), "hello world"); + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Resp = F.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(TotalReceived.load(), 4u * 1024u * 1024u); + CHECK_EQ(TotalSizeSeen.load(), 4u * 1024u * 1024u); + CHECK(ChunkCount.load() >= 1); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 0u); + } - CHECK(Resp2.IsSuccess()); - CHECK_EQ(Resp2.AsText(), "{\"ok\":true}"); + // stream.ondata.abort - returning false from OnData stops the transfer. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + std::atomic<uint32_t> ChunkCount{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + + Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t*, size_t, uint64_t) -> bool { + ChunkCount.fetch_add(1, std::memory_order_relaxed); + return false; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); - CHECK(Resp3.IsSuccess()); - CHECK_EQ(Resp3.AsText(), "POST"); + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(ChunkCount.load() <= 1); + } - CHECK(Resp4.IsSuccess()); - CHECK_EQ(Resp4.AsText(), "DELETE"); + // stream.cancel.mid.flight - Cancel during long stream surfaces + // kRequestCancelled. RetryCount=0 so no retry layer masks the error code. + { + HttpClientSettings Settings; + Settings.RetryCount = 0; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + auto Token = Client.AsyncStream( + "/api/async-test/slow?ms=2000", + [](const uint8_t*, size_t, uint64_t) -> bool { return true; }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + Sleep(50); + Token.Cancel(); + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + REQUIRE(Resp.Error.has_value()); + CHECK(Resp.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled); + } } -TEST_CASE("asynchttpclient.external.io_context") +// High-fanout, mixed verbs, large payload, streaming-source PUT. +TEST_CASE("asynchttpclient.stress") { AsyncTestServerFixture Fixture; - asio::io_context IoContext; - auto WorkGuard = asio::make_work_guard(IoContext); - std::thread IoThread([&IoContext]() { IoContext.run(); }); + // high.fanout - 32 unlimited parallel GETs all succeed. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 32; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& Pr : Promises) + { + Futures.push_back(Pr.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back( + Client.AsyncGet("/api/async-test/hello", [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (int I = 0; I < N; ++I) + { + REQUIRE(Futures[I].wait_for(std::chrono::seconds(15)) == std::future_status::ready); + auto Resp = Futures[I].get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + } + // mixed.verbs.concurrent - GET/POST/PUT/DELETE/Stream concurrently under cap=4. { - AsyncHttpClient Client = Fixture.MakeClient(IoContext); + HttpClientSettings Settings; + Settings.MaxConcurrentRequests = 4; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + + IoBuffer Payload(64); + std::memset(Payload.MutableData(), 0xAB, Payload.GetSize()); + + auto FGet = Client.Get("/api/async-test/hello"); + auto FPost = Client.Post("/api/async-test/echo", Payload); + auto FPut = Client.Put("/api/async-test/echo", Payload); + auto FDelete = Client.Delete("/api/async-test/echo/method"); + auto FJson = Client.Get("/api/async-test/json"); + + std::atomic<uint64_t> StreamBytes{0}; + std::promise<HttpClient::Response> StreamP; + auto StreamF = StreamP.get_future(); + Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t*, size_t Size, uint64_t) -> bool { + StreamBytes.fetch_add(Size, std::memory_order_relaxed); + return true; + }, + [&StreamP](HttpClient::Response R) { StreamP.set_value(std::move(R)); }); + + auto Get = FGet.get(); + auto Post = FPost.get(); + auto Put = FPut.get(); + auto Delete = FDelete.get(); + auto Json = FJson.get(); + REQUIRE(StreamF.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Stream = 
StreamF.get(); + + CHECK(Get.IsSuccess()); + CHECK_EQ(Get.AsText(), "hello world"); + CHECK(Post.IsSuccess()); + CHECK_EQ(Post.ResponsePayload.GetSize(), Payload.GetSize()); + CHECK(Put.IsSuccess()); + CHECK_EQ(Put.ResponsePayload.GetSize(), Payload.GetSize()); + CHECK(Delete.IsSuccess()); + CHECK_EQ(Delete.AsText(), "DELETE"); + CHECK(Json.IsSuccess()); + CHECK_EQ(Json.AsText(), "{\"ok\":true}"); + CHECK(Stream.IsSuccess()); + CHECK_EQ(StreamBytes.load(), 4u * 1024u * 1024u); + } - auto Future = Client.Get("/api/async-test/hello"); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "hello world"); + // large.put.roundtrip - 4 MiB PUT echoed back; spot-check positions. + { + AsyncHttpClient Client = Fixture.MakeClient(); + const size_t Size = 4u * 1024u * 1024u; + IoBuffer Payload(Size); + uint8_t* Data = static_cast<uint8_t*>(Payload.MutableData()); + for (size_t I = 0; I < Size; ++I) + { + Data[I] = static_cast<uint8_t>((I * 31u) & 0xFF); + } + auto Resp = Client.Put("/api/async-test/echo", Payload).get(); + REQUIRE(Resp.IsSuccess()); + REQUIRE_EQ(Resp.ResponsePayload.GetSize(), Size); + const uint8_t* RecvData = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + CHECK(RecvData[0] == Data[0]); + CHECK(RecvData[1u << 10] == Data[1u << 10]); + CHECK(RecvData[Size / 2] == Data[Size / 2]); + CHECK(RecvData[Size - 1] == Data[Size - 1]); } - WorkGuard.reset(); - IoThread.join(); + // streaming.put.source - AsyncPut(url, size, source, callback). + // Part 1: 2 MiB echo round-trip. + // Part 2: source returning 0 with offset < TotalSize aborts via CURL_READFUNC_ABORT. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + + { + const size_t Size = 2u * 1024u * 1024u; + std::vector<uint8_t> Source(Size); + for (size_t I = 0; I < Size; ++I) + { + Source[I] = static_cast<uint8_t>((I * 17u + 5u) & 0xFF); + } + std::atomic<uint64_t> SourceCalls{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + Client.AsyncPut( + "/api/async-test/echo", + Size, + [&](uint8_t* Dst, size_t MaxBytes, uint64_t Offset) -> size_t { + SourceCalls.fetch_add(1, std::memory_order_relaxed); + const size_t Remaining = Size > Offset ? Size - static_cast<size_t>(Offset) : 0; + const size_t Take = std::min(MaxBytes, Remaining); + if (Take == 0) + { + return 0; + } + std::memcpy(Dst, Source.data() + Offset, Take); + return Take; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + REQUIRE(F.wait_for(std::chrono::seconds(15)) == std::future_status::ready); + HttpClient::Response Resp = F.get(); + REQUIRE(Resp.IsSuccess()); + REQUIRE_EQ(Resp.ResponsePayload.GetSize(), Size); + CHECK(SourceCalls.load() >= 1); + const uint8_t* RecvData = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + CHECK(RecvData[0] == Source[0]); + CHECK(RecvData[Size / 2] == Source[Size / 2]); + CHECK(RecvData[Size - 1] == Source[Size - 1]); + CHECK(RecvData[1234567] == Source[1234567]); + } + + { + const size_t DeclaredSize = 1024u * 1024u; + std::atomic<uint64_t> SourceCalls{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + Client.AsyncPut( + "/api/async-test/echo", + DeclaredSize, + [&](uint8_t* Dst, size_t MaxBytes, uint64_t /*Offset*/) -> size_t { + const uint64_t Hits = SourceCalls.fetch_add(1, std::memory_order_relaxed); + if (Hits >= 1) + { + return 0; // abort: 0 returned with Offset < TotalSize + } + const size_t Take = std::min<size_t>(MaxBytes, 64u); + std::memset(Dst, 0xAB, Take); + return Take; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + REQUIRE(F.wait_for(std::chrono::seconds(15)) 
== std::future_status::ready); + HttpClient::Response Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + } } -TEST_CASE("asynchttpclient.connection.error") +// Connection-error / retry semantics targeting a dead port. No server fixture +// needed; each scope constructs its own client with bespoke timeout settings. +TEST_CASE("asynchttpclient.connection_errors") { - // Connect to a port where nothing is listening - AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + // connection.error - Get against dead port surfaces connection error. + { + AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + auto Resp = Client.Get("/should-fail").get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(Resp.Error->IsConnectionError()); + } + + // retry.respected.on.connection.error - 2 retries adds >=300ms backoff + // (100ms + 200ms accumulated past the initial attempt). + { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(50), + .RetryCount = 2, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + const auto Start = std::chrono::steady_clock::now(); + auto Resp = Client.Get("/should-fail").get(); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Elapsed.count() >= 300); + } - auto Future = Client.Get("/should-fail"); - auto Resp = Future.get(); + // cancel.in.flight - Cancel mid-connect doesn't hang (regardless of which + // side of the strand wins the cancel-vs-ECONNREFUSED race). 
+ { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(60000), + .RetryCount = 0, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + const auto Start = std::chrono::steady_clock::now(); + AsyncRequestToken Token = Client.AsyncGet("/should-cancel", [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + Token.Cancel(); + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + REQUIRE(Resp.Error.has_value()); + CHECK((Resp.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled || + Resp.Error->ErrorCode == HttpClientErrorCode::kConnectionFailure)); + CHECK(Elapsed.count() < 5000); + } - CHECK_FALSE(Resp.IsSuccess()); - CHECK(Resp.Error.has_value()); - CHECK(Resp.Error->IsConnectionError()); + // retry.zero.no.retry - RetryCount=0 returns promptly (< 1s) with no + // extra backoff past ConnectTimeout. 
+ { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(50), + .RetryCount = 0, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + const auto Start = std::chrono::steady_clock::now(); + auto Resp = Client.Get("/should-fail").get(); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Elapsed.count() < 1000); + } } TEST_SUITE_END(); diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp index ea88fc783..e7e904f89 100644 --- a/src/zenhttp/clients/asynchttpclient.cpp +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -4,8 +4,11 @@ #include "httpclientcurlhelpers.h" +#include <zencore/basicfile.h> #include <zencore/filesystem.h> +#include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/session.h> #include <zencore/thread.h> #include <zencore/trace.h> @@ -15,6 +18,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio/steady_timer.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#include <algorithm> +#include <charconv> +#include <deque> #include <thread> #include <unordered_map> @@ -22,45 +28,204 @@ namespace zen { ////////////////////////////////////////////////////////////////////////// // +// AsyncRequestToken state (forward-declared in the public header so tokens +// can be returned by value without leaking impl details). + +struct AsyncRequestToken::State +{ + std::function<void()> CancelFn; + std::atomic<bool> Cancelled = false; +}; + +////////////////////////////////////////////////////////////////////////// +// // TransferContext: per-transfer state associated with each CURL easy handle +// Request blueprint kept alongside each transfer so retries can re-issue with +// the original verb/url/headers/payload after the previous attempt's transient +// failure. 
+enum class AsyncRequestMethod +{ + Get, + Head, + Delete, + Post, + PostWithPayload, + Put, + PutWithPayload, + PutWithSource, // PUT, body pulled via OnReadSource (no materialized payload) + Stream, // GET, response body delivered via OnData callback (no copy) +}; + +inline std::string_view +AsyncRequestMethodName(AsyncRequestMethod M) +{ + switch (M) + { + case AsyncRequestMethod::Get: + return "GET"; + case AsyncRequestMethod::Head: + return "HEAD"; + case AsyncRequestMethod::Delete: + return "DELETE"; + case AsyncRequestMethod::Post: + return "POST"; + case AsyncRequestMethod::PostWithPayload: + return "POST(payload)"; + case AsyncRequestMethod::Put: + return "PUT"; + case AsyncRequestMethod::PutWithPayload: + return "PUT(payload)"; + case AsyncRequestMethod::PutWithSource: + return "PUT(stream)"; + case AsyncRequestMethod::Stream: + return "GET(stream)"; + } + return "?"; +} + +struct AsyncRequestSpec +{ + AsyncRequestMethod Method = AsyncRequestMethod::Get; + std::string Url; + HttpClient::KeyValueMap AdditionalHeader; + HttpClient::KeyValueMap Parameters; + IoBuffer Payload; // POST/PUT with payload + ZenContentType ContentType = ZenContentType::kUnknownContentType; + bool HasContentType = false; + AsyncHttpDataCallback OnData; // Stream method + AsyncHttpReadSource OnReadSource; // PutWithSource method + uint64_t StreamingPutSize = 0; // Content-Length for PutWithSource + + // Opt-in header capture. By default Response::Header is left empty; inline + // extracts always run because they steer the body path. WantHeaderMap pays + // O(headers) string allocs on the io thread (rare - most callers use + // Response::FindHeader). WantEtag triggers an inline parse only. 
+ bool WantEtag = false; + bool WantHeaderMap = false; +}; + struct TransferContext { - AsyncHttpCallback Callback; - std::string Body; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - CurlWriteCallbackData WriteData; - CurlHeaderCallbackData HeaderData; - curl_slist* HeaderList = nullptr; - - // For PUT/POST with payload: keep the data alive until transfer completes + AsyncHttpCallback Callback; + AsyncRequestSpec Spec; + uint8_t AttemptCount = 0; + uint64_t TokenId = 0; + bool Cancelled = false; + // Curl handle owning this transfer once submitted to curl_multi. Null between + // Submit() and SubmitFromSpec(), and again after CompleteTransfer releases the + // handle back to the pool. + CURL* CurlHandle = nullptr; + // Strong reference to the user-facing AsyncRequestToken::State. Kept alive + // so AsyncRequestToken::Cancel() retains a valid CancelFn target until the + // transfer completes. The typed pointer also lets SubmitFromSpec read + // State.Cancelled to honour cancels that arrived between Submit() returning + // and the io thread running SubmitFromSpec. + std::shared_ptr<AsyncRequestToken::State> TokenStateRef; + // Two paths gated by BodyPreallocated: when Content-Length is known we + // fill Body in place (zero-copy move into Response); otherwise BodyChunks + // accumulates per-WRITE IoBuffers, flattened at completion. + IoBuffer Body; + uint64_t BodyWriteOffset = 0; + bool BodyPreallocated = false; + std::vector<IoBuffer> BodyChunks; + + // Raw response headers, "Key: Value\r\n" lines as delivered by curl. + // One growing buffer; reserve covers common case (~1 KiB) with no realloc. + std::string HeaderArena; + + // Captured at the curl callback boundary so an exception out of the user + // OnData / OnReadSource never propagates through curl's C frames (UB). + // Surfaced as a kInternalError response by CompleteTransfer. + bool CallbackFailed = false; + std::string CallbackErrorMessage; + + // Inline-parsed in CurlHeaderCallback. 
Etag is populated only when + // Spec.WantEtag is set; the others are always parsed since they steer + // the body path / content-type tagging. + uint64_t ContentLength = 0; + bool ContentLengthSet = false; + ZenContentType BodyContentType = ZenContentType::kUnknownContentType; + std::string Etag; + + curl_slist* HeaderList = nullptr; + IoBuffer PayloadBuffer; CurlReadCallbackData ReadData; + uint64_t SourceOffset = 0; // PutWithSource: bytes pulled from Spec.OnReadSource so far. + + // Last attempt's failure, kept across the backoff so a retry-abandoned + // path can surface the underlying cause instead of a generic + // "Request canceled (retry abandoned)". Non-empty LastErrorMessage = stash valid. + CURLcode LastCurlResult = CURLE_OK; + long LastStatusCode = 0; + std::string LastErrorMessage; + + TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) { HeaderArena.reserve(1024); } + + ~TransferContext() { FreeHeaderList(); } + + TransferContext(const TransferContext&) = delete; + TransferContext& operator=(const TransferContext&) = delete; - TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) + // Reset accumulated response state so the same context can be re-submitted + // for a retry attempt. + void ResetForRetry() { - WriteData.Body = &Body; - HeaderData.Headers = &ResponseHeaders; + Body = IoBuffer{}; + BodyWriteOffset = 0; + BodyPreallocated = false; + BodyChunks.clear(); + HeaderArena.clear(); // keep capacity + ContentLength = 0; + ContentLengthSet = false; + BodyContentType = ZenContentType::kUnknownContentType; + Etag.clear(); + FreeHeaderList(); + ReadData = {}; + SourceOffset = 0; + CallbackFailed = false; + CallbackErrorMessage.clear(); + // LastCurlResult / LastStatusCode / LastErrorMessage are intentionally NOT + // cleared - they describe the just-finished attempt that triggered this + // retry, and surface in the abandon path if the next attempt is cancelled + // or shutdown-aborted. 
} - ~TransferContext() + void FreeHeaderList() { if (HeaderList) { curl_slist_free_all(HeaderList); + HeaderList = nullptr; } } - - TransferContext(const TransferContext&) = delete; - TransferContext& operator=(const TransferContext&) = delete; }; ////////////////////////////////////////////////////////////////////////// // -// AsyncHttpClient::Impl +// SocketInfo: per-socket state. -struct AsyncHttpClient::Impl +struct AsyncSocketInfo { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + int PendingFlags = 0; // directions with outstanding async_wait + + // Bound to the strand executor so async_wait completions are serialized on + // the same strand that drives curl_multi - safe even when the underlying + // io_context is multithreaded (external-context mode). + explicit AsyncSocketInfo(const asio::strand<asio::io_context::executor_type>& Strand) : Socket(Strand) {} +}; + +// Holds the curl_multi instance and a strand that serializes every curl_multi op. Owned io_context +// (default ctor) spins a private thread driving run(); external io_context mode lets the caller +// drive the loop. The strand is the single serialization point for both modes. +struct AsyncHttpClient::Impl : std::enable_shared_from_this<AsyncHttpClient::Impl> +{ + // Owned-io_context ctor: allocate a private io_context and run it on a + // dedicated thread. Cleanest path for callers that just want an + // AsyncHttpClient and don't care about the loop. 
Impl(std::string_view BaseUri, const HttpClientSettings& Settings) : m_BaseUri(BaseUri) , m_Settings(Settings) @@ -72,19 +237,33 @@ struct AsyncHttpClient::Impl { Init(); m_WorkGuard.emplace(m_IoContext.get_executor()); - m_IoThread = std::thread([this]() { - SetCurrentThreadName("async_http"); - try - { - m_IoContext.run(); - } - catch (const std::exception& Ex) + + auto ThreadGuard = MakeGuard([this]() { + m_WorkGuard.reset(); + if (m_IoThread.joinable()) { - ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what()); + m_IoThread.join(); } }); + m_IoThread = std::thread([this]() { + SetCurrentThreadName("async_http"); + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: io thread unhandled exception: {}", Ex.what()); + } + }); + ThreadGuard.Dismiss(); } + // External-io_context ctor: caller drives the run loop. We do NOT spawn a + // thread, do NOT hold a work guard, and do NOT call stop()/restart() on + // teardown - the caller's lifecycle owns those. Shutdown blocks on a + // promise until our cleanup handler runs through the strand, so the + // caller MUST keep the loop running until the AsyncHttpClient destructs. Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) : m_BaseUri(BaseUri) , m_Settings(Settings) @@ -96,70 +275,167 @@ struct AsyncHttpClient::Impl Init(); } - ~Impl() + ~Impl() { Shutdown(); } + + // Synchronous teardown from ~AsyncHttpClient. Idempotent. Branches by ownership: + // - Owned io_context: post Cleanup, drop the work guard, force run() to return, join the io + // thread, then poll() to drain lambdas that captured shared_ptr<Impl> (the "lambda in + // io_context owned by Impl" cycle would otherwise pin Impl alive forever). + // - External io_context: caller drives the loop, post Cleanup to the strand and block on a + // promise. Destroying from the io thread itself would deadlock. 
+ void Shutdown() { - // Clean up curl state on the strand where all curl_multi operations - // are serialized. Use a promise to block until the cleanup handler - // has actually executed - essential for the external io_context case - // where we don't own the run loop. - std::promise<void> Done; - std::future<void> DoneFuture = Done.get_future(); - - asio::post(m_Strand, [this, &Done]() { - m_ShuttingDown = true; - m_Timer.cancel(); - - // Release all tracked sockets (don't close - curl owns the fds). - for (auto& [Fd, Info] : m_Sockets) - { - if (Info->Socket.is_open()) - { - Info->Socket.cancel(); - Info->Socket.release(); - } - } - m_Sockets.clear(); + if (m_ShutdownDone.exchange(true, std::memory_order_acq_rel)) + { + return; + } - for (auto& [Handle, Ctx] : m_Transfers) + if (m_OwnedIoContext) + { + // Post Cleanup before releasing the work guard so the io thread + // runs Cleanup before run() returns. Run() normally exits when + // the queue is empty AND no work guard is held. + asio::post(m_Strand, [this]() { Cleanup(); }); + m_WorkGuard.reset(); + // Belt-and-suspenders: force run() to return even if asio's + // outstanding_work_ counter is left non-zero by a close-during- + // cancel race in win_iocp_socket_service. The race is observable + // as a hung join() at shutdown after a burst of socket teardowns + // inside curl_multi_cleanup; stop() here is purely a safety net + // for the teardown path. The trailing restart() + poll() drains + // any handlers stop() leaves undispatched. + m_IoContext.stop(); + if (m_IoThread.joinable()) { - curl_multi_remove_handle(m_Multi, Handle); - curl_easy_cleanup(Handle); + m_IoThread.join(); } - m_Transfers.clear(); - for (CURL* Handle : m_HandlePool) + // Drain any leftover work posted to the io_context but not run + // before the thread exited (e.g. a Cancel-after-completion lambda + // that captured shared_ptr<Impl> by value). 
Loop until the queue + // is empty: a single poll() is one quanta and may not drain + // handlers that themselves post follow-ups. + m_IoContext.restart(); + while (m_IoContext.poll() != 0) { - curl_easy_cleanup(Handle); } - m_HandlePool.clear(); - - Done.set_value(); - }); + } + else + { + // External: block on the cleanup handler so we can guarantee + // curl_multi state is gone before m_Impl drops. + std::promise<void> Done; + std::future<void> DoneFuture = Done.get_future(); + asio::post(m_Strand, [this, &Done]() { + Cleanup(); + Done.set_value(); + }); + DoneFuture.wait(); + } + } - // For owned io_context: release work guard so run() can return after - // processing the cleanup handler above. - m_WorkGuard.reset(); + // Cleanup body, run on the io thread. + void Cleanup() + { + m_ShuttingDown = true; + m_Timer.cancel(); - if (m_IoThread.joinable()) + // Tear down curl handles first; curl drives CURL_POLL_REMOVE + + // CLOSESOCKETFUNCTION for each owned socket, which our handlers + // use to retire SocketInfo entries from m_Sockets. + for (auto& [TokenId, Ctx] : m_Transfers) { - m_IoThread.join(); + if (Ctx->CurlHandle) + { + curl_multi_remove_handle(m_Multi, Ctx->CurlHandle); + curl_easy_cleanup(Ctx->CurlHandle); + Ctx->CurlHandle = nullptr; + } + + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "AsyncHttpClient shutting down", + }; + try + { + Ctx->Callback(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in shutdown callback (token={}, {} {}): {}", + TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + Ex.what()); + } } - else + m_Transfers.clear(); + m_InFlight = 0; + + // Drain transfers parked in retry backoff. The timer's pending + // async_wait fires after Cleanup with the io_context already stopped; + // it will find no entry and be a no-op. 
Fire cancel callbacks here + // while we still hold the storage so callers' futures resolve. + for (auto& [Id, Entry] : m_RetryingTransfers) { - // External io_context: wait for the cleanup handler to complete. - DoneFuture.wait(); + if (Entry.Timer) + { + Entry.Timer->cancel(); + } + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "AsyncHttpClient shutting down", + }; + try + { + Entry.Ctx->Callback(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in shutdown retry callback (token={}, {} {}): {}", + Id, + AsyncRequestMethodName(Entry.Ctx->Spec.Method), + Entry.Ctx->Spec.Url, + Ex.what()); + } } + m_RetryingTransfers.clear(); + // curl_multi_cleanup walks the connection cache and fires + // CLOSESOCKETFUNCTION for each cached fd. Run it on the io thread + // while m_Sockets is still populated so our callback routes each + // close through the map (Socket dtor closes fd exactly once). 
if (m_Multi) { curl_multi_cleanup(m_Multi); + m_Multi = nullptr; + } + + for (CURL* Handle : m_HandlePool) + { + curl_easy_cleanup(Handle); + } + m_HandlePool.clear(); + + asio::error_code CancelEc; + for (auto& [Fd, Info] : m_Sockets) + { + Info->Socket.cancel(CancelEc); } + m_Sockets.clear(); } LoggerRef Log() { return m_Log; } void Init() { + if (!m_Settings.UnixSocketPath.empty()) + { + m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); + } + m_Multi = curl_multi_init(); if (!m_Multi) { @@ -168,6 +444,24 @@ struct AsyncHttpClient::Impl SetupMultiCallbacks(); + if (m_Settings.MaxConcurrentConnectionsPerHost != 0) + { + curl_multi_setopt(m_Multi, CURLMOPT_MAX_HOST_CONNECTIONS, static_cast<long>(m_Settings.MaxConcurrentConnectionsPerHost)); + } + if (m_Settings.MaxConcurrentConnectionsTotal != 0) + { + curl_multi_setopt(m_Multi, CURLMOPT_MAX_TOTAL_CONNECTIONS, static_cast<long>(m_Settings.MaxConcurrentConnectionsTotal)); + } + + // Size the idle-conn cache to the in-flight cap so reused conns are + // never evicted while requests are queued. Each eviction costs a + // fresh TCP+TLS handshake (~280ms WAN to S3) on the next reuse. + const long MaxConnectsHint = static_cast<long>(std::max({m_Settings.MaxConcurrentRequests, + m_Settings.MaxConcurrentConnectionsTotal, + m_Settings.MaxConcurrentConnectionsPerHost, + 128u})); + curl_multi_setopt(m_Multi, CURLMOPT_MAXCONNECTS, MaxConnectsHint); + if (m_Settings.SessionId == Oid::Zero) { m_SessionId = std::string(GetSessionIdString()); @@ -178,6 +472,30 @@ struct AsyncHttpClient::Impl } } + // Run a completion callback inline on the io thread. By the time this is + // called, curl_multi_remove_handle + curl_easy_cleanup have already finalized + // the easy handle, so deferring to next io tick buys nothing. Direct call + // saves one alloc + queue insert per request. + // + // CONTRACT: user callbacks run on the AsyncHttpClient io thread. 
Heavy work + // (disk syscalls, lock contention, large allocations) must be hopped to a + // worker pool; otherwise it stalls curl_multi for ALL in-flight transfers. + void DispatchCallback(AsyncHttpCallback Cb, HttpClient::Response Resp, const TransferContext& Ctx) + { + try + { + Cb(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback (token={}, {} {}): {}", + Ctx.TokenId, + AsyncRequestMethodName(Ctx.Spec.Method), + Ctx.Spec.Url, + Ex.what()); + } + } + // -- Handle pool ----------------------------------------------------- CURL* AllocHandle() @@ -199,20 +517,15 @@ struct AsyncHttpClient::Impl void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); } - // -- Configure a handle with common settings ------------------------- - // Called only from DoAsync* lambdas running on the strand. - + // Called only from DoAsync* lambdas running on the io thread. void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) { - // Build URL ExtendableStringBuilder<256> Url; BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); - // Unix domain socket if (!m_Settings.UnixSocketPath.empty()) { - m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str()); } @@ -226,12 +539,6 @@ struct AsyncHttpClient::Impl curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count())); } - // HTTP/2 - if (m_Settings.AssumeHttp2) - { - curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); - } - // SSL if (m_Settings.InsecureSsl) { @@ -243,7 +550,6 @@ struct AsyncHttpClient::Impl curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str()); } - // Verbose/debug if (m_Settings.Verbose) { curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); @@ -252,6 
+558,22 @@ struct AsyncHttpClient::Impl // Thread safety curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + // 256 KiB recv buffer aligns with optimal read syscall size and matches + // upload buffer; pairs with downstream 512 KiB write slots (2 recv calls + // per write slot under bulk transfer). + curl_easy_setopt(Handle, CURLOPT_BUFFERSIZE, 262144L); + curl_easy_setopt(Handle, CURLOPT_UPLOAD_BUFFERSIZE, 262144L); + + // Skip per-transfer progress bookkeeping; we don't consume it. + curl_easy_setopt(Handle, CURLOPT_NOPROGRESS, 1L); + + // Disable Nagle (default since curl 7.50; explicit for safety). + curl_easy_setopt(Handle, CURLOPT_TCP_NODELAY, 1L); + + // Take ownership of socket close (see CurlCloseSocketCallback). + curl_easy_setopt(Handle, CURLOPT_CLOSESOCKETFUNCTION, &CurlCloseSocketCallback); + curl_easy_setopt(Handle, CURLOPT_CLOSESOCKETDATA, this); + if (m_Settings.ForbidReuseConnection) { curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); @@ -260,23 +582,24 @@ struct AsyncHttpClient::Impl // -- Access token ---------------------------------------------------- - std::optional<std::string> GetAccessToken() + struct AccessTokenResult { + std::optional<std::string> Token; + bool ProviderFailed = false; // provider configured but returned invalid twice + }; + + // Called only on the io thread. + AccessTokenResult GetAccessToken() + { + AccessTokenResult Result; if (!m_Settings.AccessTokenProvider.has_value()) { - return {}; + return Result; // No provider: anonymous is the intended mode. 
} - { - RwLock::SharedLockScope _(m_AccessTokenLock); - if (!m_CachedAccessToken.NeedsRefresh()) - { - return m_CachedAccessToken.GetValue(); - } - } - RwLock::ExclusiveLockScope _(m_AccessTokenLock); if (!m_CachedAccessToken.NeedsRefresh()) { - return m_CachedAccessToken.GetValue(); + Result.Token = m_CachedAccessToken.GetValue(); + return Result; } HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()(); if (!NewToken.IsValid()) @@ -287,10 +610,176 @@ struct AsyncHttpClient::Impl if (NewToken.IsValid()) { m_CachedAccessToken = NewToken; - return m_CachedAccessToken.GetValue(); + Result.Token = m_CachedAccessToken.GetValue(); + return Result; } ZEN_WARN("AsyncHttpClient: access token provider returned invalid token"); - return {}; + Result.ProviderFailed = true; + return Result; + } + + // -- Submit / resubmit ----------------------------------------------- + // + // SubmitFromSpec runs on the io thread. Used for both the initial submission + // and for retries: the AsyncRequestSpec inside Ctx encodes everything + // needed to (re)build the curl handle from scratch. + + void SubmitFromSpec(std::unique_ptr<TransferContext> Ctx) + { + if (m_ShuttingDown) + { + // Synthesize a cancel response so the user callback fires exactly once. + // Without this any Ctx that lands here post-shutdown would be dropped + // silently, leaving waiting futures unresolved. + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (client shutting down)", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Cancel-before-submit race: the user can call AsyncRequestToken::Cancel() + // between Submit() returning and the io thread running SubmitFromSpec. + // State.Cancelled is set under acq_rel by Cancel(); read here under + // acquire ensures visibility. 
If set, fire the cancel callback exactly + // once and bail before the transfer enters m_Transfers. + if (Ctx->TokenStateRef && Ctx->TokenStateRef->Cancelled.load(std::memory_order_acquire)) + { + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled before submit", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Allocate the curl handle BEFORE bumping m_InFlight: AllocHandle can throw + // (curl_easy_init returning null). The InFlightGuard covers the rest of + // the body (BuildHeaderList, ExtraHeaders push_back, GetAccessToken can + // all throw bad_alloc) so a throw post-increment doesn't leak the slot. + // HandleGuard returns the handle to the pool on throw; CallbackGuard + // synthesizes a kInternalError response so the user future resolves. + CURL* Handle = AllocHandle(); + ++m_InFlight; + auto InFlightGuard = MakeGuard([this] { --m_InFlight; }); + auto HandleGuard = MakeGuard([this, Handle] { ReleaseHandle(Handle); }); + auto CallbackGuard = MakeGuard([this, &Ctx] { + HttpClient::Response ErrResp; + ErrResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = "AsyncHttpClient::SubmitFromSpec: setup threw before dispatch", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(ErrResp), *Ctx); + }); + ConfigureHandle(Handle, Ctx->Spec.Url, Ctx->Spec.Parameters); + + switch (Ctx->Spec.Method) + { + case AsyncRequestMethod::Get: + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + break; + + case AsyncRequestMethod::Stream: + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + break; + + case AsyncRequestMethod::Head: + curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); + break; + + case AsyncRequestMethod::Delete: + curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + break; + + case AsyncRequestMethod::Post: + curl_easy_setopt(Handle, 
CURLOPT_POST, 1L); + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); + break; + + case AsyncRequestMethod::PostWithPayload: + { + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + Ctx->PayloadBuffer = Ctx->Spec.Payload; + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + break; + } + + case AsyncRequestMethod::Put: + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + break; + + case AsyncRequestMethod::PutWithPayload: + { + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + Ctx->PayloadBuffer = Ctx->Spec.Payload; + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + break; + } + + case AsyncRequestMethod::PutWithSource: + { + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + Ctx->SourceOffset = 0; + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->Spec.StreamingPutSize)); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, AsyncCurlSourceReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, Ctx.get()); + break; + } + } + + // Headers - include Content-Type for payload-bearing methods, Content-Length: 0 for empty PUT. 
+ std::vector<std::pair<std::string, std::string>> ExtraHeaders; + if (Ctx->Spec.Method == AsyncRequestMethod::PostWithPayload) + { + const ZenContentType Effective = Ctx->Spec.HasContentType ? Ctx->Spec.ContentType : Ctx->Spec.Payload.GetContentType(); + ExtraHeaders.emplace_back("Content-Type", std::string(MapContentTypeToString(Effective))); + } + else if (Ctx->Spec.Method == AsyncRequestMethod::PutWithPayload) + { + ExtraHeaders.emplace_back("Content-Type", std::string(MapContentTypeToString(Ctx->Spec.Payload.GetContentType()))); + } + else if (Ctx->Spec.Method == AsyncRequestMethod::Put) + { + ExtraHeaders.emplace_back("Content-Length", "0"); + } + + AccessTokenResult Token = GetAccessToken(); + if (Token.ProviderFailed) + { + // Provider configured but failed twice: do NOT silently downgrade + // to an anonymous request - the server will respond 403 and the + // caller has no way to tell auth failed. + HttpClient::Response ErrResp; + ErrResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = "AsyncHttpClient: access token provider failed; refusing to issue anonymous request", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(ErrResp), *Ctx); + CallbackGuard.Dismiss(); + return; + } + + Ctx->HeaderList = BuildHeaderList(Ctx->Spec.AdditionalHeader, m_SessionId, std::move(Token.Token), ExtraHeaders); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + InFlightGuard.Dismiss(); + HandleGuard.Dismiss(); + CallbackGuard.Dismiss(); + SubmitTransfer(Handle, std::move(Ctx)); } // -- Submit a transfer ----------------------------------------------- @@ -298,56 +787,66 @@ struct AsyncHttpClient::Impl void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx) { ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer"); - // Setup write/header callbacks - curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData); - 
curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData); - - m_Transfers[Handle] = std::move(Ctx); - + // Pick the WRITE callback by method: + // - Stream: forwards each chunk to the caller's OnData (no copy) + // - other: buffers bytes in TransferContext::Body + if (Ctx->Spec.Method == AsyncRequestMethod::Stream) + { + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, AsyncCurlStreamWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, Ctx.get()); + } + else + { + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, AsyncCurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, Ctx.get()); + } + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, AsyncCurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, Ctx.get()); + // Stash TokenId on the curl handle so CheckCompleted can look up the + // TransferContext directly from curl_multi_info_read's CURLMsg.easy_handle. + curl_easy_setopt(Handle, CURLOPT_PRIVATE, reinterpret_cast<void*>(static_cast<uintptr_t>(Ctx->TokenId))); + + Ctx->CurlHandle = Handle; + const uint64_t TokenIdLocal = Ctx->TokenId; + + // Try the curl_multi add first. On failure, Ctx is still owned locally so the + // rollback is a single Release path with no map churn. On success, ownership + // moves into the single TokenId-keyed lookup table. 
CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle); if (Mc != CURLM_OK) { - auto Stolen = std::move(m_Transfers[Handle]); - m_Transfers.erase(Handle); + Ctx->CurlHandle = nullptr; ReleaseHandle(Handle); HttpClient::Response ErrorResponse; ErrorResponse.Error = HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError, .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))}; - asio::post(m_IoContext, - [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); + DispatchCallback(std::move(Ctx->Callback), std::move(ErrorResponse), *Ctx); + OnSlotFreed(); return; } - } - // -- Socket-action integration --------------------------------------- - // - // curl_multi drives I/O via two callbacks: - // - SocketCallback: curl tells us which sockets to watch for read/write - // - TimerCallback: curl tells us when to fire a timeout - // - // On each socket event or timeout we call curl_multi_socket_action(), - // then drain completed transfers via curl_multi_info_read(). + m_Transfers.emplace(TokenIdLocal, std::move(Ctx)); + } - // Per-socket state: wraps the native fd in an ASIO socket for async_wait. - struct SocketInfo + // Telemetry only; this client does not gate fan-out. Callers (e.g. + // S3AsyncStorage) layer their own admission semaphore on top. + // Assert catches SubmitFromSpec/OnSlotFreed imbalance. + void OnSlotFreed() { - asio::ip::tcp::socket Socket; - int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT - - explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} - }; + ZEN_ASSERT(m_InFlight > 0); + --m_InFlight; + } - // Static thunks registered with curl_multi ---------------------------- + // curl_multi drives I/O via SocketCallback (which fds to watch) and TimerCallback (when to fire). + // On each event we call curl_multi_socket_action() and drain via curl_multi_info_read(). 
+ // Static thunks: UserData = Impl* (set via CURLMOPT_SOCKETDATA / CURLMOPT_TIMERDATA). Bodies run on io thread. static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) { - ZEN_UNUSED(Easy); auto* Self = static_cast<Impl*>(UserPtr); - Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + Self->OnCurlSocket(Easy, Fd, Action, static_cast<AsyncSocketInfo*>(SocketPtr)); return 0; } @@ -359,6 +858,219 @@ struct AsyncHttpClient::Impl return 0; } + // Async-specific HEADER callback. Appends raw "Key: Value\r\n" lines to + // Ctx->HeaderArena (one growing buffer; ~zero allocs in the common case + // where the initial reserve covers the response). Inline-parses + // Content-Length and Content-Type unconditionally; parses ETag only when + // Spec.WantEtag is set. No std::string/pair allocations per line. + static size_t AsyncCurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + std::string_view Line(Buffer, TotalBytes); + + Ctx->HeaderArena.append(Buffer, TotalBytes); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + if (Line.empty()) + { + return TotalBytes; + } + const size_t Colon = Line.find(':'); + if (Colon == std::string_view::npos) + { + return TotalBytes; // HTTP status line or malformed + } + std::string_view Key = Line.substr(0, Colon); + std::string_view Value = Line.substr(Colon + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (StrCaseEquals(Key, "Content-Length")) + { + uint64_t Length = 0; + std::from_chars_result Res = std::from_chars(Value.data(), Value.data() + Value.size(), Length); + if (Res.ec == std::errc{}) + { + Ctx->ContentLength = Length; + Ctx->ContentLengthSet = true; + } + } + 
else if (StrCaseEquals(Key, "Content-Type")) + { + Ctx->BodyContentType = ParseContentType(Value); + } + else if (Ctx->Spec.WantEtag && StrCaseEquals(Key, "ETag")) + { + Ctx->Etag.assign(Value); + } + + return TotalBytes; + } + + // Async-specific write callback. Targets a TransferContext directly. + // Preallocates Body from Ctx->ContentLength (parsed in HEADER cb). If + // Content-Length is absent (e.g. chunked encoding), falls back to + // BodyChunks accumulation; CompleteTransfer flattens at the end. + static size_t AsyncCurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + if (TotalBytes == 0) + { + return 0; + } + + if (!Ctx->BodyPreallocated && Ctx->BodyWriteOffset == 0 && Ctx->BodyChunks.empty() && Ctx->ContentLengthSet && + Ctx->ContentLength > 0) + { + Ctx->Body = IoBuffer(static_cast<size_t>(Ctx->ContentLength)); + Ctx->BodyPreallocated = true; + } + + if (Ctx->BodyPreallocated) + { + if (Ctx->BodyWriteOffset + TotalBytes > Ctx->Body.GetSize()) + { + // Server sent more than Content-Length advertised; abort. + return 0; + } + memcpy(static_cast<uint8_t*>(Ctx->Body.MutableData()) + Ctx->BodyWriteOffset, Ptr, TotalBytes); + Ctx->BodyWriteOffset += TotalBytes; + } + else + { + IoBuffer Chunk(TotalBytes); + memcpy(Chunk.MutableData(), Ptr, TotalBytes); + Ctx->BodyChunks.push_back(std::move(Chunk)); + } + + return TotalBytes; + } + + // PutWithSource read callback. Pulls up to MaxBytes from Spec.OnReadSource + // into curl's send buffer. Source closure runs on the io thread - same + // strand discipline as Stream's OnData. Returning 0 with SourceOffset < + // StreamingPutSize signals an upload abort to curl. 
+ static size_t AsyncCurlSourceReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t MaxBytes = Size * Nmemb; + if (MaxBytes == 0 || !Ctx->Spec.OnReadSource) + { + return 0; + } + // Catch at the curl boundary: user OnReadSource may call into IoBuffer + // allocation, file reads, or async dispatch which can throw. Letting an + // exception unwind through curl's C frames is UB. + size_t Pulled = 0; + try + { + Pulled = Ctx->Spec.OnReadSource(reinterpret_cast<uint8_t*>(Buffer), MaxBytes, Ctx->SourceOffset); + } + catch (const std::exception& Ex) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = fmt::format("upload source callback threw: {}", Ex.what()); + return CURL_READFUNC_ABORT; + } + catch (...) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "upload source callback threw unknown exception"; + return CURL_READFUNC_ABORT; + } + if (Pulled == 0 && Ctx->SourceOffset < Ctx->Spec.StreamingPutSize) + { + return CURL_READFUNC_ABORT; + } + Ctx->SourceOffset += Pulled; + return Pulled; + } + + // Stream-method write callback. Hands each chunk to the caller's OnData + // without allocating or copying. The pointer is curl's internal receive + // buffer; valid only for the duration of this call. Caller's OnData runs + // on the io thread, so blocking work (disk write etc) blocks the poll + // loop. TotalSize comes from inline-parsed Content-Length (0 if absent / + // chunked). + static size_t AsyncCurlStreamWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + if (TotalBytes == 0) + { + return 0; + } + + if (!Ctx->Spec.OnData) + { + // Stream method requires OnData by contract; reaching this branch + // is API misuse. Returning TotalBytes (success) would silently + // drop the body and report "ok"; fail loudly instead. 
+ Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "stream request submitted without OnData callback"; + return 0; + } + + // Catch at the curl boundary so an exception inside the user OnData + // (e.g. IoBuffer alloc failure, ScheduleWork rejection, ZEN_ASSERT in a + // downstream pool) cannot propagate through curl's C frames. Stash on + // Ctx so CompleteTransfer surfaces it as kInternalError. + try + { + const bool ContinueTransfer = Ctx->Spec.OnData(reinterpret_cast<const uint8_t*>(Ptr), TotalBytes, Ctx->ContentLength); + return ContinueTransfer ? TotalBytes : 0; // returning 0 aborts + } + catch (const std::exception& Ex) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = fmt::format("stream data callback threw: {}", Ex.what()); + return 0; + } + catch (...) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "stream data callback threw unknown exception"; + return 0; + } + } + + // Take ownership of socket close. CLOSESOCKETFUNCTION is invoked from + // within curl_multi operations which run on the io thread, so direct map + // access is safe. The asio tcp::socket destructor closes the fd; we + // return 0 to tell curl the close succeeded. Letting curl close as well + // would race (double-close, fd-reuse hazard) and on Windows IOCP + // `release()` throws `operation_not_supported`, killing the io thread. + static int CurlCloseSocketCallback(void* ClientPtr, curl_socket_t Fd) + { + auto* Self = static_cast<Impl*>(ClientPtr); + auto It = Self->m_Sockets.find(Fd); + if (It != Self->m_Sockets.end()) + { + asio::error_code Ec; + It->second->Socket.cancel(Ec); + Self->m_Sockets.erase(It); + return 0; + } + // Fd not tracked (e.g. pre-poll-add or post-shutdown); close directly. 
+#if ZEN_PLATFORM_WINDOWS + ::closesocket(Fd); +#else + ::close(Fd); +#endif + return 0; + } + void SetupMultiCallbacks() { curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); @@ -369,45 +1081,97 @@ struct AsyncHttpClient::Impl // Called by curl when socket watch state changes --------------------- - void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) + // Synthesize a transport-level failure for the easy handle currently bound + // to the curl_multi entry that owns Fd. Used when the asio side cannot bind + // the fd; without this the affected transfer would hang on curl's + // connect/transfer timeout instead of failing fast with the real error. + void FailEasyHandleForFd(CURL* Easy, std::string_view Reason) + { + if (!Easy) + { + return; + } + curl_multi_remove_handle(m_Multi, Easy); + + char* Private = nullptr; + curl_easy_getinfo(Easy, CURLINFO_PRIVATE, &Private); + const uint64_t TokenId = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(Private)); + + auto It = m_Transfers.find(TokenId); + if (It == m_Transfers.end()) + { + ReleaseHandle(Easy); + return; + } + + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + Ctx->CurlHandle = nullptr; + ReleaseHandle(Easy); + + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = std::string(Reason), + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + OnSlotFreed(); + } + + void OnCurlSocket(CURL* Easy, curl_socket_t Fd, int Action, AsyncSocketInfo* Info) { if (Action == CURL_POLL_REMOVE) { if (Info) { - // Cancel pending async_wait ops before releasing the fd. - // curl owns the fd, so we must release() rather than close(). - Info->Socket.cancel(); - if (Info->Socket.is_open()) - { - Info->Socket.release(); - } - m_Sockets.erase(Fd); + // Cancel any pending async_wait but KEEP the AsyncSocketInfo + // alive in m_Sockets. 
CURL_POLL_REMOVE only means "stop + // watching this socket"; curl may still use the fd + // (keep-alive reuse) and rewatch it later. The asio Socket + // stays bound to the same IOCP it was first assigned to; + // we never call release()+assign() on the same fd, which + // avoided a race where in-flight async_wait callbacks + // raced with curl_multi reading from the socket and + // corrupted HTTP framing. The fd's actual close happens + // in CurlCloseSocketCallback, which erases the entry and + // lets the asio Socket destructor close. + asio::error_code Ec; + Info->Socket.cancel(Ec); + Info->WatchFlags = 0; + Info->PendingFlags = 0; } return; } if (!Info) { - // New socket - wrap the native fd in an ASIO socket. - auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); - Info = It->second.get(); - - asio::error_code Ec; - // Determine protocol from the fd (v4 vs v6). Default to v4. - Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); - if (Ec) + // CURL_POLL_IN/OUT with no Info attached. Two cases: + // 1) brand new fd - emplace, assign, IOCP-bind. + // 2) curl re-watching a kept-alive fd it earlier removed - + // reuse the existing AsyncSocketInfo, no re-assign. 
+ auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) { - // Try v6 as fallback - Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); - } - if (Ec) - { - ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); - m_Sockets.erase(Fd); - return; - } + auto [NewIt, _] = m_Sockets.emplace(Fd, std::make_unique<AsyncSocketInfo>(m_Strand)); + It = NewIt; + asio::error_code Ec; + It->second->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + It->second->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) + { + std::string Reason = + fmt::format("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + ZEN_WARN("{}", Reason); + m_Sockets.erase(It); + FailEasyHandleForFd(Easy, Reason); + return; + } + } + Info = It->second.get(); curl_multi_assign(m_Multi, Fd, Info); } @@ -415,37 +1179,84 @@ struct AsyncHttpClient::Impl SetSocketWatch(Fd, Info); } - void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + void SetSocketWatch(curl_socket_t Fd, AsyncSocketInfo* Info) { - // Cancel any pending wait before issuing a new one. - Info->Socket.cancel(); + // Cancel only when a previously-watched direction is no longer wanted. + // In the common path (one-shot async_wait completes, curl re-watches + // the same flags) PendingFlags is a subset of WatchFlags and we just + // re-arm the missing direction without touching CancelIoEx. 
+ const int Desired = Info->WatchFlags & (CURL_POLL_IN | CURL_POLL_OUT); + + if (Info->PendingFlags & ~Desired) + { + asio::error_code Ec; + Info->Socket.cancel(Ec); + Info->PendingFlags = 0; + } + + const int ToAdd = Desired & ~Info->PendingFlags; - if (Info->WatchFlags & CURL_POLL_IN) + if (ToAdd & CURL_POLL_IN) { - Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) - { - return; - } - OnSocketReady(Fd, CURL_CSELECT_IN); - })); + Info->PendingFlags |= CURL_POLL_IN; + Info->Socket.async_wait(asio::socket_base::wait_read, [this, Fd](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) + { + return; + } + It->second->PendingFlags &= ~CURL_POLL_IN; + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: read async_wait fd {} error: {}; signalling curl with CURL_CSELECT_ERR", + static_cast<int>(Fd), + Ec.message()); + OnSocketReady(Fd, CURL_CSELECT_ERR); + } + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + }); } - if (Info->WatchFlags & CURL_POLL_OUT) + if (ToAdd & CURL_POLL_OUT) { - Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) - { - return; - } - OnSocketReady(Fd, CURL_CSELECT_OUT); - })); + Info->PendingFlags |= CURL_POLL_OUT; + Info->Socket.async_wait(asio::socket_base::wait_write, [this, Fd](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) + { + return; + } + It->second->PendingFlags &= ~CURL_POLL_OUT; + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: write async_wait fd {} error: {}; signalling curl with CURL_CSELECT_ERR", + static_cast<int>(Fd), + Ec.message()); + OnSocketReady(Fd, CURL_CSELECT_ERR); + } + return; + 
} + OnSocketReady(Fd, CURL_CSELECT_OUT); + }); } } void OnSocketReady(curl_socket_t Fd, int CurlAction) { - ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); int StillRunning = 0; curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); CheckCompleted(); @@ -472,7 +1283,7 @@ struct AsyncHttpClient::Impl if (TimeoutMs == 0) { - // curl wants immediate action - run it directly on the strand. + // curl wants immediate action - run it on the next strand tick. asio::post(m_Strand, [this]() { if (m_ShuttingDown) { @@ -486,16 +1297,24 @@ struct AsyncHttpClient::Impl } m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); - m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) + m_Timer.async_wait([this](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + if (Ec) { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: curl multi timer error: {}", Ec.message()); + } return; } ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); int StillRunning = 0; curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); CheckCompleted(); - })); + }); } // Drain completed transfers from curl_multi -------------------------- @@ -516,7 +1335,13 @@ struct AsyncHttpClient::Impl curl_multi_remove_handle(m_Multi, Handle); - auto It = m_Transfers.find(Handle); + // Recover TokenId from CURLOPT_PRIVATE; cheaper than a per-handle + // reverse map. Returns nullptr if option was never set. + char* Private = nullptr; + curl_easy_getinfo(Handle, CURLINFO_PRIVATE, &Private); + const uint64_t TokenId = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(Private)); + + auto It = m_Transfers.find(TokenId); if (It == m_Transfers.end()) { ReleaseHandle(Handle); @@ -530,9 +1355,45 @@ struct AsyncHttpClient::Impl } } + // Mirrors CurlHttpClient::ShouldRetry semantics; keep the two in sync. 
+ static bool ShouldRetryAsync(CURLcode CurlResult, long StatusCode) + { + switch (CurlResult) + { + case CURLE_OK: + break; + case CURLE_COULDNT_CONNECT: + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } + } + void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) { ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer"); + + // Free the in-flight counter before any retry / cancel branch. Retry + // re-submits via SubmitFromSpec, which re-increments the counter for + // the next attempt - keeping the assert balanced across retries. + OnSlotFreed(); + // Extract result info long StatusCode = 0; curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode); @@ -548,13 +1409,205 @@ struct AsyncHttpClient::Impl ReleaseHandle(Handle); + // Cancellation came in after curl ran but before we processed completion. + // Synthesize a cancel response and skip retry. + if (Ctx->Cancelled) + { + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // User OnData / OnReadSource threw inside curl. The transfer is already + // aborted; surface the stashed exception text and skip retry (a callback + // exception is not a transient transport failure). 
+ if (Ctx->CallbackFailed) + { + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = std::move(Ctx->CallbackErrorMessage), + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; + } + + // Retry path: re-issue from spec after backoff. Keeps the user callback + // untouched so the eventual final result fires only once. + if (!m_ShuttingDown && Ctx->AttemptCount < m_Settings.RetryCount && ShouldRetryAsync(CurlResult, StatusCode)) + { + ++Ctx->AttemptCount; + const long BackoffMs = 100 * Ctx->AttemptCount; + + if (CurlResult != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", + m_SessionId, + static_cast<int>(MapCurlError(CurlResult)), + curl_easy_strerror(CurlResult), + static_cast<int>(CurlResult), + Ctx->AttemptCount, + m_Settings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + m_SessionId, + StatusCode, + zen::ToString(HttpResponseCode(StatusCode)), + Ctx->AttemptCount, + m_Settings.RetryCount + 1); + } + + // Stash the just-finished attempt's failure so the abandon path can + // surface it instead of a generic "Request canceled (retry abandoned)". + // Non-empty LastErrorMessage marks the stash valid. + Ctx->LastCurlResult = CurlResult; + Ctx->LastStatusCode = StatusCode; + Ctx->LastErrorMessage = (CurlResult != CURLE_OK) ? std::string(curl_easy_strerror(CurlResult)) + : std::string(zen::ToString(HttpResponseCode(StatusCode))); + Ctx->ResetForRetry(); + + const uint64_t RetryTokenId = Ctx->TokenId; + auto RetryTimer = std::make_shared<asio::steady_timer>(m_Strand); + RetryTimer->expires_after(std::chrono::milliseconds(BackoffMs)); + + // Park Ctx + Timer in m_RetryingTransfers so HandleCancel can find + // it and cancel the timer (early cancel without paying the full + // backoff). 
The timer lambda re-claims Ctx through the map; if + // HandleCancel got there first the entry is gone and the lambda + // returns silently. + auto [It, Inserted] = m_RetryingTransfers.emplace(RetryTokenId, RetryEntry{std::move(Ctx), RetryTimer}); + ZEN_ASSERT(Inserted); + + // Capture weak_from_this() rather than raw `this`. With an external + // io_context, Cleanup cancels the timer but the cancellation handler + // is queued on the caller's loop and may fire AFTER ~Impl runs. The + // owned-context path drains via restart()+poll() before destruction + // so this is moot there, but the weak ref is the cheapest way to + // keep both paths safe. + RetryTimer->async_wait([Self = weak_from_this(), RetryTokenId](const asio::error_code& Ec) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + Impl& Me = *Locked; + auto It = Me.m_RetryingTransfers.find(RetryTokenId); + if (It == Me.m_RetryingTransfers.end()) + { + // HandleCancel already removed the entry and dispatched the + // cancel callback. + return; + } + std::unique_ptr<TransferContext> Ctx = std::move(It->second.Ctx); + Me.m_RetryingTransfers.erase(It); + + if (Ec || Me.m_ShuttingDown) + { + // Retry abandoned by timer cancellation or client shutdown. + // Surface the underlying failure (stashed pre-backoff) so + // the caller can distinguish a real timeout/throttle from a + // shutdown / cancel race. + HttpClient::Response Resp; + if (!Ctx->LastErrorMessage.empty()) + { + Resp.StatusCode = HttpResponseCode(Ctx->LastStatusCode); + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = + (Ctx->LastCurlResult != CURLE_OK) ? 
MapCurlError(Ctx->LastCurlResult) : HttpClientErrorCode::kOtherError, + .ErrorMessage = fmt::format("Request canceled (retry abandoned after: {})", Ctx->LastErrorMessage), + }; + } + else + { + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (retry abandoned)", + }; + } + Me.DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; + } + Me.SubmitFromSpec(std::move(Ctx)); + }); + return; + } + // Build response HttpClient::Response Response; Response.StatusCode = HttpResponseCode(StatusCode); Response.UploadedBytes = static_cast<int64_t>(UpBytes); Response.DownloadedBytes = static_cast<int64_t>(DownBytes); Response.ElapsedSeconds = Elapsed; - Response.Header = BuildHeaderMap(Ctx->ResponseHeaders); + // Hand the raw arena over - FindHeader scans this lazily. Build the + // parsed KeyValueMap only when caller explicitly asks (rare). + Response.HeaderArena = std::move(Ctx->HeaderArena); + if (Ctx->Spec.WantHeaderMap) + { + std::string_view View(Response.HeaderArena); + while (!View.empty()) + { + const size_t LineEnd = View.find('\n'); + std::string_view Line = LineEnd == std::string_view::npos ? View : View.substr(0, LineEnd); + View = LineEnd == std::string_view::npos ? std::string_view{} : View.substr(LineEnd + 1); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + if (auto Header = ParseHeaderLine(Line)) + { + Response.Header->insert_or_assign(std::string(Header->first), std::string(Header->second)); + } + } + } + + // Helper: produce the response IoBuffer from whichever Body path was + // taken. Preallocated path moves with zero copy; chunked fallback + // moves the single chunk or flattens N chunks into one allocation. 
+ auto BuildResponsePayload = [&]() -> IoBuffer { + if (Ctx->BodyPreallocated) + { + IoBuffer Out = std::move(Ctx->Body); + if (Ctx->BodyWriteOffset != Out.GetSize()) + { + // Server closed early - return a non-owning sub-buffer over + // the actually-received prefix. Sub-buffer holds a ref to + // Out's core so the underlying allocation stays alive; no + // memcpy. + return IoBuffer(Out, 0, Ctx->BodyWriteOffset); + } + return Out; + } + if (Ctx->BodyChunks.size() == 1) + { + return std::move(Ctx->BodyChunks[0]); + } + if (!Ctx->BodyChunks.empty()) + { + // Flatten N chunks into one IoBuffer; single alloc avoids a copy chain. + size_t Total = 0; + for (const IoBuffer& C : Ctx->BodyChunks) + { + Total += C.GetSize(); + } + IoBuffer Out(Total); + uint8_t* Dst = static_cast<uint8_t*>(Out.MutableData()); + for (const IoBuffer& C : Ctx->BodyChunks) + { + memcpy(Dst, C.GetData(), C.GetSize()); + Dst += C.GetSize(); + } + return Out; + } + return IoBuffer{}; + }; + + const bool HasBody = Ctx->BodyPreallocated ? 
Ctx->BodyWriteOffset > 0 : !Ctx->BodyChunks.empty(); if (CurlResult != CURLE_OK) { @@ -562,258 +1615,261 @@ struct AsyncHttpClient::Impl if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg); + ZEN_WARN("AsyncHttpClient failure: token={} {} '{}': ({}) '{}'", + Ctx->TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + static_cast<int>(CurlResult), + ErrorMsg); } - if (!Ctx->Body.empty()) + if (HasBody) { - Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + Response.ResponsePayload = BuildResponsePayload(); } Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)}; } - else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty()) + else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || !HasBody) { // No payload } else { - IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); - ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders); + IoBuffer PayloadBuffer = BuildResponsePayload(); + if (Ctx->BodyContentType != ZenContentType::kUnknownContentType) + { + PayloadBuffer.SetContentType(Ctx->BodyContentType); + } const HttpResponseCode Code = HttpResponseCode(StatusCode); if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound) { - ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri); + ZEN_WARN("AsyncHttpClient request failed: token={} {} '{}': status={}", + Ctx->TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + static_cast<int>(Code)); } Response.ResponsePayload = std::move(PayloadBuffer); } - // Dispatch the user callback off the strand so a slow callback - // cannot starve the curl_multi poll loop. 
- asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable { - try + // Token reaches terminal state. Ctx (and its embedded TokenState) is + // destroyed when this scope ends; late Cancel() calls find no entry in + // m_Transfers and become no-ops. + DispatchCallback(std::move(Ctx->Callback), std::move(Response), *Ctx); + } + + // -- Async verb implementations -------------------------------------- + + AsyncRequestToken Submit(std::unique_ptr<TransferContext> Ctx) + { + // Allocate token ID + state up front so callers can cancel before the + // posted submit even runs. Token::State is shared between the user-held + // AsyncRequestToken and the TransferContext (no separate strand-side map). + const uint64_t Id = m_NextTokenId.fetch_add(1, std::memory_order_relaxed); + Ctx->TokenId = Id; + + auto State = std::make_shared<AsyncRequestToken::State>(); + State->CancelFn = [WeakSelf = weak_from_this(), Id]() { + auto Self = WeakSelf.lock(); + if (!Self) { - Cb(std::move(Response)); + return; } - catch (const std::exception& Ex) + asio::post(Self->m_Strand, [Self, Id]() { Self->HandleCancel(Id); }); + }; + Ctx->TokenStateRef = State; + + asio::post(m_Strand, [this, Ctx = std::move(Ctx)]() mutable { + if (m_ShuttingDown) { - ZEN_SCOPED_LOG(LogRef); - ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what()); + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (client shutting down)", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; } + SubmitFromSpec(std::move(Ctx)); }); - } - // -- Async verb implementations -------------------------------------- + return AsyncRequestToken(std::move(State)); + } - void DoAsyncGet(std::string Url, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + 
void HandleCancel(uint64_t Id) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Get"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto It = m_Transfers.find(Id); + if (It != m_Transfers.end()) + { + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + + Ctx->Cancelled = true; + if (Ctx->CurlHandle) + { + curl_multi_remove_handle(m_Multi, Ctx->CurlHandle); + ReleaseHandle(Ctx->CurlHandle); + Ctx->CurlHandle = nullptr; + } + + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + OnSlotFreed(); + return; + } + + // Cancel landed during the retry backoff: take the parked Ctx + Timer, + // cancel the timer (the in-flight async_wait fires later with + // operation_aborted but finds no entry, so it's a no-op), and dispatch + // the cancel callback now so the user observes immediate cancellation + // rather than waiting out the backoff. 
+ auto RetIt = m_RetryingTransfers.find(Id); + if (RetIt != m_RetryingTransfers.end()) + { + std::unique_ptr<TransferContext> Ctx = std::move(RetIt->second.Ctx); + std::shared_ptr<asio::steady_timer> Timer = std::move(RetIt->second.Timer); + m_RetryingTransfers.erase(RetIt); + Timer->cancel(); + + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Cancel landed before SubmitFromSpec ran (Cancel posted between Submit + // returning and the io thread executing the posted SubmitFromSpec). + // State.Cancelled has already been set by Cancel(); SubmitFromSpec + // checks it and synthesizes the cancel callback when the transfer + // eventually arrives. Nothing to do here. } - void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncGet(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, - [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Head"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Get; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + 
Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); } - void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Delete"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Head; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + return Submit(std::move(Ctx)); } - void DoAsyncPost(std::string Url, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Post"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_POST, 1L); - curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = 
BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Delete; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + return Submit(std::move(Ctx)); } - void DoAsyncPostWithPayload(std::string Url, - IoBuffer Payload, - ZenContentType ContentType, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncPost(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Payload = std::move(Payload), - ContentType, - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_POST, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->PayloadBuffer = std::move(Payload); - Ctx->HeaderList = - BuildHeaderList(AdditionalHeader, - m_SessionId, - GetAccessToken(), - {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))}); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - // Set up read callback for payload data - Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); - Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); - Ctx->ReadData.Offset = 0; - - curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); - curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); - - 
SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Post; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); } - void DoAsyncPutWithPayload(std::string Url, - IoBuffer Payload, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncPostWithPayload(std::string Url, + IoBuffer Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Payload = std::move(Payload), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Put"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->PayloadBuffer = std::move(Payload); - Ctx->HeaderList = BuildHeaderList( - AdditionalHeader, - m_SessionId, - GetAccessToken(), - {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))}); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); - Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); - Ctx->ReadData.Offset = 0; - - curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); - curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = 
std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PostWithPayload; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Payload = std::move(Payload); + Ctx->Spec.ContentType = ContentType; + Ctx->Spec.HasContentType = true; + return Submit(std::move(Ctx)); } - void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncPutWithPayload(std::string Url, + IoBuffer Payload, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Put"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PutWithPayload; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + Ctx->Spec.Payload = std::move(Payload); + return Submit(std::move(Ctx)); + } - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + AsyncRequestToken DoAsyncPutNoPayload(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Put; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); + } - HttpClient::KeyValueMap 
ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; - Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + AsyncRequestToken DoAsyncPutWithSource(std::string Url, + uint64_t TotalSize, + AsyncHttpReadSource Source, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) + { + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PutWithSource; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.OnReadSource = std::move(Source); + Ctx->Spec.StreamingPutSize = TotalSize; + return Submit(std::move(Ctx)); + } - SubmitTransfer(Handle, std::move(Ctx)); - }); + AsyncRequestToken DoAsyncStream(std::string Url, + AsyncHttpDataCallback OnData, + AsyncHttpCallback OnComplete, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + auto Ctx = std::make_unique<TransferContext>(std::move(OnComplete)); + Ctx->Spec.Method = AsyncRequestMethod::Stream; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + Ctx->Spec.OnData = std::move(OnData); + return Submit(std::move(Ctx)); } // -- Members --------------------------------------------------------- @@ -824,25 +1880,44 @@ struct AsyncHttpClient::Impl std::string m_SessionId; std::string m_UnixSocketPathUtf8; - // io_context and strand - all curl_multi operations are serialized on the - // strand, making this safe even when the io_context has multiple threads. - std::unique_ptr<asio::io_context> m_OwnedIoContext; - asio::io_context& m_IoContext; + // io_context: either privately owned (m_OwnedIoContext non-null + we spin + // m_IoThread + hold m_WorkGuard) or supplied by the caller (m_OwnedIoContext + // null; caller drives the loop). 
Declared before m_Strand so the strand + // can bind its executor from m_IoContext during member init. + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + // Single serialization point for every curl_multi operation, every async + // completion handler, and every Cleanup call. Declared after m_IoContext + // (so make_strand can read its executor) and before m_Timer (which binds + // to the strand). asio::strand<asio::io_context::executor_type> m_Strand; std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; std::thread m_IoThread; - // curl_multi and socket-action state - CURLM* m_Multi = nullptr; - std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; - std::vector<CURL*> m_HandlePool; - std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; - asio::steady_timer m_Timer; - bool m_ShuttingDown = false; + // Strand-bound; async_wait completions land back on m_Strand. + asio::steady_timer m_Timer; + CURLM* m_Multi = nullptr; + // Single TokenId-keyed map. CurlHandle lives in TransferContext; reverse + // lookup from CURL* uses CURLOPT_PRIVATE (set in SubmitTransfer). + std::unordered_map<uint64_t, std::unique_ptr<TransferContext>> m_Transfers; + // Transfers parked between curl-side completion and the next attempt's + // SubmitFromSpec. Lookups by TokenId let HandleCancel cancel the backoff + // timer without paying the full delay. 
+ struct RetryEntry + { + std::unique_ptr<TransferContext> Ctx; + std::shared_ptr<asio::steady_timer> Timer; + }; + std::unordered_map<uint64_t, RetryEntry> m_RetryingTransfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<AsyncSocketInfo>> m_Sockets; + uint32_t m_InFlight = 0; // telemetry only; storage layer caps fan-out - // Access token cache - RwLock m_AccessTokenLock; + std::atomic<bool> m_ShuttingDown{false}; HttpClientAccessToken m_CachedAccessToken; + + std::atomic<uint64_t> m_NextTokenId{1}; + std::atomic<bool> m_ShutdownDone{false}; }; ////////////////////////////////////////////////////////////////////////// @@ -850,79 +1925,134 @@ struct AsyncHttpClient::Impl // AsyncHttpClient public API AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings) -: m_Impl(std::make_unique<Impl>(BaseUri, Settings)) +: m_Impl(std::make_shared<Impl>(BaseUri, Settings)) { } AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) -: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings)) +: m_Impl(std::make_shared<Impl>(BaseUri, IoContext, Settings)) { } -AsyncHttpClient::~AsyncHttpClient() = default; +AsyncHttpClient::~AsyncHttpClient() +{ + // Drive teardown synchronously while we're guaranteed to be on a user + // thread (not the io thread). Joining and draining here ensures any + // posted lambdas that captured a shared_ptr<Impl> by value (e.g. + // Cancel after the transfer completed) are destroyed and release + // their refs before m_Impl drops; otherwise the cycle "lambda holds + // Impl ref / lambda lives in io_context owned by Impl" would pin Impl + // alive forever and leak curl_multi + socket handles. 
+ if (m_Impl) + { + m_Impl->Shutdown(); + } +} // -- Callback-based API -------------------------------------------------- -void +AsyncRequestToken AsyncHttpClient::AsyncGet(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); } -void +AsyncRequestToken AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& 
Payload, ZenContentType ContentType, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPut(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); +} + +AsyncRequestToken +AsyncHttpClient::AsyncPut(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + return m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +AsyncRequestToken +AsyncHttpClient::AsyncPut(std::string_view Url, + uint64_t TotalSize, + AsyncHttpReadSource Source, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader) +{ + return m_Impl->DoAsyncPutWithSource(std::string(Url), TotalSize, std::move(Source), std::move(OnComplete), AdditionalHeader); +} + +AsyncRequestToken +AsyncHttpClient::AsyncStream(std::string_view Url, + AsyncHttpDataCallback OnData, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + return m_Impl->DoAsyncStream(std::string(Url), std::move(OnData), std::move(OnComplete), AdditionalHeader, Parameters); } +// -- Token cancellation -------------------------------------------------- + void -AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters) +AsyncRequestToken::Cancel() { - m_Impl->DoAsyncPutNoPayload(std::string(Url), 
std::move(Callback), Parameters); + if (!m_State) + { + return; + } + if (m_State->Cancelled.exchange(true, std::memory_order_acq_rel)) + { + return; // already cancelled + } + if (m_State->CancelFn) + { + m_State->CancelFn(); + } } // -- Future-based API ---------------------------------------------------- @@ -1026,6 +2156,7 @@ AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) AsyncPut( Url, [Promise](Response R) { Promise->set_value(std::move(R)); }, + KeyValueMap{}, Parameters); return Future; } diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h index cb5f5d9a9..410399a11 100644 --- a/src/zenhttp/clients/httpclientcurlhelpers.h +++ b/src/zenhttp/clients/httpclientcurlhelpers.h @@ -232,27 +232,48 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, const std::optional<std::string>& AccessToken, const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) { - curl_slist* Headers = nullptr; + // Manual tail tracking. curl_slist_append walks to the tail on every call, + // so chained appends are O(n^2) over the header count. We append in O(1) + // by allocating each node directly and stitching it onto a tracked tail. 
+ curl_slist* Head = nullptr; + curl_slist* Tail = nullptr; + auto Append = [&](const char* Line) { + curl_slist* New = curl_slist_append(nullptr, Line); + if (!New) + { + return; + } + if (!Head) + { + Head = New; + Tail = New; + } + else + { + Tail->next = New; + Tail = New; + } + }; for (const auto& [Key, Value] : *AdditionalHeader) { ExtendableStringBuilder<64> HeaderLine; HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + Append(HeaderLine.c_str()); } if (!SessionId.empty()) { ExtendableStringBuilder<64> SessionHeader; SessionHeader << "UE-Session: " << SessionId; - Headers = curl_slist_append(Headers, SessionHeader.c_str()); + Append(SessionHeader.c_str()); } if (AccessToken.has_value()) { ExtendableStringBuilder<128> AuthHeader; AuthHeader << "Authorization: " << AccessToken.value(); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); + Append(AuthHeader.c_str()); } bool HasContentTypeOverride = AdditionalHeader->contains("Content-Type"); @@ -264,10 +285,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, } ExtendableStringBuilder<128> HeaderLine; HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + Append(HeaderLine.c_str()); } - return Headers; + return Head; } inline HttpClient::KeyValueMap diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 9d5846f71..b6a07250e 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -320,6 +320,62 @@ HttpClient::Response::IsSuccess() const noexcept return !Error && IsHttpSuccessCode(StatusCode); } +std::string_view +HttpClient::Response::FindHeader(std::string_view Name) const +{ + // Scan the raw arena first - the async client populates this and leaves + // Header empty by default. Lines are "Key: Value\r\n" (the trailing \r\n + // is what curl hands the HEADER callback; we keep it verbatim). 
+ if (!HeaderArena.empty()) + { + std::string_view View(HeaderArena); + while (!View.empty()) + { + const size_t LineEnd = View.find('\n'); + std::string_view Line = LineEnd == std::string_view::npos ? View : View.substr(0, LineEnd); + View = LineEnd == std::string_view::npos ? std::string_view{} : View.substr(LineEnd + 1); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + if (Line.empty()) + { + continue; + } + const size_t Colon = Line.find(':'); + if (Colon == std::string_view::npos) + { + continue; // HTTP status line or malformed + } + std::string_view K = Line.substr(0, Colon); + std::string_view V = Line.substr(Colon + 1); + while (!K.empty() && K.back() == ' ') + { + K.remove_suffix(1); + } + while (!V.empty() && V.front() == ' ') + { + V.remove_prefix(1); + } + if (StrCaseEquals(K, Name)) + { + return V; + } + } + } + + // Fall back to the parsed map (sync client populates this). + for (const auto& [K, V] : *Header) + { + if (StrCaseEquals(K, Name)) + { + return V; + } + } + return {}; +} + std::string HttpClient::Response::ErrorMessage(std::string_view Prefix) const { diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index b0e097a54..ea73ff7a3 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -300,6 +300,49 @@ struct TestServerFixture TEST_SUITE_BEGIN("http.httpclient"); +TEST_CASE("httpclient.response.findheader.arena") +{ + // Async client populates HeaderArena (raw "Key: Value\r\n" bytes); FindHeader + // scans it lazily without building the full KeyValueMap. Exercise arena scan + // path independent of any live HTTP transfer. 
+ HttpClient::Response Resp; + Resp.HeaderArena = + "Content-Type: application/json\r\n" + "ETag: \"abc-123\"\r\n" + "Content-Length: 42\r\n" + "X-Amz-Request-Id: deadbeef\r\n" + "\r\n"; + + CHECK_EQ(Resp.FindHeader("Content-Type"), "application/json"); + CHECK_EQ(Resp.FindHeader("etag"), "\"abc-123\""); // case-insensitive + CHECK_EQ(Resp.FindHeader("CONTENT-LENGTH"), "42"); // case-insensitive + CHECK_EQ(Resp.FindHeader("x-amz-request-id"), "deadbeef"); + CHECK_EQ(Resp.FindHeader("missing"), ""); // absent + CHECK_EQ(Resp.FindHeader(""), ""); // empty name +} + +TEST_CASE("httpclient.response.findheader.map_fallback") +{ + // Sync client populates Header KeyValueMap; FindHeader falls back to it + // when HeaderArena is empty. + HttpClient::Response Resp; + Resp.Header->insert_or_assign("Content-Type", "text/plain"); + Resp.Header->insert_or_assign("ETag", "\"xyz\""); + + CHECK_EQ(Resp.FindHeader("content-type"), "text/plain"); + CHECK_EQ(Resp.FindHeader("ETag"), "\"xyz\""); + CHECK_EQ(Resp.FindHeader("missing"), ""); +} + +TEST_CASE("httpclient.response.findheader.arena_takes_priority") +{ + // If both populated (unusual), arena is scanned first. + HttpClient::Response Resp; + Resp.HeaderArena = "ETag: from-arena\r\n"; + Resp.Header->insert_or_assign("ETag", "from-map"); + CHECK_EQ(Resp.FindHeader("ETag"), "from-arena"); +} + TEST_CASE("httpclient.verbs") { TestServerFixture Fixture; diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h index cb41626b9..d4a33e3ac 100644 --- a/src/zenhttp/include/zenhttp/asynchttpclient.h +++ b/src/zenhttp/include/zenhttp/asynchttpclient.h @@ -19,21 +19,96 @@ namespace zen { /// Completion callback for async HTTP operations. using AsyncHttpCallback = std::function<void(HttpClient::Response)>; +/// Pull-mode body source for the streaming `AsyncPut` overload. 
Fired from +/// the AsyncHttpClient io thread when curl needs more upload bytes; runs on +/// the same strand as every other transfer. +/// +/// IMPORTANT: like AsyncHttpDataCallback, this runs on the io thread and +/// stalls curl_multi for ALL in-flight transfers if it blocks. Local-disk +/// pread is acceptable on fast storage; for anything slower, pre-fetch into +/// a caller-owned ring on a worker pool and have this callback pop from it. +/// +/// Dst - destination buffer to fill (caller writes here). +/// MaxBytes - maximum bytes Dst can hold this call. +/// AbsOffset - cumulative bytes returned so far (i.e. start offset of +/// this chunk into the request body). Strictly monotonic +/// across calls; equals TotalSize when EOF reached. +/// +/// Returns the number of bytes written into Dst (<= MaxBytes). Returning 0 +/// while AbsOffset < TotalSize is treated as an upload error (CURL_READFUNC_ABORT). +using AsyncHttpReadSource = std::function<size_t(uint8_t* Dst, size_t MaxBytes, uint64_t AbsOffset)>; + +/// Per-chunk data callback for `AsyncStream`. Fired from the AsyncHttpClient +/// io thread (the same thread driving curl_multi for every other transfer on +/// this client) once per response-body chunk that curl delivers. +/// +/// IMPORTANT: OnData runs on the io thread. Any blocking work inside it +/// (synchronous disk I/O, `std::mutex` waits, network calls, lock contention) +/// stalls curl_multi for ALL in-flight transfers on this client, not just +/// this one. Treat OnData as a strand: copy/refcount the bytes into a buffer +/// you own and hop the heavy work to a worker pool. See +/// `S3AsyncStorage::Get` (medium tier) for the canonical pattern - fill a +/// pre-sized IoBuffer from a bounded pool here, dispatch the positional disk +/// write to a worker, release the buffer back to the pool from the worker. +/// +/// Data - pointer to received bytes; valid only for the duration of +/// this call. 
Caller must consume synchronously (memcpy into +/// a pre-arranged buffer slot, etc). No allocation or copy +/// is performed by AsyncHttpClient. +/// Size - bytes available at Data for this chunk. +/// TotalSize - declared total payload size (Content-Length, in bytes); 0 +/// if the server did not declare a length (e.g. +/// Transfer-Encoding: chunked). Constant across all calls +/// for a given request. +/// +/// Returns true to continue the transfer; returning false aborts the +/// transfer and the completion callback fires with an error response. +using AsyncHttpDataCallback = std::function<bool(const uint8_t* Data, size_t Size, uint64_t TotalSize)>; + +/// Handle to an in-flight async HTTP request. Returned by every AsyncXxx call. +/// Default-constructed token is empty (Cancel is a no-op). Calling Cancel on a +/// non-empty token requests cancellation of the underlying transfer; the +/// completion callback still fires once with an error response so callers do +/// not need a second notification path. +/// +/// The token does not keep the AsyncHttpClient alive. Callers must ensure +/// Cancel is invoked before the client is destroyed; tokens left dangling are +/// safe to destroy but Cancel calls after client destruction are no-ops. +class AsyncRequestToken +{ +public: + // Forward-declared so AsyncHttpClient internals can hold a typed + // shared_ptr<State> across worker / strand boundaries (used for the + // cancel-before-submit race check). State definition is private to the + // implementation TU. + struct State; + + AsyncRequestToken() = default; + + void Cancel(); + bool IsValid() const { return m_State != nullptr; } + +private: + friend class AsyncHttpClient; + explicit AsyncRequestToken(std::shared_ptr<State> S) : m_State(std::move(S)) {} + + std::shared_ptr<State> m_State; +}; + /** Asynchronous HTTP client backed by curl_multi and ASIO. 
* * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process - * transfers without blocking the caller. All curl_multi operations are - * serialized on an internal strand; callers may issue requests from any - * thread, and the io_context may have multiple threads. + * transfers without blocking the caller. By default the client owns an + * io_context and a single io thread driving it; the second constructor reuses + * an external io_context. All curl_multi operations are serialized on an + * internal strand. Callers may issue requests from any thread. * - * Two construction modes: - * - Owned io_context: creates an internal thread (self-contained). - * - External io_context: caller runs the event loop. - * - * Completion callbacks are dispatched on the io_context (not the internal - * strand), so a slow callback will not block the curl poll loop. Future- - * based wrappers (Get, Post, ...) return a std::future<Response> for - * callers that prefer blocking on a result. + * Completion callbacks run inline on the AsyncHttpClient io thread (same + * strand as the curl poll loop): heavy work (disk syscalls, lock contention, + * large allocations) must be hopped to a worker pool. See + * `AsyncHttpDataCallback` and `AsyncHttpReadSource` for the same contract on + * streaming. Future-based wrappers (Get, Post, ...) return a + * std::future<Response> for callers that prefer blocking on a result. */ class AsyncHttpClient { @@ -41,11 +116,15 @@ public: using Response = HttpClient::Response; using KeyValueMap = HttpClient::KeyValueMap; - /// Construct with an internally-owned io_context and thread. explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {}); /// Construct with an externally-managed io_context. The io_context must - /// outlive this client and must be running (via run()) on at least one thread. + /// outlive the AsyncHttpClient and must be running (e.g. 
via run() on + /// a dedicated thread, or driven by the caller). The destructor posts + /// the cleanup handler to that loop and blocks until it completes - + /// destroying the client from the same thread that drives the io_context + /// would deadlock. Multiple threads on the same io_context are safe; + /// all curl_multi operations are serialized through an internal strand. AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {}); ~AsyncHttpClient(); @@ -54,38 +133,93 @@ public: AsyncHttpClient& operator=(const AsyncHttpClient&) = delete; // -- Callback-based API ---------------------------------------------- - - void AsyncGet(std::string_view Url, - AsyncHttpCallback Callback, - const KeyValueMap& AdditionalHeader = {}, - const KeyValueMap& Parameters = {}); - - void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); - - void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); - - void AsyncPost(std::string_view Url, - AsyncHttpCallback Callback, - const KeyValueMap& AdditionalHeader = {}, - const KeyValueMap& Parameters = {}); - - void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); - - void AsyncPost(std::string_view Url, - const IoBuffer& Payload, - ZenContentType ContentType, - AsyncHttpCallback Callback, - const KeyValueMap& AdditionalHeader = {}); - - void AsyncPut(std::string_view Url, - const IoBuffer& Payload, - AsyncHttpCallback Callback, - const KeyValueMap& AdditionalHeader = {}, - const KeyValueMap& Parameters = {}); - - void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {}); + // + // On every callback overload below the response's parsed Header map is left + // EMPTY by default - the raw header bytes live in Response::HeaderArena and + // callers should use 
Response::FindHeader(name) to look up individual values. + // This skips the per-line std::string allocations on the io thread that the + // sync client incurs. Callers that need to iterate all headers must opt in + // via AsyncRequestSpec::WantHeaderMap on the impl side; today no public + // overload exposes that flag because no in-tree caller needs it. + + AsyncRequestToken AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + AsyncRequestToken AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + AsyncRequestToken AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + AsyncRequestToken AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + AsyncRequestToken AsyncPost(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}); + + AsyncRequestToken AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}); + + AsyncRequestToken AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + AsyncRequestToken AsyncPut(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + /// Streaming PUT: body bytes are pulled from `Source` on demand, no + /// pre-materialized payload buffer. TotalSize is set as Content-Length. + /// `Source` runs on the AsyncHttpClient io thread - same strand + /// discipline as AsyncStream's OnData. 
Useful for multipart parts or + /// medium-tier uploads where materializing the body would waste RAM. + AsyncRequestToken AsyncPut(std::string_view Url, + uint64_t TotalSize, + AsyncHttpReadSource Source, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader = {}); + + /// Streaming GET with a caller-supplied per-chunk data callback. Bytes are + /// delivered to OnData as they arrive, with no internal allocation or + /// copy of the payload. Caller manages its own destination (positional + /// disk write, pre-arranged buffer slot, etc.) and must consume each + /// chunk synchronously before returning - the data pointer is only valid + /// for the duration of the call. + /// + /// WARNING: OnData runs on the AsyncHttpClient io thread. Blocking I/O, + /// lock contention, or any wait inside OnData stalls curl_multi for ALL + /// in-flight transfers on this client. See `AsyncHttpDataCallback` doc + /// for the recommended buffer-then-hop pattern. + /// + /// The completion Callback fires once with status / headers / error + /// (Response.ResponsePayload is always empty on this path). + /// + /// Useful when the caller already owns the destination file (e.g. ranged + /// multipart download writing per-range slices to a shared file) and + /// wants to avoid the curl-buffer -> in-memory IoBuffer -> file double + /// memcpy / lazy-commit page-fault cost of AsyncGet+worker-write. + AsyncRequestToken AsyncStream(std::string_view Url, + AsyncHttpDataCallback OnData, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); // -- Future-based API ------------------------------------------------ + // + // These wrappers discard the underlying AsyncRequestToken; callers who need + // to cancel an in-flight request must use the callback-based API. The + // returned future resolves once with the final Response (which may carry an + // error / cancel response if the client is shutting down). 
[[nodiscard]] std::future<Response> Get(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, @@ -115,7 +249,7 @@ public: private: struct Impl; - std::unique_ptr<Impl> m_Impl; + std::shared_ptr<Impl> m_Impl; }; void asynchttpclient_test_forcelink(); // internal diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 4cf3a86a8..3162823f3 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -136,6 +136,24 @@ struct HttpClientSettings /// nullptr disables sharing. Must outlive every HttpClient that /// references it. Non-curl backends ignore this field. HttpClientShare* OptionalShare = nullptr; + + /// Max concurrent connections per host. Used by AsyncHttpClient to set + /// CURLMOPT_MAX_HOST_CONNECTIONS. libcurl default is small; hub-scale fanout + /// against a single S3 endpoint needs much higher. 0 = leave libcurl default. + uint32_t MaxConcurrentConnectionsPerHost = 0; + + /// Max concurrent connections total. Used by AsyncHttpClient to set + /// CURLMOPT_MAX_TOTAL_CONNECTIONS. 0 = leave libcurl default. + uint32_t MaxConcurrentConnectionsTotal = 0; + + /// Hint for the maximum number of in-flight requests the caller intends to + /// keep submitted to AsyncHttpClient. AsyncHttpClient itself does not gate + /// on this value - fan-out is bounded externally by the caller (e.g. the + /// hub's S3AsyncStorage admission semaphore). This setting is reused by + /// hub configuration to derive matching curl connection caps so libcurl's + /// CONNECTTIMEOUT does not tick on connections waiting behind the cap. + /// 0 = no hint. + uint32_t MaxConcurrentRequests = 0; }; class HttpClientError : public std::runtime_error @@ -272,9 +290,17 @@ public: HttpResponseCode StatusCode = HttpResponseCode::ImATeapot; IoBuffer ResponsePayload; // Note: this also includes the content type - // Contains the response headers + // Contains the response headers. 
By default the async path leaves this + // empty (raw bytes are kept in HeaderArena instead) - use FindHeader + // for lookups, or set AsyncRequestSpec::WantHeaderMap if you need the + // full parsed map. The synchronous client populates this as before. KeyValueMap Header; + // Raw response header bytes, "Key: Value\r\n" lines concatenated. + // Populated by the async client; empty for the sync client. Owned + // and freed with the Response. FindHeader() scans this lazily. + std::string HeaderArena; + // The number of bytes sent as part of the request int64_t UploadedBytes = 0; @@ -321,6 +347,12 @@ public: // objects, returns text as-is for text types like Text, JSON, HTML etc std::string ToText() const; + // Lookup a header by name (case-insensitive). Checks HeaderArena first, + // then Header map. Returns the first matching value or empty string_view + // if absent. View is valid until this Response is destroyed (or Header + // map is mutated). + std::string_view FindHeader(std::string_view Name) const; + // Returns whether the HTTP status code is considered successful (i.e in the // 2xx range) bool IsSuccess() const noexcept; diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp index b4e9de2f0..2afee8729 100644 --- a/src/zenserver/hub/hub.cpp +++ b/src/zenserver/hub/hub.cpp @@ -21,6 +21,10 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#include <zencore/thread.h> + +#include <thread> + #if ZEN_WITH_TESTS # include <zencore/testing.h> # include <zencore/testutils.h> @@ -30,8 +34,6 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -/////////////////////////////////////////////////////////////////////////// - /** * A timeline of events with sequence IDs and timestamps. Used to * track significant events for broadcasting to listeners. 
@@ -81,7 +83,6 @@ public: */ void IterateEventsSince(auto&& Callback, uint64_t SinceEventId) { - // Hold the lock for as short a time as possible eastl::fixed_vector<EventRecord, 128> EventsToProcess; m_Lock.WithSharedLock([&] { for (auto& Event : m_Events) @@ -93,7 +94,6 @@ public: } }); - // Now invoke the callback outside the lock for (auto& Event : EventsToProcess) { Callback(Event); @@ -176,16 +176,22 @@ Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback) : m_Config(Config) , m_RunEnvironment(std::move(RunEnvironment)) -, m_WorkerPool(Config.OptionalProvisionWorkerPool) +, m_ProvisionPool(Config.OptionalProvisionPool) +, m_SpawnPool(Config.OptionalSpawnPool) , m_InstanceClientShare(std::make_unique<HttpClientShare>()) , m_BackgroundWorkLatch(1) , m_ModuleStateChangeCallback(std::move(ModuleStateChangeCallback)) , m_ActiveInstances(Config.InstanceLimit) , m_FreeActiveInstanceIndexes(Config.InstanceLimit) { - ZEN_ASSERT_FORMAT( - Config.OptionalProvisionWorkerPool != Config.OptionalHydrationWorkerPool || Config.OptionalProvisionWorkerPool == nullptr, - "Provision and hydration worker pools must be distinct to avoid deadlocks"); + ZEN_ASSERT_FORMAT((Config.OptionalProvisionPool == nullptr) == (Config.OptionalSpawnPool == nullptr), + "Provision and spawn worker pools must both be set or both be null"); + ZEN_ASSERT_FORMAT(Config.OptionalProvisionPool != Config.OptionalSpawnPool || Config.OptionalProvisionPool == nullptr, + "Provision and spawn worker pools must be distinct to avoid deadlocks"); + ZEN_ASSERT_FORMAT(Config.OptionalProvisionPool != Config.OptionalHydrationPool || Config.OptionalProvisionPool == nullptr, + "Provision and hydration worker pools must be distinct to avoid deadlocks"); + ZEN_ASSERT_FORMAT(Config.OptionalSpawnPool != Config.OptionalHydrationPool || Config.OptionalSpawnPool 
== nullptr, + "Spawn and hydration worker pools must be distinct to avoid deadlocks"); HydrationBase::Configuration HydrationConfig; if (!m_Config.HydrationTargetSpecification.empty()) @@ -202,6 +208,11 @@ Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, Asy { HydrationConfig.Options = m_Config.HydrationOptions; } + if (Config.HydrationAsyncEnabled) + { + HydrationConfig.AsyncEnabled = true; + HydrationConfig.AsyncMaxConcurrentRequests = Config.HydrationAsyncMaxConcurrentRequests; + } m_Hydration = InitHydration(HydrationConfig); m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp"); @@ -272,7 +283,7 @@ Hub::Shutdown() m_WatchDog = {}; - if (WaitForBackgroundWork && m_WorkerPool) + if (WaitForBackgroundWork && (m_ProvisionPool != nullptr || m_SpawnPool != nullptr)) { m_BackgroundWorkLatch.CountDown(); m_BackgroundWorkLatch.Wait(); @@ -300,7 +311,7 @@ Hub::Shutdown() } }); - if (WaitForBackgroundWork && m_WorkerPool) + if (WaitForBackgroundWork && (m_ProvisionPool != nullptr || m_SpawnPool != nullptr)) { m_BackgroundWorkLatch.CountDown(); m_BackgroundWorkLatch.Wait(); @@ -317,7 +328,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) size_t ActiveInstanceIndex = (size_t)-1; HubInstanceState OldState = HubInstanceState::Unprovisioned; { - RwLock::ExclusiveLockScope _(m_Lock); + RwLock::ExclusiveLockScope HubLock(m_Lock); if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end()) { @@ -339,23 +350,15 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) { auto NewInstance = std::make_unique<StorageServerInstance>( m_RunEnvironment, - *m_Hydration, - StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex), - .StateDir = m_RunEnvironment.CreateChildDir(ModuleId), - .TempDir = m_HydrationTempPath / ModuleId, - .HttpThreadCount = m_Config.InstanceHttpThreadCount, - .CoreLimit = 
m_Config.InstanceCoreLimit, - .ConfigPath = m_Config.InstanceConfigPath, - .Malloc = m_Config.InstanceMalloc, - .Trace = m_Config.InstanceTrace, - .TraceHost = m_Config.InstanceTraceHost, - .TraceFile = m_Config.InstanceTraceFile, - .EnableHydration = m_Config.EnableHydration, - .EnableDehydration = m_Config.EnableDehydration, - .HydrationPackEnabled = m_Config.HydrationPackEnabled, - .HydrationPackThresholdBytes = m_Config.HydrationPackThresholdBytes, - .HydrationMaxPackBytes = m_Config.HydrationMaxPackBytes, - .OptionalWorkerPool = m_Config.OptionalHydrationWorkerPool}, + StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex), + .StateDir = m_RunEnvironment.CreateChildDir(ModuleId), + .HttpThreadCount = m_Config.InstanceHttpThreadCount, + .CoreLimit = m_Config.InstanceCoreLimit, + .ConfigPath = m_Config.InstanceConfigPath, + .Malloc = m_Config.InstanceMalloc, + .Trace = m_Config.InstanceTrace, + .TraceHost = m_Config.InstanceTraceHost, + .TraceFile = m_Config.InstanceTraceFile}, ModuleId); #if ZEN_PLATFORM_WINDOWS @@ -367,7 +370,8 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) Instance = NewInstance->LockExclusive(/*Wait*/ true); - m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance); + m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance); + m_ActiveInstances[ActiveInstanceIndex].HydrationState = {}; m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset(); m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex); // Set Provisioning while both hub lock and instance lock are held so that any @@ -378,7 +382,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) { Instance = {}; m_ActiveInstances[ActiveInstanceIndex].Instance.reset(); - m_ActiveInstances[ActiveInstanceIndex].State.store(HubInstanceState::Unprovisioned); + UpdateInstanceState(HubLock, ActiveInstanceIndex, 
HubInstanceState::Unprovisioned); m_InstanceLookup.erase(std::string(ModuleId)); m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); throw; @@ -419,7 +423,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now()); return Response{EResponseCode::Completed}; case HubInstanceState::Hibernated: - _.ReleaseNow(); + HubLock.ReleaseNow(); return Wake(std::string(ModuleId)); default: return Response{EResponseCode::Rejected, @@ -463,54 +467,54 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Provisioning, OutInfo.Port, {}); - if (m_WorkerPool) + const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr; + if (Async) { + auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_ProvisionPool->ScheduleWork( [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, OldState, IsNewInstance, - Instance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance))]() { + Port = OutInfo.Port, + Instance = SharedInstance]() { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + if (!RunProvisionPhase1(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port)) + { + return; + } + + m_BackgroundWorkLatch.AddCount(1); try { - CompleteProvision(*Instance, ActiveInstanceIndex, OldState, IsNewInstance); + m_SpawnPool->ScheduleWork( + [this, ModuleId, ActiveInstanceIndex, OldState, IsNewInstance, Port, Instance]() { + auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + RunProvisionPhase2(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + }, + WorkerThreadPool::EMode::EnableBacklog); } - catch (const std::exception& Ex) + catch (const std::exception& DispatchEx) 
{ - ZEN_ERROR("Failed async provision of module '{}': {}", ModuleId, Ex.what()); + // Fallback: run Phase2 inline on the ProvisionPool worker. Couples pool + // lifetimes (this ProvisionPool slot now executes what should run on + // SpawnPool) but better than dropping the request. + ZEN_ERROR("Failed async dispatch of provision phase 2 for module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + RunProvisionPhase2(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); } }, WorkerThreadPool::EMode::EnableBacklog); } catch (const std::exception& DispatchEx) { - // Dispatch failed: undo latch increment and roll back state. - ZEN_ERROR("Failed async dispatch provision of module '{}': {}", ModuleId, DispatchEx.what()); + ZEN_ERROR("Failed async dispatch of provision phase 1 for module '{}': {}", ModuleId, DispatchEx.what()); m_BackgroundWorkLatch.CountDown(); - - // dispatch failed before the lambda ran, so ActiveInstance::State is still Provisioning - NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, OldState, OutInfo.Port, {}); - - std::unique_ptr<StorageServerInstance> DestroyInstance; - { - RwLock::ExclusiveLockScope HubLock(m_Lock); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end()); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex); - if (IsNewInstance) - { - DestroyInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); - m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); - m_InstanceLookup.erase(std::string(ModuleId)); - } - UpdateInstanceState(HubLock, ActiveInstanceIndex, OldState); - } - DestroyInstance.reset(); - + RollbackFailedProvision(*SharedInstance, ActiveInstanceIndex, OldState, IsNewInstance, OutInfo.Port); throw; } } @@ -519,7 +523,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) CompleteProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance); } - 
return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{Async ? EResponseCode::Accepted : EResponseCode::Completed}; } void @@ -529,38 +533,91 @@ Hub::CompleteProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, bool IsNewInstance) { ZEN_TRACE_CPU("Hub::CompleteProvision"); + const uint16_t Port = Instance.GetBasePort(); + if (RunProvisionPhase1(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port)) + { + RunProvisionPhase2(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + } +} + +bool +Hub::RunProvisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunProvisionPhase1"); const std::string ModuleId(Instance.GetModuleId()); - const uint16_t Port = Instance.GetBasePort(); - std::string BaseUri; // TODO? - if (m_ShutdownFlag.load() == false) + if (m_ShutdownFlag.load()) { - try - { - switch (OldState) - { - case HubInstanceState::Crashed: - case HubInstanceState::Unprovisioned: - Instance.Provision(); - break; - case HubInstanceState::Hibernated: - ZEN_ASSERT(false); // unreachable: Provision redirects Hibernated->Wake before setting Provisioning - break; - default: - ZEN_ASSERT(false); - } - UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned); - NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Provisioned, Port, BaseUri); - Instance = {}; - return; - } - catch (const std::exception& Ex) + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + return false; + } + + try + { + switch (OldState) { - ZEN_ERROR("Failed to provision storage server instance for module '{}': {}", ModuleId, Ex.what()); - // Instance will be notified and removed below. 
+ case HubInstanceState::Crashed: + case HubInstanceState::Unprovisioned: + HydrateInstance(ActiveInstanceIndex, ModuleId); + break; + case HubInstanceState::Hibernated: + ZEN_ASSERT(false); // unreachable: Provision redirects Hibernated->Wake before setting Provisioning + break; + default: + ZEN_ASSERT(false); } + return true; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to hydrate storage server instance for module '{}': {}", ModuleId, Ex.what()); + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + return false; + } +} + +void +Hub::RunProvisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunProvisionPhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + if (m_ShutdownFlag.load()) + { + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + return; + } + + try + { + Instance.Provision(); + UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned); + NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Provisioned, Port, {}); + Instance = {}; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to provision storage server instance for module '{}': {}", ModuleId, Ex.what()); + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); } +} +void +Hub::RollbackFailedProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port) +{ + const std::string ModuleId(Instance.GetModuleId()); if (IsNewInstance) { NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Unprovisioned, Port, {}); @@ -568,11 +625,11 @@ Hub::CompleteProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, std::unique_ptr<StorageServerInstance> 
DestroyInstance; { RwLock::ExclusiveLockScope HubLock(m_Lock); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end()); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(ModuleId) != m_InstanceLookup.end()); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(ModuleId)->second == ActiveInstanceIndex); DestroyInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); - m_InstanceLookup.erase(std::string(ModuleId)); + m_InstanceLookup.erase(ModuleId); UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); } DestroyInstance.reset(); @@ -652,11 +709,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // The exclusive instance lock acquired above prevents concurrent LockExclusive callers - // from modifying instance state. The state transition to Deprovisioning happens below, - // after the hub lock is released. - + // Outside hub lock: exclusive instance lock above blocks concurrent state mutation. See Provision for the locking argument. 
ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -664,32 +717,55 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Deprovisioning, Port, {}); - if (m_WorkerPool) + const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr; + if (Async) { - std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr = - std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); - + auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( - [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr), OldState]() mutable { + m_SpawnPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, OldState, Port, Instance = SharedInstance]() { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); try { - CompleteDeprovision(*Instance, ActiveInstanceIndex, OldState); + RunDeprovisionPhase1(*Instance, ActiveInstanceIndex, OldState, Port); } catch (const std::exception& Ex) { - ZEN_ERROR("Failed async deprovision of module '{}': {}", ModuleId, Ex.what()); + // Phase1 transitions the instance to Crashed before rethrowing, + // so the watchdog will pick this up via AttemptRecoverInstance. + // The deprovision request silently morphs into a recovery cycle; + // caller already saw EResponseCode::Accepted. 
+ ZEN_ERROR("Failed async deprovision phase 1 of module '{}': {}", ModuleId, Ex.what()); + return; + } + + m_BackgroundWorkLatch.AddCount(1); + try + { + m_ProvisionPool->ScheduleWork( + [this, ModuleId, ActiveInstanceIndex, Port, Instance]() { + auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + RunDeprovisionPhase2(*Instance, ActiveInstanceIndex, Port); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + // Fallback: run Phase2 inline on the SpawnPool worker. Couples pool + // lifetimes (this SpawnPool slot now executes what should run on + // ProvisionPool) but better than dropping the request. + ZEN_ERROR("Failed async dispatch of deprovision phase 2 for module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + RunDeprovisionPhase2(*Instance, ActiveInstanceIndex, Port); } }, WorkerThreadPool::EMode::EnableBacklog); } catch (const std::exception& DispatchEx) { - // Dispatch failed: undo latch increment and roll back state. - ZEN_ERROR("Failed async dispatch deprovision of module '{}': {}", ModuleId, DispatchEx.what()); + ZEN_ERROR("Failed async dispatch of deprovision phase 1 for module '{}': {}", ModuleId, DispatchEx.what()); m_BackgroundWorkLatch.CountDown(); NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, OldState, Port, {}); @@ -708,7 +784,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI CompleteDeprovision(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{Async ? 
EResponseCode::Accepted : EResponseCode::Completed}; } Hub::Response @@ -766,12 +842,12 @@ Hub::Obliterate(const std::string& ModuleId) m_ObliteratingInstances.insert(ModuleId); Lock.ReleaseNow(); - if (m_WorkerPool) + if (m_ProvisionPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_ProvisionPool->ScheduleWork( [this, ModuleId = std::string(ModuleId)]() { auto Guard = MakeGuard([this, ModuleId]() { m_Lock.WithExclusiveLock([this, ModuleId]() { m_ObliteratingInstances.erase(ModuleId); }); @@ -817,31 +893,51 @@ Hub::Obliterate(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Obliterating, Port, {}); - if (m_WorkerPool) + const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr; + if (Async) { - std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr = - std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); - + auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( - [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable { + m_SpawnPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Port, Instance = SharedInstance]() { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); try { - CompleteObliterate(*Instance, ActiveInstanceIndex); + RunObliteratePhase1(*Instance, ActiveInstanceIndex, Port); } catch (const std::exception& Ex) { - ZEN_ERROR("Failed async obliterate of module '{}': {}", ModuleId, Ex.what()); + ZEN_ERROR("Failed async obliterate phase 1 of module '{}': {}", ModuleId, Ex.what()); + return; + } + + m_BackgroundWorkLatch.AddCount(1); + try + { + m_ProvisionPool->ScheduleWork( + [this, ModuleId, ActiveInstanceIndex, Port, Instance]() { + auto _ = 
MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + RunObliteratePhase2(*Instance, ActiveInstanceIndex, Port); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + // Fallback: run Phase2 inline on the SpawnPool worker. Couples pool + // lifetimes (this SpawnPool slot now executes what should run on + // ProvisionPool) but better than dropping the request. + ZEN_ERROR("Failed async dispatch of obliterate phase 2 for module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + RunObliteratePhase2(*Instance, ActiveInstanceIndex, Port); } }, WorkerThreadPool::EMode::EnableBacklog); } catch (const std::exception& DispatchEx) { - ZEN_ERROR("Failed async dispatch obliterate of module '{}': {}", ModuleId, DispatchEx.what()); + ZEN_ERROR("Failed async dispatch obliterate phase 1 of module '{}': {}", ModuleId, DispatchEx.what()); m_BackgroundWorkLatch.CountDown(); NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, OldState, Port, {}); @@ -860,15 +956,23 @@ Hub::Obliterate(const std::string& ModuleId) CompleteObliterate(Instance, ActiveInstanceIndex); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{Async ? 
EResponseCode::Accepted : EResponseCode::Completed}; } void Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex) { ZEN_TRACE_CPU("Hub::CompleteObliterate"); + const uint16_t Port = Instance.GetBasePort(); + RunObliteratePhase1(Instance, ActiveInstanceIndex, Port); + RunObliteratePhase2(Instance, ActiveInstanceIndex, Port); +} + +void +Hub::RunObliteratePhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunObliteratePhase1"); const std::string ModuleId(Instance.GetModuleId()); - const uint16_t Port = Instance.GetBasePort(); try { @@ -876,15 +980,35 @@ Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, siz } catch (const std::exception& Ex) { + // Best-effort cleanup: drop tracking and mark Unprovisioned. Transitioning to + // Crashed would let the watchdog re-provision a module the operator wanted gone. ZEN_ERROR("Failed to obliterate storage server instance for module '{}': {}", ModuleId, Ex.what()); - Instance = {}; - { - RwLock::ExclusiveLockScope HubLock(m_Lock); - UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Crashed); - } - NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Crashed, Port, {}); + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); throw; } +} + +void +Hub::RunObliteratePhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunObliteratePhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + try + { + ObliterateBackendData(ModuleId); + } + catch (const std::exception& Ex) + { + // Backend delete failed - documented leak path (see hydration.cpp Obliterate + // retry-then-fail). 
Drop local tracking and mark Unprovisioned; the watchdog + // must not re-provision a module the operator wanted gone. + ZEN_ERROR("Failed to obliterate backend data for module '{}': {}", ModuleId, Ex.what()); + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); + return; + } NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); @@ -894,8 +1018,19 @@ void Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState) { ZEN_TRACE_CPU("Hub::CompleteDeprovision"); + const uint16_t Port = Instance.GetBasePort(); + RunDeprovisionPhase1(Instance, ActiveInstanceIndex, OldState, Port); + RunDeprovisionPhase2(Instance, ActiveInstanceIndex, Port); +} + +void +Hub::RunDeprovisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunDeprovisionPhase1"); const std::string ModuleId(Instance.GetModuleId()); - const uint16_t Port = Instance.GetBasePort(); try { @@ -942,9 +1077,8 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si catch (const std::exception& Ex) { ZEN_ERROR("Failed to deprovision storage server instance for module '{}': {}", ModuleId, Ex.what()); - // Effectively unreachable: Shutdown() never throws and Dehydrate() failures are swallowed - // by DeprovisionLocked. Kept as a safety net; if somehow reached, transition to Crashed - // so the watchdog can attempt recovery. + // GcClient HTTP calls and Instance.Deprovision can throw on transport failure + // or watchdog races; transition to Crashed so the watchdog can attempt recovery. 
Instance = {}; { RwLock::ExclusiveLockScope HubLock(m_Lock); @@ -953,6 +1087,22 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Crashed, Port, {}); throw; } +} + +void +Hub::RunDeprovisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunDeprovisionPhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + try + { + DehydrateInstance(ActiveInstanceIndex, ModuleId); + } + catch (const std::exception& Ex) + { + ZEN_WARN("Dehydration of module {} failed during deprovisioning, current state not saved. Reason: {}", ModuleId, Ex.what()); + } NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Unprovisioned, Port, {}); RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); @@ -1012,11 +1162,7 @@ Hub::Hibernate(const std::string& ModuleId) } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // Any concurrent caller that acquired the hub lock and saw Provisioned will now block on - // LockExclusive(Wait=true); by the time it acquires the lock, UpdateInstanceState below - // will have already changed the state and the re-validate above will reject it. - + // Outside hub lock: re-validate after re-locking in worker rejects races. See Provision for the locking argument. 
ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -1024,12 +1170,12 @@ Hub::Hibernate(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Hibernating, Port, {}); - if (m_WorkerPool) + if (m_SpawnPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_SpawnPool->ScheduleWork( [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, @@ -1069,7 +1215,7 @@ Hub::Hibernate(const std::string& ModuleId) CompleteHibernate(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{m_SpawnPool != nullptr ? EResponseCode::Accepted : EResponseCode::Completed}; } void @@ -1148,11 +1294,7 @@ Hub::Wake(const std::string& ModuleId) } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // Any concurrent caller that acquired the hub lock and saw Hibernated will now block on - // LockExclusive(Wait=true); by the time it acquires the lock, UpdateInstanceState below - // will have already changed the state and the re-validate above will reject it. - + // Outside hub lock: re-validate after re-locking in worker rejects races. See Provision for the locking argument. ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -1160,12 +1302,12 @@ Hub::Wake(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Waking, Port, {}); - if (m_WorkerPool) + if (m_SpawnPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_SpawnPool->ScheduleWork( [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, @@ -1205,7 +1347,7 @@ Hub::Wake(const std::string& ModuleId) CompleteWake(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? 
EResponseCode::Accepted : EResponseCode::Completed}; + return Response{m_SpawnPool != nullptr ? EResponseCode::Accepted : EResponseCode::Completed}; } void @@ -1243,7 +1385,8 @@ Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t auto It = m_InstanceLookup.find(std::string(ModuleId)); ZEN_ASSERT_SLOW(It != m_InstanceLookup.end()); ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex); - DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); + DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); + m_ActiveInstances[ActiveInstanceIndex].HydrationState = {}; m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); m_InstanceLookup.erase(It); UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); @@ -1251,23 +1394,64 @@ Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t DeleteInstance.reset(); } +HydrationConfig +Hub::MakeHydrationConfigForModule(std::string_view ModuleId, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag) const +{ + HydrationConfig Config{.ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId, + .TempDir = m_HydrationTempPath / ModuleId, + .ModuleId = std::string(ModuleId)}; + if (m_Config.OptionalHydrationPool) + { + Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationPool, + .AbortFlag = &AbortFlag, + .PauseFlag = &PauseFlag}); + } + Config.PackEnabled = m_Config.HydrationPackEnabled; + Config.PackThresholdBytes = m_Config.HydrationPackThresholdBytes; + Config.MaxPackBytes = m_Config.HydrationMaxPackBytes; + return Config; +} + void -Hub::ObliterateBackendData(std::string_view ModuleId) +Hub::HydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId) { - std::filesystem::path ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId; - std::filesystem::path TempDir = m_HydrationTempPath / ModuleId; + if (!m_Config.EnableHydration) + { + 
ZEN_INFO("Hydration disabled; skipping hydrate for module '{}'", ModuleId); + return; + } + ZEN_TRACE_CPU("Hub::HydrateInstance"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); + std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); + m_ActiveInstances[ActiveInstanceIndex].HydrationState = Hydrator->Hydrate(); +} - HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = TempDir, .ModuleId = std::string(ModuleId)}; - if (m_Config.OptionalHydrationWorkerPool) +void +Hub::DehydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId) +{ + if (!m_Config.EnableDehydration) { - Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationWorkerPool, - .AbortFlag = &AbortFlag, - .PauseFlag = &PauseFlag}); + ZEN_INFO("Dehydration disabled; skipping dehydrate for module '{}'", ModuleId); + return; } + ZEN_TRACE_CPU("Hub::DehydrateInstance"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); + std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); + Hydrator->Dehydrate(m_ActiveInstances[ActiveInstanceIndex].HydrationState); +} + +void +Hub::ObliterateBackendData(std::string_view ModuleId) +{ + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); Hydrator->Obliterate(); } @@ -1478,6 +1662,16 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId) try { Instance.Deprovision(); + try + { + DehydrateInstance(ActiveInstanceIndex, ModuleId); + } + catch 
(const std::exception& Ex) + { + ZEN_WARN("Dehydration of module {} failed during crash recovery cleanup, current state not saved. Reason: {}", + ModuleId, + Ex.what()); + } } catch (const std::exception& Ex) { @@ -1502,6 +1696,7 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId) try { + HydrateInstance(ActiveInstanceIndex, ModuleId); Instance.Provision(); UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned); NotifyStateUpdate(ModuleId, HubInstanceState::Recovering, HubInstanceState::Provisioned, Instance.GetBasePort(), /*BaseUri*/ {}); @@ -1795,7 +1990,6 @@ Hub::WatchDog() } catch (const std::exception& Ex) { - // TODO: Catch specific errors such as asserts, OOM, OOD, system_error etc ZEN_ERROR("Hub watchdog threw exception: {}", Ex.what()); } } @@ -1835,9 +2029,15 @@ namespace hub_testutils { struct TestHubPools { WorkerThreadPool ProvisionPool; + WorkerThreadPool SpawnPool; WorkerThreadPool HydrationPool; - explicit TestHubPools(int ThreadCount) : ProvisionPool(ThreadCount, "hub_test_prov"), HydrationPool(ThreadCount, "hub_test_hydr") {} + explicit TestHubPools(int ThreadCount) + : ProvisionPool(ThreadCount, "hub_test_provision") + , SpawnPool(ThreadCount, "hub_test_spawn") + , HydrationPool(ThreadCount, "hub_test_hydr") + { + } }; ZenServerEnvironment MakeHubEnvironment(const std::filesystem::path& BaseDir) @@ -1852,8 +2052,9 @@ namespace hub_testutils { { if (Pools) { - Config.OptionalProvisionWorkerPool = &Pools->ProvisionPool; - Config.OptionalHydrationWorkerPool = &Pools->HydrationPool; + Config.OptionalProvisionPool = &Pools->ProvisionPool; + Config.OptionalSpawnPool = &Pools->SpawnPool; + Config.OptionalHydrationPool = &Pools->HydrationPool; } return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), std::move(StateChangeCallback)); } diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index fa504de33..c9b673ad6 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -3,6 +3,7 @@ #pragma 
once #include "hubinstancestate.h" +#include "hydration.h" #include "resourcemetrics.h" #include "storageserverinstance.h" @@ -81,13 +82,22 @@ public: bool HydrationPackEnabled = true; uint64_t HydrationPackThresholdBytes = DefaultPackThresholdBytes; uint64_t HydrationMaxPackBytes = DefaultMaxPackBytes; + // Route S3 hydration through AsyncHttpClient. false falls back to the + // blocking S3Client path. + bool HydrationAsyncEnabled = true; + + // Hub-wide cap on concurrent S3 hydration requests (sizes the shared + // AsyncHttpClient connection pool and the admission semaphore). Only + // consulted when HydrationAsyncEnabled. + uint32_t HydrationAsyncMaxConcurrentRequests = 128; WatchDogConfiguration WatchDog; ResourceMetrics ResourceLimits; - WorkerThreadPool* OptionalProvisionWorkerPool = nullptr; - WorkerThreadPool* OptionalHydrationWorkerPool = nullptr; + WorkerThreadPool* OptionalProvisionPool = nullptr; + WorkerThreadPool* OptionalSpawnPool = nullptr; + WorkerThreadPool* OptionalHydrationPool = nullptr; }; typedef std::function< @@ -209,7 +219,8 @@ public: private: const Configuration m_Config; ZenServerEnvironment m_RunEnvironment; - WorkerThreadPool* m_WorkerPool = nullptr; + WorkerThreadPool* m_ProvisionPool = nullptr; + WorkerThreadPool* m_SpawnPool = nullptr; // Declared early so it destructs late: every HttpClient referencing the // share (watchdog ActivityCheckClient, GC client, in-flight worker // tasks) is required to be gone before this member runs its dtor. @@ -264,6 +275,12 @@ private: // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules. std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min(); + + // Cached hydration state returned by Hydrate, consumed by Dehydrate. 
Synchronized + // by the caller's StorageServerInstance ExclusiveLockedPtr (HydrateInstance and + // DehydrateInstance run under it); RemoveInstance also writes under the Hub + // exclusive lock when finalizing slot teardown. + CbObject HydrationState; }; // UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive @@ -319,8 +336,42 @@ private: void CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex); void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + + // Provision/Deprovision/Obliterate are split into two phases scheduled on different worker + // pools. The Phase1/Phase2 helpers are shared between sync and async code paths so behavior + // cannot diverge between them. + bool RunProvisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + void RunProvisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + void RollbackFailedProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + + void RunDeprovisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + uint16_t Port); + void RunDeprovisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + + void RunObliteratePhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + void 
RunObliteratePhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + void RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId); - void ObliterateBackendData(std::string_view ModuleId); + HydrationConfig MakeHydrationConfigForModule(std::string_view ModuleId, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag) const; + void HydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId); + void DehydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId); + void ObliterateBackendData(std::string_view ModuleId); // Notifications may fire slightly out of sync with the Hub's internal State flag. // The guarantee is that notifications are sent in the correct order, but the State diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp index 621af8a46..2df326fab 100644 --- a/src/zenserver/hub/hydration.cpp +++ b/src/zenserver/hub/hydration.cpp @@ -2,6 +2,12 @@ #include "hydration.h" +#include "s3asyncstorage.h" + +#include <zenhttp/asynchttpclient.h> +#include <zenutil/cloud/s3requestbuilder.h> +#include <zenutil/cloud/s3response.h> + #include <zencore/basicfile.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -30,6 +36,7 @@ #include <unordered_set> #if ZEN_WITH_TESTS +# include <zencore/process.h> # include <zencore/testing.h> # include <zencore/testutils.h> # include <zencore/workthreadpool.h> @@ -159,6 +166,12 @@ namespace hydration_impl { std::atomic<uint64_t> FirstScheduleUs{UINT64_MAX}; std::atomic<uint64_t> FirstStartUs{UINT64_MAX}; + // Admission-gate wait. AdmissionWaitTotalUs is summed across all requests that blocked on + // the storage-layer concurrency semaphore; AdmissionWaitMaxUs is the worst single wait + // observed. Total exposes back-pressure cost; Max surfaces head-of-line blocking. 
+ std::atomic<uint64_t> AdmissionWaitTotalUs{0}; + std::atomic<uint64_t> AdmissionWaitMaxUs{0}; + void RecordScheduled() { uint64_t Now = PhaseClock.GetElapsedTimeUs(); @@ -195,6 +208,15 @@ namespace hydration_impl { { } } + + void RecordAdmissionWait(uint64_t Us) + { + AdmissionWaitTotalUs.fetch_add(Us, std::memory_order_relaxed); + uint64_t Prev = AdmissionWaitMaxUs.load(std::memory_order_relaxed); + while (Us > Prev && !AdmissionWaitMaxUs.compare_exchange_weak(Prev, Us, std::memory_order_relaxed)) + { + } + } }; struct DehydrateStatistics @@ -362,6 +384,116 @@ namespace hydration_impl { uint64_t m_MultipartChunkSize; }; + S3AsyncStorageStats AsyncStatsFrom(PhaseStats& Stats) + { + return S3AsyncStorageStats{Stats.RequestCount, + Stats.RequestTotalUs, + Stats.RequestMaxUs, + Stats.Bytes, + Stats.InFlight, + Stats.InFlightPeak, + Stats.FirstScheduleUs, + Stats.FirstStartUs, + Stats.AdmissionWaitTotalUs, + Stats.AdmissionWaitMaxUs, + Stats.PhaseClock}; + } + + class S3AsyncStorageAdapter : public StorageBase + { + public: + static constexpr std::string_view Type = "s3-async"; + + S3AsyncStorageAdapter(AsyncHttpClient& Client, + S3RequestBuilder& Builder, + S3AsyncStorage::CredentialsCallback GetCreds, + std::string KeyPrefix, + uint64_t MultipartChunkSize, + std::shared_ptr<AdmissionSemaphore> Admission, + uint32_t AdmissionCap) + : m_Client(Client) + , m_Builder(Builder) + , m_GetCreds(std::move(GetCreds)) + , m_KeyPrefix(std::move(KeyPrefix)) + , m_MultipartChunkSize(MultipartChunkSize) + , m_Admission(std::move(Admission)) + , m_AdmissionCap(AdmissionCap) + , m_Storage( + std::make_unique<S3AsyncStorage>(Client, Builder, m_GetCreds, m_KeyPrefix, m_MultipartChunkSize, m_Admission, m_AdmissionCap)) + { + } + + virtual std::string Describe() const override { return fmt::format("s3-async://{}/{}"sv, m_Builder.BucketName(), m_KeyPrefix); } + + virtual void SaveMetadata(const CbObject& Data) override; + virtual CbObject LoadMetadata() override; + virtual CbObject 
GetSettings() override + { + CbObjectWriter Writer; + Writer << "MultipartChunkSize"sv << m_MultipartChunkSize; + return Writer.Save(); + } + virtual void ParseSettings(const CbObjectView& Settings) override + { + m_MultipartChunkSize = Settings["MultipartChunkSize"sv].AsUInt64(DefaultMultipartChunkSize); + m_Storage = std::make_unique<S3AsyncStorage>(m_Client, + m_Builder, + m_GetCreds, + m_KeyPrefix, + m_MultipartChunkSize, + m_Admission, + m_AdmissionCap); + } + virtual std::vector<IoHash> List() override { return m_Storage->List(); } + + virtual void Put(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Put(Work, WorkerPool, Hash, Size, SourcePath, AsyncStats); + } + + virtual void Get(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Get(Work, WorkerPool, Hash, Size, DestinationPath, AsyncStats); + } + + virtual void Touch(ParallelWork& Work, WorkerThreadPool& WorkerPool, const IoHash& Hash, PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Touch(Work, WorkerPool, Hash, AsyncStats); + } + + virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override + { + ZEN_UNUSED(WorkerPool); + m_Storage->DeleteAll(Work); + } + + private: + AsyncHttpClient& m_Client; + S3RequestBuilder& m_Builder; + S3AsyncStorage::CredentialsCallback m_GetCreds; + std::string m_KeyPrefix; + uint64_t m_MultipartChunkSize; + std::shared_ptr<AdmissionSemaphore> m_Admission; + uint32_t 
m_AdmissionCap = 0; + std::unique_ptr<S3AsyncStorage> m_Storage; + }; + /////////////////////////////////////////////////////////////////////// // FileStorage implementations @@ -701,9 +833,14 @@ namespace hydration_impl { RenameFile(ChunkPath, DestinationPath, Ec); if (Ec) { - Chunk.Content = IoBufferBuilder::MakeFromFile(ChunkPath); - Chunk.Content.SetDeleteOnClose(true); - WriteFile(DestinationPath, Chunk.Content); + // Cross-volume rename failed; copy the temp file to the destination + // using explicit positional reads (no mmap). Caller is responsible + // for the source file's eventual cleanup via the temp dir sweep. + BasicFile Src(ChunkPath, BasicFile::Mode::kRead); + IoBuffer Body = Src.ReadAll(); + WriteFile(DestinationPath, Body); + std::error_code RemoveEc; + std::filesystem::remove(ChunkPath, RemoveEc); } } } @@ -759,6 +896,63 @@ namespace hydration_impl { } } + void S3AsyncStorageAdapter::SaveMetadata(const CbObject& Data) + { + ZEN_TRACE_CPU("S3AsyncStorageAdapter::SaveMetadata"); + BinaryWriter Output; + SaveCompactBinary(Output, Data); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorageAdapter::SaveMetadata: no credentials available"sv); + } + + std::string Key = fmt::format("{}/incremental-state.cbo", m_KeyPrefix); + std::string Path = m_Builder.KeyToPath(Key); + std::string Hash = Sha256ToHex(ComputeSha256(Output.GetData(), Output.GetSize())); + HttpClient::KeyValueMap Signed = m_Builder.SignRequest(Creds, "PUT", Path, "", Hash); + IoBuffer Payload(IoBuffer::Clone, Output.GetData(), Output.GetSize()); + + HttpClient::Response Resp = m_Client.Put(Path, Payload, Signed).get(); + if (!Resp.IsSuccess()) + { + throw zen::runtime_error("Failed to save incremental metadata to '{}': {}"sv, Key, S3ErrorMessage("S3 PUT failed", Resp)); + } + } + + CbObject S3AsyncStorageAdapter::LoadMetadata() + { + ZEN_TRACE_CPU("S3AsyncStorageAdapter::LoadMetadata"); + SigV4Credentials Creds = 
m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorageAdapter::LoadMetadata: no credentials available"sv); + } + + std::string Key = fmt::format("{}/incremental-state.cbo", m_KeyPrefix); + std::string Path = m_Builder.KeyToPath(Key); + HttpClient::KeyValueMap Signed = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + + HttpClient::Response Resp = m_Client.Get(Path, Signed).get(); + if (!Resp.IsSuccess()) + { + if (Resp.StatusCode == HttpResponseCode::NotFound) + { + return {}; + } + throw zen::runtime_error("Failed to load incremental metadata from '{}': {}"sv, Key, S3ErrorMessage("S3 GET failed", Resp)); + } + + CbValidateError Error; + CbObject Meta = ValidateAndReadCompactBinaryObject(std::move(Resp.ResponsePayload), Error); + if (Error != CbValidateError::None) + { + throw zen::runtime_error("Failed to parse incremental metadata from '{}': {}"sv, Key, ToString(Error)); + } + return Meta; + } + /////////////////////////////////////////////////////////////////////// // IncrementalHydrator: the only HydrationStrategyBase implementation. // Summary emission for hydrate/dehydrate operations. @@ -802,6 +996,7 @@ namespace hydration_impl { // (Stats.PackUpload), loose Touch (Stats.Touch), and pack-blob Touch (Stats.PackTouch). // Per-request data is collected per PhaseStats by Storage::Put / Storage::Touch and // reported as a single combined "Requests" line. + // const uint64_t UpReqCount = Stats.Upload.RequestCount.load() + Stats.PackUpload.RequestCount.load() + Stats.Touch.RequestCount.load() + Stats.PackTouch.RequestCount.load(); const uint64_t UpReqTotalUs = Stats.Upload.RequestTotalUs.load() + Stats.PackUpload.RequestTotalUs.load() + @@ -827,6 +1022,16 @@ namespace hydration_impl { Stats.PackTouch.FirstStartUs.load()}); const uint64_t UpQueueUs = QueueWaitUs(UpFirstSchedUs, UpFirstStartUs); + // Storage-layer admission wait. 
Sum across all four upload sub-phases gives total + // time blocked acquiring slots on the dispatcher; max is the worst single wait + // observed. Both zero when admission is disabled (file-backend / blocking S3). + const uint64_t UpAdmTotalUs = Stats.Upload.AdmissionWaitTotalUs.load() + Stats.PackUpload.AdmissionWaitTotalUs.load() + + Stats.Touch.AdmissionWaitTotalUs.load() + Stats.PackTouch.AdmissionWaitTotalUs.load(); + const uint64_t UpAdmMaxUs = std::max({Stats.Upload.AdmissionWaitMaxUs.load(), + Stats.PackUpload.AdmissionWaitMaxUs.load(), + Stats.Touch.AdmissionWaitMaxUs.load(), + Stats.PackTouch.AdmissionWaitMaxUs.load()}); + const uint64_t LooseFiles = Stats.Upload.Files.load(); const uint64_t LooseBytes = Stats.Upload.Bytes.load(); const uint64_t TouchFiles = Stats.Touch.Files.load(); @@ -852,7 +1057,7 @@ namespace hydration_impl { " List existing: {}\n" " Pack: {} {} packs, {} files, {}, {}bits/s\n" " Upload: {} loose {} files ({}), packed {} blobs ({}), touched {} loose ({}) + {} packs ({}), {}bits/s\n" - " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}\n" + " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}, admission wait avg {}/req max {}\n" " Save metadata: {}\n" " Clean: {}", Prefix, @@ -896,6 +1101,8 @@ namespace hydration_impl { NiceTimeSpanUs(UpReqMaxUs), UpPeak, NiceTimeSpanUs(UpQueueUs), + NiceTimeSpanUs(SafeAvg(UpAdmTotalUs, UpReqCount)), + NiceTimeSpanUs(UpAdmMaxUs), NiceTimeSpanUs(Stats.SaveMetadataUs.load()), NiceTimeSpanUs(Stats.CleanUs.load())); } @@ -924,6 +1131,12 @@ namespace hydration_impl { const uint64_t DlFirstStartUs = std::min(Stats.Download.FirstStartUs.load(), Stats.PackDownload.FirstStartUs.load()); const uint64_t QueueUs = QueueWaitUs(DlFirstSchedUs, DlFirstStartUs); + // Storage-layer admission wait. Sum across loose + pack downloads gives total + // dispatcher block time; max is the worst single wait. 
Both zero when admission + // is disabled (file-backend / blocking S3). + const uint64_t DlAdmTotalUs = Stats.Download.AdmissionWaitTotalUs.load() + Stats.PackDownload.AdmissionWaitTotalUs.load(); + const uint64_t DlAdmMaxUs = std::max(Stats.Download.AdmissionWaitMaxUs.load(), Stats.PackDownload.AdmissionWaitMaxUs.load()); + const uint64_t PackCount = Stats.PackCount.load(); const uint64_t PackedFiles = Stats.PackedFiles.load(); const uint64_t PackUnpackUs = Stats.PackUnpackUs.load(); @@ -942,7 +1155,7 @@ namespace hydration_impl { " Load metadata: {}\n" " Create dirs: {} {} dirs, {} dirs/s\n" " Download: {} loose {} files ({}), packed {} blobs ({}), {}bits/s\n" - " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}\n" + " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}, admission wait avg {}/req max {}\n" " Unpack: {} {} packs, {} files ({}), {}bits/s\n" " Clean: {}\n" " Finalize: {}\n" @@ -969,6 +1182,8 @@ namespace hydration_impl { NiceTimeSpanUs(DlReqMaxUs), DlPeak, NiceTimeSpanUs(QueueUs), + NiceTimeSpanUs(SafeAvg(DlAdmTotalUs, DlReqCount)), + NiceTimeSpanUs(DlAdmMaxUs), NiceTimeSpanUs(PackUnpackUs), ThousandsNum(PackCount), ThousandsNum(PackedFiles), @@ -1157,20 +1372,26 @@ namespace hydration_impl { // the hash is a meta-hash combining the embedded RawHash with the file size, which // avoids a collision between an uncompressed file and a same-content compressed file. // All other files use a streaming raw hash via BasicFile + IoHashStream (sequential - // read, friendlier to the Windows cache manager than mmap). + // reads). All reads are explicit positional reads via BasicFile - no mmap, no + // IoBufferBuilder::MakeFromFile materialization. 
void HashFileContent(const std::filesystem::path& AbsPath, Entry& Out) { + BasicFile File(AbsPath, BasicFile::Mode::kRead); + if (AbsPath.extension().empty()) { std::string_view Rel = Out.RelativePath; std::string_view First = Rel.substr(0, Rel.find('/')); if (First.ends_with("cas")) { + // Read compressed bytes into a heap IoBuffer (single positional read, no mmap) + // and probe with FromCompressed. On success, derive a meta-hash from the + // embedded RawHash + size and return without hashing the bytes themselves. + IoBuffer Compressed = File.ReadAll(); IoHash RawHash; uint64_t RawSize; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(IoBufferBuilder::MakeFromFile(AbsPath)), RawHash, RawSize); - if (Compressed) + CompressedBuffer Probe = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Compressed)), RawHash, RawSize); + if (Probe) { IoHashStream Hasher; Hasher.Append(RawHash.Hash, sizeof(RawHash.Hash)); @@ -1178,10 +1399,10 @@ namespace hydration_impl { Out.Hash = Hasher.GetHash(); return; } + // Not a compressed file - fall through to streaming raw hash. } } - BasicFile File(AbsPath, BasicFile::Mode::kRead); IoHashStream Hasher; File.StreamFile([&Hasher](const void* Data, uint64_t Size) { Hasher.Append(Data, Size); }); Out.Hash = Hasher.GetHash(); @@ -1264,9 +1485,7 @@ namespace hydration_impl { // Returns one PackPlan per pack to build (empty if no packs are produced). std::vector<PackPlan> PlanPacks(std::vector<Entry>& Entries, uint64_t Threshold, uint64_t MaxPackBytes) { - // 1. Group small-file Entries[] indices by content hash. Every index in a group - // shares the same bytes, so any one of them sources the pack content; all of - // them get tagged IsPacked once the pack hash is known. + // Group small-file Entries[] indices by content hash. 
std::unordered_map<IoHash, EntryGroup, IoHash::Hasher> UniqueMap; for (size_t Index = 0; Index < Entries.size(); ++Index) { @@ -1277,7 +1496,6 @@ namespace hydration_impl { UniqueMap[Entries[Index].Hash].push_back(Index); } - // Need at least 2 unique groups for any pack to survive the "discard 1-entry packs" rule. if (UniqueMap.size() < 2) { return {}; @@ -1286,7 +1504,7 @@ namespace hydration_impl { auto GroupHash = [&](const EntryGroup& G) -> const IoHash& { return Entries[G.front()].Hash; }; auto GroupSize = [&](const EntryGroup& G) -> uint64_t { return Entries[G.front()].Size; }; - // 2. Deterministic order: ascending IoHash. Drain the map so the index vectors move. + // Sort groups by ascending IoHash for deterministic pack composition. std::vector<EntryGroup> Ordered; Ordered.reserve(UniqueMap.size()); for (auto& [h, g] : UniqueMap) @@ -1295,7 +1513,7 @@ namespace hydration_impl { } std::sort(Ordered.begin(), Ordered.end(), [&](const EntryGroup& A, const EntryGroup& B) { return GroupHash(A) < GroupHash(B); }); - // 3. Bin-pack greedily under MaxPackBytes. + // Bin-pack greedily under MaxPackBytes. std::vector<PackPlan> Plans; PackPlan Current; uint64_t CurrentSize = 0; @@ -1815,6 +2033,17 @@ namespace hydration_impl { const std::vector<PackPlan> Pending = m_Config.PackEnabled ? PlanPacks(Entries, m_Config.PackThresholdBytes, m_Config.MaxPackBytes) : std::vector<PackPlan>{}; + // Pre-build absolute paths once. MakeSafeAbsolutePath does path normalization + + // Windows long-path prefix application; ~microseconds per entry. With 100k+ + // entries the per-iter cost in the dispatch loop adds up. Mirrors the Hydrate + // side which already pre-builds EntryPaths. 
+ std::vector<std::filesystem::path> EntryPaths; + EntryPaths.reserve(Entries.size()); + for (const Entry& CurrentEntry : Entries) + { + EntryPaths.push_back(MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath)); + } + uint64_t DehydrateDurationMs = 0; { // Upload, PackUpload, Touch, and PackTouch share one ParallelWork; reset all @@ -1829,9 +2058,9 @@ namespace hydration_impl { // Schedule loose-CAS uploads first so they begin running while the pack-build // loop below executes serially on this thread. - for (const Entry& CurrentEntry : Entries) + for (size_t I = 0; I < Entries.size(); ++I) { - if (CurrentEntry.IsPacked) + if (Entries[I].IsPacked) { continue; // pack phase covers it } @@ -1839,9 +2068,9 @@ namespace hydration_impl { Work, *m_Threading.WorkerPool, ExistsLookup, - CurrentEntry.Hash, - CurrentEntry.Size, - MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath), + Entries[I].Hash, + Entries[I].Size, + EntryPaths[I], Stats.Upload, Stats.Touch); } @@ -1930,6 +2159,8 @@ namespace hydration_impl { } catch (const std::exception& Ex) { + // Failure is OK to swallow: next dehydrate or fresh hydrate falls back to + // the older state still on the backend. ZEN_WARN("Dehydration of module '{}' failed: {}. Leaving server state '{}'", m_Config.ModuleId, Ex.what(), @@ -2143,6 +2374,8 @@ namespace hydration_impl { } catch (const std::exception& Ex) { + // Failure is OK to swallow: starts the instance with empty state, next + // dehydrate re-publishes from whatever the running instance materializes. ZEN_WARN("Hydration of module '{}' failed: {}. Cleaning server state '{}'", m_Config.ModuleId, Ex.what(), @@ -2227,6 +2460,20 @@ private: Ref<ImdsCredentialProvider> m_CredentialProvider; std::unique_ptr<S3Client> m_Client; uint64_t m_DefaultMultipartChunkSize; + + // Async path: when Config.AsyncEnabled, build AsyncHttpClient + S3RequestBuilder + // shared across all per-module storage instances. Null otherwise. 
+ std::unique_ptr<AsyncHttpClient> m_AsyncClient; + std::unique_ptr<S3RequestBuilder> m_AsyncBuilder; + + // Storage-layer admission gate, shared across all per-module S3AsyncStorage + // instances. Initial slot count = AsyncMaxConcurrentRequests. + std::shared_ptr<AdmissionSemaphore> m_AsyncAdmission; + uint32_t m_AsyncAdmissionCap = 0; + + // Captures m_Credentials / m_CredentialProvider state via callable so per-module + // S3AsyncStorage instances stay decoupled from the credential rotation logic. + S3AsyncStorage::CredentialsCallback BuildCredentialsCallback(); }; HydrationBase::HydrationBase(const Configuration& Config) @@ -2357,6 +2604,49 @@ S3Hydration::S3Hydration(const Configuration& Config) : HydrationBase(Config) ClientOptions.HttpSettings.RetryCount = 3; m_Client = std::make_unique<S3Client>(ClientOptions); + + if (Config.AsyncEnabled) + { + // Curl conn caps pinned to the request cap so handles never sit on + // libcurl's internal queue waiting for a connection slot (CONNECTTIMEOUT + // would tick there). With one S3 endpoint all connections go to the same + // host: PerHost is the binding cap, Total mirrors. MaxConcurrentRequests + // is a hint shared with the storage admission semaphore below. + HttpClientSettings AsyncSettings = ClientOptions.HttpSettings; + AsyncSettings.MaxConcurrentRequests = Config.AsyncMaxConcurrentRequests; + AsyncSettings.MaxConcurrentConnectionsPerHost = Config.AsyncMaxConcurrentRequests; + AsyncSettings.MaxConcurrentConnectionsTotal = Config.AsyncMaxConcurrentRequests; + + m_AsyncBuilder = std::make_unique<S3RequestBuilder>(m_Region, m_Bucket, m_Endpoint, m_PathStyle); + m_AsyncClient = std::make_unique<AsyncHttpClient>(m_AsyncBuilder->Endpoint(), AsyncSettings); + + // Storage-layer admission: paces S3 fan-out at the same in-flight cap + // curl uses for connection limits. 
Acquire happens on the dispatcher + // thread that drives Hydrate/Dehydrate, so back-pressure flows back to + // the caller without blocking io strand or hydration-pool workers. + m_AsyncAdmissionCap = Config.AsyncMaxConcurrentRequests; + m_AsyncAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(m_AsyncAdmissionCap)); + ZEN_INFO("S3 hydration: async path enabled (max-concurrent-requests={})", Config.AsyncMaxConcurrentRequests); + } + else + { + ZEN_INFO("S3 hydration: blocking S3Client path"); + } +} + +S3AsyncStorage::CredentialsCallback +S3Hydration::BuildCredentialsCallback() +{ + if (m_CredentialProvider) + { + Ref<ImdsCredentialProvider> Provider = m_CredentialProvider; + return [Provider]() { + SigV4Credentials Creds = Provider->GetCredentials(); + return Creds; + }; + } + SigV4Credentials Creds = m_Credentials; + return [Creds]() { return Creds; }; } std::unique_ptr<HydrationStrategyBase> @@ -2365,6 +2655,19 @@ S3Hydration::CreateHydrator(const HydrationConfig& Config) using namespace hydration_impl; std::string KeyPrefix = m_KeyPrefixRoot.empty() ? std::string(Config.ModuleId) : fmt::format("{}/{}"sv, m_KeyPrefixRoot, Config.ModuleId); + + if (m_AsyncClient) + { + return std::make_unique<IncrementalHydrator>(Config, + std::make_unique<S3AsyncStorageAdapter>(*m_AsyncClient, + *m_AsyncBuilder, + BuildCredentialsCallback(), + std::move(KeyPrefix), + m_DefaultMultipartChunkSize, + m_AsyncAdmission, + m_AsyncAdmissionCap), + m_Excludes); + } return std::make_unique<IncrementalHydrator>( Config, std::make_unique<S3Storage>(*m_Client, std::move(KeyPrefix), Config.TempDir, m_DefaultMultipartChunkSize), @@ -2993,10 +3296,20 @@ TEST_CASE("hydration.file.concurrent") // The MinIO binary must be present in the same directory as the test executable (copied by xmake). // --------------------------------------------------------------------------- +namespace { + // Per-binary unique MinIO port. 
+ uint16_t AllocateHydrationMinioTestPort() + { + static const uint16_t Base = static_cast<uint16_t>(20000u + (static_cast<uint32_t>(GetCurrentProcessId()) % 30000u)); + static std::atomic<uint16_t> Slot{0}; + return Base + Slot.fetch_add(1, std::memory_order_relaxed); + } +} // namespace + TEST_CASE("hydration.s3.dehydrate_hydrate") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19011; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3055,7 +3368,7 @@ TEST_CASE("hydration.s3.concurrent") // N modules dehydrate and hydrate concurrently against MinIO. // Each module has a distinct ModuleId, so S3 key prefixes don't overlap. MinioProcessOptions MinioOpts; - MinioOpts.Port = 19013; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3149,7 +3462,7 @@ TEST_CASE("hydration.s3.concurrent") TEST_CASE("hydration.s3.obliterate") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19019; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3215,7 +3528,7 @@ TEST_CASE("hydration.s3.obliterate") TEST_CASE("hydration.s3.config_overrides") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19015; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3293,7 +3606,7 @@ TEST_CASE("hydration.s3.config_overrides") TEST_CASE("hydration.s3.dehydrate_hydrate.performance" * doctest::skip()) { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19010; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3425,7 +3738,7 @@ TEST_CASE("hydration.file.incremental") 
TEST_CASE("hydration.s3.incremental") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19017; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3524,6 +3837,328 @@ TEST_CASE("hydration.create_hydrator_rejects_invalid_config") CHECK_THROWS(InitHydration({})); } +TEST_CASE("hydration.s3async.dehydrate_hydrate") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = HydrationTemp, .ModuleId = "s3async_roundtrip"}; + + WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); + Hydration->CreateHydrator(Config)->Hydrate(); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + CreateSmallTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + CreateSmallTestTree(ServerStateDir); + WriteFile(ServerStateDir / 
"v2marker.bin", CreateSemiRandomBlob(64)); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + CleanDirectory(ServerStateDir, true); + Hydration->CreateHydrator(Config)->Hydrate(); + + CHECK(std::filesystem::exists(ServerStateDir / "v2marker.bin")); + CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "file_b.bin")); + CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "nested" / "file_c.bin")); +} + +// Exercises all three Put tiers (Small/Medium/Multipart) plus pack uploads in +// one round-trip. CreateTestTree adds 256K/512K/9M/63M blobs on top of the +// small-file set; the small files get packed, the 9M lands in PutMedium, and +// the 63M lands in PutMultipart. +TEST_CASE("hydration.s3async.dehydrate_hydrate.all_tiers") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = fmt::format( + R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"chunksize":{}}}}})", + Minio.Endpoint(), + 5u * 1024u * 1024u); // 5 MiB chunks -> multipart threshold ~6.25 MiB + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + TestThreading Threading(8); + 
HydrationConfig Config{.ServerStateDir = ServerStateDir, + .TempDir = HydrationTemp, + .ModuleId = "s3async_all_tiers", + .Threading = Threading.Options}; + + // CreateTestTree: small files (pack candidates) + 256K (Small), 512K (Medium), + // 9M (Medium), 63M (Multipart). + auto Files = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, Files); +} + +TEST_CASE("hydration.s3async.concurrent") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + constexpr int kModuleCount = 6; + constexpr int kThreadCount = 4; + + TestThreading Threading(kThreadCount); + + ScopedTemporaryDirectory TempDir; + + struct ModuleData + { + HydrationConfig Config; + std::vector<std::pair<std::filesystem::path, IoBuffer>> Files; + }; + std::vector<ModuleData> Modules(kModuleCount); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + for (int I = 0; I < kModuleCount; ++I) + { + std::string ModuleId = fmt::format("s3async_concurrent_{}"sv, I); + std::filesystem::path StateDir = TempDir.Path() / ModuleId / "state"; + std::filesystem::path TempPath = TempDir.Path() / ModuleId / "temp"; + 
CreateDirectories(StateDir); + CreateDirectories(TempPath); + + Modules[I].Config.ServerStateDir = StateDir; + Modules[I].Config.TempDir = TempPath; + Modules[I].Config.ModuleId = ModuleId; + Modules[I].Config.Threading = Threading.Options; + Modules[I].Files = CreateTestTree(StateDir); + } + + { + WorkerThreadPool Pool(kThreadCount, "hydration_s3async_dehy"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (int I = 0; I < kModuleCount; ++I) + { + Work.ScheduleWork(Pool, [&Hydration, &Config = Modules[I].Config](std::atomic<bool>&) { + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + }); + } + Work.Wait(); + CHECK_FALSE(Work.IsAborted()); + } + + { + WorkerThreadPool Pool(kThreadCount, "hydration_s3async_hy"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (int I = 0; I < kModuleCount; ++I) + { + Work.ScheduleWork(Pool, [&Hydration, &Config = Modules[I].Config](std::atomic<bool>&) { + CleanDirectory(Config.ServerStateDir, true); + Hydration->CreateHydrator(Config)->Hydrate(); + }); + } + Work.Wait(); + CHECK_FALSE(Work.IsAborted()); + } + + for (int I = 0; I < kModuleCount; ++I) + { + VerifyTree(Modules[I].Config.ServerStateDir, Modules[I].Files); + } +} + +TEST_CASE("hydration.s3async.obliterate") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + 
CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + constexpr std::string_view ModuleId = "s3async_obliterate"sv; + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = HydrationTemp, .ModuleId = std::string(ModuleId)}; + + CreateSmallTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + auto ListModuleObjects = [&]() { + S3ClientOptions Opts; + Opts.BucketName = "zen-hydration-test"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = Minio.RootUser(); + Opts.Credentials.SecretAccessKey = Minio.RootPassword(); + S3Client Client(Opts); + return Client.ListObjects(fmt::format("{}/"sv, ModuleId)); + }; + + CHECK(!ListModuleObjects().Objects.empty()); + + CreateSmallTestTree(ServerStateDir); + WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64)); + + Hydration->CreateHydrator(Config)->Obliterate(); + + CHECK(ListModuleObjects().Objects.empty()); + CHECK(std::filesystem::is_empty(ServerStateDir)); + CHECK(std::filesystem::is_empty(HydrationTemp)); +} + +TEST_CASE("hydration.s3async.incremental") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + 
ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + constexpr std::string_view ModuleId = "s3async_incremental"sv; + + TestThreading Threading(8); + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, + .TempDir = HydrationTemp, + .ModuleId = std::string(ModuleId), + .Threading = Threading.Options}; + + // Mirrors hydration.s3.incremental but with BaseConfig.AsyncEnabled set so + // the async S3AsyncStorageAdapter handles I/O. Each Dehydrate empties + // ServerStateDir as a side effect; subsequent Hydrate/Dehydrate calls + // thread the prior HydrationState so incremental dehydrate can hit its + // cache instead of re-uploading.
+ + CbObject HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + CHECK_FALSE(HydrationState); + + auto TestFiles = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, TestFiles); + + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + + TestFiles = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, TestFiles); + + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); +} + TEST_SUITE_END(); void diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h index d9a3dda5b..55db41738 100644 --- a/src/zenserver/hub/hydration.h +++ b/src/zenserver/hub/hydration.h @@ -103,6 +103,13 @@ public: // settings (e.g. S3 "chunksize") live inside Options["settings"]. The common // `excludes` entry is parsed once by HydrationBase and shared across modules. CbObject Options; + + // Routes S3 hydration through AsyncHttpClient + S3AsyncStorage when true; false + // falls back to the blocking S3Client + S3Storage path. Ignored by FileHydration. + bool AsyncEnabled = false; + + // Per-AsyncHttpClient max in-flight requests. Only consulted when AsyncEnabled. + uint32_t AsyncMaxConcurrentRequests = 128; }; // Parses common Options entries (`excludes`) into m_Excludes, applying built-in diff --git a/src/zenserver/hub/s3asyncstorage.cpp b/src/zenserver/hub/s3asyncstorage.cpp new file mode 100644 index 000000000..b8bbc55b2 --- /dev/null +++ b/src/zenserver/hub/s3asyncstorage.cpp @@ -0,0 +1,2808 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "s3asyncstorage.h" + +#include <zencore/basicfile.h> +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zenutil/cloud/s3response.h> + +#include <cstring> + +#if ZEN_WITH_TESTS +# include <zencore/iohash.h> +# include <zencore/process.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zenutil/cloud/minioprocess.h> +#endif + +namespace zen { + +using namespace std::literals; + +namespace { + // Bounded write-buffer pool. OnData callbacks on the io strand call + // Acquire to obtain a fixed-size IoBuffer; workers call Release after + // the dispatched write completes. The pool retains at most as many + // buffers as future Acquire calls remain (PendingAcquires), so it + // stabilises at peak write concurrency without dragging buffers past + // the point they can ever be re-used. + class WriteBufferPool + { + public: + WriteBufferPool(size_t InBufferSize, uint32_t TotalAcquires) : m_BufferSize(InBufferSize), m_PendingAcquires(TotalAcquires) {} + + size_t BufferSize() const { return m_BufferSize; } + + IoBuffer Acquire() + { + IoBuffer Buf; + bool Exhausted = false; + m_Lock.WithExclusiveLock([&] { + if (m_PendingAcquires == 0) + { + Exhausted = true; + return; + } + --m_PendingAcquires; + if (!m_Pool.empty()) + { + Buf = std::move(m_Pool.back()); + m_Pool.pop_back(); + } + }); + if (Exhausted) + { + // Caller (OnData / OnComplete tail flush on the curl io strand) + // over-acquired vs the precomputed TotalBlocks budget. Throw a + // real exception so the boundary catch in + // AsyncCurlStreamWriteCallback can surface it; never let + // ZEN_ASSERT propagate through curl's C frames. 
+ throw zen::runtime_error("WriteBufferPool exhausted: more buffers requested than expected"); + } + if (!Buf) + { + Buf = IoBuffer(m_BufferSize); + } + return Buf; + } + + void Release(IoBuffer Buf) + { + m_Lock.WithExclusiveLock([&] { + if (m_Pool.size() < m_PendingAcquires) + { + m_Pool.push_back(std::move(Buf)); + } + }); + } + + // Returns the slot consumed by Acquire() back to the budget without + // supplying a buffer; for paths where the buffer was already moved into a + // lambda (e.g. ScheduleWork threw after the move). + void RestoreAcquireSlot() + { + m_Lock.WithExclusiveLock([&] { ++m_PendingAcquires; }); + } + + private: + const size_t m_BufferSize; + RwLock m_Lock; + std::vector<IoBuffer> m_Pool; + uint32_t m_PendingAcquires = 0; + }; + + // Acquire one admission slot on the dispatcher thread. Returns a refcounted + // handle whose deleter releases the slot on last drop, so callers thread it + // through worker lambdas and AsyncXxx callbacks; whichever runs last fires + // the release. nullptr when admission is disabled (Sem == nullptr). + // + // Stats may be null when the caller has no per-request stat block (DeleteAll); + // the slot is still held but the wait time is not recorded. + // + // Exception safety: counting_semaphore::acquire is noexcept; only the + // shared_ptr control-block allocation can throw. The standard's deleter- + // invocation guarantee for shared_ptr(p, d) only kicks in when p is non- + // null (p == nullptr here), so libstdc++/MSVC happen to invoke the deleter + // on alloc failure but it is not portably required - hence the explicit + // guard. 
+ std::shared_ptr<void> AcquireAdmissionSlot(const std::shared_ptr<AdmissionSemaphore>& Sem, S3AsyncStorageStats* Stats = nullptr) + { + if (!Sem) + { + return nullptr; + } + Stopwatch AdmWait; + Sem->acquire(); + auto ReleaseGuard = MakeGuard([&Sem] { Sem->release(); }); + std::shared_ptr<void> SlotRef(nullptr, [Sem](void*) { Sem->release(); }); + ReleaseGuard.Dismiss(); + if (Stats) + { + Stats->RecordAdmissionWait(AdmWait.GetElapsedTimeUs()); + } + return SlotRef; + } +} // namespace + +S3AsyncStorage::S3AsyncStorage(AsyncHttpClient& Client, + S3RequestBuilder& Builder, + CredentialsCallback GetCreds, + std::string KeyPrefix, + uint64_t MultipartChunkSize, + std::shared_ptr<AdmissionSemaphore> Admission, + uint32_t AdmissionCap) +: m_Client(Client) +, m_Builder(Builder) +, m_GetCreds(std::move(GetCreds)) +, m_KeyPrefix(std::move(KeyPrefix)) +, m_MultipartChunkSize(MultipartChunkSize) +, m_Admission(std::move(Admission)) +, m_AdmissionCap(m_Admission ? AdmissionCap : 0u) +{ +} + +std::string +S3AsyncStorage::CasKey(const IoHash& Hash) const +{ + return fmt::format("{}/cas/{}", m_KeyPrefix, Hash); +} + +std::string +S3AsyncStorage::CasPath(const IoHash& Hash) const +{ + return m_Builder.KeyToPath(CasKey(Hash)); +} + +// Per-range writes target the prealloc'd file at-offset; BasicFile::Write is positional and concurrency-safe. 
+struct S3AsyncStorage::GetMultipartState +{ + GetMultipartState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats, uint32_t TotalBlocks) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + , Buffers(WriteBufferSize, TotalBlocks) + { + } + + static constexpr size_t WriteBufferSize = 512u * 1024u; + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> DestFile; + std::filesystem::path DestPath; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + uint32_t TotalRanges = 0; + std::atomic<uint32_t> PendingRanges{0}; + std::atomic<bool> AnyFailed{false}; + RwLock ErrorLock; + std::string FirstError; + + WriteBufferPool Buffers; +}; + +// Shared state for a single-stream GET (medium tier - one ranged request, +// streamed body). OnData fills 512 KiB IoBuffers from a shared pool on the +// io strand and dispatches each filled buffer to a worker for one positional +// Write. PendingWork counts in-flight writes plus a +1 slot for the stream +// itself; the side that drops it to zero (last writer or OnComplete after +// the final flush) finalises the request. 
+struct S3AsyncStorage::GetStreamState +{ + GetStreamState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats, uint32_t TotalBlocks) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + , Buffers(WriteBufferSize, TotalBlocks) + { + } + + static constexpr size_t WriteBufferSize = 512u * 1024u; + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> DestFile; + std::filesystem::path DestPath; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + uint64_t ExpectedSize = 0; + + IoBuffer ActiveBuf; + size_t BufFill = 0; + uint64_t NextAbsOffset = 0; + uint64_t TotalReceived = 0; + + std::atomic<uint32_t> PendingWork{1}; // +1 for stream completion + std::atomic<bool> Failed{false}; + RwLock ErrorLock; + std::string FirstError; + + WriteBufferPool Buffers; +}; + +// Shared state for an in-flight multipart upload. PendingParts.fetch_sub(acq_rel) +// in FinalizePutPart is the publication barrier for ETag writes; CompleteMultipart +// (gated on Remaining == 1) sees every part's slot. ETagLock guards FirstError; +// AnyFailed short-circuits the pipeline on first error. +struct S3AsyncStorage::PutMultipartState +{ + PutMultipartState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + { + } + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> File; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + // Snapshot of credentials taken once at PutMultipart entry. Reused across + // all per-part signing + Complete/Abort. Avoids per-part m_GetCreds() + // shared-lock contention on the credential provider. 
+ SigV4Credentials Creds; + std::string Key; + std::string Path; + std::string UploadId; + uint64_t TotalSize = 0; + uint64_t PartSize = 0; + uint32_t TotalParts = 0; + std::atomic<uint32_t> PendingParts{0}; + // Monotonic dispatch cursor. Initialized to PreAcquired (min(TotalParts, + // AdmissionCap)) in PutMultipart. Each AsyncPut completion racing the + // handoff path issues fetch_add(1); whichever completion observes a value + // < TotalParts owns the next dispatch. + std::atomic<uint32_t> NextPartToDispatch{0}; + uint32_t PreAcquired = 0; + // Counts pending ReadRange calls into File. Used to release File from the + // worker thread that does the last read so the close runs off the curl + // io strand. + std::atomic<uint32_t> ReadsRemaining{0}; + std::atomic<bool> AnyFailed{false}; + + RwLock ETagLock; + std::vector<std::string> ETags; // indexed by PartNumber-1 + std::string FirstError; +}; + +namespace { + // Sets AnyFailed (CAS) and captures FirstError on first failure only. + // Does NOT touch the dispatch cursor; callers that need to drain the + // undispatched tail invoke ClaimUndispatchedParts and fan out skips. + void RecordPutPartFailure(S3AsyncStorage::PutMultipartState& State, const std::string& Err) + { + // Log every failure so correlated S3 errors (e.g. AccessDenied on one + // part, ServiceUnavailable on another) all reach the operator. CAS still + // keeps the first message for the eventual Token->Fail. + ZEN_WARN("S3AsyncStorage::PutMultipart '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.AnyFailed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ETagLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } + + // Claims every part index from NextPartToDispatch up to TotalParts. + // Returns the count claimed; caller is responsible for firing one + // FinalizePutPart per claimed index so PendingParts can reach 1 and + // trip CompleteMultipart/AbortMultipart. 
Idempotent on repeated calls + // (subsequent invocations return 0). + uint32_t ClaimUndispatchedParts(S3AsyncStorage::PutMultipartState& State) + { + const uint32_t Claimed = State.NextPartToDispatch.exchange(State.TotalParts, std::memory_order_acq_rel); + return Claimed < State.TotalParts ? State.TotalParts - Claimed : 0u; + } + + void RecordGetPartFailure(S3AsyncStorage::GetMultipartState& State, const std::string& Err) + { + ZEN_WARN("S3AsyncStorage::GetMultipart '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.AnyFailed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ErrorLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } + + void RecordGetStreamFailure(S3AsyncStorage::GetStreamState& State, const std::string& Err) + { + ZEN_WARN("S3AsyncStorage::Get '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.Failed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ErrorLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } +} // namespace + +// Three tiers selected by Size: +// - Small (< kPutSmallThreshold): single PUT, body materialized into one IoBuffer; SHA in one pass. +// - Medium (< MultipartThreshold): single PUT, streaming source; two-pass (hash, then upload) with bounded RAM. +// - Large (>= MultipartThreshold): S3 multipart; per-part body materialized in DispatchPartUpload. 
+namespace { + constexpr uint64_t kPutSmallThreshold = 512u * 1024u; + constexpr size_t kPutReadChunk = 256u * 1024u; +} // namespace + +void +S3AsyncStorage::Put(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + const uint64_t MultipartThreshold = m_MultipartChunkSize + (m_MultipartChunkSize / 4); + if (Size >= MultipartThreshold) + { + PutMultipart(Work, Pool, Hash, Size, SourcePath, Stats); + return; + } + if (Size < kPutSmallThreshold) + { + PutSmall(Work, Pool, Hash, Size, SourcePath, Stats); + return; + } + PutMedium(Work, Pool, Hash, Size, SourcePath, Stats); +} + +void +S3AsyncStorage::PutSmall(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // Acquire one admission slot on the dispatcher; SlotRef threads through + // worker lambda + AsyncPut callback so the slot is held until the network + // transfer completes (or the chain is destroyed on shutdown/cancel). + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + // Snapshot creds once on the dispatcher (single shared-lock acquire on the + // credential provider) and capture by value into the worker. Mirrors + // PutMultipart and avoids per-fanout contention on m_GetCreds. + SigV4Credentials Creds = m_GetCreds(); + + // Capture Stats BY VALUE: S3AsyncStorageStats is a struct of references to + // the long-lived PhaseStats atomics (PhaseStats outlives all in-flight + // transfers via the ParallelWork latch). The Stats wrapper struct itself + // is a local in the caller (S3AsyncStorageAdapter::Put) and goes out of + // scope before deferred work runs, so capturing &Stats would dangle. 
+ Work.ScheduleWork( + Pool, + [this, + Hash = IoHash(Hash), + Size, + SourcePath = std::filesystem::path(SourcePath), + Token, + Stats, + Creds = std::move(Creds), + SlotRef = std::move(SlotRef)](std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': no credentials available"sv, Hash))); + return; + } + + BasicFile File(SourcePath, BasicFile::Mode::kRead); + if (File.FileSize() != Size) + { + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutSmall '{}': source size {} differs from declared {}"sv, + Hash, + File.FileSize(), + Size))); + return; + } + + // Heap IoBuffer of exact size; explicit chunked pread fills it. No + // IoBuffer auto-materialize, no mmap. SHA computed in one pass over + // the buffer once it's filled. + IoBuffer Body(static_cast<size_t>(Size)); + uint8_t* Dst = static_cast<uint8_t*>(Body.MutableData()); + uint64_t Off = 0; + while (Off < Size) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, Size - Off)); + File.Read(Dst + Off, Take, Off); + Off += Take; + } + + std::string PayloadHash = Sha256ToHex(ComputeSha256(Body.GetData(), Body.GetSize())); + std::string Path = CasPath(Hash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", PayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut (lambda capture + // before submit) does not strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + m_Client.AsyncPut( + Path, + std::move(Body), + [Hash, Token, Timer, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
static_cast<uint64_t>(Resp.UploadedBytes) : 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 PUT failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': {}"sv, Hash, Err))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +void +S3AsyncStorage::PutMedium(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + // Snapshot creds on the dispatcher; see PutSmall. + SigV4Credentials Creds = m_GetCreds(); + + Work.ScheduleWork( + Pool, + [this, + Hash = IoHash(Hash), + Size, + SourcePath = std::filesystem::path(SourcePath), + Token, + Stats, + Creds = std::move(Creds), + SlotRef = std::move(SlotRef)](std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': no credentials available"sv, Hash))); + return; + } + + // Hash pass: stream-read 256 KiB chunks into a heap scratch buffer, + // feed Sha256Stream incrementally. No body materialization. The file + // handle is shared with the upload pass via shared_ptr so libcurl's + // READ callback can pread from the same handle on the io thread. 
+ auto File = std::make_shared<BasicFile>(SourcePath, BasicFile::Mode::kRead); + if (File->FileSize() != Size) + { + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMedium '{}': source size {} differs from declared {}"sv, + Hash, + File->FileSize(), + Size))); + return; + } + + Sha256Stream Hasher; + { + std::unique_ptr<uint8_t[]> Scratch(new uint8_t[kPutReadChunk]); + uint64_t Off = 0; + while (Off < Size) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, Size - Off)); + File->Read(Scratch.get(), Take, Off); + Hasher.Update(Scratch.get(), Take); + Off += Take; + } + } + std::string PayloadHash = Sha256ToHex(Hasher.Finalize()); + std::string Path = CasPath(Hash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", PayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + + // Streaming source: libcurl pulls 256 KiB chunks from the file on + // the io thread. Local-disk pread is fast enough for medium tier; + // page cache makes the second pass over the same bytes near-free. + m_Client.AsyncPut( + Path, + Size, + [File](uint8_t* DstBuf, size_t MaxBytes, uint64_t AbsOffset) -> size_t { + const uint64_t Remaining = File->FileSize() > AbsOffset ? File->FileSize() - AbsOffset : 0; + const size_t Take = static_cast<size_t>(std::min<uint64_t>(MaxBytes, Remaining)); + if (Take == 0) + { + return 0; + } + File->Read(DstBuf, Take, AbsOffset); + return Take; + }, + [Hash, Token, Timer, Size, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
Size : 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 PUT failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': {}"sv, Hash, Err))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +void +S3AsyncStorage::PutMultipart(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // Fires only when an exception is in flight; a missed Dismiss() on a + // non-throwing early return falls through to Token's destructor safety + // net rather than asserting on Token->Fail(nullptr). + auto FailGuard = MakeGuard([Token] { + if (auto Ex = std::current_exception()) + { + Token->Fail(Ex); + } + }); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + FailGuard.Dismiss(); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': no credentials available"sv, Hash))); + return; + } + + auto File = std::make_shared<BasicFile>(SourcePath, BasicFile::Mode::kRead); + if (File->FileSize() != Size) + { + FailGuard.Dismiss(); + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': source size {} differs from declared {}"sv, + Hash, + File->FileSize(), + Size))); + return; + } + + auto State = std::make_shared<PutMultipartState>(Work, Pool, Stats); + State->File = std::move(File); + State->Token = Token; + State->ContentHash = Hash; + State->Creds = Creds; // snapshot reused across all parts + complete/abort + State->Key = CasKey(Hash); + State->Path = CasPath(Hash); + State->TotalSize = Size; + State->PartSize = m_MultipartChunkSize; + State->TotalParts = static_cast<uint32_t>((Size + 
State->PartSize - 1) / State->PartSize); + State->PendingParts = State->TotalParts; + State->ReadsRemaining = State->TotalParts; + State->ETags.resize(State->TotalParts); + + // Admission gating: the wave acquires min(TotalParts, AdmissionCap) slots + // inside DispatchInitialPartWave on a worker thread (off the dispatcher + // and off the io strand) so a single multipart cannot starve other + // dispatcher work while it drains the semaphore. Tail parts (TotalParts > + // cap) are dispatched lazily by HandoffSlotToNextPart from each AsyncPut + // completion, so the in-flight count per upload stays <= cap. + State->PreAcquired = m_Admission ? std::min<uint32_t>(State->TotalParts, m_AdmissionCap) : State->TotalParts; + State->NextPartToDispatch.store(State->PreAcquired, std::memory_order_relaxed); + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploads", ""}}); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "POST", State->Path, CanonicalQs, S3EmptyPayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + // CreateMultipartUpload / CompleteMultipartUpload / AbortMultipartUpload are + // metadata operations - they do not move payload bytes and would distort the + // per-request stats (avg/max latency) if folded into RequestCount. Counted + // separately here means PhaseStats.RequestCount tracks data-bearing PUTs only, + // matching the sync S3Storage path's accounting more closely. + + m_Client.AsyncPost( + FullPath, + [this, State](HttpClient::Response Resp) { + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 CreateMultipartUpload failed", Resp); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, Err))); + return; + } + // See PutMultipart entry: tolerate null exception_ptr on a missed Dismiss(). 
+ auto CallbackGuard = MakeGuard([State] { + if (auto Ex = std::current_exception()) + { + State->Token->Fail(Ex); + } + }); + std::string_view Body = Resp.AsText(); + // S3 can answer 200 with an embedded <Error> body even on Create. + // Mirror CompleteMultipart's parse so the original error code/message + // surfaces instead of a generic "missing UploadId". + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + CallbackGuard.Dismiss(); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': create returned error: {} - {}"sv, + State->ContentHash, + ErrorCode, + ErrorMessage))); + return; + } + std::string_view UploadId = S3ExtractXmlValue(Body, "UploadId"); + if (UploadId.empty()) + { + CallbackGuard.Dismiss(); + State->Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMultipart '{}': missing UploadId in CreateMultipartUpload response"sv, + State->ContentHash))); + return; + } + State->UploadId = std::string(UploadId); + // Hop the wave dispatch onto a worker so the per-part admission + // acquire never blocks the io strand or the caller dispatcher. + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State](std::atomic<bool>&) { DispatchInitialPartWave(State); }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (...) + { + CallbackGuard.Dismiss(); + RecordPutPartFailure(*State, "S3 UploadPart wave schedule failed"); + for (uint32_t J = 0; J < State->TotalParts; ++J) + { + FinalizePutPart(State); + } + return; + } + CallbackGuard.Dismiss(); + }, + Headers); + FailGuard.Dismiss(); +} + +void +S3AsyncStorage::DispatchInitialPartWave(std::shared_ptr<PutMultipartState> State) +{ + const bool AdmissionEnabled = (m_Admission != nullptr); + const uint32_t InitialDispatch = AdmissionEnabled ? 
State->PreAcquired : State->TotalParts; + + for (uint32_t I = 0; I < InitialDispatch; ++I) + { + const uint32_t PartNum = I + 1; + std::shared_ptr<void> SlotRef; + try + { + if (AdmissionEnabled) + { + SlotRef = AcquireAdmissionSlot(m_Admission, &State->Stats); + } + State->Work.ScheduleWork( + State->Pool, + [this, State, PartNum, SlotRef = std::move(SlotRef)](std::atomic<bool>&) mutable { + DispatchPartUpload(State, PartNum, std::move(SlotRef)); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // Acquire/schedule failed mid-wave. Account for this part plus + // the un-iterated wave tail, then drain the lazy tail past the + // wave. Slot (if acquired) releases when SlotRef destructs. + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} schedule failed: {}", PartNum, Ex.what())); + const uint32_t WaveTail = InitialDispatch - I; + for (uint32_t J = 0; J < WaveTail; ++J) + { + FinalizePutPart(State); + } + DrainUndispatchedParts(State); + return; + } + } +} + +void +S3AsyncStorage::HandoffSlotToNextPart(std::shared_ptr<PutMultipartState> State, uint32_t PartIdx, std::shared_ptr<void> SlotRef) +{ + const uint32_t PartNum = PartIdx + 1; + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, PartNum, SlotRef = std::move(SlotRef)](std::atomic<bool>&) mutable { + DispatchPartUpload(State, PartNum, std::move(SlotRef)); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // Handoff dispatch failed. The slot we were about to hand off + // releases when SlotRef destructs at scope exit. Account for this + // never-dispatched part and drain the rest of the tail. 
+ RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} handoff failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + } +} + +void +S3AsyncStorage::DrainUndispatchedParts(std::shared_ptr<PutMultipartState> State) +{ + const uint32_t Skipped = ClaimUndispatchedParts(*State); + for (uint32_t J = 0; J < Skipped; ++J) + { + FinalizePutPart(State); + } +} + +#if ZEN_WITH_TESTS +namespace s3asyncstorage_test_hooks { + std::atomic<uint32_t> g_ForceNextPartFailures{0}; + void ForceNextPartFailures(uint32_t Count) { g_ForceNextPartFailures.store(Count, std::memory_order_relaxed); } +} // namespace s3asyncstorage_test_hooks +#endif + +void +S3AsyncStorage::DispatchPartUpload(std::shared_ptr<PutMultipartState> State, uint32_t PartNum, std::shared_ptr<void> SlotRef) +{ + const uint64_t Offset = static_cast<uint64_t>(PartNum - 1) * State->PartSize; + const uint64_t ChunkSize = std::min<uint64_t>(State->PartSize, State->TotalSize - Offset); + + // Release File on this worker thread once all reads have been served, so + // the close never runs on the curl io strand from a captured State ref in + // an AsyncPut callback. + auto ReleaseFileIfLast = [&State] { + if (State->ReadsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + State->File.reset(); + } + }; + +#if ZEN_WITH_TESTS + // Test hook: synthesize a part-level failure to drive the AbortMultipart + // path. Decrement the counter while > 0; each consumed slot fails one part. 
+ { + uint32_t Cur = s3asyncstorage_test_hooks::g_ForceNextPartFailures.load(std::memory_order_relaxed); + while (Cur > 0 && + !s3asyncstorage_test_hooks::g_ForceNextPartFailures.compare_exchange_weak(Cur, Cur - 1, std::memory_order_acq_rel)) + { + } + if (Cur > 0) + { + ReleaseFileIfLast(); + RecordPutPartFailure(*State, fmt::format("test-injected part failure (part {})", PartNum)); + FinalizePutPart(State); + DrainUndispatchedParts(State); + return; + } + } +#endif + + // Short-circuit if a foreign failure has flipped AnyFailed since this part was + // scheduled (handoff race: success callback bumps the cursor to dispatch the + // next part, drain runs concurrently). Skip the file read + sign + AsyncPut for + // a doomed transfer; FinalizePutPart accounting is symmetric to the success path. + if (State->AnyFailed.load(std::memory_order_acquire)) + { + ReleaseFileIfLast(); + FinalizePutPart(State); + return; + } + + // Explicit chunked positional pread + Sha256Stream in one pass. One heap + // IoBuffer of exact part size; no mmap, no IoBuffer auto-materialize. + IoBuffer Part(static_cast<size_t>(ChunkSize)); + uint8_t* Dst = static_cast<uint8_t*>(Part.MutableData()); + Sha256Stream Hasher; + try + { + uint64_t Read = 0; + while (Read < ChunkSize) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, ChunkSize - Read)); + State->File->Read(Dst + Read, Take, Offset + Read); + Hasher.Update(Dst + Read, Take); + Read += Take; + } + ReleaseFileIfLast(); + } + catch (const std::exception& Ex) + { + ReleaseFileIfLast(); + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} read failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + return; + } + + try + { + // Use the snapshot taken at PutMultipart entry; saves ~TotalParts + // shared-lock acquisitions on the credential provider. 
+ const SigV4Credentials& Creds = State->Creds; + + std::string CanonicalQs = BuildCanonicalQueryString({ + {"partNumber", fmt::format("{}", PartNum)}, + {"uploadId", State->UploadId}, + }); + std::string PayloadHash = Sha256ToHex(Hasher.Finalize()); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", State->Path, CanonicalQs, PayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + Stopwatch Timer = State->Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not strand + // InFlight elevated. + auto InFlightGuard = MakeGuard([&State] { State->Stats.EndRequest(0, 0); }); + + m_Client.AsyncPut( + FullPath, + std::move(Part), + [this, State, PartNum, Timer, ChunkSize, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? ChunkSize : 0); + if (!Resp.IsSuccess()) + { + RecordPutPartFailure(*State, S3ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNum), Resp)); + DrainUndispatchedParts(State); + FinalizePutPart(State); + return; + } + std::string_view ETag = Resp.FindHeader("etag"); + if (ETag.empty()) + { + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} response missing ETag header", PartNum)); + DrainUndispatchedParts(State); + FinalizePutPart(State); + return; + } + // Per-slot writes don't race - each part owns its slot. + // PendingParts.fetch_sub(acq_rel) in FinalizePutPart is the + // publication barrier: CompleteMultipart (gated on Remaining + // == 1) sees every ETag write. New readers must go through + // the same barrier or take ETagLock. + State->ETags[PartNum - 1].assign(ETag); + + // Try to hand the slot off to the next undispatched part. Skip + // the handoff if a prior failure has flipped AnyFailed (cursor + // was already exchanged to TotalParts by the drain there). 
+ if (!State->AnyFailed.load(std::memory_order_acquire)) + { + const uint32_t NextIdx = State->NextPartToDispatch.fetch_add(1, std::memory_order_acq_rel); + if (NextIdx < State->TotalParts) + { + HandoffSlotToNextPart(State, NextIdx, std::move(SlotRef)); + } + } + // Account for THIS part. SlotRef may already be moved-from + // (if handed off); if not, releases when this lambda destructs. + FinalizePutPart(State); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (const std::exception& Ex) + { + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} dispatch failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + } +} + +void +S3AsyncStorage::FinalizePutPart(std::shared_ptr<PutMultipartState> State) +{ + const uint32_t Remaining = State->PendingParts.fetch_sub(1, std::memory_order_acq_rel); + if (Remaining != 1) + { + return; + } + + // Hop the Complete/Abort dispatch off the curl io strand. SHA over the + // parts-XML and SignRequest are CPU work; AsyncPost itself queues a strand + // wakeup. Running both inline here would pin the strand thread. + const bool Failed = State->AnyFailed.load(std::memory_order_acquire); + try + { + State->Work.ScheduleWork(State->Pool, [this, State, Failed](std::atomic<bool>&) { + if (Failed) + { + AbortMultipart(State); + } + else + { + CompleteMultipart(State); + } + }); + } + catch (...) + { + // ScheduleWork failed before Complete/Abort could run; surface the leak via Token->Fail. + State->Token->Fail(std::current_exception()); + } +} + +void +S3AsyncStorage::CompleteMultipart(std::shared_ptr<PutMultipartState> State) +{ + // Reuse the snapshot taken at PutMultipart entry. The upload-id was issued + // against these creds; refreshing here would only matter if creds expired + // mid-upload, in which case all the part uploads would already have failed. 
+ const SigV4Credentials& Creds = State->Creds; + + ExtendableStringBuilder<1024> Xml; + Xml.Append("<CompleteMultipartUpload>"); + for (uint32_t I = 0; I < State->TotalParts; ++I) + { + Xml.Append(fmt::format("<Part><PartNumber>{}</PartNumber><ETag>{}</ETag></Part>", I + 1, State->ETags[I])); + } + Xml.Append("</CompleteMultipartUpload>"); + std::string_view XmlView = Xml.ToView(); + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploadId", State->UploadId}}); + std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "POST", State->Path, CanonicalQs, PayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); + + // Metadata op; not counted in PhaseStats (see CreateMultipartUpload comment). + m_Client.AsyncPost( + FullPath, + Payload, + ZenContentType::kXML, + [State](HttpClient::Response Resp) { + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 CompleteMultipartUpload failed", Resp); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, Err))); + return; + } + std::string_view Body = Resp.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': complete returned error: {} - {}"sv, + State->ContentHash, + ErrorCode, + ErrorMessage))); + return; + } + State->Token->Complete(); + }, + Headers); +} + +void +S3AsyncStorage::AbortMultipart(std::shared_ptr<PutMultipartState> State) +{ + // Reuse the snapshot taken at PutMultipart entry. 
+ const SigV4Credentials& Creds = State->Creds; + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploadId", State->UploadId}}); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "DELETE", State->Path, CanonicalQs, S3EmptyPayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + // Metadata op; not counted in PhaseStats (see CreateMultipartUpload comment). + m_Client.AsyncDelete( + FullPath, + [State](HttpClient::Response Resp) { + std::string FirstError; + State->ETagLock.WithExclusiveLock([&] { FirstError = std::move(State->FirstError); }); + if (!Resp.IsSuccess()) + { + State->Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMultipart '{}': part failed ({}); abort also failed: {}"sv, + State->ContentHash, + FirstError, + S3ErrorMessage("S3 AbortMultipartUpload failed", Resp)))); + return; + } + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, FirstError))); + }, + Headers); +} + +void +S3AsyncStorage::Get(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + S3AsyncStorageStats& Stats) +{ + // Three tiers: in-memory AsyncGet (< 512 KiB), AsyncStream into pooled + // 512 KiB write buffers (medium), GetMultipart (>= MultiRangeThreshold). + constexpr uint64_t StreamingThreshold = 512u * 1024u; + + const uint64_t MultiRangeThreshold = m_MultipartChunkSize + (m_MultipartChunkSize / 4); + if (Size >= MultiRangeThreshold) + { + GetMultipart(Work, Pool, Hash, Size, DestinationPath, Stats); + return; + } + + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // See PutMultipart: tolerate null exception_ptr from a refactor that + // adds a non-throwing early return without Dismiss(). 
+ auto FailGuard = MakeGuard([Token] { + if (auto Ex = std::current_exception()) + { + Token->Fail(Ex); + } + }); + + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + FailGuard.Dismiss(); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': no credentials available"sv, Hash))); + return; + } + + std::string Path = CasPath(Hash); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncGet/AsyncStream does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + + if (Size < StreamingThreshold) + { + // Small: in-memory AsyncGet, single worker-hop write at completion. + m_Client.AsyncGet( + Path, + [Hash = IoHash(Hash), Token, Timer, Stats, DestPath = DestinationPath, Size, &Work, &Pool, SlotRef = std::move(SlotRef)]( + HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? Resp.ResponsePayload.GetSize() : 0); + // No remove() on the failure / size-mismatch paths: the destination + // file is only created by the worker hop below on success, so + // nothing exists to delete here. 
+ if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 GET failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, Hash, Err))); + return; + } + if (Resp.ResponsePayload.GetSize() != Size) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': received {} bytes, expected {}"sv, + Hash, + Resp.ResponsePayload.GetSize(), + Size))); + return; + } + try + { + Work.ScheduleWork( + Pool, + [Token, DestPath, Hash, Payload = std::move(Resp.ResponsePayload)](std::atomic<bool>&) mutable { + try + { + BasicFile Out(DestPath, BasicFile::Mode::kTruncate); + Out.Write(Payload.GetData(), Payload.GetSize(), 0); + Token->Complete(); + } + catch (const std::exception& Ex) + { + std::error_code Ec; + std::filesystem::remove(DestPath, Ec); + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::Get '{}': write failed: {}"sv, Hash, Ex.what()))); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': schedule failed: {}"sv, Hash, Ex.what()))); + } + }, + Headers); + InFlightGuard.Dismiss(); + FailGuard.Dismiss(); + return; + } + + // Medium tier: AsyncStream + GetStreamState (see struct doc). 
+ const uint32_t StreamTotalBlocks = + static_cast<uint32_t>((Size + GetStreamState::WriteBufferSize - 1) / GetStreamState::WriteBufferSize); + auto State = std::make_shared<GetStreamState>(Work, Pool, Stats, StreamTotalBlocks); + State->Token = Token; + State->ContentHash = Hash; + State->DestPath = DestinationPath; + State->ExpectedSize = Size; + + try + { + State->DestFile = std::make_shared<BasicFile>(DestinationPath, BasicFile::Mode::kTruncate); + std::error_code PrepareEc = PrepareFileForScatteredWrite(State->DestFile->Handle(), Size); + if (PrepareEc) + { + throw zen::runtime_error("PrepareFileForScatteredWrite failed: {}"sv, PrepareEc.message()); + } + } + catch (const std::exception& Ex) + { + FailGuard.Dismiss(); + // Drop DestFile (if opened) before remove() so Windows lets the unlink + // proceed; otherwise we leave a 0-byte file behind on the prealloc path. + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(DestinationPath, Ec); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': preallocate failed: {}"sv, Hash, Ex.what()))); + return; + } + + m_Client.AsyncStream( + Path, + [this, State](const uint8_t* Data, size_t SizeBytes, uint64_t /*TotalSize*/) -> bool { + constexpr size_t WriteBufferSize = GetStreamState::WriteBufferSize; + if (State->TotalReceived + SizeBytes > State->ExpectedSize) + { + RecordGetStreamFailure( + *State, + fmt::format("S3 GET overflow: got {} bytes past expected {}", State->TotalReceived + SizeBytes, State->ExpectedSize)); + return false; + } + const uint8_t* Cur = Data; + size_t Remaining = SizeBytes; + while (Remaining > 0) + { + if (!State->ActiveBuf) + { + State->ActiveBuf = State->Buffers.Acquire(); + } + const size_t Avail = WriteBufferSize - State->BufFill; + const size_t Take = std::min(Remaining, Avail); + std::memcpy(State->ActiveBuf.MutableData<uint8_t>() + State->BufFill, Cur, Take); + State->BufFill += Take; + State->TotalReceived += Take; + Cur += Take; + 
Remaining -= Take; + if (State->BufFill == WriteBufferSize) + { + IoBuffer Buf = std::move(State->ActiveBuf); + const size_t Fill = State->BufFill; + const uint64_t AbsOffset = State->NextAbsOffset; + State->NextAbsOffset += Fill; + State->BufFill = 0; + State->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET write failed: {}", Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetStreamFinalised(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // ScheduleWork failed before the worker took ownership of the + // PendingWork increment we did above; balance it here and + // abort the transfer so OnComplete tears down cleanly. Buf is + // gone (moved into the lambda then destroyed when ScheduleWork + // threw); restore the pool's acquire slot so OnComplete's + // tail-flush can still Acquire if needed. + RecordGetStreamFailure(*State, fmt::format("S3 GET schedule failed: {}", Ex.what())); + State->Buffers.RestoreAcquireSlot(); + State->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + return false; + } + } + } + return true; + }, + [this, State, Timer, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
State->TotalReceived : 0); + bool Failed = !Resp.IsSuccess(); + if (Failed) + { + RecordGetStreamFailure(*State, S3ErrorMessage("S3 GET failed", Resp)); + } + else if (State->TotalReceived != State->ExpectedSize) + { + RecordGetStreamFailure(*State, + fmt::format("S3 GET wrote {} bytes, expected {}", State->TotalReceived, State->ExpectedSize)); + Failed = true; + } + if (!Failed && State->BufFill > 0) + { + IoBuffer Buf = std::move(State->ActiveBuf); + const size_t Fill = State->BufFill; + const uint64_t AbsOffset = State->NextAbsOffset; + State->BufFill = 0; + State->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET write failed: {}", Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetStreamFinalised(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET tail-flush schedule failed: {}", Ex.what())); + State->Buffers.RestoreAcquireSlot(); + State->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + } + } + // PendingWork = 1 (stream slot, set in ctor) + N (per-buffer worker + // dispatched in OnData) + (optional 1 for the tail-flush dispatched + // just above). The fetch_sub returning 1 means we observed the value + // that was 1 right before our decrement - i.e. we are the last + // participant. Either this branch (no tail flush, no in-flight + // workers) or the last per-buffer worker performs the finalise. Both + // paths use acq_rel so the work-item writes are visible to whichever + // thread runs OnGetStreamFinalised. 
if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1)
            {
                OnGetStreamFinalised(State);
            }
        },
        Headers);
    InFlightGuard.Dismiss();
    FailGuard.Dismiss();
}

// Finalises a medium-tier (streamed) Get. It is invoked by whichever
// participant - the stream completion callback or a per-buffer write worker -
// observes the last PendingWork decrement. The outcome is decided by
// State->Failed as sampled on entry; closing the destination file, removing
// it on failure and completing/failing the token all happen in a work item
// scheduled on State->Pool.
//
// State - shared per-transfer state (destination file, token, first error).
void
S3AsyncStorage::OnGetStreamFinalised(std::shared_ptr<GetStreamState> State)
{
    // All disk I/O (close + unlink on failure) is hopped into the worker pool.
    // The last PendingWork dec can land on the curl io strand, and inline
    // disk syscalls there would block the curl_multi poll loop.
    const bool Failed = State->Failed.load(std::memory_order_acquire);
    try
    {
        State->Work.ScheduleWork(
            State->Pool,
            [State, Failed](std::atomic<bool>&) {
                if (Failed)
                {
                    // Take the first recorded error under the lock, close the
                    // file, then remove the partial download before failing.
                    std::string Err;
                    State->ErrorLock.WithExclusiveLock([&] { Err = std::move(State->FirstError); });
                    State->DestFile.reset();
                    std::error_code Ec;
                    std::filesystem::remove(State->DestPath, Ec);
                    State->Token->Fail(
                        std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, State->ContentHash, Err)));
                    return;
                }

                // No Flush: consumer reads via page cache; crash recovery re-hydrates.
                State->DestFile.reset();
                State->Token->Complete();
            },
            WorkerThreadPool::EMode::EnableBacklog);
    }
    catch (...)
    {
        // Schedule failed; release file handle inline and surface via Token.
        // The file is removed here regardless of Failed, and the token is
        // failed with the scheduling exception itself.
        State->DestFile.reset();
        std::error_code Ec;
        std::filesystem::remove(State->DestPath, Ec);
        State->Token->Fail(std::current_exception());
    }
}

// Downloads the object for Hash into DestinationPath as a series of ranged
// GETs of at most m_MultipartChunkSize bytes each. One external work token is
// registered on Work for the whole transfer; it is completed or failed once
// every range has reported through OnGetPartCompleted.
//
// Work / Pool     - scheduler and worker pool used for the disk-write items.
// Hash            - content hash; the object path is CasPath(Hash).
// Size            - expected object size in bytes; determines the range count
//                   and the size the destination file is prepared for.
// DestinationPath - file that is truncated and prepared for scattered writes.
// Stats           - request statistics updated as ranges are issued/finished.
//
// NOTE(review): assumes m_MultipartChunkSize > 0 and Size > 0. With Size == 0
// RangeCount is 0, no range is dispatched and OnGetPartCompleted never runs
// for this token - confirm callers only route large objects here.
void
S3AsyncStorage::GetMultipart(ParallelWork&                Work,
                             WorkerThreadPool&            Pool,
                             const IoHash&                Hash,
                             uint64_t                     Size,
                             const std::filesystem::path& DestinationPath,
                             S3AsyncStorageStats&         Stats)
{
    auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal());
    Stats.RecordScheduled();

    // Ceil-divide Size into ranges, then sum the write buffers each range
    // needs (the last range may be shorter than Chunk).
    const uint64_t Chunk       = m_MultipartChunkSize;
    const uint32_t RangeCount  = static_cast<uint32_t>((Size + Chunk - 1) / Chunk);
    uint32_t       TotalBlocks = 0;
    for (uint32_t I = 0; I < RangeCount; ++I)
    {
        const uint64_t RS = std::min<uint64_t>(Chunk, Size - static_cast<uint64_t>(I) * Chunk);
        TotalBlocks += static_cast<uint32_t>((RS + GetMultipartState::WriteBufferSize - 1) / GetMultipartState::WriteBufferSize);
    }

    auto State           = std::make_shared<GetMultipartState>(Work, Pool, Stats, TotalBlocks);
    State->Token         = Token;
    State->ContentHash   = Hash;
    State->DestPath      = DestinationPath;
    State->TotalRanges   = RangeCount;
    State->PendingRanges = RangeCount;

    // RangeCount > AdmissionCap is fine: the per-range acquire below blocks the
    // dispatcher (caller thread, NOT the io strand) until in-flight ranges fire
    // their AsyncStream completions and release slots. In-flight ranges per
    // upload stay bounded by AdmissionCap.

    try
    {
        // Sparse + preallocated. Sparse mode lets per-range writes land at
        // arbitrary offsets without the OS zero-filling intervening pages
        // first; preallocating the full size up front avoids fragmentation
        // and lazy-commit page faults during the writes themselves.
+ State->DestFile = std::make_shared<BasicFile>(DestinationPath, BasicFile::Mode::kTruncate); + std::error_code PrepareEc = PrepareFileForScatteredWrite(State->DestFile->Handle(), Size); + if (PrepareEc) + { + throw zen::runtime_error("PrepareFileForScatteredWrite failed: {}"sv, PrepareEc.message()); + } + } + catch (const std::exception& Ex) + { + // Drop DestFile (if opened) before remove() so Windows lets the unlink + // proceed; otherwise we leave a 0-byte file behind on the prealloc path. + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(DestinationPath, Ec); + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::GetMultipart '{}': preallocate failed: {}"sv, Hash, Ex.what()))); + return; + } + + std::string Path = CasPath(Hash); + + // Snapshot creds once for the whole multipart Get. Mirrors PutMultipart; + // avoids N shared-lock acquisitions on the credential provider when the + // range count is large. + const SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + // Empty creds are loop-invariant; drive PendingRanges to zero in one shot + // rather than acquiring TotalRanges admission slots only to fire identical + // failures. RecordGetPartFailure CAS keeps the first message; subsequent + // completions just decrement the counter. + RecordGetPartFailure(*State, "no credentials available"); + for (uint32_t I = 0; I < State->TotalRanges; ++I) + { + OnGetPartCompleted(State); + } + return; + } + + for (uint32_t I = 0; I < State->TotalRanges; ++I) + { + const uint64_t Offset = static_cast<uint64_t>(I) * Chunk; + const uint64_t RangeSize = std::min<uint64_t>(Chunk, Size - Offset); + const uint32_t RangeIdx = I; + + try + { + // Per-range admission slot. Acquire blocks the caller thread when the + // cap is reached; in-flight ranges release slots via SlotRef destruct + // in their AsyncStream completion callback. 
Acquire on the caller + // (vs the io strand) is safe to block: GetMultipart is called from + // the provision-pool worker driving Storage::Get, never from the + // curl io thread that has to drain completions. + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + Headers->emplace("Range", fmt::format("bytes={}-{}", Offset, Offset + RangeSize - 1)); + + Stopwatch Timer = State->Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncStream does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&State] { State->Stats.EndRequest(0, 0); }); + + // Per-range mirror of GetStreamState's buffered-write machinery (see + // the medium-tier branch in Get()). + constexpr size_t WriteBufferSize = GetMultipartState::WriteBufferSize; + + struct RangeWriteState + { + IoBuffer ActiveBuf; + size_t BufFill = 0; + uint64_t NextAbsOffset; + uint64_t TotalReceived = 0; + std::atomic<uint32_t> PendingWork{1}; // +1 for stream completion + }; + + auto WState = std::make_shared<RangeWriteState>(); + WState->NextAbsOffset = Offset; + + m_Client.AsyncStream( + Path, + [this, State, RangeSize, RangeIdx, WState](const uint8_t* Data, size_t SizeBytes, uint64_t /*TotalSize*/) -> bool { + if (WState->TotalReceived + SizeBytes > RangeSize) + { + RecordGetPartFailure(*State, + fmt::format("S3 GET range {} overflow: got {} bytes past expected {}", + RangeIdx, + WState->TotalReceived + SizeBytes, + RangeSize)); + return false; + } + const uint8_t* Cur = Data; + size_t Remaining = SizeBytes; + while (Remaining > 0) + { + if (!WState->ActiveBuf) + { + WState->ActiveBuf = State->Buffers.Acquire(); + } + const size_t Avail = WriteBufferSize - WState->BufFill; + const size_t Take = std::min(Remaining, Avail); + std::memcpy(WState->ActiveBuf.MutableData<uint8_t>() + WState->BufFill, Cur, Take); + WState->BufFill += Take; + 
WState->TotalReceived += Take; + Cur += Take; + Remaining -= Take; + if (WState->BufFill == WriteBufferSize) + { + IoBuffer Buf = std::move(WState->ActiveBuf); + const size_t Fill = WState->BufFill; + const uint64_t AbsOffset = WState->NextAbsOffset; + WState->NextAbsOffset += Fill; + WState->BufFill = 0; + WState->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, RangeIdx, WState, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, + fmt::format("S3 GET range {} write failed: {}", RangeIdx, Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetPartCompleted(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, fmt::format("S3 GET range {} schedule failed: {}", RangeIdx, Ex.what())); + State->Buffers.RestoreAcquireSlot(); + WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + return false; + } + } + } + return true; + }, + [this, State, RangeSize, RangeIdx, WState, Timer, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
WState->TotalReceived : 0);
                    bool Failed = !Resp.IsSuccess();
                    if (Failed)
                    {
                        RecordGetPartFailure(*State, S3ErrorMessage(fmt::format("S3 GET range {} failed", RangeIdx), Resp));
                    }
                    else if (WState->TotalReceived != RangeSize)
                    {
                        RecordGetPartFailure(
                            *State,
                            fmt::format("S3 GET range {} wrote {} bytes, expected {}", RangeIdx, WState->TotalReceived, RangeSize));
                        Failed = true;
                    }
                    // Tail flush: a partially filled buffer is only written
                    // when the range itself succeeded.
                    if (!Failed && WState->BufFill > 0)
                    {
                        IoBuffer       Buf       = std::move(WState->ActiveBuf);
                        const size_t   Fill      = WState->BufFill;
                        const uint64_t AbsOffset = WState->NextAbsOffset;
                        WState->BufFill          = 0;
                        WState->PendingWork.fetch_add(1, std::memory_order_acq_rel);
                        try
                        {
                            State->Work.ScheduleWork(
                                State->Pool,
                                [this, State, RangeIdx, WState, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable {
                                    try
                                    {
                                        State->DestFile->Write(Buf.GetData(), Fill, AbsOffset);
                                    }
                                    catch (const std::exception& Ex)
                                    {
                                        RecordGetPartFailure(*State, fmt::format("S3 GET range {} write failed: {}", RangeIdx, Ex.what()));
                                    }
                                    State->Buffers.Release(std::move(Buf));
                                    if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1)
                                    {
                                        OnGetPartCompleted(State);
                                    }
                                },
                                WorkerThreadPool::EMode::EnableBacklog);
                        }
                        catch (const std::exception& Ex)
                        {
                            // The worker never ran: undo the PendingWork
                            // increment taken just above.
                            RecordGetPartFailure(*State,
                                                 fmt::format("S3 GET range {} tail-flush schedule failed: {}", RangeIdx, Ex.what()));
                            State->Buffers.RestoreAcquireSlot();
                            WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel);
                        }
                    }
                    // Drop the stream's own PendingWork slot; whichever
                    // participant observes the value 1 reports the range.
                    if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1)
                    {
                        OnGetPartCompleted(State);
                    }
                },
                Headers);
            InFlightGuard.Dismiss();
        }
        catch (const std::exception& Ex)
        {
            // Dispatch threw before the range was handed to the client: record
            // the failure and still count the range so PendingRanges can
            // reach zero.
            RecordGetPartFailure(*State, fmt::format("S3 GET range {} dispatch failed: {}", RangeIdx, Ex.what()));
            OnGetPartCompleted(State);
        }
    }
}

// Called once per range when that range has finished, successfully or not.
// Only the call that takes PendingRanges from 1 to 0 proceeds: it samples
// AnyFailed and schedules a work item that closes the destination file and
// then either completes the token or, on failure, removes the file and fails
// the token with the first recorded error.
//
// State - shared state of the multipart download.
void
S3AsyncStorage::OnGetPartCompleted(std::shared_ptr<GetMultipartState> State)
{
    const uint32_t Remaining = State->PendingRanges.fetch_sub(1, std::memory_order_acq_rel);
    if (Remaining != 1)
    {
        // Other ranges are still outstanding; the last one finalises.
        return;
    }

    // All disk I/O hopped to the worker pool: see note in
    // OnGetStreamFinalised.
    const bool Failed = State->AnyFailed.load(std::memory_order_acquire);
    try
    {
        State->Work.ScheduleWork(
            State->Pool,
            [State, Failed](std::atomic<bool>&) {
                if (Failed)
                {
                    std::string FirstError;
                    State->ErrorLock.WithExclusiveLock([&] { FirstError = std::move(State->FirstError); });
                    State->DestFile.reset();
                    std::error_code Ec;
                    std::filesystem::remove(State->DestPath, Ec);
                    State->Token->Fail(
                        std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, State->ContentHash, FirstError)));
                    return;
                }

                State->DestFile.reset();
                State->Token->Complete();
            },
            WorkerThreadPool::EMode::EnableBacklog);
    }
    catch (...)
    {
        // Scheduling failed: clean up inline and fail the token with the
        // scheduling exception.
        State->DestFile.reset();
        std::error_code Ec;
        std::filesystem::remove(State->DestPath, Ec);
        State->Token->Fail(std::current_exception());
    }
}

// Issues a PUT to the object's own path with x-amz-copy-source naming the
// same key and x-amz-metadata-directive REPLACE (a self-copy; presumably used
// to refresh the object's modification time - confirm against callers). One
// external work token is registered on Work and resolved from the request's
// completion callback.
//
// Work / Pool - scheduler and pool that run the signing + dispatch work item.
// Hash        - content hash identifying the object.
// Stats       - request statistics for this phase.
void
S3AsyncStorage::Touch(ParallelWork& Work, WorkerThreadPool& Pool, const IoHash& Hash, S3AsyncStorageStats& Stats)
{
    auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal());
    Stats.RecordScheduled();

    std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats);

    // Snapshot creds on the dispatcher; see PutSmall.
+ SigV4Credentials Creds = m_GetCreds(); + + Work.ScheduleWork( + Pool, + [this, Hash = IoHash(Hash), Token, Stats, Creds = std::move(Creds), SlotRef = std::move(SlotRef)]( + std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': no credentials available"sv, Hash))); + return; + } + + std::string Key = CasKey(Hash); + std::string Path = CasPath(Hash); + + std::vector<std::pair<std::string, std::string>> ExtraSigned{ + {"x-amz-copy-source", fmt::format("/{}/{}", m_Builder.BucketName(), AwsUriEncode(Key, false))}, + {"x-amz-metadata-directive", "REPLACE"}, + }; + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", S3EmptyPayloadHash, ExtraSigned); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not strand + // InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + m_Client.AsyncPut( + Path, + [Hash, Token, Timer, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 Touch failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': {}"sv, Hash, Err))); + return; + } + // PUT-COPY (REPLACE) can return HTTP 200 with an <Error> body. Mirror + // the sync S3Client::Touch check; without this the dehydrate-touch + // fails silently. 
+ std::string_view Body = Resp.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': returned error: {} - {}"sv, + Hash, + ErrorCode, + ErrorMessage))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +std::vector<std::string> +S3AsyncStorage::ListAllObjects(std::string_view Prefix) +{ + constexpr std::string_view ContentsCloseTag = "</Contents>"; + + // List requests are not counted in PhaseStats; List/DeleteAll is admin + // scaffolding (test fixture teardown, manual obliterate). Mirrors the + // CreateMultipart/Complete/Abort exclusion in PutMultipart. + + // Snapshot creds once for the listing run; ListObjectsV2 pagination is + // fast enough that mid-list refresh isn't needed. + const SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorage::ListAllObjects: no credentials available"sv); + } + + std::vector<std::string> Keys; + std::string ContinuationToken; + std::string RootPath = m_Builder.BucketRootPath(); + while (true) + { + std::string CanonicalQs = fmt::format("list-type=2&prefix={}", AwsUriEncode(Prefix)); + if (!ContinuationToken.empty()) + { + CanonicalQs += fmt::format("&continuation-token={}", AwsUriEncode(ContinuationToken)); + } + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", RootPath, CanonicalQs, S3EmptyPayloadHash); + + std::string FullPath = S3BuildRequestPath(RootPath, CanonicalQs); + HttpClient::Response Resp = m_Client.Get(FullPath, Headers).get(); + if (!Resp.IsSuccess()) + { + throw zen::runtime_error("S3AsyncStorage::ListAllObjects '{}': {}"sv, Prefix, S3ErrorMessage("S3 LIST failed", Resp)); + } + + if (!Resp.ResponsePayload) + { + break; + } + std::string_view 
Body(reinterpret_cast<const char*>(Resp.ResponsePayload.GetData()), Resp.ResponsePayload.GetSize()); + + std::string_view Cursor = Body; + while (true) + { + size_t ContentsOpen = Cursor.find("<Contents>"); + if (ContentsOpen == std::string_view::npos) + { + break; + } + size_t ContentsClose = Cursor.find(ContentsCloseTag, ContentsOpen); + if (ContentsClose == std::string_view::npos) + { + break; + } + std::string_view ContentsBlock = Cursor.substr(ContentsOpen, ContentsClose - ContentsOpen); + std::string_view Key = S3ExtractXmlValue(ContentsBlock, "Key"); + Keys.emplace_back(Key); + Cursor = Cursor.substr(ContentsClose + ContentsCloseTag.size()); + } + + std::string_view IsTruncated = S3ExtractXmlValue(Body, "IsTruncated"); + if (IsTruncated != "true") + { + break; + } + std::string_view NextToken = S3ExtractXmlValue(Body, "NextContinuationToken"); + if (NextToken.empty()) + { + break; + } + ContinuationToken.assign(NextToken); + } + return Keys; +} + +std::vector<IoHash> +S3AsyncStorage::List() +{ + std::string CasPrefix = fmt::format("{}/cas/", m_KeyPrefix); + std::vector<std::string> Keys = ListAllObjects(CasPrefix); + + std::vector<IoHash> Hashes; + Hashes.reserve(Keys.size()); + for (const std::string& Key : Keys) + { + size_t LastSlash = Key.rfind('/'); + if (LastSlash == std::string::npos) + { + continue; + } + IoHash Hash; + if (IoHash::TryParse(std::string_view(Key).substr(LastSlash + 1), Hash)) + { + Hashes.push_back(Hash); + } + } + return Hashes; +} + +void +S3AsyncStorage::DeleteAll(ParallelWork& Work) +{ + std::string Prefix = fmt::format("{}/", m_KeyPrefix); + std::vector<std::string> Keys = ListAllObjects(Prefix); + + // Snapshot creds once for the whole obliterate; saves N shared-lock + // acquisitions on the credential provider. 
+ const SigV4Credentials DelCreds = m_GetCreds(); + if (DelCreds.AccessKeyId.empty()) + { + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::DeleteAll: no credentials available"sv))); + return; + } + + for (const std::string& Key : Keys) + { + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + + try + { + std::string Path = m_Builder.KeyToPath(Key); + + HttpClient::KeyValueMap DelHeaders = m_Builder.SignRequest(DelCreds, "DELETE", Path, "", S3EmptyPayloadHash); + + // DeleteAll is untracked by stats - admission slot is acquired but the + // wait is not recorded; passes Stats == nullptr to the helper. + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission); + + m_Client.AsyncDelete( + Path, + [KeyCopy = Key, Token, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + if (!Resp.IsSuccess() && Resp.StatusCode != HttpResponseCode::NotFound) + { + std::string Err = S3ErrorMessage("S3 DELETE failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::DeleteAll '{}': {}"sv, KeyCopy, Err))); + return; + } + Token->Complete(); + }, + DelHeaders); + } + catch (...) + { + // Sign / acquire / dispatch threw before AsyncDelete posted. Surface via Token so + // Wait() reports the partial failure instead of silently dropping the key. + Token->Fail(std::current_exception()); + } + } +} + +void +s3asyncstorage_forcelink() +{ +} + +#if ZEN_WITH_TESTS + +namespace { + // Per-binary unique MinIO port. 
uint16_t AllocateMinioTestPort()
{
    // Base is computed once per process as 20000 + (pid % 30000); every call
    // then hands out Base plus the next slot index.
    static const uint16_t        Base = static_cast<uint16_t>(20000u + (static_cast<uint32_t>(GetCurrentProcessId()) % 30000u));
    static std::atomic<uint16_t> Slot{0};
    return Base + Slot.fetch_add(1, std::memory_order_relaxed);
}

// Default MinIO options with a freshly allocated test port.
MinioProcessOptions MakeMinioOpts()
{
    MinioProcessOptions Opts;
    Opts.Port = AllocateMinioTestPort();
    return Opts;
}

// Per-test environment: spawns a MinIO server, creates the "async-test"
// bucket and builds the request builder, credentials, async HTTP client and
// a 4-thread worker pool the tests run against.
struct AsyncS3Fixture
{
    MinioProcess                      Minio{MakeMinioOpts()};
    ScopedTemporaryDirectory          TmpDir;
    std::unique_ptr<AsyncHttpClient>  ClientStorage;  // owns the client
    std::unique_ptr<S3RequestBuilder> Builder;
    AsyncHttpClient*                  Client = nullptr;  // non-owning alias of ClientStorage
    SigV4Credentials                  Creds;
    WorkerThreadPool                  Pool{4};

    AsyncS3Fixture()
    {
        Minio.SpawnMinioServer();
        Minio.CreateBucket("async-test");

        Creds.AccessKeyId     = "minioadmin";
        Creds.SecretAccessKey = "minioadmin";

        Builder = std::make_unique<S3RequestBuilder>("us-east-1", "async-test", Minio.Endpoint(), /*PathStyle*/ true);

        HttpClientSettings Settings;
        Settings.LogCategory                     = "async-s3-test";
        Settings.MaxConcurrentConnectionsPerHost = 16;
        // Match production S3Hydration: cover MinIO post-spawn 503
        // (XMinioServerNotInitialized) transients before storage backend
        // finishes warming up after the spawn ready-check passes.
        Settings.RetryCount = 3;
        ClientStorage       = std::make_unique<AsyncHttpClient>(Minio.Endpoint(), Settings);
        Client              = ClientStorage.get();
    }
};

// Writes Bytes to Path and returns Path so the call can initialise a variable.
std::filesystem::path WriteBlob(const std::filesystem::path& Path, const std::vector<uint8_t>& Bytes)
{
    WriteFile(Path, IoBuffer(IoBuffer::Wrap, Bytes.data(), Bytes.size()));
    return Path;
}

// Backing storage for the counters an S3AsyncStorageStats is built from.
// Ref() assembles a stats object over these members (presumably by
// reference - confirm against the S3AsyncStorageStats declaration), so the
// block must outlive every stats object it hands out.
struct StatsBlock
{
    std::atomic<uint64_t> RequestCount{0};
    std::atomic<uint64_t> RequestTotalUs{0};
    std::atomic<uint64_t> RequestMaxUs{0};
    std::atomic<uint64_t> Bytes{0};
    std::atomic<uint32_t> InFlight{0};
    std::atomic<uint32_t> InFlightPeak{0};
    std::atomic<uint64_t> FirstScheduleUs{UINT64_MAX};
    std::atomic<uint64_t> FirstStartUs{UINT64_MAX};
    std::atomic<uint64_t> AdmissionWaitTotalUs{0};
    std::atomic<uint64_t> AdmissionWaitMaxUs{0};
    Stopwatch             PhaseClock;

    S3AsyncStorageStats Ref()
    {
        return S3AsyncStorageStats{RequestCount,
                                   RequestTotalUs,
                                   RequestMaxUs,
                                   Bytes,
                                   InFlight,
                                   InFlightPeak,
                                   FirstScheduleUs,
                                   FirstStartUs,
                                   AdmissionWaitTotalUs,
                                   AdmissionWaitMaxUs,
                                   PhaseClock};
    }
};
}  // namespace

TEST_SUITE_BEGIN("server.s3asyncstorage");

// Uploads a 64 KiB patterned blob, downloads it again and checks that the
// size and content hash of the downloaded file match.
TEST_CASE("s3asyncstorage.put_get_round_trip")
{
    StatsBlock     Sb;
    AsyncS3Fixture F;
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-A",
        8u * 1024u * 1024u);

    const uint64_t       Size = 64u * 1024u;
    std::vector<uint8_t> Bytes(Size);
    for (size_t I = 0; I < Size; ++I)
    {
        Bytes[I] = static_cast<uint8_t>((I * 31u) & 0xFF);
    }

    std::filesystem::path SrcPath     = WriteBlob(F.TmpDir.Path() / "src.bin", Bytes);
    IoHash                ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());

    // Upload phase: Wait() returns once the Put has resolved its token.
    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Put(Work, F.Pool, ContentHash, Size, SrcPath, Stats);
        Work.Wait();
    }

std::filesystem::path DstPath = F.TmpDir.Path() / "dst.bin";
    // Download phase into a fresh destination file.
    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Get(Work, F.Pool, ContentHash, Size, DstPath, Stats);
        Work.Wait();
    }

    REQUIRE(std::filesystem::exists(DstPath));
    REQUIRE(std::filesystem::file_size(DstPath) == Size);

    BasicFile Out(DstPath, BasicFile::Mode::kRead);
    IoBuffer  Read     = Out.ReadAll();
    IoHash    ReadHash = IoHash::HashBuffer(Read);
    CHECK(ReadHash == ContentHash);
}

// Uploads a 5-byte object and then touches it. There are no explicit
// assertions; NOTE(review): presumably Work.Wait() surfaces a failed token as
// an exception, which is what would fail the test - confirm.
TEST_CASE("s3asyncstorage.touch_existing_object")
{
    StatsBlock     Sb;
    AsyncS3Fixture F;
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-touch",
        8u * 1024u * 1024u);

    std::vector<uint8_t>  Bytes{1, 2, 3, 4, 5};
    std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "touch.bin", Bytes);
    IoHash                Hash    = IoHash::HashBuffer(Bytes.data(), Bytes.size());

    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Put(Work, F.Pool, Hash, Bytes.size(), SrcPath, Stats);
        Work.Wait();
    }

    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Touch(Work, F.Pool, Hash, Stats);
        Work.Wait();
    }
}

// Uploads five distinct 64-byte blobs and checks that List() returns exactly
// their hashes (compared after sorting both sides).
TEST_CASE("s3asyncstorage.list_returns_uploaded_hashes")
{
    StatsBlock     Sb;
    AsyncS3Fixture F;
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-list",
        8u * 1024u * 1024u);

    const size_t                       N = 5;
    std::vector<IoHash>                Hashes;
    std::vector<std::filesystem::path> SrcPaths;
    for (size_t I = 0; I < N; ++I)
    {
        // Each blob is 64 copies of a different byte value, so hashes differ.
        std::vector<uint8_t>  Bytes(64, static_cast<uint8_t>(I + 1));
        std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("blob_{}.bin", I), Bytes);
        Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size()));
        SrcPaths.push_back(P);
    }

    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        for (size_t I = 0; I < N; ++I)
        {
            Storage.Put(Work, F.Pool, Hashes[I], 64, SrcPaths[I], Stats);
        }
        Work.Wait();
    }

    std::vector<IoHash> Listed = Storage.List();
    std::sort(Listed.begin(), Listed.end());
    std::vector<IoHash> Expected = Hashes;
    std::sort(Expected.begin(), Expected.end());
    CHECK(Listed == Expected);
}

TEST_CASE("s3asyncstorage.streaming_download_round_trip")
{
    // Object size sits between the streaming threshold (512 KiB) and the
    // multipart threshold (PartSize + PartSize/4). Exercises the medium-tier
    // AsyncStream branch in S3AsyncStorage::Get - body streams via a 512 KiB
    // IoBuffer pool with per-buffer worker write hops.
StatsBlock Sb;
    AsyncS3Fixture F;
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-stream",
        8u * 1024u * 1024u);

    const uint64_t       Size = 2u * 1024u * 1024u;  // 2 MiB - between 512 KiB and 10 MiB
    std::vector<uint8_t> Bytes(Size);
    for (size_t I = 0; I < Size; ++I)
    {
        Bytes[I] = static_cast<uint8_t>((I * 53u + 11u) & 0xFF);
    }
    std::filesystem::path SrcPath     = WriteBlob(F.TmpDir.Path() / "stream.bin", Bytes);
    IoHash                ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());

    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Put(Work, F.Pool, ContentHash, Size, SrcPath, Stats);
        Work.Wait();
    }

    std::filesystem::path DstPath = F.TmpDir.Path() / "stream_out.bin";
    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Get(Work, F.Pool, ContentHash, Size, DstPath, Stats);
        Work.Wait();
    }

    REQUIRE(std::filesystem::exists(DstPath));
    REQUIRE(std::filesystem::file_size(DstPath) == Size);
    BasicFile Out(DstPath, BasicFile::Mode::kRead);
    IoBuffer  Read     = Out.ReadAll();
    IoHash    ReadHash = IoHash::HashBuffer(Read);
    CHECK(ReadHash == ContentHash);
}

// Round-trips an 11 MiB object with a 5 MiB part size and verifies the
// downloaded file's size and content hash.
TEST_CASE("s3asyncstorage.multipart_round_trip")
{
    StatsBlock     Sb;
    AsyncS3Fixture F;
    const uint64_t PartSize = 5u * 1024u * 1024u;  // MinIO accepts >=5MiB parts.
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-multipart",
        PartSize);

    // Three parts: 5 MiB + 5 MiB + 1 MiB. Threshold is PartSize + PartSize/4 (=
    // 6.25 MiB). Total 11 MiB triggers multipart path.
    const uint64_t       TotalSize = 11u * 1024u * 1024u;
    std::vector<uint8_t> Bytes(TotalSize);
    for (size_t I = 0; I < TotalSize; ++I)
    {
        Bytes[I] = static_cast<uint8_t>((I * 131u + 7u) & 0xFF);
    }

    std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "big.bin", Bytes);
    IoHash                Hash    = IoHash::HashBuffer(Bytes.data(), Bytes.size());

    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Put(Work, F.Pool, Hash, TotalSize, SrcPath, Stats);
        Work.Wait();
    }

    std::filesystem::path DstPath = F.TmpDir.Path() / "big_out.bin";
    {
        std::atomic<bool>   AbortFlag{false};
        std::atomic<bool>   PauseFlag{false};
        ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
        S3AsyncStorageStats Stats = Sb.Ref();
        Storage.Get(Work, F.Pool, Hash, TotalSize, DstPath, Stats);
        Work.Wait();
    }

    REQUIRE(std::filesystem::exists(DstPath));
    REQUIRE(std::filesystem::file_size(DstPath) == TotalSize);

    BasicFile Out(DstPath, BasicFile::Mode::kRead);
    IoBuffer  Read     = Out.ReadAll();
    IoHash    ReadHash = IoHash::HashBuffer(Read);
    CHECK(ReadHash == Hash);
}

// Queues 32 one-KiB uploads on a single ParallelWork before waiting, then
// checks that List() reports 32 objects.
TEST_CASE("s3asyncstorage.parallel_puts")
{
    StatsBlock     Sb;
    AsyncS3Fixture F;
    S3AsyncStorage Storage(
        *F.Client,
        *F.Builder,
        [&F]() { return F.Creds; },
        "module-parallel",
        8u * 1024u * 1024u);

    const size_t                       N = 32;
    std::vector<IoHash>                Hashes;
    std::vector<std::filesystem::path> SrcPaths;
    for (size_t I = 0; I < N; ++I)
    {
        std::vector<uint8_t>  Bytes(1024, static_cast<uint8_t>(I));
        std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("p_{}.bin", I), Bytes);
        Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size()));
        SrcPaths.push_back(P);
    }

    std::atomic<bool>   AbortFlag{false};
    std::atomic<bool>   PauseFlag{false};
    ParallelWork        Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
    S3AsyncStorageStats Stats = Sb.Ref();
    for (size_t I = 0; I < N; ++I)
    {
        Storage.Put(Work, F.Pool, Hashes[I], 1024, SrcPaths[I], Stats);
    }
    Work.Wait();

    CHECK(Storage.List().size() == N);
}

namespace {
// One MinIO + builder + creds + pool, reused across multiple admission /
// throttle test scopes. Each scope constructs its own AsyncHttpClient via
// MakeClient(Cap) so different MaxConcurrentRequests values are exercised
// without re-spawning MinIO.
struct ThrottledAsyncS3Fixture
{
    MinioProcess                      Minio{MakeMinioOpts()};
    ScopedTemporaryDirectory          TmpDir;
    std::unique_ptr<S3RequestBuilder> Builder;
    SigV4Credentials                  Creds;
    WorkerThreadPool                  Pool{4};

    ThrottledAsyncS3Fixture()
    {
        Minio.SpawnMinioServer();
        Minio.CreateBucket("async-test-throttle");

        Creds.AccessKeyId     = "minioadmin";
        Creds.SecretAccessKey = "minioadmin";

        Builder = std::make_unique<S3RequestBuilder>("us-east-1", "async-test-throttle", Minio.Endpoint(), /*PathStyle*/ true);
    }

    // Builds a client against this fixture's MinIO endpoint with the given
    // submit-side request cap; the caller owns the returned client.
    std::unique_ptr<AsyncHttpClient> MakeClient(uint32_t MaxConcurrentRequests)
    {
        HttpClientSettings Settings;
        Settings.LogCategory                     = "async-s3-throttle-test";
        Settings.MaxConcurrentConnectionsPerHost = 16;
        Settings.MaxConcurrentRequests           = MaxConcurrentRequests;
        // Match production S3Hydration: cover MinIO post-spawn 503
        // (XMinioServerNotInitialized) transients before storage backend
        // finishes warming up after the spawn ready-check passes.
        Settings.RetryCount = 3;
        return std::make_unique<AsyncHttpClient>(Minio.Endpoint(), Settings);
    }
};
}  // namespace

// Non-multipart admission scenarios: small-file fanout under a cap, fanout
// with admission disabled, and admission-wait stat recording. All share one
// MinIO instance via ThrottledAsyncS3Fixture; each scope picks its own Cap
// and KeyPrefix so the bucket entries don't collide across scopes.
+TEST_CASE("s3asyncstorage.admission.fanout")
+{
+    ThrottledAsyncS3Fixture F;
+
+    // Size in bytes of every blob uploaded by the scopes below.
+    static constexpr size_t BlobSize = 256;
+
+    // Content hashes and on-disk source paths of a generated blob set, in
+    // matching order.
+    struct BlobSet
+    {
+        std::vector<IoHash> Hashes;
+        std::vector<std::filesystem::path> SrcPaths;
+    };
+
+    // Writes Count blobs named "<FilePrefix>_<I>.bin" into the fixture temp
+    // dir. Blob I is BlobSize bytes of value I, so every hash is distinct.
+    // Shared by all three scopes; the per-scope FilePrefix keeps files apart.
+    auto MakeBlobs = [&F](const char* FilePrefix, size_t Count) {
+        BlobSet Blobs;
+        for (size_t I = 0; I < Count; ++I)
+        {
+            std::vector<uint8_t> Bytes(BlobSize, static_cast<uint8_t>(I));
+            std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("{}_{}.bin", FilePrefix, I), Bytes);
+            Blobs.Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size()));
+            Blobs.SrcPaths.push_back(P);
+        }
+        return Blobs;
+    };
+
+    // respects.async.client.cap - in-flight peak <= storage admission cap.
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 2;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-cap",
+            8u * 1024u * 1024u,
+            Admission,
+            Cap);
+
+        const size_t N = 8;
+        const BlobSet Blobs = MakeBlobs("c", N);
+
+        std::atomic<bool> AbortFlag{false};
+        std::atomic<bool> PauseFlag{false};
+        ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+        S3AsyncStorageStats Stats = Sb.Ref();
+        for (size_t I = 0; I < N; ++I)
+        {
+            Storage.Put(Work, F.Pool, Blobs.Hashes[I], BlobSize, Blobs.SrcPaths[I], Stats);
+        }
+        Work.Wait();
+
+        // List order is not assumed; compare as sorted sets.
+        std::vector<IoHash> Listed = Storage.List();
+        std::sort(Listed.begin(), Listed.end());
+        std::vector<IoHash> Expected = Blobs.Hashes;
+        std::sort(Expected.begin(), Expected.end());
+        CHECK(Listed == Expected);
+        // Bound is the scope's own Cap rather than a repeated literal.
+        CHECK(Sb.InFlightPeak.load() <= Cap);
+    }
+
+    // unlimited_concurrent_requests - admission disabled (nullptr semaphore,
+    // Cap=0). Storage drains the fanout cleanly without gating.
+    {
+        StatsBlock Sb;
+        auto Client = F.MakeClient(0);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-uncapped",
+            8u * 1024u * 1024u);
+
+        const size_t N = 8;
+        const BlobSet Blobs = MakeBlobs("u", N);
+
+        std::atomic<bool> AbortFlag{false};
+        std::atomic<bool> PauseFlag{false};
+        ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+        S3AsyncStorageStats Stats = Sb.Ref();
+        for (size_t I = 0; I < N; ++I)
+        {
+            Storage.Put(Work, F.Pool, Blobs.Hashes[I], BlobSize, Blobs.SrcPaths[I], Stats);
+        }
+        Work.Wait();
+
+        std::vector<IoHash> Listed = Storage.List();
+        std::sort(Listed.begin(), Listed.end());
+        std::vector<IoHash> Expected = Blobs.Hashes;
+        std::sort(Expected.begin(), Expected.end());
+        CHECK(Listed == Expected);
+    }
+
+    // admission.wait.us.recorded - cap=2 with 8 fanout submits forces at least
+    // some submissions to block on the admission semaphore. AdmissionWait
+    // totals must be non-zero.
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 2;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-adm-wait",
+            8u * 1024u * 1024u,
+            Admission,
+            Cap);
+
+        const size_t N = 8;
+        const BlobSet Blobs = MakeBlobs("aw", N);
+
+        std::atomic<bool> AbortFlag{false};
+        std::atomic<bool> PauseFlag{false};
+        ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+        S3AsyncStorageStats Stats = Sb.Ref();
+        for (size_t I = 0; I < N; ++I)
+        {
+            Storage.Put(Work, F.Pool, Blobs.Hashes[I], BlobSize, Blobs.SrcPaths[I], Stats);
+        }
+        Work.Wait();
+
+        CHECK(Sb.AdmissionWaitTotalUs.load() > 0u);
+        CHECK(Sb.AdmissionWaitMaxUs.load() > 0u);
+        CHECK(Sb.InFlightPeak.load() <= Cap);
+    }
+}
+
+// Drives the AbortMultipart path: a 3-part multipart upload where one part is
+// forced to fail via the test hook. AnyFailed flips, FinalizePutPart hops
+// off-strand once PendingParts reaches 1, AbortMultipart sends DELETE to S3
+// to discard the upload-id.
+// Token must surface the original part failure; bucket must not contain the
+// uploaded CAS object after Wait returns.
+TEST_CASE("s3asyncstorage.multipart.abort_on_part_failure")
+{
+    // Arms the process-wide forced-part-failure hook for the lifetime of the
+    // object and always disarms it on destruction. With a plain
+    // arm/reset pair, an exception escaping Put() would skip the reset and
+    // leave the hook armed for every later multipart test in the process.
+    struct ForcedPartFailureScope
+    {
+        explicit ForcedPartFailureScope(uint32_t Count) { s3asyncstorage_test_hooks::ForceNextPartFailures(Count); }
+        ~ForcedPartFailureScope() { s3asyncstorage_test_hooks::ForceNextPartFailures(0); }
+        ForcedPartFailureScope(const ForcedPartFailureScope&) = delete;
+        ForcedPartFailureScope& operator=(const ForcedPartFailureScope&) = delete;
+    };
+
+    StatsBlock Sb;
+    AsyncS3Fixture F;
+    const uint64_t PartSize = 5u * 1024u * 1024u;
+    S3AsyncStorage Storage(
+        *F.Client,
+        *F.Builder,
+        [&F]() { return F.Creds; },
+        "module-multipart-abort",
+        PartSize);
+
+    const uint64_t TotalSize = 11u * 1024u * 1024u;  // 3 parts: 5+5+1 MiB
+    std::vector<uint8_t> Bytes(TotalSize);
+    for (size_t I = 0; I < TotalSize; ++I)
+    {
+        Bytes[I] = static_cast<uint8_t>((I * 23u + 9u) & 0xFF);
+    }
+    std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "abort.bin", Bytes);
+    IoHash Hash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+    bool ThrewFromWait = false;
+    std::string ErrMessage;
+    {
+        // Force one part to fail; AnyFailed -> AbortMultipart. Declared first
+        // so it is destroyed last: the hook is reset after Work is torn down
+        // and before the bucket is inspected below.
+        ForcedPartFailureScope ForcedFailure(1);
+
+        std::atomic<bool> AbortFlag{false};
+        std::atomic<bool> PauseFlag{false};
+        ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+        S3AsyncStorageStats Stats = Sb.Ref();
+        Storage.Put(Work, F.Pool, Hash, TotalSize, SrcPath, Stats);
+        try
+        {
+            Work.Wait();
+        }
+        catch (const std::exception& Ex)
+        {
+            // Expected path: Wait() must surface the injected part failure.
+            ThrewFromWait = true;
+            ErrMessage = Ex.what();
+        }
+    }
+    CHECK(ThrewFromWait);
+    CHECK(ErrMessage.find("test-injected part failure") != std::string::npos);
+
+    // Bucket must not contain the CAS object - AbortMultipart discards the
+    // upload before S3 publishes the assembled object.
+    std::vector<IoHash> Listed = Storage.List();
+    CHECK(std::find(Listed.begin(), Listed.end(), Hash) == Listed.end());
+}
+
+// Multipart admission scenarios sharing one MinIO instance. Each scope picks
+// its own Cap, KeyPrefix, and payload pattern so bucket entries don't collide.
+// Covers:
+// - under.cap: TotalParts < cap (no slot handoff needed); round-trip Put+Get.
+// - no.deadlock.streaming.get: Put + ranged Get on same hydration pool.
+// - parts.exceed.cap.paces: TotalParts > cap; wave + slot-handoff completes
+//   successfully with InFlightPeak <= cap.
+// - aborts.release.tokens: forced part failure releases held slots so a
+//   follow-up Put can run.
+// - handoff.in.flight: cap=1 forces strictly sequential dispatch.
+// - ranges.exceed.cap.paces: ranged Get with more ranges than the cap stays
+//   within InFlightPeak <= cap.
+TEST_CASE("s3asyncstorage.admission.multipart")
+{
+    ThrottledAsyncS3Fixture F;
+    const uint64_t PartSize = 5u * 1024u * 1024u;  // MinIO accepts >= 5 MiB parts.
+
+    // under.cap: 3-part Put + Get sits within the initial wave (cap=4).
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 4;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-multipart-cap",
+            PartSize,
+            Admission,
+            Cap);
+
+        const uint64_t TotalSize = 11u * 1024u * 1024u;
+        std::vector<uint8_t> Bytes(TotalSize);
+        for (size_t I = 0; I < TotalSize; ++I)
+        {
+            Bytes[I] = static_cast<uint8_t>((I * 17u + 5u) & 0xFF);
+        }
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp.bin", Bytes);
+        IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats);
+            Work.Wait();
+        }
+
+        std::filesystem::path DstPath = F.TmpDir.Path() / "mp_out.bin";
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats);
+            Work.Wait();
+        }
+
+        REQUIRE(std::filesystem::exists(DstPath));
+        REQUIRE(std::filesystem::file_size(DstPath) == TotalSize);
+        BasicFile Out(DstPath, BasicFile::Mode::kRead);
+        IoBuffer Read = Out.ReadAll();
+        IoHash ReadHash = IoHash::HashBuffer(Read);
+        CHECK(ReadHash == ContentHash);
+    }
+
+    // no.deadlock.streaming.get: PutMultipart fans onto hydration pool while
+    // io strand fires AsyncPut completions; followed by Get exercising stream
+    // write-back on the same pool.
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 4;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-stream-no-deadlock",
+            PartSize,
+            Admission,
+            Cap);
+
+        const uint64_t TotalSize = 11u * 1024u * 1024u;
+        std::vector<uint8_t> Bytes(TotalSize);
+        for (size_t I = 0; I < TotalSize; ++I)
+        {
+            Bytes[I] = static_cast<uint8_t>((I * 41u + 7u) & 0xFF);
+        }
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "sd.bin", Bytes);
+        IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats);
+            Work.Wait();
+        }
+
+        std::filesystem::path DstPath = F.TmpDir.Path() / "sd_out.bin";
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats);
+            Work.Wait();
+        }
+
+        REQUIRE(std::filesystem::exists(DstPath));
+        REQUIRE(std::filesystem::file_size(DstPath) == TotalSize);
+        BasicFile Out(DstPath, BasicFile::Mode::kRead);
+        IoBuffer Read = Out.ReadAll();
+        IoHash ReadHash = IoHash::HashBuffer(Read);
+        CHECK(ReadHash == ContentHash);
+    }
+
+    // parts.exceed.cap.paces: TotalParts > cap completes by pacing via slot
+    // handoff. Initial wave dispatches Cap parts; each completion hands its
+    // slot to the next undispatched part. InFlightPeak stays bounded by Cap.
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 2;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-mp-paces",
+            PartSize,
+            Admission,
+            Cap);
+
+        const uint64_t TotalSize = 16u * 1024u * 1024u;  // 4 parts > cap=2
+        std::vector<uint8_t> Bytes(TotalSize, 0xA5);
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_paces.bin", Bytes);
+        IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats);
+            Work.Wait();
+        }
+        CHECK(Sb.InFlightPeak.load() <= Cap);
+        std::vector<IoHash> Listed = Storage.List();
+        CHECK(std::find(Listed.begin(), Listed.end(), ContentHash) != Listed.end());
+    }
+
+    // aborts.release.tokens: forced part failure mid-flight; drain releases
+    // admission slots so a follow-up small Put can run without blocking.
+    {
+        // Arms the process-wide forced-part-failure hook and always disarms it
+        // on destruction, so an exception escaping Put() cannot leave the hook
+        // armed for the scopes and test cases that run afterwards.
+        struct ForcedPartFailureScope
+        {
+            explicit ForcedPartFailureScope(uint32_t Count) { s3asyncstorage_test_hooks::ForceNextPartFailures(Count); }
+            ~ForcedPartFailureScope() { s3asyncstorage_test_hooks::ForceNextPartFailures(0); }
+            ForcedPartFailureScope(const ForcedPartFailureScope&) = delete;
+            ForcedPartFailureScope& operator=(const ForcedPartFailureScope&) = delete;
+        };
+
+        StatsBlock Sb;
+        const uint32_t Cap = 4;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-mp-abort-release",
+            PartSize,
+            Admission,
+            Cap);
+
+        const uint64_t TotalSize = 11u * 1024u * 1024u;  // 3 parts under cap=4
+        std::vector<uint8_t> Bytes(TotalSize, 0x5A);
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_abrel.bin", Bytes);
+        IoHash AbortHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            // Declared first so the hook is disarmed after Work is destroyed
+            // and before the follow-up Put below is submitted.
+            ForcedPartFailureScope ForcedFailure(1);
+
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, AbortHash, TotalSize, SrcPath, Stats);
+            try
+            {
+                Work.Wait();
+            }
+            catch (...)
+            {
+                // Intentionally ignored: the failure itself is asserted by
+                // s3asyncstorage.multipart.abort_on_part_failure. This scope
+                // only checks that the follow-up Put below still completes.
+            }
+        }
+
+        std::vector<uint8_t> Small(1024, 0x11);
+        std::filesystem::path SmallPath = WriteBlob(F.TmpDir.Path() / "mp_abrel_small.bin", Small);
+        IoHash SmallHash = IoHash::HashBuffer(Small.data(), Small.size());
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, SmallHash, Small.size(), SmallPath, Stats);
+            Work.Wait();
+        }
+        std::vector<IoHash> Listed = Storage.List();
+        CHECK(std::find(Listed.begin(), Listed.end(), SmallHash) != Listed.end());
+    }
+
+    // handoff.in.flight: cap=1 forces strictly sequential dispatch (initial
+    // wave size 1, every subsequent part dispatched via handoff from the prior
+    // AsyncPut completion).
+    {
+        StatsBlock Sb;
+        const uint32_t Cap = 1;
+        auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap));
+        auto Client = F.MakeClient(Cap);
+        S3AsyncStorage Storage(
+            *Client,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-mp-handoff",
+            PartSize,
+            Admission,
+            Cap);
+
+        const uint64_t TotalSize = 16u * 1024u * 1024u;  // 4 parts; handoff fires 3 times
+        std::vector<uint8_t> Bytes(TotalSize, 0x33);
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_handoff.bin", Bytes);
+        IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats);
+            Work.Wait();
+        }
+        CHECK(Sb.InFlightPeak.load() == 1u);
+        std::vector<IoHash> Listed = Storage.List();
+        CHECK(std::find(Listed.begin(), Listed.end(), ContentHash) != Listed.end());
+    }
+
+    // ranges.exceed.cap.paces: GetMultipart with RangeCount > AdmissionCap.
+    // The dispatcher loop's per-range AcquireAdmissionSlot blocks the caller
+    // thread until in-flight ranges fire AsyncStream completions and release
+    // slots. InFlightPeak per Get stays bounded by Cap. Upload first with a
+    // larger client cap so the Put itself doesn't pace, then re-open Storage
+    // with the small cap purely to exercise the Get pacing path.
+    {
+        StatsBlock Sb;
+        const uint32_t UploadCap = 8;
+        auto UploadAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(UploadCap));
+        auto UploadClient = F.MakeClient(UploadCap);
+        S3AsyncStorage UploadStorage(
+            *UploadClient,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-mp-get-paces",
+            PartSize,
+            UploadAdmission,
+            UploadCap);
+
+        const uint64_t TotalSize = 16u * 1024u * 1024u;  // 4 ranges with PartSize=5 MiB
+        std::vector<uint8_t> Bytes(TotalSize, 0x77);
+        std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_get_paces.bin", Bytes);
+        IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size());
+
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = Sb.Ref();
+            UploadStorage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats);
+            Work.Wait();
+        }
+
+        // Separate stats block so the Get's in-flight peak is measured on its
+        // own, not mixed with the upload's.
+        StatsBlock GetSb;
+        const uint32_t GetCap = 2;
+        auto GetAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(GetCap));
+        auto GetClient = F.MakeClient(GetCap);
+        S3AsyncStorage GetStorage(
+            *GetClient,
+            *F.Builder,
+            [&F]() { return F.Creds; },
+            "module-mp-get-paces",
+            PartSize,
+            GetAdmission,
+            GetCap);
+
+        std::filesystem::path DstPath = F.TmpDir.Path() / "mp_get_paces_out.bin";
+        {
+            std::atomic<bool> AbortFlag{false};
+            std::atomic<bool> PauseFlag{false};
+            ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+            S3AsyncStorageStats Stats = GetSb.Ref();
+            GetStorage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats);
+            Work.Wait();
+        }
+        CHECK(GetSb.InFlightPeak.load() <= GetCap);
+        REQUIRE(std::filesystem::exists(DstPath));
+        REQUIRE(std::filesystem::file_size(DstPath) == TotalSize);
+        BasicFile Out(DstPath, BasicFile::Mode::kRead);
+        IoBuffer Read = Out.ReadAll();
+        IoHash ReadHash = IoHash::HashBuffer(Read);
+        CHECK(ReadHash == ContentHash);
+    }
+}
+
+TEST_SUITE_END();
+
+#endif  // ZEN_WITH_TESTS
+
+}  // namespace zen
diff --git a/src/zenserver/hub/s3asyncstorage.h b/src/zenserver/hub/s3asyncstorage.h
new file mode 100644
index 000000000..1bb8c14ca
--- /dev/null
+++ b/src/zenserver/hub/s3asyncstorage.h
@@ -0,0 +1,224 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/parallelwork.h>
+#include <zencore/timer.h>
+#include <zenhttp/asynchttpclient.h>
+#include <zenutil/cloud/s3client.h>
+#include <zenutil/cloud/s3requestbuilder.h>
+#include <zenutil/cloud/sigv4.h>
+
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <filesystem>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <semaphore>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+// Storage-layer admission semaphore. Sized at runtime to MaxConcurrentRequests;
+// LeastMaxValue is the compile-time upper bound. Caller may pass nullptr to
+// disable gating entirely.
+using AdmissionSemaphore = std::counting_semaphore<std::numeric_limits<std::ptrdiff_t>::max()>;
+
+// Bundle of references to caller-owned counters plus the clock they are
+// measured against. The struct owns nothing; every referenced object must
+// outlive it. All updates use relaxed atomics.
+struct S3AsyncStorageStats
+{
+    std::atomic<uint64_t>& RequestCount;
+    std::atomic<uint64_t>& RequestTotalUs;
+    std::atomic<uint64_t>& RequestMaxUs;
+    std::atomic<uint64_t>& Bytes;
+    std::atomic<uint32_t>& InFlight;
+    std::atomic<uint32_t>& InFlightPeak;
+    std::atomic<uint64_t>& FirstScheduleUs;
+    std::atomic<uint64_t>& FirstStartUs;
+    std::atomic<uint64_t>& AdmissionWaitTotalUs;
+    std::atomic<uint64_t>& AdmissionWaitMaxUs;
+    Stopwatch& PhaseClock;
+
+    // Lowers Target to Candidate when Candidate is smaller. Retries the CAS
+    // until it succeeds or another thread has stored a value <= Candidate.
+    template<typename T>
+    static void StoreMin(std::atomic<T>& Target, T Candidate)
+    {
+        T Seen = Target.load(std::memory_order_relaxed);
+        while (Candidate < Seen && !Target.compare_exchange_weak(Seen, Candidate, std::memory_order_relaxed))
+        {
+        }
+    }
+
+    // Raises Target to Candidate when Candidate is larger. Retries the CAS
+    // until it succeeds or another thread has stored a value >= Candidate.
+    template<typename T>
+    static void StoreMax(std::atomic<T>& Target, T Candidate)
+    {
+        T Seen = Target.load(std::memory_order_relaxed);
+        while (Candidate > Seen && !Target.compare_exchange_weak(Seen, Candidate, std::memory_order_relaxed))
+        {
+        }
+    }
+
+    // Records the current PhaseClock reading as the earliest schedule time if
+    // it is lower than the value already stored.
+    void RecordScheduled()
+    {
+        const uint64_t Now = PhaseClock.GetElapsedTimeUs();
+        StoreMin(FirstScheduleUs, Now);
+    }
+
+    // Marks a request as started: updates the earliest start time, bumps
+    // InFlight and raises InFlightPeak if needed. Returns a new Stopwatch for
+    // the caller to time the request with.
+    Stopwatch BeginRequest()
+    {
+        const uint64_t Now = PhaseClock.GetElapsedTimeUs();
+        StoreMin(FirstStartUs, Now);
+        const uint32_t Current = InFlight.fetch_add(1, std::memory_order_relaxed) + 1;
+        StoreMax(InFlightPeak, Current);
+        return Stopwatch{};
+    }
+
+    // Marks a request as finished: drops InFlight, accumulates count, elapsed
+    // time and bytes, and raises RequestMaxUs if ElapsedUs is a new maximum.
+    void EndRequest(uint64_t ElapsedUs, uint64_t BytesValue)
+    {
+        InFlight.fetch_sub(1, std::memory_order_relaxed);
+        RequestCount.fetch_add(1, std::memory_order_relaxed);
+        RequestTotalUs.fetch_add(ElapsedUs, std::memory_order_relaxed);
+        Bytes.fetch_add(BytesValue, std::memory_order_relaxed);
+        StoreMax(RequestMaxUs, ElapsedUs);
+    }
+
+    // Adds Us to the admission-wait total and raises the admission-wait
+    // maximum if Us exceeds it.
+    void RecordAdmissionWait(uint64_t Us)
+    {
+        AdmissionWaitTotalUs.fetch_add(Us, std::memory_order_relaxed);
+        StoreMax(AdmissionWaitMaxUs, Us);
+    }
+};
+
+// Async S3 storage adapter for hub hydration. Mirrors the behavior of the
+// blocking `S3Storage` (defined inside hydration.cpp) but submits requests
+// via `AsyncHttpClient` and counts in-flight work against `ParallelWork`
+// using `ExternalWorkToken` instead of occupying worker-pool threads.
+//
+// Construction:
+// S3AsyncStorage Storage(AsyncClient, RequestBuilder, GetCreds, KeyPrefix,
+// MultipartChunkSize);
+//
+// AsyncClient and RequestBuilder are owned by the caller (typically the hub
+// or a test fixture). GetCreds is a callable returning the latest SigV4
+// credentials - keeps the storage decoupled from any specific provider.
+//
+// Per-call SourcePath/DestinationPath args carry temp-dir state from the
+// caller (IncrementalHydrator picks paths under HydrationConfig::TempDir);
+// S3AsyncStorage itself never owns a temp directory.
+//
+// DeleteAll blocks the caller while the prefix listing runs, then issues
+// async deletes via Work.
+class S3AsyncStorage
+{
+public:
+    using CredentialsCallback = std::function<SigV4Credentials()>;
+
+    // Admission and AdmissionCap are optional; the defaults (nullptr, 0)
+    // disable admission gating.
+    S3AsyncStorage(AsyncHttpClient& Client,
+                   S3RequestBuilder& Builder,
+                   CredentialsCallback GetCreds,
+                   std::string KeyPrefix,
+                   uint64_t MultipartChunkSize,
+                   std::shared_ptr<AdmissionSemaphore> Admission = nullptr,
+                   uint32_t AdmissionCap = 0);
+
+    // View of the key prefix passed at construction; valid while this object lives.
+    std::string_view KeyPrefix() const { return m_KeyPrefix; }
+
+    // Async data operations. Each registers a ParallelWork::ExternalWorkToken;
+    // the async callback completes/fails the token. Caller drives Work.Wait()
+    // to block until all submissions finish.
+    //
+    // Stats (must outlive in-flight callbacks): every S3 request the storage
+    // issues calls RecordScheduled at submit time, BeginRequest just before
+    // the network handoff, and EndRequest from the completion callback. Bytes
+    // is the payload size on success, 0 on failure or zero-payload requests.
+    // Multipart and ranged GET fire one Begin/EndRequest per part/range.
+    void Put(ParallelWork& Work,
+             WorkerThreadPool& Pool,
+             const IoHash& Hash,
+             uint64_t Size,
+             const std::filesystem::path& SourcePath,
+             S3AsyncStorageStats& Stats);
+    void Get(ParallelWork& Work,
+             WorkerThreadPool& Pool,
+             const IoHash& Hash,
+             uint64_t Size,
+             const std::filesystem::path& DestinationPath,
+             S3AsyncStorageStats& Stats);
+    void Touch(ParallelWork& Work, WorkerThreadPool& Pool, const IoHash& Hash, S3AsyncStorageStats& Stats);
+
+    // Synchronous list of all CAS hashes under this module's prefix.
+    std::vector<IoHash> List();
+
+    // Synchronous delete-all under the module prefix. Lists then issues async
+    // deletes via Work; caller still drives Work.Wait().
+    void DeleteAll(ParallelWork& Work);
+
+    // Forward declarations are public (not the structs themselves) so file-scope
+    // free helpers in s3asyncstorage.cpp can name the types in their signatures
+    // without being declared friends. The structs are defined privately in the
+    // .cpp; no out-of-module caller can construct or inspect them.
+    struct PutMultipartState;
+    struct GetMultipartState;
+    struct GetStreamState;
+
+private:
+    std::string CasKey(const IoHash& Hash) const;
+    std::string CasPath(const IoHash& Hash) const;
+
+    // Put variants; all share the public Put() parameter list.
+    void PutSmall(ParallelWork& Work,
+                  WorkerThreadPool& Pool,
+                  const IoHash& Hash,
+                  uint64_t Size,
+                  const std::filesystem::path& SourcePath,
+                  S3AsyncStorageStats& Stats);
+    void PutMedium(ParallelWork& Work,
+                   WorkerThreadPool& Pool,
+                   const IoHash& Hash,
+                   uint64_t Size,
+                   const std::filesystem::path& SourcePath,
+                   S3AsyncStorageStats& Stats);
+    void PutMultipart(ParallelWork& Work,
+                      WorkerThreadPool& Pool,
+                      const IoHash& Hash,
+                      uint64_t Size,
+                      const std::filesystem::path& SourcePath,
+                      S3AsyncStorageStats& Stats);
+
+    // Multipart-upload steps operating on a shared PutMultipartState.
+    void DispatchInitialPartWave(std::shared_ptr<PutMultipartState> State);
+    void DispatchPartUpload(std::shared_ptr<PutMultipartState> State, uint32_t PartNum, std::shared_ptr<void> SlotRef);
+    void HandoffSlotToNextPart(std::shared_ptr<PutMultipartState> State, uint32_t PartIdx, std::shared_ptr<void> SlotRef);
+    void DrainUndispatchedParts(std::shared_ptr<PutMultipartState> State);
+    void FinalizePutPart(std::shared_ptr<PutMultipartState> State);
+    void CompleteMultipart(std::shared_ptr<PutMultipartState> State);
+    void AbortMultipart(std::shared_ptr<PutMultipartState> State);
+
+    // Multipart / streaming download steps.
+    void GetMultipart(ParallelWork& Work,
+                      WorkerThreadPool& Pool,
+                      const IoHash& Hash,
+                      uint64_t Size,
+                      const std::filesystem::path& DestinationPath,
+                      S3AsyncStorageStats& Stats);
+    void OnGetPartCompleted(std::shared_ptr<GetMultipartState> State);
+    void OnGetStreamFinalised(std::shared_ptr<GetStreamState> State);
+
+    std::vector<std::string> ListAllObjects(std::string_view Prefix);
+
+    AsyncHttpClient& m_Client;
+    S3RequestBuilder& m_Builder;
+    CredentialsCallback m_GetCreds;
+    std::string m_KeyPrefix;
+    uint64_t m_MultipartChunkSize;
+    std::shared_ptr<AdmissionSemaphore> m_Admission;  // nullable; when null, no admission gating
+    uint32_t m_AdmissionCap;                          // initial slot count; 0 when admission disabled
+};
+
+void s3asyncstorage_forcelink();
+
+#if ZEN_WITH_TESTS
+namespace s3asyncstorage_test_hooks {
+    // Per-process counter consumed by DispatchPartUpload. While > 0, each
+    // invocation decrements it and synthesizes a part-level failure (via
+    // RecordPutPartFailure + drain + FinalizePutPart fan-out) instead of
+    // issuing the UploadPart request. Used to drive the AbortMultipart path
+    // from tests without fault-injecting MinIO.
+    void ForceNextPartFailures(uint32_t Count);
+}  // namespace s3asyncstorage_test_hooks
+#endif
+
+}  // namespace zen
diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp index 8d36e6a46..34860eb9b 100644 --- a/src/zenserver/hub/storageserverinstance.cpp +++ b/src/zenserver/hub/storageserverinstance.cpp @@ -2,8 +2,6 @@ #include "storageserverinstance.h" -#include "hydration.h" - #include <zencore/assertfmt.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> @@ -14,12 +12,8 @@ namespace zen { -StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironment, - HydrationBase& Hydration, - const Configuration& Config, - std::string_view ModuleId) -: m_Hydration(Hydration) -, m_Config(Config) +StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId) +: m_Config(Config) , m_ModuleId(ModuleId) , m_ServerInstance(RunEnvironment, ZenServerInstance::ServerMode::kStorageServer) { @@ -138,7 +132,6 @@ StorageServerInstance::ProvisionLocked() ZEN_INFO("Provisioning storage server instance for module 
'{}', at '{}'", m_ModuleId, m_Config.StateDir); try { - Hydrate(); SpawnServerProcess(); } catch (const std::exception& Ex) @@ -156,19 +149,6 @@ StorageServerInstance::DeprovisionLocked() { ZEN_TRACE_CPU("StorageServerInstance::DeprovisionLocked"); ShutdownServerProcess(); - - // Crashed or Hibernated: process already dead; skip Shutdown. - // Dehydrate preserves instance state for future re-provisioning. Failure means saved state - // may be stale or absent, but the process is already dead so the slot can still be released. - // Swallow the exception and proceed with cleanup rather than leaving the module stuck. - try - { - Dehydrate(); - } - catch (const std::exception& Ex) - { - ZEN_WARN("Dehydration of module {} failed during deprovisioning, current state not saved. Reason: {}", m_ModuleId, Ex.what()); - } } void @@ -176,12 +156,6 @@ StorageServerInstance::ObliterateLocked() { ZEN_TRACE_CPU("StorageServerInstance::ObliterateLocked"); ShutdownServerProcess(); - - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - Hydrator->Obliterate(); } void @@ -218,54 +192,6 @@ StorageServerInstance::WakeLocked() } } -void -StorageServerInstance::Hydrate() -{ - if (!m_Config.EnableHydration) - { - ZEN_INFO("Hydration disabled; skipping hydrate for module '{}'", m_ModuleId); - return; - } - ZEN_TRACE_CPU("StorageServerInstance::Hydrate"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - m_HydrationState = Hydrator->Hydrate(); -} - -void -StorageServerInstance::Dehydrate() -{ - if (!m_Config.EnableDehydration) - { - ZEN_INFO("Dehydration disabled; skipping dehydrate for module '{}'", m_ModuleId); - return; - } - 
ZEN_TRACE_CPU("StorageServerInstance::Dehydrate"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - Hydrator->Dehydrate(m_HydrationState); -} - -HydrationConfig -StorageServerInstance::MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag) -{ - HydrationConfig Config{.ServerStateDir = m_Config.StateDir, .TempDir = m_Config.TempDir, .ModuleId = m_ModuleId}; - if (m_Config.OptionalWorkerPool) - { - Config.Threading.emplace( - HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalWorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag}); - } - Config.PackEnabled = m_Config.HydrationPackEnabled; - Config.PackThresholdBytes = m_Config.HydrationPackThresholdBytes; - Config.MaxPackBytes = m_Config.HydrationMaxPackBytes; - - return Config; -} - StorageServerInstance::SharedLockedPtr::SharedLockedPtr() : m_Lock(nullptr), m_Instance(nullptr) { } diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h index a2f376a23..599431300 100644 --- a/src/zenserver/hub/storageserverinstance.h +++ b/src/zenserver/hub/storageserverinstance.h @@ -2,18 +2,12 @@ #pragma once -#include "hydration.h" - -#include <zencore/compactbinary.h> #include <zenutil/zenserverprocess.h> -#include <atomic> #include <filesystem> namespace zen { -class WorkerThreadPool; - /** * Storage Server Instance * @@ -28,7 +22,6 @@ public: { uint16_t BasePort; std::filesystem::path StateDir; - std::filesystem::path TempDir; uint32_t HttpThreadCount = 0; // Automatic int CoreLimit = 0; // Automatic std::filesystem::path ConfigPath; @@ -36,19 +29,9 @@ public: std::string Trace; std::string TraceHost; std::string TraceFile; - bool EnableHydration = true; - bool EnableDehydration = true; - bool HydrationPackEnabled = true; - uint64_t 
HydrationPackThresholdBytes = DefaultPackThresholdBytes; - uint64_t HydrationMaxPackBytes = DefaultMaxPackBytes; - - WorkerThreadPool* OptionalWorkerPool = nullptr; }; - StorageServerInstance(ZenServerEnvironment& RunEnvironment, - HydrationBase& Hydration, - const Configuration& Config, - std::string_view ModuleId); + StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId); ~StorageServerInstance(); inline std::string_view GetModuleId() const { return m_ModuleId; } @@ -146,13 +129,10 @@ private: void WakeLocked(); mutable RwLock m_Lock; - HydrationBase& m_Hydration; const Configuration m_Config; std::string m_ModuleId; ZenServerInstance m_ServerInstance; - CbObject m_HydrationState; - #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; #endif @@ -160,10 +140,6 @@ private: void SpawnServerProcess(); void ShutdownServerProcess(); - void Hydrate(); - void Dehydrate(); - HydrationConfig MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag); - friend class SharedLockedPtr; friend class ExclusiveLockedPtr; }; diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 303b1f1b2..a2a366a80 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -188,18 +188,29 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HubInstanceConfigPath), "<instance config>"); - const uint32_t DefaultHubInstanceProvisionThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + // Provision pool: scheduling-bound, blocked on hydrate. clamp(cpu/8, 4, 16). + const uint32_t DefaultHubInstanceProvisionThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 16u); + // Spawn pool: blocking on CreateProcess + /health, RSS-bounded. clamp(cpu/8, 4, 16). 
+ const uint32_t DefaultHubInstanceSpawnThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 16u); Options.add_option("hub", "", "hub-instance-provision-threads", - fmt::format("Number of threads for instance provisioning (default {})", DefaultHubInstanceProvisionThreadCount), + fmt::format("Hub-wide hydrate/dehydrate scheduling pool size (default {})", DefaultHubInstanceProvisionThreadCount), cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceProvisionThreadCount) ->default_value(fmt::format("{}", DefaultHubInstanceProvisionThreadCount)), "<threads>"); Options.add_option("hub", "", + "hub-instance-spawn-threads", + fmt::format("Hub-wide child-process spawn/despawn pool size (default {})", DefaultHubInstanceSpawnThreadCount), + cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceSpawnThreadCount) + ->default_value(fmt::format("{}", DefaultHubInstanceSpawnThreadCount)), + "<threads>"); + + Options.add_option("hub", + "", "hub-hydration-target-spec", "Specification for hydration target. 'file://<path>' prefix indicates file storage at <path>. Defaults to " "<data-dir>/servers/hydration_storage", @@ -214,13 +225,18 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HydrationTargetConfigPath), "<path>"); - const uint32_t DefaultHubHydrationThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + // Hydration pool: per-file workers inside one hydrate/dehydrate call (xxhash, pack + // build, file I/O; S3 I/O only on the blocking path when HubHydrationAsyncEnabled is + // false). At 30ms latency on loopback the AsyncHttpClient strand is the bottleneck; + // hydration=8/12/16/32 all landed within run-to-run noise (~10s) at 1000 modules. + // clamp(cpu/8, 4, 12). 
+ const uint32_t DefaultHubHydrationThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 12u); Options.add_option( "hub", "", "hub-hydration-threads", - fmt::format("Number of threads for hydration/dehydration (default {})", DefaultHubHydrationThreadCount), + fmt::format("Per-file worker pool size inside hydrate/dehydrate (default {})", DefaultHubHydrationThreadCount), cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationThreadCount)->default_value(fmt::format("{}", DefaultHubHydrationThreadCount)), "<threads>"); @@ -273,6 +289,30 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", + "hub-hydration-async-enabled", + "Route S3 hydration through AsyncHttpClient. false falls back to the blocking S3Client path.", + cxxopts::value<bool>(m_ServerOptions.HubHydrationAsyncEnabled)->default_value("true"), + "<bool>"); + + // Async S3 in-flight cap: also drives CURLMOPT_MAXCONNECTS and curl per-host conn limits in + // the AsyncHttpClient. Sized so big modules (~hundreds of parallel ranges) saturate the pipe + // without admission-wait, and the conn cache stays large enough to avoid eviction churn on + // reused conns. clamp(cpu*4, 128, 512). + const uint32_t DefaultHubHydrationAsyncMaxConcurrentRequests = Min(Max(GetHardwareConcurrency() * 4u, 128u), 512u); + + Options.add_option( + "hub", + "", + "hub-hydration-async-max-concurrent-requests", + fmt::format("Max in-flight S3 requests submitted to the AsyncHttpClient. 
Only used when --hub-hydration-async-enabled " + "(default {}).", + DefaultHubHydrationAsyncMaxConcurrentRequests), + cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationAsyncMaxConcurrentRequests) + ->default_value(fmt::format("{}", DefaultHubHydrationAsyncMaxConcurrentRequests)), + "<n>"); + + Options.add_option("hub", + "", "hub-watchdog-cycle-interval-ms", "Interval between watchdog cycles in milliseconds", cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.CycleIntervalMs)->default_value("3000"), @@ -401,6 +441,7 @@ ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) Options.AddOption("hub.instance.provisionthreads"sv, m_ServerOptions.HubInstanceProvisionThreadCount, "hub-instance-provision-threads"sv); + Options.AddOption("hub.instance.spawnthreads"sv, m_ServerOptions.HubInstanceSpawnThreadCount, "hub-instance-spawn-threads"sv); Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv); Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv); @@ -412,6 +453,10 @@ ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) m_ServerOptions.HubHydrationPackThresholdBytes, "hub-hydration-pack-threshold-bytes"sv); Options.AddOption("hub.hydration.maxpackbytes"sv, m_ServerOptions.HubHydrationMaxPackBytes, "hub-hydration-max-pack-bytes"sv); + Options.AddOption("hub.hydration.async.enabled"sv, m_ServerOptions.HubHydrationAsyncEnabled, "hub-hydration-async-enabled"sv); + Options.AddOption("hub.hydration.async.maxconcurrentrequests"sv, + m_ServerOptions.HubHydrationAsyncMaxConcurrentRequests, + "hub-hydration-async-max-concurrent-requests"sv); Options.AddOption("hub.watchdog.cycleintervalms"sv, m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv); Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv, @@ -585,9 +630,9 @@ ZenHubServer::Initialize(const 
ZenHubServerConfig& ServerConfig, ZenServerState: // the main test range. ZenServerEnvironment::SetBaseChildId(1000); - m_ProvisionWorkerPool = - std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision"); - m_HydrationWorkerPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration"); + m_ProvisionPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision"); + m_SpawnPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceSpawnThreadCount), "hub_spawn"); + m_HydrationPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration"); m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; @@ -695,22 +740,24 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) { ZEN_INFO("instantiating Hub"); Hub::Configuration HubConfig{ - .UseJobObject = ServerConfig.HubUseJobObject, - .BasePortNumber = ServerConfig.HubBasePortNumber, - .InstanceLimit = ServerConfig.HubInstanceLimit, - .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, - .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, - .InstanceMalloc = ServerConfig.HubInstanceMalloc, - .InstanceTrace = ServerConfig.HubInstanceTrace, - .InstanceTraceHost = ServerConfig.HubInstanceTraceHost, - .InstanceTraceFile = ServerConfig.HubInstanceTraceFile, - .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, - .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, - .EnableHydration = ServerConfig.HubEnableHydration, - .EnableDehydration = ServerConfig.HubEnableDehydration, - .HydrationPackEnabled = ServerConfig.HubHydrationPackEnabled, - .HydrationPackThresholdBytes = ServerConfig.HubHydrationPackThresholdBytes, - .HydrationMaxPackBytes = ServerConfig.HubHydrationMaxPackBytes, + .UseJobObject = 
ServerConfig.HubUseJobObject, + .BasePortNumber = ServerConfig.HubBasePortNumber, + .InstanceLimit = ServerConfig.HubInstanceLimit, + .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, + .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, + .InstanceMalloc = ServerConfig.HubInstanceMalloc, + .InstanceTrace = ServerConfig.HubInstanceTrace, + .InstanceTraceHost = ServerConfig.HubInstanceTraceHost, + .InstanceTraceFile = ServerConfig.HubInstanceTraceFile, + .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, + .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, + .EnableHydration = ServerConfig.HubEnableHydration, + .EnableDehydration = ServerConfig.HubEnableDehydration, + .HydrationPackEnabled = ServerConfig.HubHydrationPackEnabled, + .HydrationPackThresholdBytes = ServerConfig.HubHydrationPackThresholdBytes, + .HydrationMaxPackBytes = ServerConfig.HubHydrationMaxPackBytes, + .HydrationAsyncEnabled = ServerConfig.HubHydrationAsyncEnabled, + .HydrationAsyncMaxConcurrentRequests = ServerConfig.HubHydrationAsyncMaxConcurrentRequests, .WatchDog = { .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), @@ -722,9 +769,10 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), }, - .ResourceLimits = ResolveLimits(ServerConfig), - .OptionalProvisionWorkerPool = m_ProvisionWorkerPool.get(), - .OptionalHydrationWorkerPool = m_HydrationWorkerPool.get()}; + .ResourceLimits = ResolveLimits(ServerConfig), + .OptionalProvisionPool = m_ProvisionPool.get(), + .OptionalSpawnPool = m_SpawnPool.get(), + .OptionalHydrationPool = m_HydrationPool.get()}; if (!ServerConfig.HydrationTargetConfigPath.empty()) { @@ -842,9 +890,8 @@ 
ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi } consul::ServiceRegistrationInfo Info; - Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId); - Info.ServiceName = "zen-hub"; - // Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri? + Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId); + Info.ServiceName = "zen-hub"; Info.Port = static_cast<uint16_t>(EffectivePort); Info.HealthEndpoint = "health"; Info.Tags = std::vector<std::pair<std::string, std::string>>{ diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index 6416792a6..0d5ac600f 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -53,18 +53,22 @@ struct ZenHubServerConfig : public ZenServerConfig bool HubHydrationPackEnabled = true; // Concatenate small files into raw CAS pack blobs during dehydrate uint64_t HubHydrationPackThresholdBytes = DefaultPackThresholdBytes; // Files strictly smaller than this are pack candidates uint64_t HubHydrationMaxPackBytes = DefaultMaxPackBytes; // Upper bound on a pack's concatenation size - std::string HubInstanceHttpClass = "asio"; - std::string HubInstanceMalloc; - std::string HubInstanceTrace; - std::string HubInstanceTraceHost; - std::string HubInstanceTraceFile; - uint32_t HubInstanceHttpThreadCount = 0; // Automatic - uint32_t HubInstanceProvisionThreadCount = 0; // Synchronous provisioning - uint32_t HubHydrationThreadCount = 0; // Synchronous hydration/dehydration - int HubInstanceCoreLimit = 0; // Automatic - std::filesystem::path HubInstanceConfigPath; // Path to Lua config file - std::string HydrationTargetSpecification; // hydration/dehydration target specification - std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification) + bool HubHydrationAsyncEnabled = true; // Route S3 hydration through AsyncHttpClient 
+ uint32_t HubHydrationAsyncMaxConcurrentRequests = + 128; // Hub-wide cap on concurrent S3 hydration requests (only when HubHydrationAsyncEnabled) + std::string HubInstanceHttpClass = "asio"; + std::string HubInstanceMalloc; + std::string HubInstanceTrace; + std::string HubInstanceTraceHost; + std::string HubInstanceTraceFile; + uint32_t HubInstanceHttpThreadCount = 0; // Automatic + uint32_t HubInstanceProvisionThreadCount = 0; // Hub-wide hydrate/dehydrate scheduling pool size + uint32_t HubInstanceSpawnThreadCount = 0; // Hub-wide child process spawn/despawn pool size + uint32_t HubHydrationThreadCount = 0; // Internal hydration parallelism (per-file) + int HubInstanceCoreLimit = 0; // Automatic + std::filesystem::path HubInstanceConfigPath; // Path to Lua config file + std::string HydrationTargetSpecification; // hydration/dehydration target specification + std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification) ZenHubWatchdogConfig WatchdogConfig; uint64_t HubProvisionDiskLimitBytes = 0; uint32_t HubProvisionDiskLimitPercent = 0; @@ -137,8 +141,9 @@ private: bool m_DebugOptionForcedCrash = false; std::unique_ptr<HttpProxyHandler> m_Proxy; - std::unique_ptr<WorkerThreadPool> m_ProvisionWorkerPool; - std::unique_ptr<WorkerThreadPool> m_HydrationWorkerPool; + std::unique_ptr<WorkerThreadPool> m_ProvisionPool; + std::unique_ptr<WorkerThreadPool> m_SpawnPool; + std::unique_ptr<WorkerThreadPool> m_HydrationPool; std::unique_ptr<Hub> m_Hub; std::unique_ptr<HttpHubService> m_HubService; diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index ab80cfcc7..83443b98b 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -4,6 +4,7 @@ #include <zenutil/cloud/imdscredentials.h> #include <zenutil/cloud/minioprocess.h> +#include <zenutil/cloud/s3response.h> #include <zencore/except_fmt.h> #include <zencore/iobuffer.h> @@ -19,210 +20,21 @@ 
ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { - - /// The SHA-256 hash of an empty payload, precomputed - constexpr std::string_view EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; - - /// Simple XML value extractor. Finds the text content between <Tag> and </Tag>. - /// This is intentionally minimal - we only need to parse ListBucketResult responses. - /// Returns a string_view into the original XML when no entity decoding is needed. - std::string_view ExtractXmlValue(std::string_view Xml, std::string_view Tag) - { - std::string OpenTag = fmt::format("<{}>", Tag); - std::string CloseTag = fmt::format("</{}>", Tag); - - size_t Start = Xml.find(OpenTag); - if (Start == std::string_view::npos) - { - return {}; - } - Start += OpenTag.size(); - - size_t End = Xml.find(CloseTag, Start); - if (End == std::string_view::npos) - { - return {}; - } - - return Xml.substr(Start, End - Start); - } - - /// Decode the five standard XML entities (&amp; &lt; &gt; &quot; &apos;) into a StringBuilderBase. - void DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) - { - if (Input.find('&') == std::string_view::npos) - { - Out.Append(Input); - return; - } - - for (size_t i = 0; i < Input.size(); ++i) - { - if (Input[i] == '&') - { - std::string_view Remaining = Input.substr(i); - if (Remaining.starts_with("&amp;")) - { - Out.Append('&'); - i += 4; - } - else if (Remaining.starts_with("&lt;")) - { - Out.Append('<'); - i += 3; - } - else if (Remaining.starts_with("&gt;")) - { - Out.Append('>'); - i += 3; - } - else if (Remaining.starts_with("&quot;")) - { - Out.Append('"'); - i += 5; - } - else if (Remaining.starts_with("&apos;")) - { - Out.Append('\''); - i += 5; - } - else - { - Out.Append(Input[i]); - } - } - else - { - Out.Append(Input[i]); - } - } - } - - /// Convenience: decode XML entities and return as std::string. 
- std::string DecodeXmlEntities(std::string_view Input) - { - if (Input.find('&') == std::string_view::npos) - { - return std::string(Input); - } - - ExtendableStringBuilder<256> Sb; - DecodeXmlEntities(Input, Sb); - return Sb.ToString(); - } - - /// Join a path and canonical query string into a full request path for the HTTP client. - std::string BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) - { - if (CanonicalQS.empty()) - { - return std::string(Path); - } - return fmt::format("{}?{}", Path, CanonicalQS); - } - - /// Case-insensitive header lookup in an HttpClient response header map. - const std::string* FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) - { - for (const auto& [K, V] : *Headers) - { - if (StrCaseCompare(K, Name) == 0) - { - return &V; - } - } - return nullptr; - } - - /// Extract Code/Message from an S3 XML error body. Returns true if an <Error> element was - /// found, even if Code/Message are empty. - bool ExtractS3Error(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) - { - if (Body.find("<Error>") == std::string_view::npos) - { - return false; - } - OutCode = ExtractXmlValue(Body, "Code"); - OutMessage = ExtractXmlValue(Body, "Message"); - return true; - } - - /// True if the response indicates S3 throttling (503 SlowDown / ServiceUnavailable / 429). - /// Code is checked on both the HTTP status and the XML error code so we catch proxies that - /// return 200 with a SlowDown body. 
- bool IsS3Throttled(const HttpClient::Response& Response, std::string_view ErrorCode) - { - const int Status = static_cast<int>(Response.StatusCode); - if (Status == 503 || Status == 429) - { - return true; - } - if (ErrorCode == "SlowDown" || ErrorCode == "ServiceUnavailable" || ErrorCode == "ThrottlingException" || - ErrorCode == "RequestLimitExceeded" || ErrorCode == "TooManyRequests") - { - return true; - } - return false; - } - - /// Build a human-readable error message for a failed S3 response. When the response body - /// contains an S3 `<Error>` element, the Code and Message fields are included in the string - /// so transient 4xx/5xx failures (SignatureDoesNotMatch, AuthorizationHeaderMalformed, etc.) - /// show up in logs instead of being swallowed. Falls back to the generic HTTP/transport - /// message when no XML body is available (HEAD responses, transport errors). - /// Also emits a distinct `S3 THROTTLED` warning when the response indicates throttling so - /// callers can grep for it without parsing combined error text. 
- std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) - { - if (!Response.Error.has_value() && Response.ResponsePayload) - { - std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); - std::string_view Code; - std::string_view Message; - if (ExtractS3Error(Body, Code, Message) && (!Code.empty() || !Message.empty())) - { - ExtendableStringBuilder<256> Decoded; - DecodeXmlEntities(Message, Decoded); - if (IsS3Throttled(Response, Code)) - { - ZEN_WARN("S3 THROTTLED [{}] status={} code='{}' message='{}'", - Prefix, - static_cast<int>(Response.StatusCode), - Code, - Decoded.ToView()); - } - return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); - } - } - if (IsS3Throttled(Response, {})) - { - ZEN_WARN("S3 THROTTLED [{}] status={} (no XML body)", Prefix, static_cast<int>(Response.StatusCode)); - } - return Response.ErrorMessage(Prefix); - } - -} // namespace - std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; S3Client::S3Client(const S3ClientOptions& Options) : m_Log(logging::Get("s3")) -, m_BucketName(Options.BucketName) -, m_Region(Options.Region) -, m_Endpoint(Options.Endpoint) -, m_PathStyle(Options.PathStyle) +, m_Builder(Options.Region, Options.BucketName, Options.Endpoint, Options.PathStyle) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) -, m_HttpClient(BuildEndpoint(), Options.HttpSettings) +, m_HttpClient(std::string(m_Builder.Endpoint()), Options.HttpSettings) , m_Verbose(Options.HttpSettings.Verbose) { - m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", - m_BucketName, - m_Region, + m_Builder.BucketName(), + m_Builder.Region(), m_HttpClient.GetBaseUri(), - m_PathStyle ? "path-style" : "virtual-hosted"); + m_Builder.PathStyle() ? 
"path-style" : "virtual-hosted"); } S3Client::~S3Client() = default; @@ -235,19 +47,15 @@ S3Client::GetCurrentCredentials() SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); if (!Creds.AccessKeyId.empty()) { - // Invalidate the signing key cache when the access key changes, and update stored - // credentials atomically under the same lock so callers see a consistent snapshot. - RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); - if (Creds.AccessKeyId != m_Credentials.AccessKeyId) - { - m_CachedDateStamp.clear(); - } + // Update stored credentials atomically so callers see a consistent snapshot. + // The builder's signing-key cache is keyed on (DateStamp, AccessKeyId), so it + // self-invalidates on the next Sign() call when either rotates. + RwLock::ExclusiveLockScope _(m_CredentialsLock); m_Credentials = Creds; - // Return Creds directly - avoids reading m_Credentials after releasing the lock, - // which would race with another concurrent write. return Creds; } // IMDS returned empty credentials; fall back to the last known-good credentials. 
+ RwLock::SharedLockScope _(m_CredentialsLock); return m_Credentials; } return m_Credentials; @@ -261,153 +69,6 @@ S3Client::BuildNoCredentialsError(std::string Context) return Err; } -std::string -S3Client::BuildEndpoint() const -{ - if (!m_Endpoint.empty()) - { - return m_Endpoint; - } - - if (m_PathStyle) - { - // Path-style: https://s3.region.amazonaws.com - return fmt::format("https://s3.{}.amazonaws.com", m_Region); - } - - // Virtual-hosted style: https://bucket.s3.region.amazonaws.com - return fmt::format("https://{}.s3.{}.amazonaws.com", m_BucketName, m_Region); -} - -std::string -S3Client::BuildHostHeader() const -{ - if (!m_Endpoint.empty()) - { - // Extract host from custom endpoint URL (strip scheme) - std::string_view Ep = m_Endpoint; - if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) - { - Ep = Ep.substr(Pos + 3); - } - // Strip trailing slash - if (!Ep.empty() && Ep.back() == '/') - { - Ep = Ep.substr(0, Ep.size() - 1); - } - return std::string(Ep); - } - - if (m_PathStyle) - { - return fmt::format("s3.{}.amazonaws.com", m_Region); - } - - return fmt::format("{}.s3.{}.amazonaws.com", m_BucketName, m_Region); -} - -std::string -S3Client::KeyToPath(std::string_view Key) const -{ - if (m_PathStyle) - { - return fmt::format("/{}/{}", m_BucketName, Key); - } - return fmt::format("/{}", Key); -} - -std::string -S3Client::BucketRootPath() const -{ - if (m_PathStyle) - { - return fmt::format("/{}/", m_BucketName); - } - return "/"; -} - -Sha256Digest -S3Client::GetSigningKey(std::string_view DateStamp) -{ - // Fast path: shared lock for cache hit (common case - key only changes once per day) - { - RwLock::SharedLockScope SharedLock(m_SigningKeyLock); - if (m_CachedDateStamp == DateStamp) - { - return m_CachedSigningKey; - } - } - - // Slow path: exclusive lock to recompute the signing key - RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); - - // Double-check after acquiring exclusive lock (another thread may have updated it) - if 
(m_CachedDateStamp == DateStamp) - { - return m_CachedSigningKey; - } - - std::string SecretPrefix = fmt::format("AWS4{}", m_Credentials.SecretAccessKey); - - Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); - SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); - - Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); - Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); - m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); - m_CachedDateStamp = std::string(DateStamp); - - return m_CachedSigningKey; -} - -HttpClient::KeyValueMap -S3Client::SignRequest(const SigV4Credentials& Credentials, - std::string_view Method, - std::string_view Path, - std::string_view CanonicalQueryString, - std::string_view PayloadHash, - std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders) -{ - std::string AmzDate = GetAmzTimestamp(); - - // Build sorted headers to sign (must be sorted by lowercase name) - std::vector<std::pair<std::string, std::string>> HeadersToSign; - HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); - HeadersToSign.emplace_back("host", m_Host); - HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); - HeadersToSign.emplace_back("x-amz-date", AmzDate); - if (!Credentials.SessionToken.empty()) - { - HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); - } - for (const auto& [K, V] : ExtraSignedHeaders) - { - HeadersToSign.emplace_back(K, V); - } - std::sort(HeadersToSign.begin(), HeadersToSign.end()); - - std::string_view DateStamp(AmzDate.data(), 8); - Sha256Digest SigningKey = GetSigningKey(DateStamp); - - SigV4SignedHeaders Signed = - SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); - - HttpClient::KeyValueMap Result; - Result->emplace("Authorization", std::move(Signed.Authorization)); - 
Result->emplace("x-amz-date", std::move(Signed.AmzDate)); - Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); - if (!Credentials.SessionToken.empty()) - { - Result->emplace("x-amz-security-token", Credentials.SessionToken); - } - for (const auto& [K, V] : ExtraSignedHeaders) - { - Result->emplace(K, V); - } - - return Result; -} - S3Result S3Client::PutObject(std::string_view Key, IoBuffer Content) { @@ -417,12 +78,12 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); // Hash the payload std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, "", PayloadHash); HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); if (!Response.IsSuccess()) @@ -448,9 +109,9 @@ S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFileP return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "GET", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) @@ -484,9 +145,9 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = 
m_Builder.SignRequest(Credentials, "GET", Path, "", S3EmptyPayloadHash); Headers->emplace("Range", fmt::format("bytes={}-{}", RangeStart, RangeStart + RangeSize - 1)); HttpClient::Response Response = m_HttpClient.Get(Path, Headers); @@ -537,9 +198,9 @@ S3Client::DeleteObject(std::string_view Key) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "DELETE", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); if (!Response.IsSuccess()) @@ -565,17 +226,17 @@ S3Client::Touch(std::string_view Key) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); // x-amz-copy-source is always "/bucket/key" regardless of addressing style. // Key must be URI-encoded except for '/' separators. When source and destination // are identical, REPLACE is required; COPY is rejected with InvalidRequest. 
const std::array<std::pair<std::string, std::string>, 2> ExtraSigned{{ - {"x-amz-copy-source", fmt::format("/{}/{}", m_BucketName, AwsUriEncode(Key, false))}, + {"x-amz-copy-source", fmt::format("/{}/{}", m_Builder.BucketName(), AwsUriEncode(Key, false))}, {"x-amz-metadata-directive", "REPLACE"}, }}; - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", EmptyPayloadHash, ExtraSigned); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, "", S3EmptyPayloadHash, ExtraSigned); HttpClient::Response Response = m_HttpClient.Put(Path, IoBuffer{}, Headers); if (!Response.IsSuccess()) @@ -589,7 +250,7 @@ S3Client::Touch(std::string_view Key) std::string_view ResponseBody = Response.AsText(); std::string_view ErrorCode; std::string_view ErrorMessage; - if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + if (S3ExtractError(ResponseBody, ErrorCode, ErrorMessage)) { std::string Err = fmt::format("S3 Touch '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); @@ -612,9 +273,9 @@ S3Client::HeadObject(std::string_view Key) return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "HEAD", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "HEAD", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Head(Path, Headers); if (!Response.IsSuccess()) @@ -632,17 +293,17 @@ S3Client::HeadObject(std::string_view Key) S3ObjectInfo Info; Info.Key = std::string(Key); - if (const std::string* V = FindResponseHeader(Response.Header, "content-length")) + if (const std::string* V = S3FindResponseHeader(Response.Header, "content-length")) { Info.Size = ParseInt<uint64_t>(*V).value_or(0); } - if (const std::string* V = FindResponseHeader(Response.Header, "etag")) + if 
(const std::string* V = S3FindResponseHeader(Response.Header, "etag")) { Info.ETag = *V; } - if (const std::string* V = FindResponseHeader(Response.Header, "last-modified")) + if (const std::string* V = S3FindResponseHeader(Response.Header, "last-modified")) { Info.LastModified = *V; } @@ -687,10 +348,10 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); - std::string RootPath = BucketRootPath(); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", RootPath, CanonicalQS, EmptyPayloadHash); + std::string RootPath = m_Builder.BucketRootPath(); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "GET", RootPath, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); + std::string FullPath = S3BuildRequestPath(RootPath, CanonicalQS); HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); if (!Response.IsSuccess()) { @@ -722,11 +383,11 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) std::string_view ContentsXml = Remaining.substr(ContentsStart, ContentsEnd - ContentsStart + 11); S3ObjectInfo Info; - Info.Key = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "Key")); - Info.ETag = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "ETag")); - Info.LastModified = std::string(ExtractXmlValue(ContentsXml, "LastModified")); + Info.Key = S3DecodeXmlEntities(S3ExtractXmlValue(ContentsXml, "Key")); + Info.ETag = S3DecodeXmlEntities(S3ExtractXmlValue(ContentsXml, "ETag")); + Info.LastModified = std::string(S3ExtractXmlValue(ContentsXml, "LastModified")); - std::string_view SizeStr = ExtractXmlValue(ContentsXml, "Size"); + std::string_view SizeStr = S3ExtractXmlValue(ContentsXml, "Size"); if (!SizeStr.empty()) { Info.Size = ParseInt<uint64_t>(SizeStr).value_or(0); @@ -741,13 +402,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } // Check if there are more 
pages - std::string_view IsTruncated = ExtractXmlValue(ResponseBody, "IsTruncated"); + std::string_view IsTruncated = S3ExtractXmlValue(ResponseBody, "IsTruncated"); if (IsTruncated != "true") { break; } - std::string_view NextToken = ExtractXmlValue(ResponseBody, "NextContinuationToken"); + std::string_view NextToken = S3ExtractXmlValue(ResponseBody, "NextContinuationToken"); if (NextToken.empty()) { break; @@ -779,12 +440,12 @@ S3Client::CreateMultipartUpload(std::string_view Key) return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "POST", Path, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); if (!Response.IsSuccess()) { @@ -800,7 +461,7 @@ S3Client::CreateMultipartUpload(std::string_view Key) // <UploadId>...</UploadId> // </InitiateMultipartUploadResult> std::string_view ResponseBody = Response.AsText(); - std::string_view UploadId = ExtractXmlValue(ResponseBody, "UploadId"); + std::string_view UploadId = S3ExtractXmlValue(ResponseBody, "UploadId"); if (UploadId.empty()) { std::string Err = "failed to parse UploadId from CreateMultipartUpload response"; @@ -824,7 +485,7 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({ {"partNumber", fmt::format("{}", PartNumber)}, {"uploadId", std::string(UploadId)}, @@ -832,9 
+493,9 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); if (!Response.IsSuccess()) { @@ -844,7 +505,7 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P } // Extract ETag from response headers - const std::string* ETag = FindResponseHeader(Response.Header, "etag"); + const std::string* ETag = S3FindResponseHeader(Response.Header, "etag"); if (!ETag) { std::string Err = "S3 UploadPart response missing ETag header"; @@ -870,7 +531,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); // Build the CompleteMultipartUpload XML payload @@ -885,12 +546,12 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view XmlView = XmlBody.ToView(); std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); Headers->emplace("Content-Type", "application/xml"); IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response 
Response = m_HttpClient.Post(FullPath, Payload, Headers); if (!Response.IsSuccess()) { @@ -903,7 +564,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view ResponseBody = Response.AsText(); std::string_view ErrorCode; std::string_view ErrorMessage; - if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + if (S3ExtractError(ResponseBody, ErrorCode, ErrorMessage)) { std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); @@ -926,12 +587,12 @@ S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "DELETE", Path, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); if (!Response.IsSuccess()) { @@ -968,15 +629,15 @@ S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view M return {}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string Scheme = "https"; - if (!m_Endpoint.empty() && m_Endpoint.starts_with("http://")) + if (!m_Builder.Endpoint().empty() && m_Builder.Endpoint().starts_with("http://")) { Scheme = "http"; } - return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); + return GeneratePresignedUrl(Credentials, Method, Scheme, m_Builder.Host(), Path, m_Builder.Region(), "s3", ExpiresIn); } S3Result @@ -1004,8 +665,26 @@ 
S3Client::PutObjectMultipart(std::string_view Key, const std::string& UploadId = InitResult.UploadId; - // Upload parts sequentially - // TODO: upload parts in parallel for improved throughput on large uploads + // Cleanup helper: AbortMultipartUpload itself can throw on transport failure; + // inside the catch (...) below that would replace the original exception with + // a less actionable transport one. Swallow + log. + auto SafeAbort = [this, &Key, &UploadId]() noexcept { + try + { + AbortMultipartUpload(Key, UploadId); + } + catch (const std::exception& Ex) + { + ZEN_WARN("S3 AbortMultipartUpload '{}' threw during cleanup: {}", Key, Ex.what()); + } + catch (...) + { + ZEN_WARN("S3 AbortMultipartUpload '{}' threw during cleanup", Key); + } + }; + + // Sequential upload by design; for parallel multipart use S3AsyncStorage::PutMultipart + // in the hub hydration path. std::vector<std::pair<uint32_t, std::string>> PartETags; uint64_t Offset = 0; @@ -1020,7 +699,7 @@ S3Client::PutObjectMultipart(std::string_view Key, S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, std::move(PartContent)); if (!PartResult) { - AbortMultipartUpload(Key, UploadId); + SafeAbort(); return S3Result{std::move(PartResult.Error)}; } @@ -1032,13 +711,13 @@ S3Client::PutObjectMultipart(std::string_view Key, S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); if (!CompleteResult) { - AbortMultipartUpload(Key, UploadId); + SafeAbort(); return CompleteResult; } } catch (...) 
{ - AbortMultipartUpload(Key, UploadId); + SafeAbort(); throw; } @@ -1077,25 +756,25 @@ TEST_CASE("s3client.xml_extract") "<Contents><Key>test/file.txt</Key><Size>1234</Size>" "<ETag>\"abc123\"</ETag><LastModified>2024-01-01T00:00:00Z</LastModified></Contents>"; - CHECK(ExtractXmlValue(Xml, "Key") == "test/file.txt"); - CHECK(ExtractXmlValue(Xml, "Size") == "1234"); - CHECK(ExtractXmlValue(Xml, "ETag") == "\"abc123\""); - CHECK(ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); - CHECK(ExtractXmlValue(Xml, "NonExistent") == ""); + CHECK(S3ExtractXmlValue(Xml, "Key") == "test/file.txt"); + CHECK(S3ExtractXmlValue(Xml, "Size") == "1234"); + CHECK(S3ExtractXmlValue(Xml, "ETag") == "\"abc123\""); + CHECK(S3ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); + CHECK(S3ExtractXmlValue(Xml, "NonExistent") == ""); } TEST_CASE("s3client.xml_entity_decode") { - CHECK(DecodeXmlEntities("no entities") == "no entities"); - CHECK(DecodeXmlEntities("a&amp;b") == "a&b"); - CHECK(DecodeXmlEntities("&lt;tag&gt;") == "<tag>"); - CHECK(DecodeXmlEntities("&quot;hello&apos;") == "\"hello'"); - CHECK(DecodeXmlEntities("&amp;&amp;") == "&&"); - CHECK(DecodeXmlEntities("") == ""); + CHECK(S3DecodeXmlEntities("no entities") == "no entities"); + CHECK(S3DecodeXmlEntities("a&amp;b") == "a&b"); + CHECK(S3DecodeXmlEntities("&lt;tag&gt;") == "<tag>"); + CHECK(S3DecodeXmlEntities("&quot;hello&apos;") == "\"hello'"); + CHECK(S3DecodeXmlEntities("&amp;&amp;") == "&&"); + CHECK(S3DecodeXmlEntities("") == ""); // Key with entities as S3 would return it std::string_view Xml = "<Key>path/file&amp;name&lt;1&gt;.txt</Key>"; - CHECK(DecodeXmlEntities(ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); + CHECK(S3DecodeXmlEntities(S3ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); } TEST_CASE("s3client.path_style_addressing") @@ -1129,6 +808,97 @@ TEST_CASE("s3client.virtual_hosted_addressing") CHECK(Client.Region() == "eu-west-1"); } +TEST_CASE("s3requestbuilder.path_style_paths") +{ + S3RequestBuilder Builder("us-east-1", 
"test-bucket", "http://localhost:9000", /*PathStyle*/ true); + + CHECK(Builder.Region() == "us-east-1"); + CHECK(Builder.BucketName() == "test-bucket"); + CHECK(Builder.PathStyle()); + CHECK(Builder.Endpoint() == "http://localhost:9000"); + CHECK(Builder.Host() == "localhost:9000"); + CHECK(Builder.KeyToPath("foo/bar") == "/test-bucket/foo/bar"); + CHECK(Builder.BucketRootPath() == "/test-bucket/"); +} + +TEST_CASE("s3requestbuilder.virtual_hosted_paths") +{ + S3RequestBuilder Builder("eu-west-1", "my-bucket", /*Endpoint*/ "", /*PathStyle*/ false); + + CHECK(Builder.Endpoint() == "https://my-bucket.s3.eu-west-1.amazonaws.com"); + CHECK(Builder.Host() == "my-bucket.s3.eu-west-1.amazonaws.com"); + CHECK(Builder.KeyToPath("foo/bar") == "/foo/bar"); + CHECK(Builder.BucketRootPath() == "/"); +} + +TEST_CASE("s3requestbuilder.derived_endpoint_path_style") +{ + S3RequestBuilder Builder("us-west-2", "another-bucket", /*Endpoint*/ "", /*PathStyle*/ true); + CHECK(Builder.Endpoint() == "https://s3.us-west-2.amazonaws.com"); + CHECK(Builder.Host() == "s3.us-west-2.amazonaws.com"); +} + +TEST_CASE("s3requestbuilder.sign_request_headers") +{ + S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", /*PathStyle*/ true); + + SigV4Credentials Creds; + Creds.AccessKeyId = "AKIDEXAMPLE"; + Creds.SecretAccessKey = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + + HttpClient::KeyValueMap Headers = Builder.SignRequest(Creds, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + // All four required SigV4 headers present. + CHECK(Headers->find("Authorization") != Headers->end()); + CHECK(Headers->find("x-amz-date") != Headers->end()); + CHECK(Headers->find("x-amz-content-sha256") != Headers->end()); + + // SessionToken absent -> no x-amz-security-token header. 
+ CHECK(Headers->find("x-amz-security-token") == Headers->end()); + + const std::string& Auth = Headers->find("Authorization")->second; + CHECK(Auth.starts_with("AWS4-HMAC-SHA256 ")); + CHECK(Auth.find("Credential=AKIDEXAMPLE/") != std::string::npos); + CHECK(Auth.find("/us-east-1/s3/aws4_request") != std::string::npos); + CHECK(Auth.find("Signature=") != std::string::npos); +} + +TEST_CASE("s3requestbuilder.session_token_emits_security_header") +{ + S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", true); + + SigV4Credentials Creds; + Creds.AccessKeyId = "ASIA-tmp"; + Creds.SecretAccessKey = "secret"; + Creds.SessionToken = "sts-session-token-value"; + + HttpClient::KeyValueMap Headers = Builder.SignRequest(Creds, "PUT", "/bucket/key", "", S3EmptyPayloadHash); + + auto It = Headers->find("x-amz-security-token"); + REQUIRE(It != Headers->end()); + CHECK(It->second == "sts-session-token-value"); +} + +TEST_CASE("s3requestbuilder.signing_key_cache_invalidates_on_key_rotate") +{ + // Two consecutive Sign() calls with the same date but different AccessKeyId + // must produce different Authorization signatures, proving the cache is keyed + // on (DateStamp, AccessKeyId) and not date alone. 
+ S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", true); + + SigV4Credentials A; + A.AccessKeyId = "AKIDEXAMPLE"; + A.SecretAccessKey = "secretA"; + HttpClient::KeyValueMap HA = Builder.SignRequest(A, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + SigV4Credentials B; + B.AccessKeyId = "AKIDEXAMPLE2"; + B.SecretAccessKey = "secretB"; + HttpClient::KeyValueMap HB = Builder.SignRequest(B, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + CHECK(HA->find("Authorization")->second != HB->find("Authorization")->second); +} + TEST_CASE("s3client.minio_integration") { using namespace std::literals; diff --git a/src/zenutil/cloud/s3requestbuilder.cpp b/src/zenutil/cloud/s3requestbuilder.cpp new file mode 100644 index 000000000..7bf200308 --- /dev/null +++ b/src/zenutil/cloud/s3requestbuilder.cpp @@ -0,0 +1,149 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/s3requestbuilder.h> + +#include <zencore/fmtutils.h> + +#include <algorithm> + +namespace zen { + +S3RequestBuilder::S3RequestBuilder(std::string Region, std::string BucketName, std::string Endpoint, bool PathStyle) +: m_Region(std::move(Region)) +, m_BucketName(std::move(BucketName)) +, m_Endpoint(DeriveEndpoint(Endpoint, m_Region, m_BucketName, PathStyle)) +, m_Host(HostFromEndpoint(m_Endpoint)) +, m_PathStyle(PathStyle) +{ +} + +std::string +S3RequestBuilder::DeriveEndpoint(std::string_view Endpoint, std::string_view Region, std::string_view BucketName, bool PathStyle) +{ + if (!Endpoint.empty()) + { + return std::string(Endpoint); + } + if (PathStyle) + { + return fmt::format("https://s3.{}.amazonaws.com", Region); + } + return fmt::format("https://{}.s3.{}.amazonaws.com", BucketName, Region); +} + +std::string +S3RequestBuilder::HostFromEndpoint(std::string_view Endpoint) +{ + std::string_view Ep = Endpoint; + if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) + { + Ep = Ep.substr(Pos + 3); + } + if (!Ep.empty() && Ep.back() == '/') + { + Ep 
= Ep.substr(0, Ep.size() - 1); + } + return std::string(Ep); +} + +std::string +S3RequestBuilder::KeyToPath(std::string_view Key) const +{ + if (m_PathStyle) + { + return fmt::format("/{}/{}", m_BucketName, Key); + } + return fmt::format("/{}", Key); +} + +std::string +S3RequestBuilder::BucketRootPath() const +{ + if (m_PathStyle) + { + return fmt::format("/{}/", m_BucketName); + } + return "/"; +} + +// Cached for the (DateStamp, AccessKeyId) tuple. DateStamp rolls over at UTC +// midnight; concurrent callers around the rollover may each compute a fresh +// key once before the cache settles on the new day's value. Harmless extra +// HMACs, no signature corruption. +Sha256Digest +S3RequestBuilder::GetSigningKey(std::string_view DateStamp, const SigV4Credentials& Credentials) +{ + { + RwLock::SharedLockScope _(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp && m_CachedAccessKeyId == Credentials.AccessKeyId) + { + return m_CachedSigningKey; + } + } + + RwLock::ExclusiveLockScope _(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp && m_CachedAccessKeyId == Credentials.AccessKeyId) + { + return m_CachedSigningKey; + } + + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); + m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + m_CachedDateStamp = std::string(DateStamp); + m_CachedAccessKeyId = Credentials.AccessKeyId; + + return m_CachedSigningKey; +} + +HttpClient::KeyValueMap +S3RequestBuilder::SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view CanonicalQueryString, + std::string_view PayloadHash, + std::span<const 
std::pair<std::string, std::string>> ExtraSignedHeaders) +{ + std::string AmzDate = GetAmzTimestamp(); + + std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); + HeadersToSign.emplace_back("host", m_Host); + HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); + HeadersToSign.emplace_back("x-amz-date", AmzDate); + if (!Credentials.SessionToken.empty()) + { + HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); + } + for (const auto& [K, V] : ExtraSignedHeaders) + { + HeadersToSign.emplace_back(K, V); + } + std::sort(HeadersToSign.begin(), HeadersToSign.end()); + + std::string_view DateStamp(AmzDate.data(), 8); + Sha256Digest SigningKey = GetSigningKey(DateStamp, Credentials); + + SigV4SignedHeaders Signed = + SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); + + HttpClient::KeyValueMap Result; + Result->emplace("Authorization", std::move(Signed.Authorization)); + Result->emplace("x-amz-date", std::move(Signed.AmzDate)); + Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); + if (!Credentials.SessionToken.empty()) + { + Result->emplace("x-amz-security-token", Credentials.SessionToken); + } + for (const auto& [K, V] : ExtraSignedHeaders) + { + Result->emplace(K, V); + } + return Result; +} + +} // namespace zen diff --git a/src/zenutil/cloud/s3response.cpp b/src/zenutil/cloud/s3response.cpp new file mode 100644 index 000000000..a9e7f0208 --- /dev/null +++ b/src/zenutil/cloud/s3response.cpp @@ -0,0 +1,181 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/cloud/s3response.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +namespace zen { + +std::string_view +S3ExtractXmlValue(std::string_view Xml, std::string_view Tag) +{ + std::string OpenTag = fmt::format("<{}>", Tag); + std::string CloseTag = fmt::format("</{}>", Tag); + + size_t Start = Xml.find(OpenTag); + if (Start == std::string_view::npos) + { + return {}; + } + Start += OpenTag.size(); + + size_t End = Xml.find(CloseTag, Start); + if (End == std::string_view::npos) + { + return {}; + } + + return Xml.substr(Start, End - Start); +} + +void +S3DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) +{ + if (Input.find('&') == std::string_view::npos) + { + Out.Append(Input); + return; + } + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '&') + { + std::string_view Remaining = Input.substr(i); + if (Remaining.starts_with("&amp;")) + { + Out.Append('&'); + i += 4; + } + else if (Remaining.starts_with("&lt;")) + { + Out.Append('<'); + i += 3; + } + else if (Remaining.starts_with("&gt;")) + { + Out.Append('>'); + i += 3; + } + else if (Remaining.starts_with("&quot;")) + { + Out.Append('"'); + i += 5; + } + else if (Remaining.starts_with("&apos;")) + { + Out.Append('\''); + i += 5; + } + else + { + Out.Append(Input[i]); + } + } + else + { + Out.Append(Input[i]); + } + } +} + +std::string +S3DecodeXmlEntities(std::string_view Input) +{ + if (Input.find('&') == std::string_view::npos) + { + return std::string(Input); + } + ExtendableStringBuilder<256> Sb; + S3DecodeXmlEntities(Input, Sb); + return Sb.ToString(); +} + +std::string +S3BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) +{ + if (CanonicalQS.empty()) + { + return std::string(Path); + } + return fmt::format("{}?{}", Path, CanonicalQS); +} + +const std::string* +S3FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) +{ + for (const auto& [K, V] : *Headers) + { + if (StrCaseCompare(K, Name) == 0) + { + return &V; 
+ } + } + return nullptr; +} + +bool +S3ExtractError(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) +{ + if (Body.find("<Error>") == std::string_view::npos) + { + return false; + } + OutCode = S3ExtractXmlValue(Body, "Code"); + OutMessage = S3ExtractXmlValue(Body, "Message"); + // Treat malformed bodies (Error tag present but no parseable Code/Message) + // as a parse miss; callers format "<prefix>: <Code> - <Message>" and an + // empty render is indistinguishable from "no error". S3IsThrottled with + // empty ErrorCode + S3ErrorMessage's Response.ErrorMessage fallback path + // covers status-only triage. + return !OutCode.empty() || !OutMessage.empty(); +} + +bool +S3IsThrottled(const HttpClient::Response& Response, std::string_view ErrorCode) +{ + const int Status = static_cast<int>(Response.StatusCode); + if (Status == 503 || Status == 429) + { + return true; + } + if (ErrorCode == "SlowDown" || ErrorCode == "ServiceUnavailable" || ErrorCode == "ThrottlingException" || + ErrorCode == "RequestLimitExceeded" || ErrorCode == "TooManyRequests") + { + return true; + } + return false; +} + +std::string +S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) +{ + if (!Response.Error.has_value() && Response.ResponsePayload) + { + std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string_view Code; + std::string_view Message; + if (S3ExtractError(Body, Code, Message)) + { + ExtendableStringBuilder<256> Decoded; + S3DecodeXmlEntities(Message, Decoded); + if (S3IsThrottled(Response, Code)) + { + ZEN_WARN("S3 THROTTLED [{}] status={} code='{}' message='{}'", + Prefix, + static_cast<int>(Response.StatusCode), + Code, + Decoded.ToView()); + } + return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); + } + } + if (S3IsThrottled(Response, {})) + { + ZEN_WARN("S3 THROTTLED [{}] 
status={} (no XML body)", Prefix, static_cast<int>(Response.StatusCode)); + } + return Response.ErrorMessage(Prefix); +} + +} // namespace zen diff --git a/src/zenutil/cloud/sigv4.cpp b/src/zenutil/cloud/sigv4.cpp index 055ccb2ad..34bd7f5f3 100644 --- a/src/zenutil/cloud/sigv4.cpp +++ b/src/zenutil/cloud/sigv4.cpp @@ -53,6 +53,62 @@ ComputeSha256(const void* Data, size_t Size) return Result; } +Sha256Stream::Sha256Stream() +{ + EVP_MD_CTX* Ctx = EVP_MD_CTX_new(); + ZEN_ASSERT(Ctx != nullptr); + int Rc = EVP_DigestInit_ex(Ctx, EVP_sha256(), nullptr); + ZEN_ASSERT(Rc == 1); + m_Ctx = Ctx; +} + +Sha256Stream::~Sha256Stream() +{ + if (m_Ctx) + { + EVP_MD_CTX_free(static_cast<EVP_MD_CTX*>(m_Ctx)); + m_Ctx = nullptr; + } +} + +Sha256Stream::Sha256Stream(Sha256Stream&& Other) noexcept : m_Ctx(Other.m_Ctx) +{ + Other.m_Ctx = nullptr; +} + +Sha256Stream& +Sha256Stream::operator=(Sha256Stream&& Other) noexcept +{ + if (this != &Other) + { + if (m_Ctx) + { + EVP_MD_CTX_free(static_cast<EVP_MD_CTX*>(m_Ctx)); + } + m_Ctx = Other.m_Ctx; + Other.m_Ctx = nullptr; + } + return *this; +} + +void +Sha256Stream::Update(const void* Data, size_t Size) +{ + int Rc = EVP_DigestUpdate(static_cast<EVP_MD_CTX*>(m_Ctx), Data, Size); + ZEN_ASSERT(Rc == 1); +} + +Sha256Digest +Sha256Stream::Finalize() +{ + Sha256Digest Result; + unsigned int Len = 0; + int Rc = EVP_DigestFinal_ex(static_cast<EVP_MD_CTX*>(m_Ctx), Result.data(), &Len); + ZEN_ASSERT(Rc == 1); + ZEN_ASSERT(Len == 32); + return Result; +} + Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) { @@ -171,6 +227,62 @@ ComputeSha256(const void* Data, size_t Size) return BcryptHash(GetBcryptHandles().Sha256, Data, Size); } +Sha256Stream::Sha256Stream() +{ + BCRYPT_HASH_HANDLE Handle = nullptr; + NTSTATUS Status = BCryptCreateHash(GetBcryptHandles().Sha256, &Handle, nullptr, 0, nullptr, 0, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + m_Ctx = Handle; +} + +Sha256Stream::~Sha256Stream() +{ + if 
(m_Ctx) + { + BCryptDestroyHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx)); + m_Ctx = nullptr; + } +} + +Sha256Stream::Sha256Stream(Sha256Stream&& Other) noexcept : m_Ctx(Other.m_Ctx) +{ + Other.m_Ctx = nullptr; +} + +Sha256Stream& +Sha256Stream::operator=(Sha256Stream&& Other) noexcept +{ + if (this != &Other) + { + if (m_Ctx) + { + BCryptDestroyHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx)); + } + m_Ctx = Other.m_Ctx; + Other.m_Ctx = nullptr; + } + return *this; +} + +void +Sha256Stream::Update(const void* Data, size_t Size) +{ + NTSTATUS Status = BCryptHashData(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx), + reinterpret_cast<PUCHAR>(const_cast<void*>(Data)), + static_cast<ULONG>(Size), + 0); + ZEN_ASSERT(NT_SUCCESS(Status)); +} + +Sha256Digest +Sha256Stream::Finalize() +{ + Sha256Digest Result; + NTSTATUS Status = BCryptFinishHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx), Result.data(), static_cast<ULONG>(Result.size()), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + return Result; +} + Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) { @@ -251,6 +363,8 @@ GetAmzTimestamp() std::string AwsUriEncode(std::string_view Input, bool EncodeSlash) { + static constexpr char kHex[] = "0123456789ABCDEF"; + ExtendableStringBuilder<256> Result; for (char C : Input) { @@ -264,7 +378,11 @@ AwsUriEncode(std::string_view Input, bool EncodeSlash) } else { - Result.Append(fmt::format("%{:02X}", static_cast<unsigned char>(C))); + // Hand-rolled hex encode; avoids per-char fmt::format std::string alloc + // inside a loop that runs once per non-unreserved input character. 
+ const uint8_t Byte = static_cast<uint8_t>(C); + const char Encoded[3] = {'%', kHex[Byte >> 4], kHex[Byte & 0x0F]}; + Result.Append(std::string_view(Encoded, 3)); } } return std::string(Result.ToView()); @@ -313,18 +431,10 @@ SignRequestV4(const SigV4Credentials& Credentials, std::string DateStamp = GetDateStamp(Result.AmzDate); - // Step 1: Create canonical request - // CanonicalRequest = - // HTTPRequestMethod + '\n' + - // CanonicalURI + '\n' + - // CanonicalQueryString + '\n' + - // CanonicalHeaders + '\n' + - // SignedHeaders + '\n' + - // HexEncode(Hash(RequestPayload)) - + // Step 1: canonical request. std::string CanonicalUri = AwsUriEncode(Url, false); - // Build canonical headers and signed headers (headers must be sorted by lowercase name) + // Headers assumed pre-sorted by lowercase name. ExtendableStringBuilder<512> CanonicalHeadersSb; ExtendableStringBuilder<256> SignedHeadersSb; @@ -342,30 +452,45 @@ SignRequestV4(const SigV4Credentials& Credentials, SignedHeadersSb.Append(Headers[i].first); } - std::string SignedHeaders = std::string(SignedHeadersSb.ToView()); - - std::string CanonicalRequest = fmt::format("{}\n{}\n{}\n{}\n{}\n{}", - Method, - CanonicalUri, - CanonicalQueryString, - CanonicalHeadersSb.ToView(), - SignedHeaders, - PayloadHash); - - // Step 2: Create the string to sign - std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service); + std::string_view SignedHeaders = SignedHeadersSb.ToView(); + + ExtendableStringBuilder<2048> CanonicalRequestSb; + CanonicalRequestSb.Append(Method); + CanonicalRequestSb.Append('\n'); + CanonicalRequestSb.Append(CanonicalUri); + CanonicalRequestSb.Append('\n'); + CanonicalRequestSb.Append(CanonicalQueryString); + CanonicalRequestSb.Append('\n'); + CanonicalRequestSb.Append(CanonicalHeadersSb.ToView()); + CanonicalRequestSb.Append('\n'); + CanonicalRequestSb.Append(SignedHeaders); + CanonicalRequestSb.Append('\n'); + CanonicalRequestSb.Append(PayloadHash); + 
std::string_view CanonicalRequest = CanonicalRequestSb.ToView(); + + // Step 2: string-to-sign. + ExtendableStringBuilder<128> CredentialScopeSb; + CredentialScopeSb.Append(DateStamp); + CredentialScopeSb.Append('/'); + CredentialScopeSb.Append(Region); + CredentialScopeSb.Append('/'); + CredentialScopeSb.Append(Service); + CredentialScopeSb.Append("/aws4_request"); + std::string_view CredentialScope = CredentialScopeSb.ToView(); Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest); std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash); - std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", Result.AmzDate, CredentialScope, CanonicalRequestHex); - - // Step 3: Calculate the signing key - // kDate = HMAC("AWS4" + SecretKey, DateStamp) - // kRegion = HMAC(kDate, Region) - // kService = HMAC(kRegion, Service) - // kSigning = HMAC(kService, "aws4_request") + ExtendableStringBuilder<256> StringToSignSb; + StringToSignSb.Append("AWS4-HMAC-SHA256\n"); + StringToSignSb.Append(Result.AmzDate); + StringToSignSb.Append('\n'); + StringToSignSb.Append(CredentialScope); + StringToSignSb.Append('\n'); + StringToSignSb.Append(CanonicalRequestHex); + std::string_view StringToSign = StringToSignSb.ToView(); + // Step 3: derive signing key (kDate -> kRegion -> kService -> kSigning HMAC chain). Sha256Digest DerivedSigningKey; if (!SigningKeyPtr) { @@ -380,11 +505,11 @@ SignRequestV4(const SigV4Credentials& Credentials, SigningKeyPtr = &DerivedSigningKey; } - // Step 4: Calculate the signature + // Step 4: signature. Sha256Digest Signature = ComputeHmacSha256(*SigningKeyPtr, StringToSign); std::string SignatureHex = Sha256ToHex(Signature); - // Step 5: Build the Authorization header + // Step 5: Authorization header. 
Result.Authorization = fmt::format("AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", Credentials.AccessKeyId, CredentialScope, @@ -489,6 +614,69 @@ TEST_CASE("sigv4.sha256") CHECK(HelloHex == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"); } +TEST_CASE("sigv4.sha256stream.matches_oneshot") +{ + // Empty input. + { + Sha256Stream S; + Sha256Digest Streamed = S.Finalize(); + Sha256Digest OneShot = ComputeSha256("", 0); + CHECK(Streamed == OneShot); + } + + // Single update. + { + Sha256Stream S; + S.Update("hello", 5); + Sha256Digest Streamed = S.Finalize(); + Sha256Digest OneShot = ComputeSha256("hello"); + CHECK(Streamed == OneShot); + } + + // Multi-chunk update; result must equal one-shot over concatenated bytes. + { + const std::string Whole = "the quick brown fox jumps over the lazy dog"; + Sha256Stream S; + S.Update(Whole.data(), 4); + S.Update(Whole.data() + 4, 16); + S.Update(Whole.data() + 20, Whole.size() - 20); + Sha256Digest Streamed = S.Finalize(); + Sha256Digest OneShot = ComputeSha256(Whole.data(), Whole.size()); + CHECK(Streamed == OneShot); + } + + // Large input fed in 256 KiB chunks (matches PutMedium hash pass shape). + { + std::vector<uint8_t> Big(1u * 1024u * 1024u + 7u); + for (size_t I = 0; I < Big.size(); ++I) + { + Big[I] = static_cast<uint8_t>((I * 31u + 11u) & 0xFF); + } + Sha256Stream S; + constexpr size_t kChunk = 256u * 1024u; + size_t Off = 0; + while (Off < Big.size()) + { + const size_t Take = std::min(kChunk, Big.size() - Off); + S.Update(Big.data() + Off, Take); + Off += Take; + } + Sha256Digest Streamed = S.Finalize(); + Sha256Digest OneShot = ComputeSha256(Big.data(), Big.size()); + CHECK(Streamed == OneShot); + } + + // Move construct + finalize on the moved-to instance. 
+ { + Sha256Stream A; + A.Update("abc", 3); + Sha256Stream B(std::move(A)); + B.Update("def", 3); + Sha256Digest Result = B.Finalize(); + CHECK(Result == ComputeSha256("abcdef")); + } +} + TEST_CASE("sigv4.hmac_sha256") { // RFC 4231 Test Case 2 diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h index 1ce2a768e..f09788b82 100644 --- a/src/zenutil/include/zenutil/cloud/s3client.h +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -3,6 +3,7 @@ #pragma once #include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/s3requestbuilder.h> #include <zenutil/cloud/sigv4.h> #include <zencore/iobuffer.h> @@ -187,8 +188,8 @@ public: /// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) std::string GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1)); - std::string_view BucketName() const { return m_BucketName; } - std::string_view Region() const { return m_Region; } + std::string_view BucketName() const { return m_Builder.BucketName(); } + std::string_view Region() const { return m_Builder.Region(); } private: /// Shared implementation for pre-signed URL generation @@ -196,31 +197,6 @@ private: LoggerRef Log() { return m_Log; } - /// Build the endpoint URL for the bucket - std::string BuildEndpoint() const; - - /// Build the host header value - std::string BuildHostHeader() const; - - /// Build the S3 object path from a key, accounting for path-style addressing - std::string KeyToPath(std::string_view Key) const; - - /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style) - std::string BucketRootPath() const; - - /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256. - /// Additional x-amz-* headers that must participate in the signature are passed via - /// ExtraSignedHeaders (lowercase name, value); they are also copied into the returned map. 
- HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials, - std::string_view Method, - std::string_view Path, - std::string_view QueryString, - std::string_view PayloadHash, - std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {}); - - /// Get or compute the signing key for the given date stamp, caching across requests on the same day - Sha256Digest GetSigningKey(std::string_view DateStamp); - /// Get the current credentials, either from the provider or from static config SigV4Credentials GetCurrentCredentials(); @@ -242,20 +218,12 @@ private: std::string BuildNoCredentialsError(std::string Context); LoggerRef m_Log; - std::string m_BucketName; - std::string m_Region; - std::string m_Endpoint; - std::string m_Host; - bool m_PathStyle; + S3RequestBuilder m_Builder; + mutable RwLock m_CredentialsLock; SigV4Credentials m_Credentials; Ref<ImdsCredentialProvider> m_CredentialProvider; HttpClient m_HttpClient; bool m_Verbose = false; - - // Cached signing key (only changes once per day, protected by RwLock for thread safety) - mutable RwLock m_SigningKeyLock; - std::string m_CachedDateStamp; - Sha256Digest m_CachedSigningKey{}; }; void s3client_forcelink(); diff --git a/src/zenutil/include/zenutil/cloud/s3requestbuilder.h b/src/zenutil/include/zenutil/cloud/s3requestbuilder.h new file mode 100644 index 000000000..c46167fba --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/s3requestbuilder.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cloud/sigv4.h> + +#include <zencore/thread.h> +#include <zenhttp/httpclient.h> + +#include <span> +#include <string> +#include <string_view> + +namespace zen { + +// Stateless builder of signed S3 requests, shared between blocking S3Client +// (HttpClient-backed) and S3AsyncStorage (AsyncHttpClient-backed). Owns the +// per-day signing-key cache so identical signatures across consecutive +// requests do not redo HMAC derivation. 
+// +// Configuration (region, bucket, endpoint, addressing style) is fixed at +// construction. Credentials are passed per Sign() call. The cache is keyed on +// (DateStamp, AccessKeyId), so any AccessKeyId rotation invalidates the +// derived signing key. STS / IMDS rotate AccessKeyId and SecretAccessKey +// together, so this is sufficient there. Callers hand-rotating SecretAccessKey +// while reusing AccessKeyId would sign with a stale derived key - not a usage +// pattern this builder supports. +class S3RequestBuilder +{ +public: + S3RequestBuilder(std::string Region, std::string BucketName, std::string Endpoint, bool PathStyle); + + std::string_view Region() const { return m_Region; } + std::string_view BucketName() const { return m_BucketName; } + std::string_view Endpoint() const { return m_Endpoint; } + std::string_view Host() const { return m_Host; } + bool PathStyle() const { return m_PathStyle; } + + std::string KeyToPath(std::string_view Key) const; + std::string BucketRootPath() const; + + // Sign a request. Returns headers including Authorization, x-amz-date, + // x-amz-content-sha256, and x-amz-security-token (when SessionToken set). + // ExtraSignedHeaders are lowercase-name pairs that must participate in the + // signature; they are also copied into the returned map. 
+ HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view CanonicalQueryString, + std::string_view PayloadHash, + std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {}); + +private: + Sha256Digest GetSigningKey(std::string_view DateStamp, const SigV4Credentials& Credentials); + + static std::string DeriveEndpoint(std::string_view Endpoint, std::string_view Region, std::string_view BucketName, bool PathStyle); + static std::string HostFromEndpoint(std::string_view Endpoint); + + std::string m_Region; + std::string m_BucketName; + std::string m_Endpoint; + std::string m_Host; + bool m_PathStyle; + + mutable RwLock m_SigningKeyLock; + std::string m_CachedDateStamp; + std::string m_CachedAccessKeyId; + Sha256Digest m_CachedSigningKey{}; +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/s3response.h b/src/zenutil/include/zenutil/cloud/s3response.h new file mode 100644 index 000000000..9dec3215b --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/s3response.h @@ -0,0 +1,51 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/string.h> +#include <zenhttp/httpclient.h> + +#include <string> +#include <string_view> + +namespace zen { + +// Helpers for parsing S3 (or S3-compatible) HTTP responses. Shared between +// the blocking S3Client and the async S3AsyncStorage path so XML/error +// handling stays consistent across implementations. +// +// The XML parser is intentionally minimal: only ListBucketResult / Error +// shapes are needed. CDATA, namespaces, and attributes are not handled. + +constexpr std::string_view S3EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + +// Find text between <Tag> and </Tag>. Returns a view into Xml; empty if not found. 
+std::string_view S3ExtractXmlValue(std::string_view Xml, std::string_view Tag); + +// Decode the five standard XML entities into Out. +void S3DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out); + +// Convenience overload returning std::string. +std::string S3DecodeXmlEntities(std::string_view Input); + +// Append a canonical query string to a path: "Path?CanonicalQS" or "Path" when QS empty. +std::string S3BuildRequestPath(std::string_view Path, std::string_view CanonicalQS); + +// Case-insensitive header lookup in a response header map. Returns nullptr if not found. +const std::string* S3FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name); + +// Extract Code/Message from an S3 <Error> body. Returns true only when at least one of +// Code/Message parsed non-empty - malformed bodies (Error tag present but no parseable +// children) are reported as parse miss to keep formatted error strings non-empty. +bool S3ExtractError(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage); + +// Detect throttling: HTTP 503/429 or known S3 throttle codes (SlowDown, ServiceUnavailable, etc). +// Code is checked on both the HTTP status and the XML error code. +bool S3IsThrottled(const HttpClient::Response& Response, std::string_view ErrorCode); + +// Build a human-readable error message for a failed S3 response. Falls back to the +// generic HTTP/transport message when no <Error> body is available. Logs a +// distinct "S3 THROTTLED" warning when the response indicates throttling. 
+std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/sigv4.h b/src/zenutil/include/zenutil/cloud/sigv4.h index 9ac08df76..a16daf2cf 100644 --- a/src/zenutil/include/zenutil/cloud/sigv4.h +++ b/src/zenutil/include/zenutil/cloud/sigv4.h @@ -19,6 +19,26 @@ using Sha256Digest = std::array<uint8_t, 32>; Sha256Digest ComputeSha256(const void* Data, size_t Size); Sha256Digest ComputeSha256(std::string_view Data); +/// Streaming SHA-256: feed data in chunks, finalize once. +/// Move-only; copying digest state is not supported by the underlying APIs. +class Sha256Stream +{ +public: + Sha256Stream(); + ~Sha256Stream(); + + Sha256Stream(const Sha256Stream&) = delete; + Sha256Stream& operator=(const Sha256Stream&) = delete; + Sha256Stream(Sha256Stream&& Other) noexcept; + Sha256Stream& operator=(Sha256Stream&& Other) noexcept; + + void Update(const void* Data, size_t Size); + Sha256Digest Finalize(); // single-use; further Update/Finalize is undefined + +private: + void* m_Ctx = nullptr; +}; + /// Compute HMAC-SHA256 with the given key and data Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize); Sha256Digest ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data); |