40 files changed, 8723 insertions, 1608 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index f30e56754..44b9fb65d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ - Improvement: Hub pools HTTP connections to managed instances so provision/deprovision churn no longer exhausts Windows ephemeral ports - Improvement: `zen` consolidates the `project-*` commands into a single `zen project <sub>` command tree (`project create`, `project drop`, `project info`, `project op-details`, `project stats`). The legacy `project-create`/`project-drop`/etc. names remain as hidden deprecated aliases that forward to the new dispatcher, so existing scripts keep working. - Improvement: `zen` consolidates the `oplog-*` commands into a single `zen oplog <sub>` command tree (`oplog create`, `oplog export`, `oplog import`, `oplog snapshot`, `oplog mirror`, `oplog validate`, `oplog download`). The legacy `oplog-create`/`oplog-export`/etc. names remain as hidden deprecated aliases that forward to the new dispatcher, so existing scripts keep working. 
+- Feature: `AsyncHttpClient` adds cancellable request tokens, streaming GET to a file (`AsyncDownload`), zero-copy chunk-callback GET (`AsyncStream`), pull-mode body source for streaming `AsyncPut`, retry layer mirroring the synchronous client, and a submit-side in-flight cap (`HttpClientSettings::MaxConcurrentRequests`) so hub-scale fanout against a single host cannot stall queued handles into curl's connect-timeout window +- Feature: Hub hydration can route S3 transfers through a non-blocking `AsyncHttpClient` (curl_multi + asio) backed by a single io thread; hydrate and dehydrate now pipeline requests instead of blocking worker threads + - `--hub-hydration-async-enabled` (Lua: `hub.hydration.async.enabled`, default true) + - `--hub-hydration-async-max-concurrent-requests` (Lua: `hub.hydration.async.maxconcurrentrequests`, default `clamp(cpu*4, 128, 512)`) +- Feature: Hub provision/deprovision/obliterate now run as two phases on separate worker pools so per-module hydration cannot starve child-process spawn/despawn (and vice versa) + - New `--hub-instance-spawn-threads` (Lua: `hub.instance.spawnthreads`, default `clamp(cpu/8, 4, 16)`) drives child-process spawn/despawn + - `--hub-instance-provision-threads` (Lua: `hub.instance.provisionthreads`) now drives per-module hydrate/dehydrate scheduling only; default changed from `max(cpu/4, 2)` to `clamp(cpu/8, 4, 12)` + - `--hub-hydration-threads` (Lua: `hub.hydration.threads`) now controls per-file workers inside a single hydrate/dehydrate; default changed from `max(cpu/4, 2)` to `clamp(cpu/8, 4, 12)` +- Feature: `Hub::Configuration` C++ struct fields renamed (`OptionalProvisionWorkerPool`/`OptionalHydrationWorkerPool` -> `OptionalProvisionPool`/`OptionalSpawnPool`/`OptionalHydrationPool`). Embedders constructing `Hub` directly must update field names; provision and spawn pools must both be set or both null (asserted at construction). 
+- Bugfix: `S3Client` signing-key cache no longer returns stale signatures after IMDS-rotated credentials change `AccessKeyId`; cache is now keyed on `(DateStamp, AccessKeyId)` - Bugfix: `zen builds download` no longer injects the `default` build part when `--download-spec` is given; the spec drives part selection ## 5.8.9 diff --git a/VERSION.txt b/VERSION.txt index 00813e7d5..b9756b309 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -5.8.10-pre1
\ No newline at end of file +5.8.10-pre2
\ No newline at end of file diff --git a/docs/hub.md b/docs/hub.md index 63d90c502..610f5d053 100644 --- a/docs/hub.md +++ b/docs/hub.md @@ -180,7 +180,8 @@ persist across hibernation and are removed only on deprovision. | `--hub-instance-http` | `hub.instance.http` | `httpsys` (Windows), `asio` (Linux/macOS) | HTTP server implementation for child instances. On Windows, use `asio` if the hub runs without elevation and no URL reservation covers the instance port range. | | `--hub-instance-http-threads` | `hub.instance.httpthreads` | `0` | HTTP connection threads per child instance. `0` uses hardware concurrency. | | `--hub-instance-corelimit` | `hub.instance.corelimit` | `0` | Concurrency limit for child instances. `0` is automatic. | -| `--hub-instance-provision-threads` | `hub.instance.provisionthreads` | `max(cpu/4, 2)` | Thread count for the instance provisioning worker pool. Controls parallel I/O during provision and deprovision operations. Set to `0` for synchronous operation. | +| `--hub-instance-provision-threads` | `hub.instance.provisionthreads` | `clamp(cpu/8, 4, 12)` | Per-module hydrate/dehydrate scheduling pool size. One thread per in-flight module hydrate or dehydrate; the per-file work fans out to `--hub-hydration-threads`. | +| `--hub-instance-spawn-threads` | `hub.instance.spawnthreads` | `clamp(cpu/8, 4, 16)` | Per-module child-process spawn/despawn pool size. One thread per `CreateProcess`/health-poll or terminate cycle. | | `--hub-instance-config` | `hub.instance.config` | _(none)_ | Path to a Lua config file passed to every spawned child instance. Use this to configure storage paths, cache sizes, and other storage server settings. See the zenserver configuration documentation. | | `--hub-instance-malloc` | `hub.instance.malloc` | _(none)_ | Memory allocator for child instances (`ansi`, `stomp`, `rpmalloc`, `mimalloc`). When unset, instances use their compiled-in default. 
| | `--hub-instance-trace` | `hub.instance.trace` | _(none)_ | Trace channel specification for child instances (e.g. `default`, `cpu,log`, `memory`). When set, instances start with tracing enabled on the specified channels. | @@ -247,12 +248,14 @@ suitable for single-host deployments where instances share locally cached data. |---------------------------------------------|---------------------------------------------|---------------------------|-------------| | `--hub-hydration-target-spec` | `hub.hydration.targetspec` | _(local path, see above)_ | Shorthand URI for the hydration source. Must use the `file://` prefix for file targets: `file:///absolute/path`. | | `--hub-hydration-target-config` | `hub.hydration.targetconfig` | _(none)_ | Path to a JSON file specifying the hydration source. Supports `file` and `s3` backends. | -| `--hub-hydration-threads` | `hub.hydration.threads` | `max(cpu/4, 2)` | Thread count for the hydration/dehydration worker pool. Controls parallel file hashing and backend I/O during hydrate/dehydrate. Set to `0` for synchronous operation. | +| `--hub-hydration-threads` | `hub.hydration.threads` | `clamp(cpu/8, 4, 12)` | Per-file worker pool size inside a single hydrate/dehydrate. Drives parallel file hashing and pack assembly; backend I/O on the async S3 path runs on the `AsyncHttpClient` io thread instead of these workers. Set to `0` for synchronous operation. | | `--hub-enable-hydration` | `hub.enablehydration` | `true` | Load instance state from the hydration target on provision. Disable to start every provision from an empty instance directory. | | `--hub-enable-dehydration` | `hub.enabledehydration` | `true` | Save instance state to the hydration target on deprovision. Disable to run the hydrate-only path (useful for perf testing against a fixed backend snapshot). | | `--hub-hydration-enable-pack` | `hub.hydration.enablepack` | `true` | Concatenate small files into CAS pack blobs during dehydrate. See [Pack](#pack). 
| | `--hub-hydration-pack-threshold-bytes` | `hub.hydration.packthresholdbytes` | `262144` (256 KiB) | Files strictly smaller than this are pack candidates. Larger files are stored as standalone CAS entries. | | `--hub-hydration-max-pack-bytes` | `hub.hydration.maxpackbytes` | `4194304` (4 MiB) | Upper bound on a single pack's concatenation size. Candidates are bin-packed greedily; packs that would exceed this cap are closed and a new pack is started. A unique candidate larger than the cap falls back to standalone upload. | +| `--hub-hydration-async-enabled` | `hub.hydration.async.enabled` | `true` | Route S3 hydration through `AsyncHttpClient` (curl_multi + asio, single io thread). `false` falls back to the blocking `S3Client` path. | +| `--hub-hydration-async-max-concurrent-requests` | `hub.hydration.async.maxconcurrentrequests` | `128` | Cap on in-flight S3 requests submitted to the `AsyncHttpClient`; excess submissions queue inside the client until a slot frees. Only consulted when `--hub-hydration-async-enabled=true`. | Multipart chunk size is S3-specific and set via the target config (see [Multipart chunking](#multipart-chunking)). @@ -497,8 +500,11 @@ hub = { http = "asio", httpthreads = 4, - -- Threads for provision/deprovision I/O (0 = synchronous) - provisionthreads = 4, + -- Per-module hydrate/dehydrate scheduling pool (0 = synchronous) + provisionthreads = 8, + + -- Per-module child-process spawn/despawn pool (0 = synchronous) + spawnthreads = 12, -- Config file applied to every child instance config = "/etc/zen/instance.lua", @@ -507,7 +513,14 @@ hub = { -- Hydrate new instances from S3 hydration = { targetconfig = "/etc/zen/hydration.json", - threads = 4, + threads = 8, -- per-file workers inside a single hydrate/dehydrate + + -- Async S3 path: pipeline requests on a single AsyncHttpClient io thread + -- instead of blocking worker threads. Default true. 
+ async = { + enabled = true, + maxconcurrentrequests = 64, + }, }, watchdog = { diff --git a/scripts/test_scripts/hub/PERF_SEED_README.md b/scripts/test_scripts/hub/PERF_SEED_README.md index fb471d4bb..eacb0da55 100644 --- a/scripts/test_scripts/hub/PERF_SEED_README.md +++ b/scripts/test_scripts/hub/PERF_SEED_README.md @@ -3,35 +3,40 @@ Three-stage pipeline for running repeatable hub-hydration perf tests against a local MinIO backend seeded with real module data pulled from production S3. +The pipeline is **pack-on only** - the seeded baseline always comes from a hub +launched with `--hub-hydration-enable-pack=true`. The pack-off variant is no +longer maintained. + ## Layout -All scripts default to a single perf-seed root - currently `E:/Dev/zen-perf-seed/` -in the script defaults, but every path is overridable via CLI flag (see the -per-stage options below). Pick a root with enough free space (snapshots and -preserved CAS dirs can be large) and either pass the matching `--*-dir` flag on -each invocation or change the script defaults to your chosen root. +All path arguments are required (no hardcoded defaults). Pick a perf-seed root +with enough free space (snapshots and preserved CAS dirs can be large) and pass +the matching `--*-dir` flag on each invocation. Stage A's hub data dir should +live on the same volume as the snapshot dir so snapshotting is an O(1) rename +per module instead of a cross-volume byte copy; Stage C's hub data dir should +live on a different volume from the MinIO data dir so hub I/O does not skew the +measured perf run. 
-Layout under the chosen root (`<perf-seed>/`): +Example layout (directory names only; pick volumes/roots and pass via `--*-dir` +flags): ``` -<perf-seed>/ - hub-a/ Stage A hub data dir (transient) +<perf-seed-A>/ bulk data + Stage A/B flow (one volume = move-friendly) + hub-a/ Stage A hub data dir (transient; snapshot-step rename source) servers/<moduleid>/ - s3-snapshot/ Preserved production server-state trees (read-only after Stage A) + s3-snapshot/ Preserved production server-state trees (read-only after Stage A) <moduleid>/ - hubs/ Stage B per-bucket hub data dirs (transient) + hubs/ Stage B per-bucket hub data dirs (transient) hub-b-zen-seed-packed/ - hub-b-zen-seed-unpacked/ - minio-data/ Stage B MinIO data dir (transient, carries every seeded bucket) - minio-seeded-baseline/ Preserved baseline MinIO CAS (read-only after Stage B + preserve) - README.txt - minio-seeded-packed/ Preserved packed MinIO CAS (filled by the pack worktree) + minio-data/ Stage B MinIO data dir (transient) + minio-seeded-packed/ Preserved packed MinIO CAS (read-only after Stage B + preserve) README.txt - hub-perf/ Stage C hub data dir (wiped each run) - minio-run/ Stage C MinIO data dir (wiped + re-copied each run) - perf-runs/ Per-run archive: hub.log, logs/, hub.utrace, summary.json + minio-run/ Stage C MinIO data dir (wiped + re-copied each run) + perf-runs/ Per-run archive: hub.log, logs/, hub.utrace, summary.json 20260423-141530_zen-seed-packed/ - 20260423-143112_zen-seed-unpacked/ + +<perf-seed-B>/ separate volume from <perf-seed-A> for measurement isolation + hub-perf/ Stage C hub data dir (wiped each run) ``` ## Prerequisites @@ -46,7 +51,7 @@ Layout under the chosen root (`<perf-seed>/`): ## Stage A - snapshot real S3 data -One-time (or when you want a fresh baseline from production). +One-time (or when you want a fresh snapshot from production). 
``` export ZEN_PERF_S3_URI=s3://your-bucket/ @@ -54,8 +59,11 @@ export ZEN_PERF_AWS_PROFILE=your-sso-profile python scripts/test_scripts/hub/seed_s3_snapshot.py ``` -Provisions N modules from `$ZEN_PERF_S3_URI`, hibernates them, then copies -`hub-a/servers/<mid>/` to `s3-snapshot/<mid>/`. Triggers `aws sso login` +Provisions N modules from `$ZEN_PERF_S3_URI`, hibernates them, then **moves** +`hub-a/servers/<mid>/` to `s3-snapshot/<mid>/`. When `--hub-data-dir` and +`--snapshot-dir` share a volume (the default) the move is an O(1) rename per +module; cross-volume falls back to a byte copy with the old cost profile. The +hub data dir is wiped on the next run regardless. Triggers `aws sso login` automatically if the SSO token is missing or expired. Module selection ranks all UUID-shaped folders by their @@ -64,37 +72,27 @@ most-recently-accessed) and takes the top `--module-count`. Options: - `--module-count N` (default 1000) -- `--snapshot-dir PATH` (default `<perf-seed>/s3-snapshot`) -- `--hub-data-dir PATH` (default `<perf-seed>/hub-a`) +- `--snapshot-dir PATH` (required, e.g. `<perf-seed>/s3-snapshot`) +- `--hub-data-dir PATH` (required, e.g. `<perf-seed>/hub-a`) ## Stage B - seed MinIO from the snapshot -One-time per pack-mode (or when `s3-snapshot` changes). +One-time, or when `s3-snapshot/` changes. -`seed_minio.py` seeds a **single** bucket per invocation. The pack flag is -hardcoded inside the script (`--hub-hydration-enable-pack=true` near the -top of `_start_hub`). To produce both packed and unpacked baselines for -comparison, invoke the script twice from two separate worktrees - one with -the flag flipped to `false` - and preserve the resulting MinIO data dir -each time. +`seed_minio.py` seeds the `zen-seed-packed` bucket with pack ON +(`--hub-hydration-enable-pack=true` is hardcoded). 
The script provisions every +module found under `s3-snapshot/`, hibernates them, overlays the snapshot on +top of the hub's servers dir, then deprovisions all modules - which runs the +dehydrate path and uploads the content into the bucket. ``` -# In the pack worktree (flag = true), seeds zen-seed-packed python scripts/test_scripts/hub/seed_minio.py --wipe --bucket zen-seed-packed python scripts/test_scripts/hub/preserve_minio_state.py --dest <perf-seed>/minio-seeded-packed - -# In the no-pack worktree (flag = false), seeds zen-seed-unpacked -python scripts/test_scripts/hub/seed_minio.py --wipe --bucket zen-seed-unpacked -python scripts/test_scripts/hub/preserve_minio_state.py --dest <perf-seed>/minio-seeded-unpacked ``` -The script provisions every module found under `s3-snapshot/`, hibernates -them, overlays the snapshot on top of the hub's servers dir, then -deprovisions all modules - which runs the dehydrate path and uploads the -content into the bucket. - -`preserve_minio_state.py` copies the resulting `minio-data/` to a -variant-specific preservation dir and writes a README with provenance. +`preserve_minio_state.py` MOVES (default; `--copy` to keep source) the +resulting `minio-data/` to the preservation dir and writes a README with +provenance. Options of interest: - `--bucket NAME` - bucket name (default `zen-seed-packed`). @@ -107,24 +105,18 @@ Options of interest: Repeat as often as you want; each run starts from the preserved baseline. ``` -# Pack-on bucket python scripts/test_scripts/hub/run_minio_perf.py --bucket zen-seed-packed --trace - -# Pack-off bucket (for comparison) -python scripts/test_scripts/hub/run_minio_perf.py --bucket zen-seed-unpacked --trace ``` Steps: -1. Copies `--minio-seeded` (default `minio-seeded-baseline/`) over `minio-run/` so MinIO starts from a known state. -2. Wipes `hub-perf/` (unless `--no-wipe-hub`). +1. Copies `--minio-seeded` over `--minio-run` so MinIO starts from a known state. +2. 
Wipes `--hub-data-dir` (unless `--no-wipe-hub`). 3. Starts MinIO and hub. 4. Provisions all modules, waits for `provisioned`, deprovisions, waits gone. 5. Stops everything cleanly. Default mode is `--hub-enable-dehydration=false` so MinIO isn't modified; every -iteration exercises the hydrate-only path against the same baseline CAS. The -`--bucket` flag selects which seeded bucket (and therefore which pack mode) -to exercise. +iteration exercises the hydrate-only path against the same baseline CAS. Pass `--enable-dehydration` to run a full provision -> deprovision cycle that includes re-upload (dehydrate) at deprovision time. Use this to measure the @@ -139,10 +131,9 @@ post-hoc. Override the destination with `--archive-dir PATH`. ## Resetting between runs -- **Keep**: `s3-snapshot/`, `minio-seeded-baseline/`, `minio-seeded-packed/`. These are expensive to rebuild. +- **Keep**: `s3-snapshot/`, `minio-seeded-packed/`. These are expensive to rebuild. - **Discard freely**: `hub-a/`, `hubs/`, `hub-perf/`, `minio-data/`, `minio-run/`. -To force a fresh MinIO seed for one variant: delete the matching -`minio-seeded-<variant>/` and re-run Stage B + preserve (with the matching -`--dest`) in that worktree. To force a fresh S3 snapshot: delete -`s3-snapshot/` and re-run Stage A. +To force a fresh MinIO seed: delete `minio-seeded-packed/` and re-run Stage B ++ preserve. To force a fresh S3 snapshot: delete `s3-snapshot/` and re-run +Stage A. 
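The PERF_SEED_README.md changes above state that the Stage A snapshot move is an O(1) rename only when `--hub-data-dir` and `--snapshot-dir` share a volume, and that Stage C's hub data dir should sit on a different volume from the MinIO data dir. A minimal sketch of how a caller might check this before launching a stage is shown below. The helper names `nearest_existing` and `same_volume` are hypothetical and do not appear in the scripts in this diff. The check assumes `os.stat(...).st_dev` distinguishes volumes on the host platform, which is worth verifying on the target machine.

```python
import os
import tempfile
from pathlib import Path


def nearest_existing(path: Path) -> Path:
    """Walk up from `path` until a directory component that exists is found.

    The perf-seed dirs may not have been created yet, so the volume is
    inferred from the closest ancestor that is already on disk.
    """
    p = path.resolve()
    while not p.exists():
        if p.parent == p:  # reached the filesystem root
            break
        p = p.parent
    return p


def same_volume(a: Path, b: Path) -> bool:
    """Return True when both paths appear to live on the same volume.

    Assumption: st_dev differs between volumes on this platform.
    """
    dev_a = os.stat(nearest_existing(a)).st_dev
    dev_b = os.stat(nearest_existing(b)).st_dev
    return dev_a == dev_b


if __name__ == "__main__":
    # Illustration only: two not-yet-created siblings under one temp root
    # resolve to the same existing ancestor, so they share a volume.
    with tempfile.TemporaryDirectory() as root:
        hub_a = Path(root) / "hub-a"
        snapshot = Path(root) / "s3-snapshot"
        print("same volume:", same_volume(hub_a, snapshot))
```

For Stage A the result should be true (rename-friendly); for Stage C's `--hub-data-dir` versus `--minio-run` the README's advice corresponds to the result being false.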
diff --git a/scripts/test_scripts/hub/hub_load_test_s3.py b/scripts/test_scripts/hub/hub_load_test_s3.py index 23014409c..e71222e31 100644 --- a/scripts/test_scripts/hub/hub_load_test_s3.py +++ b/scripts/test_scripts/hub/hub_load_test_s3.py @@ -372,8 +372,8 @@ def _wait_for_deprovisioned( def main() -> None: parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--data-dir", default="E:/Dev/hub-loadtest-s3", - help="Hub --data-dir (default: E:/Dev/hub-loadtest-s3)") + parser.add_argument("--data-dir", required=True, + help="Hub --data-dir.") parser.add_argument("--port", type=int, default=8558, help="Hub HTTP port (default: 8558)") parser.add_argument("--module-count", type=int, default=200, diff --git a/scripts/test_scripts/hub/parse_perf_log.py b/scripts/test_scripts/hub/parse_perf_log.py new file mode 100644 index 000000000..8833a5211 --- /dev/null +++ b/scripts/test_scripts/hub/parse_perf_log.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +"""Parse hub perf log to extract hydration / spawn / deprovision timings.""" +from __future__ import annotations +import re +import statistics +import sys +from pathlib import Path + +HYDRATE_RE = re.compile(r"Hydration complete module '([^']+)': (\d+) files \(([^)]+)\) in ([\d.]+)(ms|s)") +SPAWN_RE = re.compile(r"module '([^']+)' started, listening on port \d+, spawn took ([\d.]+)(ms|s)") +PROVISION_START_RE = re.compile(r"\[(\d{2}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\] \[inf\] Provisioning storage server instance for module '([^']+)'") +SPAWN_DONE_RE = re.compile(r"\[(\d{2}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\] \[inf\] Storage server instance for module '([^']+)' started") +HYDRATE_DONE_RE = re.compile(r"\[(\d{2}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\] \[inf\] Hydration complete module '([^']+)'") +DEPROV_START_RE = re.compile(r"\[(\d{2}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\] \[inf\] Module '([^']+)' changed state from provisioned to 
deprovisioning") +DEPROV_REMOVE_RE = re.compile(r"\[(\d{2}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\] \[inf\] Module '([^']+)' (?:removed|changed state from deprovisioning to unprovisioned)") + + +def to_ms(num: str, unit: str) -> float: + v = float(num) + return v * 1000.0 if unit == "s" else v + + +def parse_ts(ts: str) -> float: + # 26-04-29 22:41:34.009 -> seconds since arbitrary epoch (use day-of-year * 86400 + h*3600 + m*60 + s) + import datetime + dt = datetime.datetime.strptime(ts, "%y-%m-%d %H:%M:%S.%f") + return dt.timestamp() + + +def stats(name: str, vals: list[float]) -> str: + if not vals: + return f"{name:14s} n=0" + vs = sorted(vals) + n = len(vs) + p50 = vs[n // 2] + p95 = vs[min(n - 1, int(n * 0.95))] + p99 = vs[min(n - 1, int(n * 0.99))] + return (f"{name:14s} n={n:5d} mean={statistics.mean(vs):8.0f}ms " + f"p50={p50:7.0f}ms p95={p95:7.0f}ms p99={p99:7.0f}ms max={max(vs):7.0f}ms " + f"sum={sum(vs)/1000.0:7.1f}s") + + +def analyse(log_path: Path) -> None: + print(f"=== {log_path}") + hydrate_durations: list[float] = [] + spawn_durations: list[float] = [] + + prov_start: dict[str, float] = {} + hydrate_done: dict[str, float] = {} + spawn_done: dict[str, float] = {} + deprov_start: dict[str, float] = {} + deprov_done: dict[str, float] = {} + + with log_path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + m = HYDRATE_RE.search(line) + if m: + hydrate_durations.append(to_ms(m.group(4), m.group(5))) + m = SPAWN_RE.search(line) + if m: + spawn_durations.append(to_ms(m.group(2), m.group(3))) + m = PROVISION_START_RE.search(line) + if m: + prov_start[m.group(2)] = parse_ts(m.group(1)) + continue + m = HYDRATE_DONE_RE.search(line) + if m: + hydrate_done[m.group(2)] = parse_ts(m.group(1)) + continue + m = SPAWN_DONE_RE.search(line) + if m: + spawn_done[m.group(2)] = parse_ts(m.group(1)) + continue + m = DEPROV_START_RE.search(line) + if m: + deprov_start[m.group(2)] = parse_ts(m.group(1)) + continue + m = DEPROV_REMOVE_RE.search(line) + 
if m: + deprov_done[m.group(2)] = parse_ts(m.group(1)) + continue + + # Reconstruct end-to-end provision wall time per module: + e2e_prov_ms: list[float] = [] + queue_to_hydrate_ms: list[float] = [] # time in DataPool queue before hydrate begun (proxy: prov_start->hydrate_done minus hydrate ms) + spawn_minus_data_ms: list[float] = [] # time between hydrate_done and spawn_done (lifetime + lifetime queue) + for mid, t0 in prov_start.items(): + t_end = spawn_done.get(mid) + if t_end is None: + continue + e2e_prov_ms.append((t_end - t0) * 1000.0) + t_h = hydrate_done.get(mid) + if t_h is not None: + spawn_minus_data_ms.append((t_end - t_h) * 1000.0) + + e2e_deprov_ms: list[float] = [] + for mid, t0 in deprov_start.items(): + t_end = deprov_done.get(mid) + if t_end is None: + continue + e2e_deprov_ms.append((t_end - t0) * 1000.0) + + print(stats("hydrate_dur", hydrate_durations)) + print(stats("spawn_dur", spawn_durations)) + print(stats("e2e_prov", e2e_prov_ms)) + print(stats("after_hydrate", spawn_minus_data_ms)) + print(stats("e2e_deprov", e2e_deprov_ms)) + if hydrate_durations and spawn_durations: + h_sum = sum(hydrate_durations) / 1000.0 + s_sum = sum(spawn_durations) / 1000.0 + print(f" total hydrate work: {h_sum:.1f}s, total spawn work: {s_sum:.1f}s, ratio H/S: {h_sum/s_sum:.2f}") + print() + + +if __name__ == "__main__": + for arg in sys.argv[1:]: + analyse(Path(arg)) diff --git a/scripts/test_scripts/hub/perf_configs/hub.lua b/scripts/test_scripts/hub/perf_configs/hub.lua index f3cf3e697..ff9ab582e 100644 --- a/scripts/test_scripts/hub/perf_configs/hub.lua +++ b/scripts/test_scripts/hub/perf_configs/hub.lua @@ -12,19 +12,14 @@ hub = { disklimitpercent = 90, -- default: 0 (disabled) }, corelimit = 4, -- default: 0 (auto) - provisionthreads = 8, -- default: auto + -- provisionthreads / spawnthreads / hub.hydration.threads left unset: defaults + -- (clamp(cpu/8,4,12) / clamp(cpu/8,4,16) / clamp(cpu/8,4,12)) are tuned from + -- 1000-module sweep at 128 vCPU + 30ms 
latency. Override here only to A/B test. -- NOTE: hub.instance.config (path to instance lua) is overridden via -- --hub-instance-config on the CLI. If left here, it would be resolved -- relative to the hub's CWD at spawn time (NOT this file's dir). }, - hydration = { - -- Match production's per-module download pool size. Without this, the - -- default auto-picks hardware_concurrency/4 which on --corelimit=128 - -- would be 32. Prod logs consistently show "16 threads" in Download phase. - threads = 16, - }, - watchdog = { cycleintervalms = 5000, -- default: 3000. slower cycle, 1000 instances to scan cycleprocessingbudgetms = 1000, -- default: 500. more budget per cycle for larger instance count diff --git a/scripts/test_scripts/hub/perf_configs/instance.lua b/scripts/test_scripts/hub/perf_configs/instance.lua index 1251997db..cb292a0cd 100644 --- a/scripts/test_scripts/hub/perf_configs/instance.lua +++ b/scripts/test_scripts/hub/perf_configs/instance.lua @@ -13,3 +13,9 @@ cache = { }, }, } + +server = { + sentry = { + disable = true, -- perf runs: keep Sentry on the hub only; skip per-child crash reporter init + }, +} diff --git a/scripts/test_scripts/hub/preserve_minio_state.py b/scripts/test_scripts/hub/preserve_minio_state.py index 365e4a542..cfa63af31 100644 --- a/scripts/test_scripts/hub/preserve_minio_state.py +++ b/scripts/test_scripts/hub/preserve_minio_state.py @@ -9,10 +9,8 @@ wipes --minio-data-dir on its next invocation anyway. Pass --copy to keep --source in place (slower; needs 2x disk). -Typical invocation: - python preserve_minio_state.py - -Defaults map to the paths recommended by PERF_SEED_README.md. +Both --source and --dest are required. See PERF_SEED_README.md for the +expected layout. 
""" from __future__ import annotations @@ -57,11 +55,10 @@ def _size_of(path: Path) -> tuple[int, int]: def main() -> int: parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--source", default="E:/Dev/zen-perf-seed/minio-data", - help="Source MinIO data dir (default: E:/Dev/zen-perf-seed/minio-data)") - parser.add_argument("--dest", default="E:/Dev/zen-perf-seed/minio-seeded-packed", - help="Preservation path (default: E:/Dev/zen-perf-seed/minio-seeded-packed). " - "Sibling to E:/Dev/zen-perf-seed/minio-seeded-baseline.") + parser.add_argument("--source", required=True, + help="Source MinIO data dir produced by Stage B (seed_minio.py).") + parser.add_argument("--dest", required=True, + help="Preservation path; the move/copy target that becomes the baseline read by Stage C.") parser.add_argument("--s3-uri", default=os.environ.get("ZEN_PERF_S3_URI", ""), help="Source S3 URI recorded in the README (defaults to $ZEN_PERF_S3_URI)") parser.add_argument("--bucket", default="zen-seed", diff --git a/scripts/test_scripts/hub/run_minio_perf.py b/scripts/test_scripts/hub/run_minio_perf.py index b59cff41a..575303c61 100644 --- a/scripts/test_scripts/hub/run_minio_perf.py +++ b/scripts/test_scripts/hub/run_minio_perf.py @@ -68,7 +68,7 @@ def _find_zenserver(override: Optional[str]) -> Path: sys.exit(f"zenserver not found at {p}") return p script_dir = Path(__file__).resolve().parent - repo_root = script_dir.parent.parent + repo_root = script_dir.parents[2] for mode in ("release", "debug"): for plat in (("windows", "x64"), ("linux", "x86_64"), ("macosx", "x86_64")): p = repo_root / "build" / plat[0] / plat[1] / mode / f"zenserver{_EXE_SUFFIX}" @@ -91,6 +91,73 @@ def _find_minio(zenserver_path: Path) -> Path: return p +def _resolve_toxiproxy_exe(arg_value: Optional[str], env_var: str, default_name: str) -> Path: + """Resolve a toxiproxy executable: explicit --toxiproxy-* arg wins, then env var, + then 
PATH lookup of `default_name`. Exits if none of the three resolve to a real + file.""" + candidate = arg_value or os.environ.get(env_var) + if candidate: + p = Path(candidate) + if not p.exists(): + sys.exit(f"[toxiproxy] {default_name}: explicit path '{p}' not found") + return p + found = shutil.which(default_name) + if found: + return Path(found) + sys.exit(f"[toxiproxy] {default_name} not found: pass --toxiproxy-server-exe / --toxiproxy-cli-exe, " + f"set {env_var}, or put {default_name} on PATH") + + +def _start_toxiproxy(server_exe: Path, cli_exe: Path, api_port: int, listen_port: int, + upstream_port: int, latency_ms: int, jitter_ms: int) -> subprocess.Popen: + """Start toxiproxy-server, create a proxy minio_proxy listening on listen_port and + forwarding to localhost:upstream_port, and add a latency toxic. Returns the server + Popen so the caller can terminate it on cleanup.""" + popen_kwargs: dict = {} + if sys.platform == "win32": + popen_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP + proc = subprocess.Popen( + [str(server_exe), "-port", str(api_port)], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, **popen_kwargs, + ) + print(f"[toxiproxy] server started (pid {proc.pid}) api-port={api_port}") + # Wait for the API to come up. + api_url = f"http://localhost:{api_port}/version" + deadline = time.monotonic() + 10.0 + while time.monotonic() < deadline: + try: + with urllib.request.urlopen(api_url, timeout=1): + break + except Exception: + time.sleep(0.1) + else: + proc.terminate() + sys.exit(f"[toxiproxy] API at {api_url} never came up") + + env = os.environ.copy() + env["TOXIPROXY_URL"] = f"http://localhost:{api_port}" + # Create the proxy. + subprocess.run( + [str(cli_exe), "create", "-l", f"127.0.0.1:{listen_port}", + "-u", f"127.0.0.1:{upstream_port}", "minio_proxy"], + env=env, check=True, + ) + # Add latency toxic to both directions (default direction is downstream; + # add explicit upstream too for symmetry). 
+ args_common = ["-t", "latency", "-a", f"latency={latency_ms}", "-a", f"jitter={jitter_ms}"] + subprocess.run( + [str(cli_exe), "toxic", "add", "-n", "lat_down", "-d", *args_common, "minio_proxy"], + env=env, check=True, + ) + subprocess.run( + [str(cli_exe), "toxic", "add", "-n", "lat_up", "-u", *args_common, "minio_proxy"], + env=env, check=True, + ) + print(f"[toxiproxy] minio_proxy listening on 127.0.0.1:{listen_port} -> " + f"127.0.0.1:{upstream_port}, latency={latency_ms}ms jitter={jitter_ms}ms (per direction)") + return proc + + def _start_minio(minio_exe: Path, data_dir: Path, port: int, console_port: int) -> subprocess.Popen: data_dir.mkdir(parents=True, exist_ok=True) env = os.environ.copy() @@ -116,10 +183,12 @@ def _start_minio(minio_exe: Path, data_dir: Path, port: int, console_port: int) def _wait_for_minio(port: int, timeout_s: float = 30.0) -> None: deadline = time.monotonic() + timeout_s url = f"http://localhost:{port}/minio/health/live" + t0 = time.monotonic() + print(f"[minio] waiting for ready on port {port} ...", flush=True) while time.monotonic() < deadline: try: with urllib.request.urlopen(url, timeout=1): - print("[minio] ready") + print(f"[minio] ready ({time.monotonic()-t0:.1f}s)", flush=True) return except Exception: time.sleep(0.1) @@ -178,14 +247,21 @@ def _wait_for_hub(proc: subprocess.Popen, port: int, timeout_s: float = 100.0) - deadline = time.monotonic() + timeout_s req = urllib.request.Request(f"http://localhost:{port}/hub/status", headers={"Accept": "application/json"}) + t0 = time.monotonic() + last_tick = t0 + print(f"[hub] waiting for ready on port {port} ...", flush=True) while time.monotonic() < deadline: if proc.poll() is not None: sys.exit(f"[hub] exited unexpectedly (rc={proc.returncode})") try: with urllib.request.urlopen(req, timeout=2): - print("[hub] ready") + print(f"[hub] ready ({time.monotonic()-t0:.1f}s)", flush=True) return except Exception: + now = time.monotonic() + if now - last_tick >= 5.0: + print(f"[hub] 
still waiting ({now-t0:.1f}s elapsed)", flush=True) + last_tick = now time.sleep(0.2) sys.exit(f"[hub] timed out after {timeout_s}s") @@ -195,14 +271,16 @@ def _zen_down_hub(zen_exe: Path, hub_proc: subprocess.Popen, timeout_s: float = if hub_proc.poll() is not None: return pid = hub_proc.pid - print(f"[hub] zen down --pid {pid}") + print(f"[hub] zen down --pid {pid} ...", flush=True) + t0 = time.time() rc = subprocess.call([str(zen_exe), "down", "--pid", str(pid), "--force"]) if rc != 0: - print(f"[hub] zen down returned rc={rc}; waiting for exit anyway") + print(f"[hub] zen down returned rc={rc}; waiting for exit anyway", flush=True) try: hub_proc.wait(timeout=timeout_s) + print(f"[hub] exited ({time.time()-t0:.1f}s)", flush=True) except subprocess.TimeoutExpired: - print(f"[hub] did not exit after {timeout_s}s, killing") + print(f"[hub] did not exit after {timeout_s}s, killing", flush=True) hub_proc.kill() hub_proc.wait() @@ -210,6 +288,8 @@ def _zen_down_hub(zen_exe: Path, hub_proc: subprocess.Popen, timeout_s: float = def _stop_minio_graceful(proc: subprocess.Popen, timeout_s: float = 30.0) -> None: if proc.poll() is not None: return + print(f"[minio] stopping (pid {proc.pid}) ...", flush=True) + t0 = time.time() try: if sys.platform == "win32": proc.send_signal(signal.CTRL_BREAK_EVENT) @@ -219,8 +299,9 @@ def _stop_minio_graceful(proc: subprocess.Popen, timeout_s: float = 30.0) -> Non proc.terminate() try: proc.wait(timeout=timeout_s) + print(f"[minio] stopped ({time.time()-t0:.1f}s)", flush=True) except subprocess.TimeoutExpired: - print(f"[minio] did not exit after {timeout_s}s, killing") + print(f"[minio] did not exit after {timeout_s}s, killing", flush=True) proc.kill() proc.wait() @@ -341,29 +422,34 @@ def _wait_for_gone(port: int, ids: list[str], timeout_s: float) -> list[str]: def _robust_copytree(src: Path, dst: Path) -> None: - """Windows-friendly directory copy with progress. 
+ """Directory copy of the seeded baseline onto the working MinIO data dir. + + On a ReFS volume we use `refs_clone.clone_tree` to issue + FSCTL_DUPLICATE_EXTENTS_TO_FILE per file: O(1) per file via copy-on-write + metadata, regardless of size. A 300+ GB baseline clones in seconds. + Off ReFS we fall back to `shutil.copytree`. PowerShell `Copy-Item` was + tried previously and observed to byte-copy on this host, so it is no + longer used. - Uses robocopy /MIR to mirror src -> dst and /COPY:DAT /DCOPY:DAT to copy - file data + attributes + timestamps explicitly (not just the defaults). Safe because dst is always the working dir (minio-run) - never the preserved baseline. """ - if sys.platform == "win32": - cmd = [ - "robocopy", str(src), str(dst), - "/MIR", - "/COPY:DAT", "/DCOPY:DAT", - "/NJH", "/NJS", "/NC", "/NDL", "/NFL", "/NP", - "/R:2", "/W:1", - ] - print(f"[reset] robocopy {src} -> {dst}") - rc = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - if rc >= 8: - sys.exit(f"[reset] robocopy failed rc={rc}") + from refs_clone import clone_tree, is_refs_volume + + if dst.exists(): + _rmtree_robust(dst) + + if sys.platform == "win32" and is_refs_volume(src) and is_refs_volume(dst.parent): + print(f"[reset] ReFS block-clone {src} -> {dst} ...", flush=True) + t0 = time.time() + files, bytes_total = clone_tree(src, dst) + print(f"[reset] cloned {files:,} files, {bytes_total/1024/1024:.1f} MB " + f"in {time.time()-t0:.1f}s", flush=True) else: - if dst.exists(): - _rmtree_robust(dst) + print(f"[reset] copytree (non-ReFS) {src} -> {dst} ...", flush=True) + t0 = time.time() shutil.copytree(src, dst, symlinks=False) + print(f"[reset] copytree done in {time.time()-t0:.1f}s", flush=True) def _archive_run( @@ -415,26 +501,27 @@ def _archive_run( def main() -> int: parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--minio-seeded", 
default="E:/Dev/zen-perf-seed/minio-seeded-packed", - help="Preserved MinIO baseline (default: E:/Dev/zen-perf-seed/minio-seeded-packed). " - "Sibling to E:/Dev/zen-perf-seed/minio-seeded-baseline.") - parser.add_argument("--minio-run", default="E:/Dev/zen-perf-seed/minio-run", - help="Working MinIO data dir, wiped and re-copied each run (default: E:/Dev/zen-perf-seed/minio-run)") - parser.add_argument("--snapshot-dir", default="E:/Dev/zen-perf-seed/s3-snapshot", - help="Source of module IDs to run against (default: E:/Dev/zen-perf-seed/s3-snapshot)") - parser.add_argument("--hub-data-dir", default="E:/Dev/zen-perf-seed/hub-perf", - help="Hub --data-dir for the perf run (wiped each run if --wipe). Default: E:/Dev/zen-perf-seed/hub-perf") - parser.add_argument("--archive-dir", default="E:/Dev/zen-perf-seed/perf-runs", - help="Where to archive hub.log + zenserver.log + hub.utrace after each run " - "(default: E:/Dev/zen-perf-seed/perf-runs)") + parser.add_argument("--minio-seeded", required=True, + help="Preserved MinIO baseline (read-only source the working dir is reset from each run).") + parser.add_argument("--minio-run", required=True, + help="Working MinIO data dir, wiped and re-copied from --minio-seeded each run.") + parser.add_argument("--snapshot-dir", required=True, + help="Source of module IDs to run against (per-module server-state trees from Stage A).") + parser.add_argument("--hub-data-dir", required=True, + help="Hub --data-dir for the perf run (wiped each run unless --no-wipe-hub). " + "Place on a different volume from --minio-run to keep hub I/O from skewing MinIO measurements.") + parser.add_argument("--archive-dir", required=True, + help="Where to archive hub.log + zenserver.log + hub.utrace after each run.") parser.add_argument("--bucket", default="zen-seed-packed", - help="MinIO bucket name to exercise (default: zen-seed-packed). 
" - "Pack worktree seeds only the packed bucket.") + help="MinIO bucket name to exercise (default: zen-seed-packed).") parser.add_argument("--minio-port", type=int, default=9000) parser.add_argument("--minio-console-port", type=int, default=9001) parser.add_argument("--hub-port", type=int, default=8558) - parser.add_argument("--module-count", type=int, default=0, - help="Cap on modules used (0 = all from snapshot-dir)") + parser.add_argument("--module-count", type=int, default=1000, + help="Cap on modules used (default 1000, the hub instance shared-state " + "limit; raising this above ~1023 will hit 'all slots occupied' errors). " + "Pass 0 to use every module under --snapshot-dir but expect failures " + "if the snapshot has more than ~1023 entries.") parser.add_argument("--workers", type=int, default=50) parser.add_argument("--poll-timeout", type=float, default=1200.0) parser.add_argument("--settle-seconds", type=float, default=5.0) @@ -446,8 +533,29 @@ def main() -> int: "baseline; pass this to measure dehydrate cost as well as hydrate.") parser.add_argument("--no-wipe-hub", action="store_true", help="Don't wipe the hub data dir before starting (useful to inspect leftover state)") + parser.add_argument("--blocking", action="store_true", + help="Force blocking S3Client path (pass --hub-hydration-async-enabled=false). " + "Default is the hub default, which routes S3 hydration through AsyncHttpClient.") + parser.add_argument("--hub-arg", action="append", default=[], + help="Extra raw arg appended verbatim to the hub command line. Repeatable.") parser.add_argument("--zenserver-dir", help="Directory containing zenserver + minio executables (auto-detected)") + # Toxiproxy: simulate network latency between hub and MinIO. + parser.add_argument("--toxiproxy-latency-ms", type=int, default=0, + help="If >0, route hub through toxiproxy with this latency injected on inbound+outbound. 
" + "Approximates real-network RTT on top of localhost MinIO.") + parser.add_argument("--toxiproxy-jitter-ms", type=int, default=0, + help="Jitter added to toxiproxy latency (passed as -a jitter=N to toxic add).") + parser.add_argument("--toxiproxy-port", type=int, default=9100, + help="Toxiproxy listen port for the MinIO proxy (the hub connects here when --toxiproxy-latency-ms > 0).") + parser.add_argument("--toxiproxy-api-port", type=int, default=8474, + help="Toxiproxy server API port.") + parser.add_argument("--toxiproxy-server-exe", default=None, + help="Path to toxiproxy-server executable. Defaults to ZEN_PERF_TOXIPROXY_SERVER_EXE " + "env var, or 'toxiproxy-server' from PATH.") + parser.add_argument("--toxiproxy-cli-exe", default=None, + help="Path to toxiproxy-cli executable. Defaults to ZEN_PERF_TOXIPROXY_CLI_EXE " + "env var, or 'toxiproxy-cli' from PATH.") args = parser.parse_args() minio_seeded = Path(args.minio_seeded).resolve() @@ -491,27 +599,39 @@ def main() -> int: # Wipe the hub data dir so every run starts from scratch unless the user opts out. if not args.no_wipe_hub and hub_data_dir.exists(): - print(f"[reset] wiping {hub_data_dir}") + print(f"[reset] wiping {hub_data_dir} ...", flush=True) + t0 = time.time() _rmtree_robust(hub_data_dir) + print(f"[reset] wipe done in {time.time()-t0:.1f}s", flush=True) hub_data_dir.mkdir(parents=True, exist_ok=True) + # If toxiproxy is enabled, the hub points at the proxy port, which forwards + # to the real MinIO with the configured latency injected. 
+ hub_endpoint_port = args.toxiproxy_port if args.toxiproxy_latency_ms > 0 else args.minio_port + + s3_settings = { + "uri": f"s3://{args.bucket}", + "endpoint": f"http://localhost:{hub_endpoint_port}", + "path-style": True, + "region": "us-east-1", + } config_path = hub_data_dir / "hydration_config.json" config_path.write_text( - json.dumps({ - "type": "s3", - "settings": { - "uri": f"s3://{args.bucket}", - "endpoint": f"http://localhost:{args.minio_port}", - "path-style": True, - "region": "us-east-1", - }, - }), + json.dumps({"type": "s3", "settings": s3_settings}), encoding="ascii", ) hub_extra_args = [ f"--hub-hydration-target-config={config_path}", f"--hub-enable-dehydration={'true' if args.enable_dehydration else 'false'}", ] + if args.blocking: + hub_extra_args.append("--hub-hydration-async-enabled=false") + print("[mode] --blocking=true: hub will use blocking S3Client path") + else: + print("[mode] async path (default): hub routes S3 hydration through AsyncHttpClient") + if args.hub_arg: + hub_extra_args.extend(args.hub_arg) + print(f"[mode] extra hub args: {args.hub_arg}") if args.enable_dehydration: print("[mode] --enable-dehydration=true: deprovision will re-upload to MinIO; baseline will diverge") if args.trace: @@ -524,6 +644,7 @@ def main() -> int: } minio_proc: Optional[subprocess.Popen] = None + toxiproxy_proc: Optional[subprocess.Popen] = None hub_proc: Optional[subprocess.Popen] = None hub_log_handle = None exit_code = 0 @@ -542,6 +663,21 @@ def main() -> int: minio_proc = _start_minio(minio_exe, minio_run, args.minio_port, args.minio_console_port) _wait_for_minio(args.minio_port) + if args.toxiproxy_latency_ms > 0: + toxiproxy_server = _resolve_toxiproxy_exe( + args.toxiproxy_server_exe, "ZEN_PERF_TOXIPROXY_SERVER_EXE", "toxiproxy-server") + toxiproxy_cli = _resolve_toxiproxy_exe( + args.toxiproxy_cli_exe, "ZEN_PERF_TOXIPROXY_CLI_EXE", "toxiproxy-cli") + toxiproxy_proc = _start_toxiproxy( + toxiproxy_server, + toxiproxy_cli, + 
args.toxiproxy_api_port, + args.toxiproxy_port, + args.minio_port, + args.toxiproxy_latency_ms, + args.toxiproxy_jitter_ms, + ) + hub_log = hub_data_dir / "hub.log" hub_proc, hub_log_handle = _start_hub( zenserver_exe, hub_data_dir, args.hub_port, hub_log, @@ -601,6 +737,13 @@ def main() -> int: _zen_down_hub(zen_exe, hub_proc) if hub_log_handle is not None: hub_log_handle.close() + if toxiproxy_proc is not None and toxiproxy_proc.poll() is None: + try: + toxiproxy_proc.terminate() + toxiproxy_proc.wait(timeout=5) + except Exception: + toxiproxy_proc.kill() + print("[toxiproxy] stopped") if minio_proc is not None and minio_proc.poll() is None: _stop_minio_graceful(minio_proc) # Archive AFTER the hub has exited so in-flight log writes are flushed. diff --git a/scripts/test_scripts/hub/seed_minio.py b/scripts/test_scripts/hub/seed_minio.py index e0e45c4cb..f5928b995 100644 --- a/scripts/test_scripts/hub/seed_minio.py +++ b/scripts/test_scripts/hub/seed_minio.py @@ -1,13 +1,10 @@ #!/usr/bin/env python3 """Stage B of the perf-seed workflow. -Replays the snapshot produced by seed_s3_snapshot.py into a single MinIO -bucket so a later perf run can exercise that bucket. Pack mode is hardcoded -in this script (see --hub-hydration-enable-pack flag in _start_hub) and -governs how the hub uploads into MinIO. To compare packed vs unpacked, -invoke this script twice from two separate worktrees - one with the pack -flag flipped to false - producing two preserved MinIO data dirs, then run -run_minio_perf.py against each. +Replays the snapshot produced by seed_s3_snapshot.py into a MinIO bucket so +a later perf run can exercise it. Pack mode is fixed ON (the only mode the +perf-seed pipeline caters to) - the hub is launched with +--hub-hydration-enable-pack=true so dehydrate emits packed CAS. Flow: 1. Start a local MinIO server against --minio-data-dir, create the bucket. 
@@ -78,7 +75,7 @@ def _find_zenserver(override: Optional[str]) -> Path: sys.exit(f"zenserver not found at {p}") return p script_dir = Path(__file__).resolve().parent - repo_root = script_dir.parent.parent + repo_root = script_dir.parents[2] for mode in ("release", "debug"): for plat in (("windows", "x64"), ("linux", "x86_64"), ("macosx", "x86_64")): p = repo_root / "build" / plat[0] / plat[1] / mode / f"zenserver{_EXE_SUFFIX}" @@ -189,11 +186,10 @@ def _start_hub( "--hub-provision-disk-limit-percent=99", "--hub-provision-memory-limit-percent=80", f"--hub-instance-limit={instance_limit}", - # Seeding is not a perf-measurement path - we want it as fast as the - # host can manage. Let the hub go wide on both provisioning and - # hydration thread pools rather than matching prod limits. - "--hub-instance-provision-threads=64", - "--hub-hydration-threads=64", + # Provision / hydration / async-cap use server defaults; on a 128-core + # host these resolve to 16 / 16 / 512 which are sized for both the + # async hydrate (Stage A) and the sync dehydrate-PUT path that drives + # the upload here. # Prevent the watchdog from auto-deprovisioning modules while we're # still hydrating the tail / in the overlay phase. BOTH timers have to # be extended - the provisioned one (default 600s) is what bites on @@ -203,7 +199,7 @@ def _start_hub( # Explicit - default is true, but make it obvious that Stage B needs # it since the final deprovision drives the MinIO upload. "--hub-enable-dehydration=true", - # Pack worktree: turn pack on so dehydrate emits packed CAS. + # Pack ON (the only seeding mode) so dehydrate emits packed CAS. "--hub-hydration-enable-pack=true", ] + extra_args @@ -411,7 +407,16 @@ def _overlay_snapshot(snapshot_root: Path, hub_servers_root: Path, module_ids: l """Replace hub_servers_root/<mid>/* with snapshot_root/<mid>/*. snapshot_root is treated as read-only; only hub_servers_root is written to. 
+ Uses ReFS block-clone (`refs_clone.clone_tree`) when source and dest live on + a ReFS volume so the overlay is O(1) per file via copy-on-write metadata + rather than a full byte copy. Falls back to byte copy per-file when the + volume is not ReFS or a file is too small for FSCTL_DUPLICATE_EXTENTS_TO_FILE. """ + from refs_clone import clone_tree, is_refs_volume + + use_clone = is_refs_volume(snapshot_root) and is_refs_volume(hub_servers_root) + print(f"[overlay] {'ReFS block-clone path' if use_clone else 'byte-copy path (non-ReFS)'}") + files_copied = 0 bytes_copied = 0 modules_overlaid = 0 @@ -424,16 +429,21 @@ def _overlay_snapshot(snapshot_root: Path, hub_servers_root: Path, module_ids: l continue if dst.exists(): _rmtree_robust(dst) - shutil.copytree(src, dst, symlinks=False, dirs_exist_ok=False) + if use_clone: + f, b = clone_tree(src, dst) + files_copied += f + bytes_copied += b + else: + shutil.copytree(src, dst, symlinks=False, dirs_exist_ok=False) + for root, _dirs, files in os.walk(dst): + for f in files: + p = Path(root) / f + try: + bytes_copied += p.stat().st_size + except OSError: + pass + files_copied += 1 modules_overlaid += 1 - for root, _dirs, files in os.walk(dst): - for f in files: - p = Path(root) / f - try: - bytes_copied += p.stat().st_size - except OSError: - pass - files_copied += 1 if i % 25 == 0 or i == len(module_ids): print(f"[overlay] {i}/{len(module_ids)} modules overlaid " f"({files_copied:,} files, {bytes_copied/1024/1024:.1f} MB)") @@ -601,21 +611,23 @@ def _seed_one_bucket( def main() -> int: parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--snapshot-dir", default="E:/Dev/zen-perf-seed/s3-snapshot", - help="Source of per-module server-state trees (READ-ONLY) (default: E:/Dev/zen-perf-seed/s3-snapshot)") - parser.add_argument("--hub-data-root", default="E:/Dev/zen-perf-seed/hubs", - help="Each bucket gets its own hub data dir under this root: 
<root>/hub-b-<bucket>/ " - "(default: E:/Dev/zen-perf-seed/hubs)") - parser.add_argument("--minio-data-dir", default="E:/Dev/zen-perf-seed/minio-data", - help="MinIO data dir shared by every bucket (default: E:/Dev/zen-perf-seed/minio-data)") + parser.add_argument("--snapshot-dir", required=True, + help="Source of per-module server-state trees (READ-ONLY).") + parser.add_argument("--hub-data-root", required=True, + help="Each bucket gets its own hub data dir under this root: <root>/hub-b-<bucket>/") + parser.add_argument("--minio-data-dir", required=True, + help="MinIO data dir shared by every bucket.") parser.add_argument("--minio-port", type=int, default=9000) parser.add_argument("--minio-console-port", type=int, default=9001) parser.add_argument("--hub-port", type=int, default=8558) parser.add_argument("--bucket", default="zen-seed-packed", - help="Bucket to seed (default: zen-seed-packed). Pack worktree - " - "hub is launched with --hub-hydration-enable-pack=true.") - parser.add_argument("--module-count", type=int, default=0, - help="Cap on modules processed (0 = all modules in snapshot-dir)") + help="Bucket to seed (default: zen-seed-packed). Hub is launched with " + "--hub-hydration-enable-pack=true (the only seeding mode).") + parser.add_argument("--module-count", type=int, default=1000, + help="Cap on modules processed (default 1000, the hub instance shared-state " + "limit; raising this above ~1023 will hit 'all slots occupied' errors). 
" + "Pass 0 to process every module under --snapshot-dir but expect failures " + "if the snapshot has more than ~1023 entries.") parser.add_argument("--workers", type=int, default=50) parser.add_argument("--poll-timeout", type=float, default=1800.0, help="Max seconds to wait for each state transition (default: 1800)") @@ -706,7 +718,7 @@ def main() -> int: for key, lbl in zip(phases, labels): print(f" {lbl:<22s} {timings.get(key, 0.0):>8.1f}") - print(f"[summary] next: preserve {minio_data_dir} to E:/Dev/zen-perf-seed/minio-seeded/") + print(f"[summary] next: preserve {minio_data_dir} via preserve_minio_state.py --source <this> --dest <baseline>") finally: if minio_proc is not None and minio_proc.poll() is None: diff --git a/scripts/test_scripts/hub/seed_s3_snapshot.py b/scripts/test_scripts/hub/seed_s3_snapshot.py index f0bc7b607..a99c9e130 100644 --- a/scripts/test_scripts/hub/seed_s3_snapshot.py +++ b/scripts/test_scripts/hub/seed_s3_snapshot.py @@ -87,7 +87,7 @@ def _find_zenserver(override: Optional[str]) -> Path: return p script_dir = Path(__file__).resolve().parent - repo_root = script_dir.parent.parent + repo_root = script_dir.parents[2] for mode in ("release", "debug"): for plat in (("windows", "x64"), ("linux", "x86_64"), ("macosx", "x86_64")): p = repo_root / "build" / plat[0] / plat[1] / mode / f"zenserver{_EXE_SUFFIX}" @@ -161,14 +161,24 @@ def _parse_s3_uri(uri: str) -> tuple[str, str]: return bucket, prefix +_EMPTY_MIN_OBJECTS = 3 +_EMPTY_MIN_BYTES = 16 * 1024 + +# Size buckets for stratified selection. Pyramid distribution mirrors typical +# workloads: many small modules, fewer medium, fewest large. 
+_SMALL_MAX_BYTES = 1 * 1024 * 1024 # <1 MiB +_LARGE_MIN_BYTES = 500 * 1024 * 1024 # >500 MiB +_BUCKET_RATIO = (("small", 500), ("medium", 350), ("large", 150)) + + def _list_module_ids(session, bucket: str, prefix: str, region: str, limit: int) -> list[str]: """List UUID-shaped module folders under the bucket and return the `limit` - most recently active ones, ranked by the LastModified of each module's - `incremental-state.cbo` (newest first - these were last dehydrated most - recently, which is the closest proxy for "most recently accessed"). + most recently active non-empty ones, ranked by the LastModified of each + module's `incremental-state.cbo` (newest first - these were last dehydrated + most recently, which is the closest proxy for "most recently accessed"). - Falls back to listing order for any folder whose state file can't be - HEADed (missing / 403 / transient error). + Modules with < _EMPTY_MIN_OBJECTS objects or < _EMPTY_MIN_BYTES total bytes + are treated as empty and dropped. """ s3 = session.client("s3", region_name=region) prefix_norm = prefix if (not prefix or prefix.endswith("/")) else prefix + "/" @@ -181,30 +191,147 @@ def _list_module_ids(session, bucket: str, prefix: str, region: str, limit: int) folder = cp.get("Prefix", "")[len(prefix_norm):].rstrip("/") if folder and _MODULEID_RE.match(folder): candidates.append(folder) - print(f"[s3] {len(candidates)} module folders match UUID shape; ranking by state.cbo LastModified...") + print(f"[s3] {len(candidates)} module folders match UUID shape; " + f"sizing + ranking by state.cbo LastModified (skip empty <{_EMPTY_MIN_OBJECTS} obj or <{_EMPTY_MIN_BYTES} B)...") - # 2. HEAD each module's state file in parallel. Missing/failed HEADs land - # at the tail via a sentinel epoch 0 timestamp. + # 2. Per-module ListObjectsV2: total object count + total bytes + state.cbo + # LastModified, all in one pass. Missing state file => epoch sentinel. 
from datetime import datetime, timezone epoch = datetime(1970, 1, 1, tzinfo=timezone.utc) - def _state_mtime(mid: str) -> datetime: - key = f"{prefix_norm}{mid}/incremental-state.cbo" + def _probe(mid: str) -> tuple[int, int, datetime]: + mid_prefix = f"{prefix_norm}{mid}/" + state_key = f"{mid_prefix}incremental-state.cbo" + count = 0 + total_bytes = 0 + state_mtime = epoch try: - resp = s3.head_object(Bucket=bucket, Key=key) - return resp.get("LastModified", epoch) + for page in paginator.paginate(Bucket=bucket, Prefix=mid_prefix): + for obj in page.get("Contents", []) or []: + count += 1 + total_bytes += int(obj.get("Size", 0)) + if obj.get("Key") == state_key: + state_mtime = obj.get("LastModified", epoch) or epoch except Exception: - return epoch + return 0, 0, epoch + return count, total_bytes, state_mtime with ThreadPoolExecutor(max_workers=50) as pool: - times = list(pool.map(_state_mtime, candidates)) - - # 3. Sort descending (newest first). Folders without a state file sink. - ranked = sorted(zip(candidates, times), key=lambda x: x[1], reverse=True) - missing = sum(1 for _, t in ranked if t == epoch) - if missing: - print(f"[s3] {missing}/{len(ranked)} modules have no incremental-state.cbo (sorted to tail)") - return [mid for mid, _ in ranked[:limit]] + probes = list(pool.map(_probe, candidates)) + + # 3. Stratified selection. Bucket non-empty modules by total_bytes into + # small/medium/large, sort each bucket by state.cbo mtime desc, and pull + # according to _BUCKET_RATIO scaled to `limit`. If a bucket is short, + # spill the deficit onto neighbouring buckets first, then onto the relaxed + # pool (empties + no-state) so the caller still gets `limit` modules. 
+ small: list[tuple[str, datetime, int]] = [] + medium: list[tuple[str, datetime, int]] = [] + large: list[tuple[str, datetime, int]] = [] + relaxed: list[tuple[str, datetime, int, int]] = [] + empties = 0 + no_state = 0 + for mid, (count, total_bytes, mtime) in zip(candidates, probes): + is_empty = count < _EMPTY_MIN_OBJECTS or total_bytes < _EMPTY_MIN_BYTES + if is_empty: + empties += 1 + if mtime == epoch: + no_state += 1 + if not is_empty and mtime != epoch: + if total_bytes < _SMALL_MAX_BYTES: + small.append((mid, mtime, total_bytes)) + elif total_bytes >= _LARGE_MIN_BYTES: + large.append((mid, mtime, total_bytes)) + else: + medium.append((mid, mtime, total_bytes)) + else: + relaxed.append((mid, mtime, count, total_bytes)) + + for bucket in (small, medium, large): + bucket.sort(key=lambda x: x[1], reverse=True) + print(f"[s3] candidates={len(candidates)} empty={empties} no-state={no_state} " + f"small={len(small)} medium={len(medium)} large={len(large)}") + + total_ratio = sum(w for _, w in _BUCKET_RATIO) + targets = {name: max(0, (limit * w) // total_ratio) for name, w in _BUCKET_RATIO} + # Distribute rounding leftovers to first bucket(s) deterministically. + leftover = limit - sum(targets.values()) + for name, _ in _BUCKET_RATIO: + if leftover <= 0: + break + targets[name] += 1 + leftover -= 1 + + pools = {"small": small, "medium": medium, "large": large} + # If a bucket is short of its target, the deficit redirects to the nearest + # neighbour bucket(s) before falling back to anything further away. Small + # is rare in real workloads, so its deficit pads from medium - not large. 
+ fallback_chain = { + "small": ["medium", "large"], + "medium": ["small", "large"], + "large": ["medium", "small"], + } + cursors = {"small": 0, "medium": 0, "large": 0} + selected: list[str] = [] + seen: set[str] = set() + fills = {"small": 0, "medium": 0, "large": 0} + redirected = {"small": 0, "medium": 0, "large": 0} + + def _take_from(name: str, n: int) -> int: + """Take up to n entries from pools[name] starting at cursor. Returns + the number actually taken.""" + if n <= 0: + return 0 + pool = pools[name] + start = cursors[name] + end = start + n + taken = 0 + for mid, _, _ in pool[start:end]: + if mid in seen: + continue + seen.add(mid) + selected.append(mid) + taken += 1 + cursors[name] = end + return taken + + for name, _ in _BUCKET_RATIO: + want = targets[name] + got = _take_from(name, want) + fills[name] = got + deficit = want - got + if deficit <= 0: + continue + for fb in fallback_chain[name]: + if deficit <= 0: + break + avail = max(0, len(pools[fb]) - cursors[fb]) + grab = min(deficit, avail) + if grab <= 0: + continue + got_fb = _take_from(fb, grab) + redirected[name] += got_fb + deficit -= got_fb + print(f"[s3] bucket fills: small={fills['small']}+{redirected['small']} " + f"medium={fills['medium']}+{redirected['medium']} " + f"large={fills['large']}+{redirected['large']} " + f"(numbers after '+' = redirected from neighbour buckets)") + + deficit = limit - len(selected) + if deficit > 0: + # Last-resort spill: relaxed pool (empties + no-state). 
+ relaxed.sort(key=lambda x: (x[1], x[3], x[2]), reverse=True) + added = 0 + for mid, _, _, _ in relaxed: + if added >= deficit: + break + if mid in seen: + continue + seen.add(mid) + selected.append(mid) + added += 1 + print(f"[s3] extended with {added} relaxed entries to fill deficit") + + return selected[:limit] # --------------------------------------------------------------------------- @@ -231,11 +358,9 @@ def _start_hub( "--hub-provision-disk-limit-percent=99", "--hub-provision-memory-limit-percent=80", f"--hub-instance-limit={instance_limit}", - # Seeding is not a perf-measurement path - we want it as fast as the - # host can manage. Let the hub go wide on both provisioning and - # hydration thread pools rather than matching prod limits. - "--hub-instance-provision-threads=64", - "--hub-hydration-threads=64", + # Provision pool + async S3 in-flight cap use server defaults; on a + # 128-core host these resolve to 16 / 512 which is what the seed run + # needs (clamp(cpu/8, 4, 16) and clamp(cpu*4, 128, 512)). # With 1000 modules the seeding flow runs for 20+ minutes. Extend BOTH # watchdog inactivity timers so early-provisioned modules do not get # auto-deprovisioned while we're still hydrating the tail (hibernated @@ -326,6 +451,37 @@ def _hub_post(port: int, path: str, timeout_s: float = 60.0) -> tuple[int, dict] return 0, {"error": str(e)} +# re.MULTILINE so `$` ends the reason at each log line, not only at the end of the chunk. +_HYDRATION_FAIL_RE = re.compile(r"Hydration of module '([0-9a-f-]+)' failed: (.+?)(?:\.\.|$)", re.MULTILINE) + + +def _scan_hydration_failures(log_path: Path, offset: int, already_warned: set[str], label: str) -> int: + """Read new hub.log content from `offset` to EOF; print a + `[<label>] HYDRATION FAILED` line for each newly observed + `Hydration of module 'X' failed` warning. 
Returns the + new offset (resume point for next call).""" + try: + size = log_path.stat().st_size + except OSError: + return offset + if size <= offset: + return offset + try: + with open(log_path, "rb") as f: + f.seek(offset) + chunk = f.read(size - offset).decode("utf-8", errors="replace") + except OSError: + return offset + for m in _HYDRATION_FAIL_RE.finditer(chunk): + mid, reason = m.group(1), m.group(2).strip() + if mid in already_warned: + continue + already_warned.add(mid) + # Trim long S3 error reasons so the line stays scannable. + if len(reason) > 200: + reason = reason[:200] + "..." + print(f"\n[{label}] HYDRATION FAILED {mid}: {reason}", flush=True) + return size + + def _hub_module_states(port: int, timeout_s: float = 10.0) -> Optional[dict[str, str]]: url = f"http://localhost:{port}/hub/status" req = urllib.request.Request(url, headers={"Accept": "application/json"}) @@ -370,16 +526,25 @@ def _wait_for_state( target_state: str, timeout_s: float, label: str, + hub_log: Optional[Path] = None, + hydration_warned: Optional[set[str]] = None, ) -> tuple[list[str], list[str], dict[str, str]]: """Poll hub status until every module hits target_state, fails, or times out. Returns (stuck, failed, last_states). 'stuck' = still mid-transition when we timed out. 'failed' = hit an _FAILED_STATES value such as 'crashed'. + + When hub_log + hydration_warned are supplied, the wait loop also tails the + hub log and surfaces `Hydration of module ... failed` warnings as they + appear. Hydration failures do NOT push modules into _FAILED_STATES (the hub + cleans the state and still marks the module 'provisioned'), so this is the + only way to see them in script output. 
""" deadline = time.monotonic() + timeout_s remaining = set(module_ids) failed: list[str] = [] last_states: dict[str, str] = {mid: "" for mid in module_ids} + log_offset = hub_log.stat().st_size if (hub_log and hub_log.exists()) else 0 while remaining and time.monotonic() < deadline: states = _hub_module_states(port) @@ -393,10 +558,16 @@ def _wait_for_state( remaining.discard(mid) failed.append(mid) done = len(module_ids) - len(remaining) - print(f"[{label}] {done}/{len(module_ids)} '{target_state}' ({len(failed)} failed)...", end="\r") + warned_count = len(hydration_warned) if hydration_warned is not None else 0 + suffix = f", hydration-failed {warned_count}" if hydration_warned is not None else "" + print(f"[{label}] {done}/{len(module_ids)} '{target_state}' ({len(failed)} failed{suffix})...", end="\r") + if hub_log is not None and hydration_warned is not None: + log_offset = _scan_hydration_failures(hub_log, log_offset, hydration_warned, label) time.sleep(2.0) print() + if hub_log is not None and hydration_warned is not None: + _scan_hydration_failures(hub_log, log_offset, hydration_warned, label) return list(remaining), failed, last_states @@ -404,16 +575,37 @@ def _wait_for_state( # Snapshot copy (run AFTER hub shutdown so there's no concurrent writer) # --------------------------------------------------------------------------- +def _same_volume(a: Path, b: Path) -> bool: + """True when a and b live on the same filesystem volume (so an os.replace + rename is O(1) instead of an EXDEV-driven copy+delete fallback).""" + try: + return os.stat(a).st_dev == os.stat(b).st_dev + except OSError: + return False + + def _copy_snapshot(src_root: Path, dst_root: Path, module_ids: list[str]) -> tuple[int, int, int]: - """Copy src_root/<mid>/* to dst_root/<mid>/* for each module. + """Move (preferred) or copy src_root/<mid>/ to dst_root/<mid>/ for each module. - Returns (modules_copied, files_copied, bytes_copied). 
+ Hub data dir is throwaway after Stage A (the hub is shut down right after + this step), so per-module trees can be moved out of it. When src_root and + dst_root share a volume the move is an O(1) directory rename; when they + don't, shutil.move falls back to a byte copy plus rmtree, which matches + the old shutil.copytree cost. + + Returns (modules_moved, files_moved, bytes_moved). Replaces only the specific per-module subdirs; never touches siblings. """ dst_root.mkdir(parents=True, exist_ok=True) - modules_copied = 0 - files_copied = 0 - bytes_copied = 0 + same_vol = _same_volume(src_root, dst_root) + if same_vol: + print("[snapshot] src and dst share a volume; moving per-module trees (O(1) rename per module)") + else: + print("[snapshot] src and dst are on different volumes; falling back to byte copy") + + modules_moved = 0 + files_moved = 0 + bytes_moved = 0 for i, mid in enumerate(module_ids, 1): src = src_root / mid @@ -423,21 +615,24 @@ def _copy_snapshot(src_root: Path, dst_root: Path, module_ids: list[str]) -> tup continue if dst.exists(): _rmtree_robust(dst) - shutil.copytree(src, dst, symlinks=False, dirs_exist_ok=False) - modules_copied += 1 + if same_vol: + os.replace(src, dst) + else: + shutil.move(str(src), str(dst)) + modules_moved += 1 for root, _dirs, files in os.walk(dst): for f in files: p = Path(root) / f try: - bytes_copied += p.stat().st_size + bytes_moved += p.stat().st_size except OSError: pass - files_copied += 1 + files_moved += 1 if i % 25 == 0 or i == len(module_ids): - print(f"[snapshot] {i}/{len(module_ids)} modules copied " - f"({files_copied:,} files, {bytes_copied/1024/1024:.1f} MB)") + print(f"[snapshot] {i}/{len(module_ids)} modules moved " + f"({files_moved:,} files, {bytes_moved/1024/1024:.1f} MB)") - return modules_copied, files_copied, bytes_copied + return modules_moved, files_moved, bytes_moved # --------------------------------------------------------------------------- @@ -447,10 +642,10 @@ def 
_copy_snapshot(src_root: Path, dst_root: Path, module_ids: list[str]) -> tup def main() -> int: parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--hub-data-dir", default="E:/Dev/zen-perf-seed/hub-a", - help="Hub --data-dir (default: E:/Dev/zen-perf-seed/hub-a)") - parser.add_argument("--snapshot-dir", default="E:/Dev/zen-perf-seed/s3-snapshot", - help="Destination for per-module server-state trees (default: E:/Dev/zen-perf-seed/s3-snapshot)") + parser.add_argument("--hub-data-dir", required=True, + help="Hub --data-dir. Place on the same volume as --snapshot-dir so the snapshot step is a rename per module instead of a cross-volume byte copy.") + parser.add_argument("--snapshot-dir", required=True, + help="Destination for per-module server-state trees.") parser.add_argument("--port", type=int, default=8558, help="Hub HTTP port (default: 8558)") parser.add_argument("--module-count", type=int, default=1000, @@ -505,6 +700,15 @@ def main() -> int: if len(module_ids) < args.module_count: print(f"[s3] WARNING: asked for {args.module_count} modules, only {len(module_ids)} matched the UUID filter") + # _list_module_ids returns modules grouped by size bucket (small first, then + # medium, then large). Provisioning in that order would create a size-sorted + # fan-out and skew load: all small first (cheap, fast hydrations) then all + # large (heavy, slow). Shuffle so provision/hibernate/copy hit a mixed-size + # stream that better resembles real workload distribution. Fixed seed keeps + # runs reproducible across reseeds with the same bucket contents. 
+ import random + random.Random(0xC0FFEE).shuffle(module_ids) + if not module_ids: sys.exit("[s3] no module folders found, aborting") @@ -554,10 +758,14 @@ def main() -> int: if not accepted: sys.exit("[provision] nothing accepted, aborting") - stuck, failed, last_states = _wait_for_state(args.port, accepted, "provisioned", args.poll_timeout, "provision") + hydration_warned: set[str] = set() + stuck, failed, last_states = _wait_for_state( + args.port, accepted, "provisioned", args.poll_timeout, "provision", + hub_log=hub_log, hydration_warned=hydration_warned, + ) prov_done = len(accepted) - len(stuck) - len(failed) - print(f"[provision] complete: {prov_done}/{len(accepted)} provisioned, {len(failed)} failed, {len(stuck)} stuck " - f"({time.monotonic()-t0:.1f}s)") + print(f"[provision] complete: {prov_done}/{len(accepted)} provisioned, {len(failed)} failed, {len(stuck)} stuck, " + f"{len(hydration_warned)} hydration-failed ({time.monotonic()-t0:.1f}s)") if failed: for mid in failed[:10]: print(f"[provision] FAILED {mid}: last state='{last_states.get(mid, '')}'") @@ -589,20 +797,23 @@ def main() -> int: for mid in (failed_hib + stuck_hib)[:10]: print(f"[hibernate] not-hibernated {mid}: last state='{last_states_hib.get(mid, '')}'") - # --- Copy snapshots while hub is still running. All instances are - # hibernated (no writers), watchdog-hibernated-timeout is 86400s - # (no auto-deprovision), hub is only touching its own metadata - # outside servers/<mid>/. Safe. --- - copy_src = hub_data_dir / "servers" - to_copy = [m for m in hib_accepted if m not in set(stuck_hib) and m not in set(failed_hib)] - print(f"[snapshot] copying {len(to_copy)} module trees from {copy_src} -> {snapshot_dir}") + # --- Snapshot per-module state out of the hub data dir while the hub + # is still running. All instances are hibernated (no writers), + # watchdog-hibernated-timeout is 86400s (no auto-deprovision), + # hub is only touching its own metadata outside servers/<mid>/. 
+ # When --hub-data-dir and --snapshot-dir share a volume the per- + # module trees are renamed (O(1)); cross-volume falls back to a + # byte copy. The hub data dir is wiped on the next run regardless. + snapshot_src = hub_data_dir / "servers" + to_snapshot = [m for m in hib_accepted if m not in set(stuck_hib) and m not in set(failed_hib)] + print(f"[snapshot] moving {len(to_snapshot)} module trees from {snapshot_src} -> {snapshot_dir}") t0 = time.monotonic() - modules_copied, files_copied, bytes_copied = _copy_snapshot(copy_src, snapshot_dir, to_copy) - print(f"[snapshot] copied {modules_copied} modules, {files_copied:,} files, {bytes_copied/1024/1024:.1f} MB " + modules_moved, files_moved, bytes_moved = _copy_snapshot(snapshot_src, snapshot_dir, to_snapshot) + print(f"[snapshot] moved {modules_moved} modules, {files_moved:,} files, {bytes_moved/1024/1024:.1f} MB " f"({time.monotonic()-t0:.1f}s)") - if modules_copied < len(to_copy): - print(f"[snapshot] WARNING: only {modules_copied}/{len(to_copy)} trees copied") + if modules_moved < len(to_snapshot): + print(f"[snapshot] WARNING: only {modules_moved}/{len(to_snapshot)} trees moved") exit_code = 1 # --- Graceful hub shutdown via 'zen down' --- diff --git a/scripts/test_scripts/hub/sweep_threads.py b/scripts/test_scripts/hub/sweep_threads.py new file mode 100644 index 000000000..f92a5d734 --- /dev/null +++ b/scripts/test_scripts/hub/sweep_threads.py @@ -0,0 +1,264 @@ +"""Greedy hill-climb search for hub thread balance. + +Per phase: from current best, try +STEP on each of (data, lifetime, hydration). +Pick best. Repeat. Stop when no candidate improves or all blocked by caps. 
+ +Constraints: +- async_threads fixed at 1 +- step size = 4 per phase +- per-dim cap: 32 +- total cap: data + lifetime + hydration <= 96 +- start (8, 8, 8) + +Usage: + python sweep_threads.py +""" +from __future__ import annotations +import argparse +import atexit +import json +import signal +import subprocess +import sys +from pathlib import Path +import csv +import time + +HERE = Path(__file__).resolve().parent +RUNNER = HERE / "run_minio_perf.py" +LOG = HERE / "threads_sweep.csv" +ZEN_DIR = "build/windows/x64/release" +STEP = 4 +PER_DIM_CAP = 32 +TOTAL_CAP = 96 +ASYNC = 1 +DIMS = ("data", "lifetime", "hydration") +CHILD_PROCS: list[subprocess.Popen] = [] + + +def kill_strays(): + """Best-effort terminate hub, minio, toxiproxy. Idempotent.""" + for proc in CHILD_PROCS: + if proc.poll() is None: + try: + proc.terminate() + except Exception: + pass + targets = ["zenserver.exe", "minio.exe", "toxiproxy-server.exe", "toxiproxy.exe"] + for name in targets: + try: + subprocess.run(["taskkill", "/F", "/IM", name], capture_output=True, timeout=10) + except Exception: + pass + + +def install_signal_handlers(): + def handler(signum, frame): + print(f"\n[sweep] signal {signum} received - cleaning up", flush=True) + kill_strays() + sys.exit(130) + for sig in (signal.SIGINT, signal.SIGTERM): + try: + signal.signal(sig, handler) + except Exception: + pass + atexit.register(kill_strays) + + +def run_once(data: int, lifetime: int, hydration: int, latency_ms: int, label: str, async_threads: int = ASYNC) -> dict: + cmd = [ + sys.executable, "-u", str(RUNNER), + "--async-http", + f"--async-threads={async_threads}", + f"--toxiproxy-latency-ms={latency_ms}", + f"--zenserver-dir={ZEN_DIR}", + f"--hub-arg=--hub-instance-provision-threads={data}", + f"--hub-arg=--hub-instance-spawn-threads={lifetime}", + f"--hub-arg=--hub-hydration-threads={hydration}", + ] + print(f"\n=== [{label}] data={data} lifetime={lifetime} hydration={hydration} async={async_threads} latency={latency_ms}ms ===", 
flush=True) + t0 = time.time() + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) + CHILD_PROCS.append(proc) + archive = None + rc = 0 + assert proc.stdout is not None + for line in proc.stdout: + line = line.rstrip() + if not line: + continue + if line.startswith("[archive]"): + archive = line.split(" ", 1)[1].split(" ")[0] + # Compress spinner-tick pile-up to last segment so we don't print + # "[provision] N/1000 provisioned..." 5x per line. Only triggers when + # the line is actually a spinner concat (contains the milestone marker). + if line.count("...") >= 2 and ("/1000 provisioned" in line or "/1000 gone" in line): + line = line.split("...")[-1].lstrip() + if not line: + continue + print(f" {line}", flush=True) + rc = proc.wait() + if proc in CHILD_PROCS: + CHILD_PROCS.remove(proc) + wall = time.time() - t0 + summary = {} + if archive: + sj = Path(archive) / "summary.json" + if sj.exists(): + summary = json.loads(sj.read_text()) + row = { + "label": label, + "data": data, + "lifetime": lifetime, + "hydration": hydration, + "async_threads": async_threads, + "latency_ms": latency_ms, + "exit_code": rc, + "wall_s": round(wall, 1), + "provision_s": summary.get("provision_total_s"), + "deprovision_s": summary.get("deprovision_total_s"), + "total_s": summary.get("total_s"), + "archive": archive, + } + print(f" -> total={row['total_s']}s provision={row['provision_s']}s deprovision={row['deprovision_s']}s exit={row['exit_code']}", flush=True) + append_row(row) + return row + + +def append_row(row: dict): + new = not LOG.exists() + with LOG.open("a", newline="") as f: + w = csv.DictWriter(f, fieldnames=list(row.keys())) + if new: + w.writeheader() + w.writerow(row) + + +def feasible(p: dict) -> bool: + if any(p[d] > PER_DIM_CAP for d in DIMS): + return False + if sum(p[d] for d in DIMS) > TOTAL_CAP: + return False + return True + + +def candidates(center: dict) -> list[tuple[str, dict]]: + out = [] + for d in DIMS: + 
c = dict(center) + c[d] += STEP + if feasible(c): + out.append((d, c)) + return out + + +def parse_points(s: str, default_async: int) -> list[tuple[int, int, int, int]]: + out = [] + for chunk in s.split(";"): + chunk = chunk.strip() + if not chunk: + continue + parts = [int(x.strip()) for x in chunk.split(",")] + if len(parts) == 3: + d, l, h = parts + a = default_async + elif len(parts) == 4: + d, l, h, a = parts + else: + raise SystemExit(f"--points entry needs 3 ints (d,l,h) or 4 (d,l,h,async): {chunk!r}") + out.append((d, l, h, a)) + return out + + +def run_points(points: list[tuple[int, int, int, int]], latency_ms: int): + rows = [] + for i, (d, l, h, a) in enumerate(points, 1): + if d > PER_DIM_CAP or l > PER_DIM_CAP or h > PER_DIM_CAP: + print(f" -- skip {d},{l},{h} (per-dim cap {PER_DIM_CAP})", flush=True) + continue + if d + l + h > TOTAL_CAP: + print(f" -- skip {d},{l},{h} (total {d+l+h} > {TOTAL_CAP})", flush=True) + continue + row = run_once(data=d, lifetime=l, hydration=h, latency_ms=latency_ms, + label=f"pt{i}-{d},{l},{h},a{a}", async_threads=a) + rows.append(row) + return rows + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--latency-ms", type=int, default=30) + p.add_argument("--max-phases", type=int, default=20) + p.add_argument("--async-threads", type=int, default=ASYNC, + help=f"AsyncHttpClient io_context threads/shards (default {ASYNC}). " + "Per-point override available via 4th value in --points entry.") + p.add_argument("--points", type=str, default="", + help="Run an explicit list of (d,l,h) or (d,l,h,async) points instead of hill-climb. " + "Format: 'd,l,h;d,l,h,a;...'. 
CSV appended.") + args = p.parse_args() + + install_signal_handlers() + + if args.points: + pts = parse_points(args.points, args.async_threads) + print(f"[sweep] explicit list: {len(pts)} points", flush=True) + rows = run_points(pts, args.latency_ms) + valid = [r for r in rows if r.get("total_s") is not None and r["exit_code"] == 0] + if valid: + valid.sort(key=lambda r: r["total_s"]) + print("\n=== POINT RESULTS (sorted by total_s) ===", flush=True) + for r in valid: + tot = r["data"] + r["lifetime"] + r["hydration"] + print(f" d={r['data']:>3} l={r['lifetime']:>3} h={r['hydration']:>3} a={r['async_threads']:>2} sum={tot:>3} -> total={r['total_s']:.1f}s prov={r['provision_s']:.1f}s deprov={r['deprovision_s']:.1f}s", flush=True) + print(f" log={LOG}", flush=True) + return + + cur = dict(data=8, lifetime=8, hydration=8) + tried = set() + history = [] + + base = run_once(**cur, latency_ms=args.latency_ms, label="phase0") + history.append(base) + tried.add((cur["data"], cur["lifetime"], cur["hydration"])) + best = base + + for phase in range(1, args.max_phases + 1): + cands = candidates(cur) + if not cands: + print(f" === phase {phase}: all candidates blocked (per-dim {PER_DIM_CAP} or total {TOTAL_CAP}) ===", flush=True) + break + phase_best = None + for dim, c in cands: + key = (c["data"], c["lifetime"], c["hydration"]) + if key in tried: + continue + tried.add(key) + row = run_once(**c, latency_ms=args.latency_ms, label=f"p{phase}-{dim}+{STEP}") + history.append(row) + if row.get("total_s") is not None and row["exit_code"] == 0: + if phase_best is None or row["total_s"] < phase_best["total_s"]: + phase_best = row + if phase_best is None or phase_best["total_s"] >= best["total_s"] - 1.0: + print(f" === phase {phase}: no improvement (best {best['total_s']}s vs phase {phase_best and phase_best['total_s']}s) ===", flush=True) + break + cur = dict(data=phase_best["data"], lifetime=phase_best["lifetime"], hydration=phase_best["hydration"]) + best = phase_best + print(f" 
+++ phase {phase}: -> data={cur['data']} lifetime={cur['lifetime']} hydration={cur['hydration']} total={best['total_s']}s", flush=True) + + print("\n=== OPTIMUM ===", flush=True) + print(f" data={cur['data']} lifetime={cur['lifetime']} hydration={cur['hydration']} async={ASYNC}", flush=True) + print(f" total={best['total_s']}s provision={best['provision_s']}s deprovision={best['deprovision_s']}s", flush=True) + print(f" log={LOG}", flush=True) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n[sweep] interrupted - cleaning up", flush=True) + kill_strays() + sys.exit(130) + except Exception as e: + print(f"\n[sweep] error: {e!r} - cleaning up", flush=True) + kill_strays() + raise diff --git a/src/zencore/include/zencore/parallelwork.h b/src/zencore/include/zencore/parallelwork.h index 536b0a056..d9b20b9d7 100644 --- a/src/zencore/include/zencore/parallelwork.h +++ b/src/zencore/include/zencore/parallelwork.h @@ -13,6 +13,8 @@ namespace zen { class ParallelWork { public: + class ExternalWorkToken; + ParallelWork(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, WorkerThreadPool::EMode Mode); ~ParallelWork(); @@ -74,9 +76,53 @@ public: Latch& PendingWork() { return m_PendingWork; } + // Register a unit of work whose completion is signalled out-of-band (typically + // from an async callback firing on a different thread). Counter increments now; + // the returned token's Complete()/Fail() decrements. Used by S3AsyncStorage so + // in-flight S3 requests count against ParallelWork without occupying a worker. 
+ [[nodiscard]] ExternalWorkToken RegisterExternal(); + + class ExternalWorkToken + { + public: + ExternalWorkToken() = default; + + ExternalWorkToken(const ExternalWorkToken&) = delete; + ExternalWorkToken& operator=(const ExternalWorkToken&) = delete; + + ExternalWorkToken(ExternalWorkToken&& Other) noexcept : m_Owner(Other.m_Owner) { Other.m_Owner = nullptr; } + ExternalWorkToken& operator=(ExternalWorkToken&& Other) noexcept + { + if (this != &Other) + { + Release(); + m_Owner = Other.m_Owner; + Other.m_Owner = nullptr; + } + return *this; + } + + ~ExternalWorkToken() { Release(); } + + void Complete(); + void Fail(std::exception_ptr Ex); + + bool IsActive() const { return m_Owner != nullptr; } + + private: + friend class ParallelWork; + + explicit ExternalWorkToken(ParallelWork* Owner) : m_Owner(Owner) {} + + void Release(); + + ParallelWork* m_Owner = nullptr; + }; + private: ExceptionCallback DefaultErrorFunction(); void RethrowErrors(); + void RecordExternalError(std::exception_ptr Ex); std::atomic<bool>& m_AbortFlag; std::atomic<bool>& m_PauseFlag; diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp index 94696f479..ec00fe0bc 100644 --- a/src/zencore/parallelwork.cpp +++ b/src/zencore/parallelwork.cpp @@ -2,6 +2,7 @@ #include <zencore/parallelwork.h> +#include <zencore/assertfmt.h> #include <zencore/callstack.h> #include <zencore/except.h> #include <zencore/fmtutils.h> @@ -11,6 +12,8 @@ #if ZEN_WITH_TESTS # include <zencore/testing.h> + +# include <thread> #endif // ZEN_WITH_TESTS namespace zen { @@ -90,6 +93,65 @@ ParallelWork::DefaultErrorFunction() } void +ParallelWork::RecordExternalError(std::exception_ptr Ex) +{ + m_ErrorLock.WithExclusiveLock([&]() { m_Errors.push_back(Ex); }); + m_AbortFlag = true; +} + +ParallelWork::ExternalWorkToken +ParallelWork::RegisterExternal() +{ + ZEN_ASSERT(!m_DispatchComplete); + m_PendingWork.AddCount(1); + return ExternalWorkToken(this); +} + +void +ParallelWork::ExternalWorkToken::Complete() +{ + 
ZEN_ASSERT(m_Owner != nullptr); + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; +} + +void +ParallelWork::ExternalWorkToken::Fail(std::exception_ptr Ex) +{ + ZEN_ASSERT(m_Owner != nullptr); + // Null exception_ptr would propagate as std::bad_exception via + // rethrow_exception(nullptr) and mask the real failure mode. Catches + // patterns like MakeGuard([Token]{ Token->Fail(std::current_exception()); }) + // firing on a normal-return path where no exception is in flight. + ZEN_ASSERT(Ex != nullptr); + m_Owner->RecordExternalError(Ex); + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; +} + +void +ParallelWork::ExternalWorkToken::Release() +{ + if (m_Owner != nullptr) + { + // Tests should fail loudly so that any leaked path surfaces immediately; + // in production we keep the safety-net countdown so a leak does not deadlock + // the latch but log it as an error rather than a warning - this is always + // a programming bug. +#if ZEN_WITH_TESTS + ZEN_ASSERT_FORMAT(false, "ParallelWork::ExternalWorkToken destroyed without Complete()/Fail()"); +#else + ZEN_ERROR("ParallelWork::ExternalWorkToken destroyed without Complete()/Fail(); decrementing latch as safety net"); + // Surface as an error from Wait()/RethrowErrors() so the caller does not see a phantom success. 
+ m_Owner->RecordExternalError( + std::make_exception_ptr(std::runtime_error("ParallelWork::ExternalWorkToken destroyed without Complete()/Fail()"))); +#endif + m_Owner->m_PendingWork.CountDown(); + m_Owner = nullptr; + } +} + +void ParallelWork::Wait(int32_t UpdateIntervalMS, UpdateCallback&& UpdateCallback) { ZEN_ASSERT(!m_DispatchComplete); @@ -257,6 +319,100 @@ TEST_CASE("parallellwork.limitqueue") Work.Wait(); } +TEST_CASE("parallellwork.external_basic") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::vector<ParallelWork::ExternalWorkToken> Tokens; + for (uint32_t I = 0; I < 5; I++) + { + Tokens.push_back(Work.RegisterExternal()); + } + for (auto& Token : Tokens) + { + Token.Complete(); + } + + Work.Wait(); + CHECK_FALSE(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_completes_from_other_thread") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + auto Token = Work.RegisterExternal(); + std::thread Worker([Token = std::move(Token)]() mutable { + Sleep(20); + Token.Complete(); + }); + + Work.Wait(); + Worker.join(); + CHECK_FALSE(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_fail_propagates_exception") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + auto Token = Work.RegisterExternal(); + try + { + throw std::runtime_error("external work failed"); + } + catch (...) 
+ { + Token.Fail(std::current_exception()); + } + + CHECK_THROWS_WITH(Work.Wait(), "external work failed"); + CHECK(AbortFlag.load()); +} + +TEST_CASE("parallellwork.external_mixed_with_scheduled") +{ + WorkerThreadPool WorkerPool(2); + + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + std::atomic<uint32_t> ScheduledCount = 0; + for (uint32_t I = 0; I < 3; I++) + { + Work.ScheduleWork(WorkerPool, [&ScheduledCount](std::atomic<bool>& AbortFlag) { + ZEN_UNUSED(AbortFlag); + ScheduledCount++; + }); + } + + std::vector<ParallelWork::ExternalWorkToken> Tokens; + for (uint32_t I = 0; I < 3; I++) + { + Tokens.push_back(Work.RegisterExternal()); + } + + std::thread Completer([&]() { + for (auto& Token : Tokens) + { + Sleep(5); + Token.Complete(); + } + }); + + Work.Wait(); + Completer.join(); + + CHECK_EQ(ScheduledCount.load(), 3u); +} + TEST_SUITE_END(); void diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp index 151863370..0b6877c7b 100644 --- a/src/zenhttp/asynchttpclient_test.cpp +++ b/src/zenhttp/asynchttpclient_test.cpp @@ -5,21 +5,24 @@ #if ZEN_WITH_TESTS +# include <zencore/basicfile.h> # include <zencore/iobuffer.h> # include <zencore/logging.h> # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zencore/thread.h> # include "servers/httpasio.h" -# include <atomic> -# include <thread> - ZEN_THIRD_PARTY_INCLUDES_START # include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END +# include <atomic> +# include <cstring> +# include <thread> + namespace zen { using namespace std::literals; @@ -67,13 +70,65 @@ public: Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}"); }, HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + // 4 MB body so the response exercises chunked write callbacks. 
+ IoBuffer Body(4u * 1024u * 1024u); + uint8_t* Data = static_cast<uint8_t*>(Body.MutableData()); + for (size_t I = 0; I < Body.GetSize(); ++I) + { + Data[I] = static_cast<uint8_t>(I & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + std::string_view MsStr = Params.GetValue("ms"); + int Ms = MsStr.empty() ? 100 : std::atoi(std::string(MsStr).c_str()); + m_SlowHits.fetch_add(1, std::memory_order_relaxed); + zen::Sleep(Ms); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow ok"); + }, + HttpVerb::kGet); + + // Returns 503 for the first ?fail=N requests, then 200 for the rest. + m_Router.RegisterRoute( + "flaky", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + std::string_view FailStr = Params.GetValue("fail"); + const int FailN = FailStr.empty() ? 
0 : std::atoi(std::string(FailStr).c_str()); + const int Hit = m_FlakyHits.fetch_add(1, std::memory_order_relaxed) + 1; + if (Hit <= FailN) + { + HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable, HttpContentType::kText, "fail"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + }, + HttpVerb::kGet); } + std::atomic<int>& SlowHits() { return m_SlowHits; } + std::atomic<int>& FlakyHits() { return m_FlakyHits; } + virtual const char* BaseUri() const override { return "/api/async-test/"; } virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } private: HttpRequestRouter m_Router; + std::atomic<int> m_SlowHits{0}; + std::atomic<int> m_FlakyHits{0}; }; ////////////////////////////////////////////////////////////////////////// @@ -118,189 +173,593 @@ struct AsyncTestServerFixture TEST_SUITE_BEGIN("http.asynchttpclient"); -TEST_CASE("asynchttpclient.future.verbs") +// Future API + callback API + verb dispatch + payload echo + lifecycle. All +// scopes share one fixture and one default-settings client. Per-scope sets +// up its own promises/futures. +TEST_CASE("asynchttpclient.basic") { AsyncTestServerFixture Fixture; AsyncHttpClient Client = Fixture.MakeClient(); - SUBCASE("GET returns 200 with expected body") + // future.verbs - GET / POST / PUT / DELETE / HEAD echo the verb. 
{ - auto Future = Client.Get("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "GET"); } - - SUBCASE("POST dispatches correctly") { - auto Future = Client.Post("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Post("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "POST"); } - - SUBCASE("PUT dispatches correctly") { - auto Future = Client.Put("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Put("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "PUT"); } - - SUBCASE("DELETE dispatches correctly") { - auto Future = Client.Delete("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Delete("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "DELETE"); } - - SUBCASE("HEAD returns 200 with empty body") { - auto Future = Client.Head("/api/async-test/echo/method"); - auto Resp = Future.get(); + auto Resp = Client.Head("/api/async-test/echo/method").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), ""sv); } -} -TEST_CASE("asynchttpclient.future.get") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - - SUBCASE("simple GET with text response") + // future.get - text body, JSON body, 204 NoContent. 
{ - auto Future = Client.Get("/api/async-test/hello"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/hello").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); CHECK_EQ(Resp.AsText(), "hello world"); } - - SUBCASE("GET returning JSON") { - auto Future = Client.Get("/api/async-test/json"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/json").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.AsText(), "{\"ok\":true}"); } - - SUBCASE("GET 204 NoContent") { - auto Future = Client.Get("/api/async-test/nocontent"); - auto Resp = Future.get(); + auto Resp = Client.Get("/api/async-test/nocontent").get(); CHECK(Resp.IsSuccess()); CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); } -} - -TEST_CASE("asynchttpclient.future.post.with.payload") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - std::string_view PayloadStr = "async payload data"; - IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size()); - Payload.SetContentType(ZenContentType::kText); + // future.post.with.payload + future.put.with.payload - echo round-trips. + { + std::string_view Str = "async payload data"; + IoBuffer Payload(IoBuffer::Clone, Str.data(), Str.size()); + Payload.SetContentType(ZenContentType::kText); + auto Resp = Client.Post("/api/async-test/echo", Payload).get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "async payload data"); + } + { + std::string_view Str = "put payload"; + IoBuffer Payload(IoBuffer::Clone, Str.data(), Str.size()); + Payload.SetContentType(ZenContentType::kText); + auto Resp = Client.Put("/api/async-test/echo", Payload).get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload"); + } - auto Future = Client.Post("/api/async-test/echo", Payload); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "async payload data"); -} + // callback - AsyncGet completion fires the callback. 
+ { + std::promise<HttpClient::Response> Promise; + auto Future = Promise.get_future(); + Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } -TEST_CASE("asynchttpclient.future.put.with.payload") -{ - AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); + // concurrent.requests - multiple verbs in flight at once. + { + auto F1 = Client.Get("/api/async-test/hello"); + auto F2 = Client.Get("/api/async-test/json"); + auto F3 = Client.Post("/api/async-test/echo/method"); + auto F4 = Client.Delete("/api/async-test/echo/method"); + auto R1 = F1.get(); + auto R2 = F2.get(); + auto R3 = F3.get(); + auto R4 = F4.get(); + CHECK(R1.IsSuccess()); + CHECK_EQ(R1.AsText(), "hello world"); + CHECK(R2.IsSuccess()); + CHECK_EQ(R2.AsText(), "{\"ok\":true}"); + CHECK(R3.IsSuccess()); + CHECK_EQ(R3.AsText(), "POST"); + CHECK(R4.IsSuccess()); + CHECK_EQ(R4.AsText(), "DELETE"); + } - std::string_view PutStr = "put payload"; - IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size()); - Payload.SetContentType(ZenContentType::kText); + // cancel.after.completion.is.noop - late Cancel must be quiet. + { + auto Resp = Client.Get("/api/async-test/hello").get(); + REQUIRE(Resp.IsSuccess()); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + AsyncRequestToken Token = Client.AsyncGet("/api/async-test/hello", [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + auto Resp2 = F.get(); + REQUIRE(Resp2.IsSuccess()); + Token.Cancel(); // no-op; must not crash + } - auto Future = Client.Put("/api/async-test/echo", Payload); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "put payload"); + // lifecycle.repeated.construct.destroy - 8 fresh clients against the same + // server. Catches io thread / curl_multi leaks across construct/destroy. 
+ for (int I = 0; I < 8; ++I) + { + AsyncHttpClient Local = Fixture.MakeClient(); + auto Resp = Local.Get("/api/async-test/hello").get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } } -TEST_CASE("asynchttpclient.callback") +// Submit-side behavior: external io_context, shutdown cancel, no-queue +// contract, cross-thread cancel-before-submit race, unlimited fan-out. +// +// MaxConcurrentRequests is applied as curl connection caps only; cap-level +// fan-out throttling lives in the storage layer (see +// server.s3asyncstorage.admission.fanout). +TEST_CASE("asynchttpclient.submit_and_shutdown") { AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - std::promise<HttpClient::Response> Promise; - auto Future = Promise.get_future(); + // external.io_context - caller drives the run loop. Verifies the + // Cleanup-via-promise path in the dtor. + { + asio::io_context IoContext; + auto WorkGuard = asio::make_work_guard(IoContext); + std::thread IoThread([&IoContext]() { IoContext.run(); }); + { + AsyncHttpClient Client = Fixture.MakeClient(IoContext); + auto Resp = Client.Get("/api/async-test/echo/method").get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + WorkGuard.reset(); + IoThread.join(); + } + + // shutdown.cancels.in.flight - dtor synthesizes cancel for all in-flight. 
+ { + const int N = 6; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + { + AsyncHttpClient Client = Fixture.MakeClient(); + std::vector<AsyncRequestToken> Tokens; + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=2000", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + Sleep(50); // let requests actually start before client teardown + } + int CancelCount = 0; + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + HttpClient::Response R = F.get(); + if (R.Error.has_value() && R.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled) + { + ++CancelCount; + } + } + CHECK(CancelCount == N); + } + + // contract.no.queue - all submissions reach the network despite cap=4. + // Fan-out gating is the storage layer's responsibility. 
+ { + HttpClientSettings Settings; + Settings.MaxConcurrentRequests = 4; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + + const int N = 100; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back( + Client.AsyncGet("/api/async-test/hello", [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(15)) == std::future_status::ready); + CHECK(F.get().IsSuccess()); + } + } - Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + // cancel.before.submit - cross-thread race: SubmitFromSpec posted, Cancel + // from another thread fires before/after submit handler runs. All callbacks + // must surface kRequestCancelled exactly once. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 16; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=2000", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + std::thread CancelThread([&]() { + for (auto& T : Tokens) + { + T.Cancel(); + } + }); + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + HttpClient::Response R = F.get(); + REQUIRE(R.Error.has_value()); + CHECK(R.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled); + } + CancelThread.join(); + } - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "hello world"); + // unlimited.parallel.fanout - 8 parallel 100ms requests with default + // settings (no cap) finish well under the 800ms serial floor. Sized to one + // batch on the asio server's 8-thread pool so per-request setup overhead on + // slow CI agents does not dominate; threshold leaves >=2x margin over the + // ~100ms parallel ideal. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 8; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& P : Promises) + { + Futures.push_back(P.get_future()); + } + const auto Start = std::chrono::steady_clock::now(); + for (int I = 0; I < N; ++I) + { + Tokens.push_back(Client.AsyncGet("/api/async-test/slow?ms=100", + [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (auto& F : Futures) + { + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + CHECK(F.get().IsSuccess()); + } + const auto ElapsedMs = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start).count(); + CHECK(ElapsedMs < 600); + } } -TEST_CASE("asynchttpclient.concurrent.requests") +// AsyncStream coverage: chunk delivery, OnData abort, mid-flight cancel. +TEST_CASE("asynchttpclient.stream") { AsyncTestServerFixture Fixture; - AsyncHttpClient Client = Fixture.MakeClient(); - - // Fire multiple requests concurrently - auto Future1 = Client.Get("/api/async-test/hello"); - auto Future2 = Client.Get("/api/async-test/json"); - auto Future3 = Client.Post("/api/async-test/echo/method"); - auto Future4 = Client.Delete("/api/async-test/echo/method"); - auto Resp1 = Future1.get(); - auto Resp2 = Future2.get(); - auto Resp3 = Future3.get(); - auto Resp4 = Future4.get(); + // stream.basic - 4 MiB stream completes; bytes accounted; no body buffering. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + std::atomic<uint64_t> TotalReceived{0}; + std::atomic<uint32_t> ChunkCount{0}; + std::atomic<uint64_t> TotalSizeSeen{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + + auto Token = Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t* /*Data*/, size_t Size, uint64_t TotalSize) -> bool { + TotalReceived.fetch_add(Size, std::memory_order_relaxed); + ChunkCount.fetch_add(1, std::memory_order_relaxed); + if (TotalSize != 0) + { + TotalSizeSeen.store(TotalSize, std::memory_order_relaxed); + } + return true; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); - CHECK(Resp1.IsSuccess()); - CHECK_EQ(Resp1.AsText(), "hello world"); + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Resp = F.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(TotalReceived.load(), 4u * 1024u * 1024u); + CHECK_EQ(TotalSizeSeen.load(), 4u * 1024u * 1024u); + CHECK(ChunkCount.load() >= 1); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 0u); + } - CHECK(Resp2.IsSuccess()); - CHECK_EQ(Resp2.AsText(), "{\"ok\":true}"); + // stream.ondata.abort - returning false from OnData stops the transfer. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + std::atomic<uint32_t> ChunkCount{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + + Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t*, size_t, uint64_t) -> bool { + ChunkCount.fetch_add(1, std::memory_order_relaxed); + return false; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); - CHECK(Resp3.IsSuccess()); - CHECK_EQ(Resp3.AsText(), "POST"); + REQUIRE(F.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(ChunkCount.load() <= 1); + } - CHECK(Resp4.IsSuccess()); - CHECK_EQ(Resp4.AsText(), "DELETE"); + // stream.cancel.mid.flight - Cancel during long stream surfaces + // kRequestCancelled. RetryCount=0 so no retry layer masks the error code. + { + HttpClientSettings Settings; + Settings.RetryCount = 0; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + auto Token = Client.AsyncStream( + "/api/async-test/slow?ms=2000", + [](const uint8_t*, size_t, uint64_t) -> bool { return true; }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + Sleep(50); + Token.Cancel(); + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + REQUIRE(Resp.Error.has_value()); + CHECK(Resp.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled); + } } -TEST_CASE("asynchttpclient.external.io_context") +// High-fanout, mixed verbs, large payload, streaming-source PUT. +TEST_CASE("asynchttpclient.stress") { AsyncTestServerFixture Fixture; - asio::io_context IoContext; - auto WorkGuard = asio::make_work_guard(IoContext); - std::thread IoThread([&IoContext]() { IoContext.run(); }); + // high.fanout - 32 unlimited parallel GETs all succeed. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + const int N = 32; + std::vector<std::promise<HttpClient::Response>> Promises(N); + std::vector<std::future<HttpClient::Response>> Futures; + std::vector<AsyncRequestToken> Tokens; + for (auto& Pr : Promises) + { + Futures.push_back(Pr.get_future()); + } + for (int I = 0; I < N; ++I) + { + Tokens.push_back( + Client.AsyncGet("/api/async-test/hello", [&Promises, I](HttpClient::Response R) { Promises[I].set_value(std::move(R)); })); + } + for (int I = 0; I < N; ++I) + { + REQUIRE(Futures[I].wait_for(std::chrono::seconds(15)) == std::future_status::ready); + auto Resp = Futures[I].get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + } + // mixed.verbs.concurrent - GET/POST/PUT/DELETE/Stream concurrently under cap=4. { - AsyncHttpClient Client = Fixture.MakeClient(IoContext); + HttpClientSettings Settings; + Settings.MaxConcurrentRequests = 4; + AsyncHttpClient Client = Fixture.MakeClient(Settings); + + IoBuffer Payload(64); + std::memset(Payload.MutableData(), 0xAB, Payload.GetSize()); + + auto FGet = Client.Get("/api/async-test/hello"); + auto FPost = Client.Post("/api/async-test/echo", Payload); + auto FPut = Client.Put("/api/async-test/echo", Payload); + auto FDelete = Client.Delete("/api/async-test/echo/method"); + auto FJson = Client.Get("/api/async-test/json"); + + std::atomic<uint64_t> StreamBytes{0}; + std::promise<HttpClient::Response> StreamP; + auto StreamF = StreamP.get_future(); + Client.AsyncStream( + "/api/async-test/large", + [&](const uint8_t*, size_t Size, uint64_t) -> bool { + StreamBytes.fetch_add(Size, std::memory_order_relaxed); + return true; + }, + [&StreamP](HttpClient::Response R) { StreamP.set_value(std::move(R)); }); + + auto Get = FGet.get(); + auto Post = FPost.get(); + auto Put = FPut.get(); + auto Delete = FDelete.get(); + auto Json = FJson.get(); + REQUIRE(StreamF.wait_for(std::chrono::seconds(10)) == std::future_status::ready); + auto Stream = 
StreamF.get(); + + CHECK(Get.IsSuccess()); + CHECK_EQ(Get.AsText(), "hello world"); + CHECK(Post.IsSuccess()); + CHECK_EQ(Post.ResponsePayload.GetSize(), Payload.GetSize()); + CHECK(Put.IsSuccess()); + CHECK_EQ(Put.ResponsePayload.GetSize(), Payload.GetSize()); + CHECK(Delete.IsSuccess()); + CHECK_EQ(Delete.AsText(), "DELETE"); + CHECK(Json.IsSuccess()); + CHECK_EQ(Json.AsText(), "{\"ok\":true}"); + CHECK(Stream.IsSuccess()); + CHECK_EQ(StreamBytes.load(), 4u * 1024u * 1024u); + } - auto Future = Client.Get("/api/async-test/hello"); - auto Resp = Future.get(); - CHECK(Resp.IsSuccess()); - CHECK_EQ(Resp.AsText(), "hello world"); + // large.put.roundtrip - 4 MiB PUT echoed back; spot-check positions. + { + AsyncHttpClient Client = Fixture.MakeClient(); + const size_t Size = 4u * 1024u * 1024u; + IoBuffer Payload(Size); + uint8_t* Data = static_cast<uint8_t*>(Payload.MutableData()); + for (size_t I = 0; I < Size; ++I) + { + Data[I] = static_cast<uint8_t>((I * 31u) & 0xFF); + } + auto Resp = Client.Put("/api/async-test/echo", Payload).get(); + REQUIRE(Resp.IsSuccess()); + REQUIRE_EQ(Resp.ResponsePayload.GetSize(), Size); + const uint8_t* RecvData = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + CHECK(RecvData[0] == Data[0]); + CHECK(RecvData[1u << 10] == Data[1u << 10]); + CHECK(RecvData[Size / 2] == Data[Size / 2]); + CHECK(RecvData[Size - 1] == Data[Size - 1]); } - WorkGuard.reset(); - IoThread.join(); + // streaming.put.source - AsyncPut(url, size, source, callback). + // Part 1: 2 MiB echo round-trip. + // Part 2: source returning 0 with offset < TotalSize aborts via CURL_READFUNC_ABORT. 
+ { + AsyncHttpClient Client = Fixture.MakeClient(); + + { + const size_t Size = 2u * 1024u * 1024u; + std::vector<uint8_t> Source(Size); + for (size_t I = 0; I < Size; ++I) + { + Source[I] = static_cast<uint8_t>((I * 17u + 5u) & 0xFF); + } + std::atomic<uint64_t> SourceCalls{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + Client.AsyncPut( + "/api/async-test/echo", + Size, + [&](uint8_t* Dst, size_t MaxBytes, uint64_t Offset) -> size_t { + SourceCalls.fetch_add(1, std::memory_order_relaxed); + const size_t Remaining = Size > Offset ? Size - static_cast<size_t>(Offset) : 0; + const size_t Take = std::min(MaxBytes, Remaining); + if (Take == 0) + { + return 0; + } + std::memcpy(Dst, Source.data() + Offset, Take); + return Take; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + REQUIRE(F.wait_for(std::chrono::seconds(15)) == std::future_status::ready); + HttpClient::Response Resp = F.get(); + REQUIRE(Resp.IsSuccess()); + REQUIRE_EQ(Resp.ResponsePayload.GetSize(), Size); + CHECK(SourceCalls.load() >= 1); + const uint8_t* RecvData = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + CHECK(RecvData[0] == Source[0]); + CHECK(RecvData[Size / 2] == Source[Size / 2]); + CHECK(RecvData[Size - 1] == Source[Size - 1]); + CHECK(RecvData[1234567] == Source[1234567]); + } + + { + const size_t DeclaredSize = 1024u * 1024u; + std::atomic<uint64_t> SourceCalls{0}; + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + Client.AsyncPut( + "/api/async-test/echo", + DeclaredSize, + [&](uint8_t* Dst, size_t MaxBytes, uint64_t /*Offset*/) -> size_t { + const uint64_t Hits = SourceCalls.fetch_add(1, std::memory_order_relaxed); + if (Hits >= 1) + { + return 0; // abort: 0 returned with Offset < TotalSize + } + const size_t Take = std::min<size_t>(MaxBytes, 64u); + std::memset(Dst, 0xAB, Take); + return Take; + }, + [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + REQUIRE(F.wait_for(std::chrono::seconds(15)) 
== std::future_status::ready); + HttpClient::Response Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + } } -TEST_CASE("asynchttpclient.connection.error") +// Connection-error / retry semantics targeting a dead port. No server fixture +// needed; each scope constructs its own client with bespoke timeout settings. +TEST_CASE("asynchttpclient.connection_errors") { - // Connect to a port where nothing is listening - AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + // connection.error - Get against dead port surfaces connection error. + { + AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + auto Resp = Client.Get("/should-fail").get(); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(Resp.Error->IsConnectionError()); + } + + // retry.respected.on.connection.error - 2 retries adds >=300ms backoff + // (100ms + 200ms accumulated past the initial attempt). + { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(50), + .RetryCount = 2, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + const auto Start = std::chrono::steady_clock::now(); + auto Resp = Client.Get("/should-fail").get(); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Elapsed.count() >= 300); + } - auto Future = Client.Get("/should-fail"); - auto Resp = Future.get(); + // cancel.in.flight - Cancel mid-connect doesn't hang (regardless of which + // side of the strand wins the cancel-vs-ECONNREFUSED race). 
+ { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(60000), + .RetryCount = 0, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + std::promise<HttpClient::Response> P; + auto F = P.get_future(); + const auto Start = std::chrono::steady_clock::now(); + AsyncRequestToken Token = Client.AsyncGet("/should-cancel", [&P](HttpClient::Response R) { P.set_value(std::move(R)); }); + Token.Cancel(); + REQUIRE(F.wait_for(std::chrono::seconds(5)) == std::future_status::ready); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + auto Resp = F.get(); + CHECK_FALSE(Resp.IsSuccess()); + REQUIRE(Resp.Error.has_value()); + CHECK((Resp.Error->ErrorCode == HttpClientErrorCode::kRequestCancelled || + Resp.Error->ErrorCode == HttpClientErrorCode::kConnectionFailure)); + CHECK(Elapsed.count() < 5000); + } - CHECK_FALSE(Resp.IsSuccess()); - CHECK(Resp.Error.has_value()); - CHECK(Resp.Error->IsConnectionError()); + // retry.zero.no.retry - RetryCount=0 returns promptly (< 1s) with no + // extra backoff past ConnectTimeout. 
+ { + HttpClientSettings Settings{ + .ConnectTimeout = std::chrono::milliseconds(50), + .RetryCount = 0, + }; + AsyncHttpClient Client("127.0.0.1:1", Settings); + const auto Start = std::chrono::steady_clock::now(); + auto Resp = Client.Get("/should-fail").get(); + const auto Elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - Start); + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Elapsed.count() < 1000); + } } TEST_SUITE_END(); diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp index ea88fc783..e7e904f89 100644 --- a/src/zenhttp/clients/asynchttpclient.cpp +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -4,8 +4,11 @@ #include "httpclientcurlhelpers.h" +#include <zencore/basicfile.h> #include <zencore/filesystem.h> +#include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/session.h> #include <zencore/thread.h> #include <zencore/trace.h> @@ -15,6 +18,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio/steady_timer.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#include <algorithm> +#include <charconv> +#include <deque> #include <thread> #include <unordered_map> @@ -22,45 +28,204 @@ namespace zen { ////////////////////////////////////////////////////////////////////////// // +// AsyncRequestToken state (forward-declared in the public header so tokens +// can be returned by value without leaking impl details). + +struct AsyncRequestToken::State +{ + std::function<void()> CancelFn; + std::atomic<bool> Cancelled = false; +}; + +////////////////////////////////////////////////////////////////////////// +// // TransferContext: per-transfer state associated with each CURL easy handle +// Request blueprint kept alongside each transfer so retries can re-issue with +// the original verb/url/headers/payload after the previous attempt's transient +// failure. 
+enum class AsyncRequestMethod +{ + Get, + Head, + Delete, + Post, + PostWithPayload, + Put, + PutWithPayload, + PutWithSource, // PUT, body pulled via OnReadSource (no materialized payload) + Stream, // GET, response body delivered via OnData callback (no copy) +}; + +inline std::string_view +AsyncRequestMethodName(AsyncRequestMethod M) +{ + switch (M) + { + case AsyncRequestMethod::Get: + return "GET"; + case AsyncRequestMethod::Head: + return "HEAD"; + case AsyncRequestMethod::Delete: + return "DELETE"; + case AsyncRequestMethod::Post: + return "POST"; + case AsyncRequestMethod::PostWithPayload: + return "POST(payload)"; + case AsyncRequestMethod::Put: + return "PUT"; + case AsyncRequestMethod::PutWithPayload: + return "PUT(payload)"; + case AsyncRequestMethod::PutWithSource: + return "PUT(stream)"; + case AsyncRequestMethod::Stream: + return "GET(stream)"; + } + return "?"; +} + +struct AsyncRequestSpec +{ + AsyncRequestMethod Method = AsyncRequestMethod::Get; + std::string Url; + HttpClient::KeyValueMap AdditionalHeader; + HttpClient::KeyValueMap Parameters; + IoBuffer Payload; // POST/PUT with payload + ZenContentType ContentType = ZenContentType::kUnknownContentType; + bool HasContentType = false; + AsyncHttpDataCallback OnData; // Stream method + AsyncHttpReadSource OnReadSource; // PutWithSource method + uint64_t StreamingPutSize = 0; // Content-Length for PutWithSource + + // Opt-in header capture. By default Response::Header is left empty; inline + // extracts always run because they steer the body path. WantHeaderMap pays + // O(headers) string allocs on the io thread (rare - most callers use + // Response::FindHeader). WantEtag triggers an inline parse only. 
+ bool WantEtag = false; + bool WantHeaderMap = false; +}; + struct TransferContext { - AsyncHttpCallback Callback; - std::string Body; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - CurlWriteCallbackData WriteData; - CurlHeaderCallbackData HeaderData; - curl_slist* HeaderList = nullptr; - - // For PUT/POST with payload: keep the data alive until transfer completes + AsyncHttpCallback Callback; + AsyncRequestSpec Spec; + uint8_t AttemptCount = 0; + uint64_t TokenId = 0; + bool Cancelled = false; + // Curl handle owning this transfer once submitted to curl_multi. Null between + // Submit() and SubmitFromSpec(), and again after CompleteTransfer releases the + // handle back to the pool. + CURL* CurlHandle = nullptr; + // Strong reference to the user-facing AsyncRequestToken::State. Kept alive + // so AsyncRequestToken::Cancel() retains a valid CancelFn target until the + // transfer completes. The typed pointer also lets SubmitFromSpec read + // State.Cancelled to honour cancels that arrived between Submit() returning + // and the io thread running SubmitFromSpec. + std::shared_ptr<AsyncRequestToken::State> TokenStateRef; + // Two paths gated by BodyPreallocated: when Content-Length is known we + // fill Body in place (zero-copy move into Response); otherwise BodyChunks + // accumulates per-WRITE IoBuffers, flattened at completion. + IoBuffer Body; + uint64_t BodyWriteOffset = 0; + bool BodyPreallocated = false; + std::vector<IoBuffer> BodyChunks; + + // Raw response headers, "Key: Value\r\n" lines as delivered by curl. + // One growing buffer; reserve covers common case (~1 KiB) with no realloc. + std::string HeaderArena; + + // Captured at the curl callback boundary so an exception out of the user + // OnData / OnReadSource never propagates through curl's C frames (UB). + // Surfaced as a kInternalError response by CompleteTransfer. + bool CallbackFailed = false; + std::string CallbackErrorMessage; + + // Inline-parsed in CurlHeaderCallback. 
Etag is populated only when + // Spec.WantEtag is set; the others are always parsed since they steer + // the body path / content-type tagging. + uint64_t ContentLength = 0; + bool ContentLengthSet = false; + ZenContentType BodyContentType = ZenContentType::kUnknownContentType; + std::string Etag; + + curl_slist* HeaderList = nullptr; + IoBuffer PayloadBuffer; CurlReadCallbackData ReadData; + uint64_t SourceOffset = 0; // PutWithSource: bytes pulled from Spec.OnReadSource so far. + + // Last attempt's failure, kept across the backoff so a retry-abandoned + // path can surface the underlying cause instead of a generic + // "Request canceled (retry abandoned)". Non-empty LastErrorMessage = stash valid. + CURLcode LastCurlResult = CURLE_OK; + long LastStatusCode = 0; + std::string LastErrorMessage; + + TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) { HeaderArena.reserve(1024); } + + ~TransferContext() { FreeHeaderList(); } + + TransferContext(const TransferContext&) = delete; + TransferContext& operator=(const TransferContext&) = delete; - TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) + // Reset accumulated response state so the same context can be re-submitted + // for a retry attempt. + void ResetForRetry() { - WriteData.Body = &Body; - HeaderData.Headers = &ResponseHeaders; + Body = IoBuffer{}; + BodyWriteOffset = 0; + BodyPreallocated = false; + BodyChunks.clear(); + HeaderArena.clear(); // keep capacity + ContentLength = 0; + ContentLengthSet = false; + BodyContentType = ZenContentType::kUnknownContentType; + Etag.clear(); + FreeHeaderList(); + ReadData = {}; + SourceOffset = 0; + CallbackFailed = false; + CallbackErrorMessage.clear(); + // LastCurlResult / LastStatusCode / LastErrorMessage are intentionally NOT + // cleared - they describe the just-finished attempt that triggered this + // retry, and surface in the abandon path if the next attempt is cancelled + // or shutdown-aborted. 
} - ~TransferContext() + void FreeHeaderList() { if (HeaderList) { curl_slist_free_all(HeaderList); + HeaderList = nullptr; } } - - TransferContext(const TransferContext&) = delete; - TransferContext& operator=(const TransferContext&) = delete; }; ////////////////////////////////////////////////////////////////////////// // -// AsyncHttpClient::Impl +// SocketInfo: per-socket state. -struct AsyncHttpClient::Impl +struct AsyncSocketInfo { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + int PendingFlags = 0; // directions with outstanding async_wait + + // Bound to the strand executor so async_wait completions are serialized on + // the same strand that drives curl_multi - safe even when the underlying + // io_context is multithreaded (external-context mode). + explicit AsyncSocketInfo(const asio::strand<asio::io_context::executor_type>& Strand) : Socket(Strand) {} +}; + +// Holds the curl_multi instance and a strand that serializes every curl_multi op. Owned io_context +// (default ctor) spins a private thread driving run(); external io_context mode lets the caller +// drive the loop. The strand is the single serialization point for both modes. +struct AsyncHttpClient::Impl : std::enable_shared_from_this<AsyncHttpClient::Impl> +{ + // Owned-io_context ctor: allocate a private io_context and run it on a + // dedicated thread. Cleanest path for callers that just want an + // AsyncHttpClient and don't care about the loop. 
Impl(std::string_view BaseUri, const HttpClientSettings& Settings) : m_BaseUri(BaseUri) , m_Settings(Settings) @@ -72,19 +237,33 @@ struct AsyncHttpClient::Impl { Init(); m_WorkGuard.emplace(m_IoContext.get_executor()); - m_IoThread = std::thread([this]() { - SetCurrentThreadName("async_http"); - try - { - m_IoContext.run(); - } - catch (const std::exception& Ex) + + auto ThreadGuard = MakeGuard([this]() { + m_WorkGuard.reset(); + if (m_IoThread.joinable()) { - ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what()); + m_IoThread.join(); } }); + m_IoThread = std::thread([this]() { + SetCurrentThreadName("async_http"); + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: io thread unhandled exception: {}", Ex.what()); + } + }); + ThreadGuard.Dismiss(); } + // External-io_context ctor: caller drives the run loop. We do NOT spawn a + // thread, do NOT hold a work guard, and do NOT call stop()/restart() on + // teardown - the caller's lifecycle owns those. Shutdown blocks on a + // promise until our cleanup handler runs through the strand, so the + // caller MUST keep the loop running until the AsyncHttpClient destructs. Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) : m_BaseUri(BaseUri) , m_Settings(Settings) @@ -96,70 +275,167 @@ struct AsyncHttpClient::Impl Init(); } - ~Impl() + ~Impl() { Shutdown(); } + + // Synchronous teardown from ~AsyncHttpClient. Idempotent. Branches by ownership: + // - Owned io_context: post Cleanup, drop the work guard, force run() to return, join the io + // thread, then poll() to drain lambdas that captured shared_ptr<Impl> (the "lambda in + // io_context owned by Impl" cycle would otherwise pin Impl alive forever). + // - External io_context: caller drives the loop, post Cleanup to the strand and block on a + // promise. Destroying from the io thread itself would deadlock. 
+ void Shutdown() { - // Clean up curl state on the strand where all curl_multi operations - // are serialized. Use a promise to block until the cleanup handler - // has actually executed - essential for the external io_context case - // where we don't own the run loop. - std::promise<void> Done; - std::future<void> DoneFuture = Done.get_future(); - - asio::post(m_Strand, [this, &Done]() { - m_ShuttingDown = true; - m_Timer.cancel(); - - // Release all tracked sockets (don't close - curl owns the fds). - for (auto& [Fd, Info] : m_Sockets) - { - if (Info->Socket.is_open()) - { - Info->Socket.cancel(); - Info->Socket.release(); - } - } - m_Sockets.clear(); + if (m_ShutdownDone.exchange(true, std::memory_order_acq_rel)) + { + return; + } - for (auto& [Handle, Ctx] : m_Transfers) + if (m_OwnedIoContext) + { + // Post Cleanup before releasing the work guard so the io thread + // runs Cleanup before run() returns. Run() normally exits when + // the queue is empty AND no work guard is held. + asio::post(m_Strand, [this]() { Cleanup(); }); + m_WorkGuard.reset(); + // Belt-and-suspenders: force run() to return even if asio's + // outstanding_work_ counter is left non-zero by a close-during- + // cancel race in win_iocp_socket_service. The race is observable + // as a hung join() at shutdown after a burst of socket teardowns + // inside curl_multi_cleanup; stop() here is purely a safety net + // for the teardown path. The trailing restart() + poll() drains + // any handlers stop() leaves undispatched. + m_IoContext.stop(); + if (m_IoThread.joinable()) { - curl_multi_remove_handle(m_Multi, Handle); - curl_easy_cleanup(Handle); + m_IoThread.join(); } - m_Transfers.clear(); - for (CURL* Handle : m_HandlePool) + // Drain any leftover work posted to the io_context but not run + // before the thread exited (e.g. a Cancel-after-completion lambda + // that captured shared_ptr<Impl> by value). 
Loop until the queue + // is empty: a single poll() is one quantum and may not drain + // handlers that themselves post follow-ups. + m_IoContext.restart(); + while (m_IoContext.poll() != 0) + { + } + } + else + { + // External: block on the cleanup handler so we can guarantee + // curl_multi state is gone before m_Impl drops. + std::promise<void> Done; + std::future<void> DoneFuture = Done.get_future(); + asio::post(m_Strand, [this, &Done]() { + Cleanup(); + Done.set_value(); + }); + DoneFuture.wait(); + } + } - // For owned io_context: release work guard so run() can return after - // processing the cleanup handler above. - m_WorkGuard.reset(); + // Cleanup body, run on the io thread. + void Cleanup() + { + m_ShuttingDown = true; + m_Timer.cancel(); - if (m_IoThread.joinable()) + // Tear down curl handles first; curl drives CURL_POLL_REMOVE + + // CLOSESOCKETFUNCTION for each owned socket, which our handlers + // use to retire SocketInfo entries from m_Sockets. + for (auto& [TokenId, Ctx] : m_Transfers) { - m_IoThread.join(); + if (Ctx->CurlHandle) + { + curl_multi_remove_handle(m_Multi, Ctx->CurlHandle); + curl_easy_cleanup(Ctx->CurlHandle); + Ctx->CurlHandle = nullptr; + } + + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "AsyncHttpClient shutting down", + }; + try + { + Ctx->Callback(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in shutdown callback (token={}, {} {}): {}", + TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + Ex.what()); + } } - else + m_Transfers.clear(); + m_InFlight = 0; + + // Drain transfers parked in retry backoff. The timer's pending + // async_wait fires after Cleanup with the io_context already stopped; + // it will find no entry and be a no-op.
Fire cancel callbacks here + // while we still hold the storage so callers' futures resolve. + for (auto& [Id, Entry] : m_RetryingTransfers) { - // External io_context: wait for the cleanup handler to complete. - DoneFuture.wait(); + if (Entry.Timer) + { + Entry.Timer->cancel(); + } + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "AsyncHttpClient shutting down", + }; + try + { + Entry.Ctx->Callback(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in shutdown retry callback (token={}, {} {}): {}", + Id, + AsyncRequestMethodName(Entry.Ctx->Spec.Method), + Entry.Ctx->Spec.Url, + Ex.what()); + } } + m_RetryingTransfers.clear(); + // curl_multi_cleanup walks the connection cache and fires + // CLOSESOCKETFUNCTION for each cached fd. Run it on the io thread + // while m_Sockets is still populated so our callback routes each + // close through the map (Socket dtor closes fd exactly once). 
if (m_Multi) { curl_multi_cleanup(m_Multi); + m_Multi = nullptr; + } + + for (CURL* Handle : m_HandlePool) + { + curl_easy_cleanup(Handle); + } + m_HandlePool.clear(); + + asio::error_code CancelEc; + for (auto& [Fd, Info] : m_Sockets) + { + Info->Socket.cancel(CancelEc); } + m_Sockets.clear(); } LoggerRef Log() { return m_Log; } void Init() { + if (!m_Settings.UnixSocketPath.empty()) + { + m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); + } + m_Multi = curl_multi_init(); if (!m_Multi) { @@ -168,6 +444,24 @@ struct AsyncHttpClient::Impl SetupMultiCallbacks(); + if (m_Settings.MaxConcurrentConnectionsPerHost != 0) + { + curl_multi_setopt(m_Multi, CURLMOPT_MAX_HOST_CONNECTIONS, static_cast<long>(m_Settings.MaxConcurrentConnectionsPerHost)); + } + if (m_Settings.MaxConcurrentConnectionsTotal != 0) + { + curl_multi_setopt(m_Multi, CURLMOPT_MAX_TOTAL_CONNECTIONS, static_cast<long>(m_Settings.MaxConcurrentConnectionsTotal)); + } + + // Size the idle-conn cache to the in-flight cap so reused conns are + // never evicted while requests are queued. Each eviction costs a + // fresh TCP+TLS handshake (~280ms WAN to S3) on the next reuse. + const long MaxConnectsHint = static_cast<long>(std::max({m_Settings.MaxConcurrentRequests, + m_Settings.MaxConcurrentConnectionsTotal, + m_Settings.MaxConcurrentConnectionsPerHost, + 128u})); + curl_multi_setopt(m_Multi, CURLMOPT_MAXCONNECTS, MaxConnectsHint); + if (m_Settings.SessionId == Oid::Zero) { m_SessionId = std::string(GetSessionIdString()); @@ -178,6 +472,30 @@ struct AsyncHttpClient::Impl } } + // Run a completion callback inline on the io thread. By the time this is + // called, curl_multi_remove_handle + curl_easy_cleanup have already finalized + // the easy handle, so deferring to next io tick buys nothing. Direct call + // saves one alloc + queue insert per request. + // + // CONTRACT: user callbacks run on the AsyncHttpClient io thread. 
Heavy work + // (disk syscalls, lock contention, large allocations) must be hopped to a + // worker pool; otherwise it stalls curl_multi for ALL in-flight transfers. + void DispatchCallback(AsyncHttpCallback Cb, HttpClient::Response Resp, const TransferContext& Ctx) + { + try + { + Cb(std::move(Resp)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback (token={}, {} {}): {}", + Ctx.TokenId, + AsyncRequestMethodName(Ctx.Spec.Method), + Ctx.Spec.Url, + Ex.what()); + } + } + // -- Handle pool ----------------------------------------------------- CURL* AllocHandle() @@ -199,20 +517,15 @@ struct AsyncHttpClient::Impl void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); } - // -- Configure a handle with common settings ------------------------- - // Called only from DoAsync* lambdas running on the strand. - + // Called only from DoAsync* lambdas running on the io thread. void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) { - // Build URL ExtendableStringBuilder<256> Url; BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); - // Unix domain socket if (!m_Settings.UnixSocketPath.empty()) { - m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str()); } @@ -226,12 +539,6 @@ struct AsyncHttpClient::Impl curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count())); } - // HTTP/2 - if (m_Settings.AssumeHttp2) - { - curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); - } - // SSL if (m_Settings.InsecureSsl) { @@ -243,7 +550,6 @@ struct AsyncHttpClient::Impl curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str()); } - // Verbose/debug if (m_Settings.Verbose) { curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); @@ -252,6 
+558,22 @@ struct AsyncHttpClient::Impl // Thread safety curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + // 256 KiB recv buffer aligns with optimal read syscall size and matches + // upload buffer; pairs with downstream 512 KiB write slots (2 recv calls + // per write slot under bulk transfer). + curl_easy_setopt(Handle, CURLOPT_BUFFERSIZE, 262144L); + curl_easy_setopt(Handle, CURLOPT_UPLOAD_BUFFERSIZE, 262144L); + + // Skip per-transfer progress bookkeeping; we don't consume it. + curl_easy_setopt(Handle, CURLOPT_NOPROGRESS, 1L); + + // Disable Nagle (default since curl 7.50; explicit for safety). + curl_easy_setopt(Handle, CURLOPT_TCP_NODELAY, 1L); + + // Take ownership of socket close (see CurlCloseSocketCallback). + curl_easy_setopt(Handle, CURLOPT_CLOSESOCKETFUNCTION, &CurlCloseSocketCallback); + curl_easy_setopt(Handle, CURLOPT_CLOSESOCKETDATA, this); + if (m_Settings.ForbidReuseConnection) { curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); @@ -260,23 +582,24 @@ struct AsyncHttpClient::Impl // -- Access token ---------------------------------------------------- - std::optional<std::string> GetAccessToken() + struct AccessTokenResult { + std::optional<std::string> Token; + bool ProviderFailed = false; // provider configured but returned invalid twice + }; + + // Called only on the io thread. + AccessTokenResult GetAccessToken() + { + AccessTokenResult Result; if (!m_Settings.AccessTokenProvider.has_value()) { - return {}; + return Result; // No provider: anonymous is the intended mode. 
} - { - RwLock::SharedLockScope _(m_AccessTokenLock); - if (!m_CachedAccessToken.NeedsRefresh()) - { - return m_CachedAccessToken.GetValue(); - } - } - RwLock::ExclusiveLockScope _(m_AccessTokenLock); if (!m_CachedAccessToken.NeedsRefresh()) { - return m_CachedAccessToken.GetValue(); + Result.Token = m_CachedAccessToken.GetValue(); + return Result; } HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()(); if (!NewToken.IsValid()) @@ -287,10 +610,176 @@ struct AsyncHttpClient::Impl if (NewToken.IsValid()) { m_CachedAccessToken = NewToken; - return m_CachedAccessToken.GetValue(); + Result.Token = m_CachedAccessToken.GetValue(); + return Result; } ZEN_WARN("AsyncHttpClient: access token provider returned invalid token"); - return {}; + Result.ProviderFailed = true; + return Result; + } + + // -- Submit / resubmit ----------------------------------------------- + // + // SubmitFromSpec runs on the io thread. Used for both the initial submission + // and for retries: the AsyncRequestSpec inside Ctx encodes everything + // needed to (re)build the curl handle from scratch. + + void SubmitFromSpec(std::unique_ptr<TransferContext> Ctx) + { + if (m_ShuttingDown) + { + // Synthesize a cancel response so the user callback fires exactly once. + // Without this any Ctx that lands here post-shutdown would be dropped + // silently, leaving waiting futures unresolved. + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (client shutting down)", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Cancel-before-submit race: the user can call AsyncRequestToken::Cancel() + // between Submit() returning and the io thread running SubmitFromSpec. + // State.Cancelled is set under acq_rel by Cancel(); read here under + // acquire ensures visibility. 
If set, fire the cancel callback exactly + // once and bail before the transfer enters m_Transfers. + if (Ctx->TokenStateRef && Ctx->TokenStateRef->Cancelled.load(std::memory_order_acquire)) + { + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled before submit", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Allocate the curl handle BEFORE bumping m_InFlight: AllocHandle can throw + // (curl_easy_init returning null). The InFlightGuard covers the rest of + // the body (BuildHeaderList, ExtraHeaders push_back, GetAccessToken can + // all throw bad_alloc) so a throw post-increment doesn't leak the slot. + // HandleGuard returns the handle to the pool on throw; CallbackGuard + // synthesizes a kInternalError response so the user future resolves. + CURL* Handle = AllocHandle(); + ++m_InFlight; + auto InFlightGuard = MakeGuard([this] { --m_InFlight; }); + auto HandleGuard = MakeGuard([this, Handle] { ReleaseHandle(Handle); }); + auto CallbackGuard = MakeGuard([this, &Ctx] { + HttpClient::Response ErrResp; + ErrResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = "AsyncHttpClient::SubmitFromSpec: setup threw before dispatch", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(ErrResp), *Ctx); + }); + ConfigureHandle(Handle, Ctx->Spec.Url, Ctx->Spec.Parameters); + + switch (Ctx->Spec.Method) + { + case AsyncRequestMethod::Get: + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + break; + + case AsyncRequestMethod::Stream: + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + break; + + case AsyncRequestMethod::Head: + curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); + break; + + case AsyncRequestMethod::Delete: + curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + break; + + case AsyncRequestMethod::Post: + curl_easy_setopt(Handle, 
CURLOPT_POST, 1L); + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); + break; + + case AsyncRequestMethod::PostWithPayload: + { + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + Ctx->PayloadBuffer = Ctx->Spec.Payload; + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + break; + } + + case AsyncRequestMethod::Put: + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + break; + + case AsyncRequestMethod::PutWithPayload: + { + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + Ctx->PayloadBuffer = Ctx->Spec.Payload; + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + break; + } + + case AsyncRequestMethod::PutWithSource: + { + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + Ctx->SourceOffset = 0; + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->Spec.StreamingPutSize)); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, AsyncCurlSourceReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, Ctx.get()); + break; + } + } + + // Headers - include Content-Type for payload-bearing methods, Content-Length: 0 for empty PUT. 
+ std::vector<std::pair<std::string, std::string>> ExtraHeaders; + if (Ctx->Spec.Method == AsyncRequestMethod::PostWithPayload) + { + const ZenContentType Effective = Ctx->Spec.HasContentType ? Ctx->Spec.ContentType : Ctx->Spec.Payload.GetContentType(); + ExtraHeaders.emplace_back("Content-Type", std::string(MapContentTypeToString(Effective))); + } + else if (Ctx->Spec.Method == AsyncRequestMethod::PutWithPayload) + { + ExtraHeaders.emplace_back("Content-Type", std::string(MapContentTypeToString(Ctx->Spec.Payload.GetContentType()))); + } + else if (Ctx->Spec.Method == AsyncRequestMethod::Put) + { + ExtraHeaders.emplace_back("Content-Length", "0"); + } + + AccessTokenResult Token = GetAccessToken(); + if (Token.ProviderFailed) + { + // Provider configured but failed twice: do NOT silently downgrade + // to an anonymous request - the server will respond 403 and the + // caller has no way to tell auth failed. + HttpClient::Response ErrResp; + ErrResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = "AsyncHttpClient: access token provider failed; refusing to issue anonymous request", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(ErrResp), *Ctx); + CallbackGuard.Dismiss(); + return; + } + + Ctx->HeaderList = BuildHeaderList(Ctx->Spec.AdditionalHeader, m_SessionId, std::move(Token.Token), ExtraHeaders); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + InFlightGuard.Dismiss(); + HandleGuard.Dismiss(); + CallbackGuard.Dismiss(); + SubmitTransfer(Handle, std::move(Ctx)); } // -- Submit a transfer ----------------------------------------------- @@ -298,56 +787,66 @@ struct AsyncHttpClient::Impl void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx) { ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer"); - // Setup write/header callbacks - curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData); - 
curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData); - - m_Transfers[Handle] = std::move(Ctx); - + // Pick the WRITE callback by method: + // - Stream: forwards each chunk to the caller's OnData (no copy) + // - other: buffers bytes in TransferContext::Body + if (Ctx->Spec.Method == AsyncRequestMethod::Stream) + { + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, AsyncCurlStreamWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, Ctx.get()); + } + else + { + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, AsyncCurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, Ctx.get()); + } + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, AsyncCurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, Ctx.get()); + // Stash TokenId on the curl handle so CheckCompleted can look up the + // TransferContext directly from curl_multi_info_read's CURLMsg.easy_handle. + curl_easy_setopt(Handle, CURLOPT_PRIVATE, reinterpret_cast<void*>(static_cast<uintptr_t>(Ctx->TokenId))); + + Ctx->CurlHandle = Handle; + const uint64_t TokenIdLocal = Ctx->TokenId; + + // Try the curl_multi add first. On failure, Ctx is still owned locally so the + // rollback is a single Release path with no map churn. On success, ownership + // moves into the single TokenId-keyed lookup table. 
CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle); if (Mc != CURLM_OK) { - auto Stolen = std::move(m_Transfers[Handle]); - m_Transfers.erase(Handle); + Ctx->CurlHandle = nullptr; ReleaseHandle(Handle); HttpClient::Response ErrorResponse; ErrorResponse.Error = HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError, .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))}; - asio::post(m_IoContext, - [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); + DispatchCallback(std::move(Ctx->Callback), std::move(ErrorResponse), *Ctx); + OnSlotFreed(); return; } - } - // -- Socket-action integration --------------------------------------- - // - // curl_multi drives I/O via two callbacks: - // - SocketCallback: curl tells us which sockets to watch for read/write - // - TimerCallback: curl tells us when to fire a timeout - // - // On each socket event or timeout we call curl_multi_socket_action(), - // then drain completed transfers via curl_multi_info_read(). + m_Transfers.emplace(TokenIdLocal, std::move(Ctx)); + } - // Per-socket state: wraps the native fd in an ASIO socket for async_wait. - struct SocketInfo + // Telemetry only; this client does not gate fan-out. Callers (e.g. + // S3AsyncStorage) layer their own admission semaphore on top. + // Assert catches SubmitFromSpec/OnSlotFreed imbalance. + void OnSlotFreed() { - asio::ip::tcp::socket Socket; - int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT - - explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} - }; + ZEN_ASSERT(m_InFlight > 0); + --m_InFlight; + } - // Static thunks registered with curl_multi ---------------------------- + // curl_multi drives I/O via SocketCallback (which fds to watch) and TimerCallback (when to fire). + // On each event we call curl_multi_socket_action() and drain via curl_multi_info_read(). 
+ // Static thunks: UserData = Impl* (set via CURLMOPT_SOCKETDATA / CURLMOPT_TIMERDATA). Bodies run on io thread. static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) { - ZEN_UNUSED(Easy); auto* Self = static_cast<Impl*>(UserPtr); - Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + Self->OnCurlSocket(Easy, Fd, Action, static_cast<AsyncSocketInfo*>(SocketPtr)); return 0; } @@ -359,6 +858,219 @@ struct AsyncHttpClient::Impl return 0; } + // Async-specific HEADER callback. Appends raw "Key: Value\r\n" lines to + // Ctx->HeaderArena (one growing buffer; ~zero allocs in the common case + // where the initial reserve covers the response). Inline-parses + // Content-Length and Content-Type unconditionally; parses ETag only when + // Spec.WantEtag is set. No std::string/pair allocations per line. + static size_t AsyncCurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + std::string_view Line(Buffer, TotalBytes); + + Ctx->HeaderArena.append(Buffer, TotalBytes); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + if (Line.empty()) + { + return TotalBytes; + } + const size_t Colon = Line.find(':'); + if (Colon == std::string_view::npos) + { + return TotalBytes; // HTTP status line or malformed + } + std::string_view Key = Line.substr(0, Colon); + std::string_view Value = Line.substr(Colon + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (StrCaseEquals(Key, "Content-Length")) + { + uint64_t Length = 0; + std::from_chars_result Res = std::from_chars(Value.data(), Value.data() + Value.size(), Length); + if (Res.ec == std::errc{}) + { + Ctx->ContentLength = Length; + Ctx->ContentLengthSet = true; + } + } + 
else if (StrCaseEquals(Key, "Content-Type")) + { + Ctx->BodyContentType = ParseContentType(Value); + } + else if (Ctx->Spec.WantEtag && StrCaseEquals(Key, "ETag")) + { + Ctx->Etag.assign(Value); + } + + return TotalBytes; + } + + // Async-specific write callback. Targets a TransferContext directly. + // Preallocates Body from Ctx->ContentLength (parsed in HEADER cb). If + // Content-Length is absent (e.g. chunked encoding), falls back to + // BodyChunks accumulation; CompleteTransfer flattens at the end. + static size_t AsyncCurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + if (TotalBytes == 0) + { + return 0; + } + + if (!Ctx->BodyPreallocated && Ctx->BodyWriteOffset == 0 && Ctx->BodyChunks.empty() && Ctx->ContentLengthSet && + Ctx->ContentLength > 0) + { + Ctx->Body = IoBuffer(static_cast<size_t>(Ctx->ContentLength)); + Ctx->BodyPreallocated = true; + } + + if (Ctx->BodyPreallocated) + { + if (Ctx->BodyWriteOffset + TotalBytes > Ctx->Body.GetSize()) + { + // Server sent more than Content-Length advertised; abort. + return 0; + } + memcpy(static_cast<uint8_t*>(Ctx->Body.MutableData()) + Ctx->BodyWriteOffset, Ptr, TotalBytes); + Ctx->BodyWriteOffset += TotalBytes; + } + else + { + IoBuffer Chunk(TotalBytes); + memcpy(Chunk.MutableData(), Ptr, TotalBytes); + Ctx->BodyChunks.push_back(std::move(Chunk)); + } + + return TotalBytes; + } + + // PutWithSource read callback. Pulls up to MaxBytes from Spec.OnReadSource + // into curl's send buffer. Source closure runs on the io thread - same + // strand discipline as Stream's OnData. Returning 0 with SourceOffset < + // StreamingPutSize signals an upload abort to curl. 
+ static size_t AsyncCurlSourceReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t MaxBytes = Size * Nmemb; + if (MaxBytes == 0 || !Ctx->Spec.OnReadSource) + { + return 0; + } + // Catch at the curl boundary: user OnReadSource may call into IoBuffer + // allocation, file reads, or async dispatch which can throw. Letting an + // exception unwind through curl's C frames is UB. + size_t Pulled = 0; + try + { + Pulled = Ctx->Spec.OnReadSource(reinterpret_cast<uint8_t*>(Buffer), MaxBytes, Ctx->SourceOffset); + } + catch (const std::exception& Ex) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = fmt::format("upload source callback threw: {}", Ex.what()); + return CURL_READFUNC_ABORT; + } + catch (...) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "upload source callback threw unknown exception"; + return CURL_READFUNC_ABORT; + } + if (Pulled == 0 && Ctx->SourceOffset < Ctx->Spec.StreamingPutSize) + { + return CURL_READFUNC_ABORT; + } + Ctx->SourceOffset += Pulled; + return Pulled; + } + + // Stream-method write callback. Hands each chunk to the caller's OnData + // without allocating or copying. The pointer is curl's internal receive + // buffer; valid only for the duration of this call. Caller's OnData runs + // on the io thread, so blocking work (disk write etc) blocks the poll + // loop. TotalSize comes from inline-parsed Content-Length (0 if absent / + // chunked). + static size_t AsyncCurlStreamWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) + { + auto* Ctx = static_cast<TransferContext*>(UserData); + const size_t TotalBytes = Size * Nmemb; + if (TotalBytes == 0) + { + return 0; + } + + if (!Ctx->Spec.OnData) + { + // Stream method requires OnData by contract; reaching this branch + // is API misuse. Returning TotalBytes (success) would silently + // drop the body and report "ok"; fail loudly instead. 
+ Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "stream request submitted without OnData callback"; + return 0; + } + + // Catch at the curl boundary so an exception inside the user OnData + // (e.g. IoBuffer alloc failure, ScheduleWork rejection, ZEN_ASSERT in a + // downstream pool) cannot propagate through curl's C frames. Stash on + // Ctx so CompleteTransfer surfaces it as kInternalError. + try + { + const bool ContinueTransfer = Ctx->Spec.OnData(reinterpret_cast<const uint8_t*>(Ptr), TotalBytes, Ctx->ContentLength); + return ContinueTransfer ? TotalBytes : 0; // returning 0 aborts + } + catch (const std::exception& Ex) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = fmt::format("stream data callback threw: {}", Ex.what()); + return 0; + } + catch (...) + { + Ctx->CallbackFailed = true; + Ctx->CallbackErrorMessage = "stream data callback threw unknown exception"; + return 0; + } + } + + // Take ownership of socket close. CLOSESOCKETFUNCTION is invoked from + // within curl_multi operations which run on the io thread, so direct map + // access is safe. The asio tcp::socket destructor closes the fd; we + // return 0 to tell curl the close succeeded. Letting curl close as well + // would race (double-close, fd-reuse hazard) and on Windows IOCP + // `release()` throws `operation_not_supported`, killing the io thread. + static int CurlCloseSocketCallback(void* ClientPtr, curl_socket_t Fd) + { + auto* Self = static_cast<Impl*>(ClientPtr); + auto It = Self->m_Sockets.find(Fd); + if (It != Self->m_Sockets.end()) + { + asio::error_code Ec; + It->second->Socket.cancel(Ec); + Self->m_Sockets.erase(It); + return 0; + } + // Fd not tracked (e.g. pre-poll-add or post-shutdown); close directly. 
+#if ZEN_PLATFORM_WINDOWS + ::closesocket(Fd); +#else + ::close(Fd); +#endif + return 0; + } + void SetupMultiCallbacks() { curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); @@ -369,45 +1081,97 @@ struct AsyncHttpClient::Impl // Called by curl when socket watch state changes --------------------- - void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) + // Synthesize a transport-level failure for the easy handle currently bound + // to the curl_multi entry that owns Fd. Used when the asio side cannot bind + // the fd; without this the affected transfer would hang on curl's + // connect/transfer timeout instead of failing fast with the real error. + void FailEasyHandleForFd(CURL* Easy, std::string_view Reason) + { + if (!Easy) + { + return; + } + curl_multi_remove_handle(m_Multi, Easy); + + char* Private = nullptr; + curl_easy_getinfo(Easy, CURLINFO_PRIVATE, &Private); + const uint64_t TokenId = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(Private)); + + auto It = m_Transfers.find(TokenId); + if (It == m_Transfers.end()) + { + ReleaseHandle(Easy); + return; + } + + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + Ctx->CurlHandle = nullptr; + ReleaseHandle(Easy); + + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = std::string(Reason), + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + OnSlotFreed(); + } + + void OnCurlSocket(CURL* Easy, curl_socket_t Fd, int Action, AsyncSocketInfo* Info) { if (Action == CURL_POLL_REMOVE) { if (Info) { - // Cancel pending async_wait ops before releasing the fd. - // curl owns the fd, so we must release() rather than close(). - Info->Socket.cancel(); - if (Info->Socket.is_open()) - { - Info->Socket.release(); - } - m_Sockets.erase(Fd); + // Cancel any pending async_wait but KEEP the AsyncSocketInfo + // alive in m_Sockets. 
CURL_POLL_REMOVE only means "stop + // watching this socket"; curl may still use the fd + // (keep-alive reuse) and rewatch it later. The asio Socket + // stays bound to the same IOCP it was first assigned to; + // we never call release()+assign() on the same fd, which + // avoided a race where in-flight async_wait callbacks + // raced with curl_multi reading from the socket and + // corrupted HTTP framing. The fd's actual close happens + // in CurlCloseSocketCallback, which erases the entry and + // lets the asio Socket destructor close. + asio::error_code Ec; + Info->Socket.cancel(Ec); + Info->WatchFlags = 0; + Info->PendingFlags = 0; } return; } if (!Info) { - // New socket - wrap the native fd in an ASIO socket. - auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); - Info = It->second.get(); - - asio::error_code Ec; - // Determine protocol from the fd (v4 vs v6). Default to v4. - Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); - if (Ec) + // CURL_POLL_IN/OUT with no Info attached. Two cases: + // 1) brand new fd - emplace, assign, IOCP-bind. + // 2) curl re-watching a kept-alive fd it earlier removed - + // reuse the existing AsyncSocketInfo, no re-assign. 
+ auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) { - // Try v6 as fallback - Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); - } - if (Ec) - { - ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); - m_Sockets.erase(Fd); - return; - } + auto [NewIt, _] = m_Sockets.emplace(Fd, std::make_unique<AsyncSocketInfo>(m_Strand)); + It = NewIt; + asio::error_code Ec; + It->second->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + It->second->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) + { + std::string Reason = + fmt::format("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + ZEN_WARN("{}", Reason); + m_Sockets.erase(It); + FailEasyHandleForFd(Easy, Reason); + return; + } + } + Info = It->second.get(); curl_multi_assign(m_Multi, Fd, Info); } @@ -415,37 +1179,84 @@ struct AsyncHttpClient::Impl SetSocketWatch(Fd, Info); } - void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + void SetSocketWatch(curl_socket_t Fd, AsyncSocketInfo* Info) { - // Cancel any pending wait before issuing a new one. - Info->Socket.cancel(); + // Cancel only when a previously-watched direction is no longer wanted. + // In the common path (one-shot async_wait completes, curl re-watches + // the same flags) PendingFlags is a subset of WatchFlags and we just + // re-arm the missing direction without touching CancelIoEx. 
+ const int Desired = Info->WatchFlags & (CURL_POLL_IN | CURL_POLL_OUT); + + if (Info->PendingFlags & ~Desired) + { + asio::error_code Ec; + Info->Socket.cancel(Ec); + Info->PendingFlags = 0; + } + + const int ToAdd = Desired & ~Info->PendingFlags; - if (Info->WatchFlags & CURL_POLL_IN) + if (ToAdd & CURL_POLL_IN) { - Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) - { - return; - } - OnSocketReady(Fd, CURL_CSELECT_IN); - })); + Info->PendingFlags |= CURL_POLL_IN; + Info->Socket.async_wait(asio::socket_base::wait_read, [this, Fd](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) + { + return; + } + It->second->PendingFlags &= ~CURL_POLL_IN; + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: read async_wait fd {} error: {}; signalling curl with CURL_CSELECT_ERR", + static_cast<int>(Fd), + Ec.message()); + OnSocketReady(Fd, CURL_CSELECT_ERR); + } + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + }); } - if (Info->WatchFlags & CURL_POLL_OUT) + if (ToAdd & CURL_POLL_OUT) { - Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) - { - return; - } - OnSocketReady(Fd, CURL_CSELECT_OUT); - })); + Info->PendingFlags |= CURL_POLL_OUT; + Info->Socket.async_wait(asio::socket_base::wait_write, [this, Fd](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + auto It = m_Sockets.find(Fd); + if (It == m_Sockets.end()) + { + return; + } + It->second->PendingFlags &= ~CURL_POLL_OUT; + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: write async_wait fd {} error: {}; signalling curl with CURL_CSELECT_ERR", + static_cast<int>(Fd), + Ec.message()); + OnSocketReady(Fd, CURL_CSELECT_ERR); + } + return; + 
} + OnSocketReady(Fd, CURL_CSELECT_OUT); + }); } } void OnSocketReady(curl_socket_t Fd, int CurlAction) { - ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); int StillRunning = 0; curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); CheckCompleted(); @@ -472,7 +1283,7 @@ struct AsyncHttpClient::Impl if (TimeoutMs == 0) { - // curl wants immediate action - run it directly on the strand. + // curl wants immediate action - run it on the next strand tick. asio::post(m_Strand, [this]() { if (m_ShuttingDown) { @@ -486,16 +1297,24 @@ struct AsyncHttpClient::Impl } m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); - m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { - if (Ec || m_ShuttingDown) + m_Timer.async_wait([this](const asio::error_code& Ec) { + if (m_ShuttingDown) + { + return; + } + if (Ec) { + if (Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("AsyncHttpClient: curl multi timer error: {}", Ec.message()); + } return; } ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); int StillRunning = 0; curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); CheckCompleted(); - })); + }); } // Drain completed transfers from curl_multi -------------------------- @@ -516,7 +1335,13 @@ struct AsyncHttpClient::Impl curl_multi_remove_handle(m_Multi, Handle); - auto It = m_Transfers.find(Handle); + // Recover TokenId from CURLOPT_PRIVATE; cheaper than a per-handle + // reverse map. Returns nullptr if option was never set. + char* Private = nullptr; + curl_easy_getinfo(Handle, CURLINFO_PRIVATE, &Private); + const uint64_t TokenId = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(Private)); + + auto It = m_Transfers.find(TokenId); if (It == m_Transfers.end()) { ReleaseHandle(Handle); @@ -530,9 +1355,45 @@ struct AsyncHttpClient::Impl } } + // Mirrors CurlHttpClient::ShouldRetry semantics; keep the two in sync. 
+ static bool ShouldRetryAsync(CURLcode CurlResult, long StatusCode) + { + switch (CurlResult) + { + case CURLE_OK: + break; + case CURLE_COULDNT_CONNECT: + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } + } + void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) { ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer"); + + // Free the in-flight counter before any retry / cancel branch. Retry + // re-submits via SubmitFromSpec, which re-increments the counter for + // the next attempt - keeping the assert balanced across retries. + OnSlotFreed(); + // Extract result info long StatusCode = 0; curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode); @@ -548,13 +1409,205 @@ struct AsyncHttpClient::Impl ReleaseHandle(Handle); + // Cancellation came in after curl ran but before we processed completion. + // Synthesize a cancel response and skip retry. + if (Ctx->Cancelled) + { + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // User OnData / OnReadSource threw inside curl. The transfer is already + // aborted; surface the stashed exception text and skip retry (a callback + // exception is not a transient transport failure). 
+ if (Ctx->CallbackFailed) + { + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = std::move(Ctx->CallbackErrorMessage), + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; + } + + // Retry path: re-issue from spec after backoff. Keeps the user callback + // untouched so the eventual final result fires only once. + if (!m_ShuttingDown && Ctx->AttemptCount < m_Settings.RetryCount && ShouldRetryAsync(CurlResult, StatusCode)) + { + ++Ctx->AttemptCount; + const long BackoffMs = 100 * Ctx->AttemptCount; + + if (CurlResult != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", + m_SessionId, + static_cast<int>(MapCurlError(CurlResult)), + curl_easy_strerror(CurlResult), + static_cast<int>(CurlResult), + Ctx->AttemptCount, + m_Settings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + m_SessionId, + StatusCode, + zen::ToString(HttpResponseCode(StatusCode)), + Ctx->AttemptCount, + m_Settings.RetryCount + 1); + } + + // Stash the just-finished attempt's failure so the abandon path can + // surface it instead of a generic "Request canceled (retry abandoned)". + // Non-empty LastErrorMessage marks the stash valid. + Ctx->LastCurlResult = CurlResult; + Ctx->LastStatusCode = StatusCode; + Ctx->LastErrorMessage = (CurlResult != CURLE_OK) ? std::string(curl_easy_strerror(CurlResult)) + : std::string(zen::ToString(HttpResponseCode(StatusCode))); + Ctx->ResetForRetry(); + + const uint64_t RetryTokenId = Ctx->TokenId; + auto RetryTimer = std::make_shared<asio::steady_timer>(m_Strand); + RetryTimer->expires_after(std::chrono::milliseconds(BackoffMs)); + + // Park Ctx + Timer in m_RetryingTransfers so HandleCancel can find + // it and cancel the timer (early cancel without paying the full + // backoff). 
The timer lambda re-claims Ctx through the map; if + // HandleCancel got there first the entry is gone and the lambda + // returns silently. + auto [It, Inserted] = m_RetryingTransfers.emplace(RetryTokenId, RetryEntry{std::move(Ctx), RetryTimer}); + ZEN_ASSERT(Inserted); + + // Capture weak_from_this() rather than raw `this`. With an external + // io_context, Cleanup cancels the timer but the cancellation handler + // is queued on the caller's loop and may fire AFTER ~Impl runs. The + // owned-context path drains via restart()+poll() before destruction + // so this is moot there, but the weak ref is the cheapest way to + // keep both paths safe. + RetryTimer->async_wait([Self = weak_from_this(), RetryTokenId](const asio::error_code& Ec) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + Impl& Me = *Locked; + auto It = Me.m_RetryingTransfers.find(RetryTokenId); + if (It == Me.m_RetryingTransfers.end()) + { + // HandleCancel already removed the entry and dispatched the + // cancel callback. + return; + } + std::unique_ptr<TransferContext> Ctx = std::move(It->second.Ctx); + Me.m_RetryingTransfers.erase(It); + + if (Ec || Me.m_ShuttingDown) + { + // Retry abandoned by timer cancellation or client shutdown. + // Surface the underlying failure (stashed pre-backoff) so + // the caller can distinguish a real timeout/throttle from a + // shutdown / cancel race. + HttpClient::Response Resp; + if (!Ctx->LastErrorMessage.empty()) + { + Resp.StatusCode = HttpResponseCode(Ctx->LastStatusCode); + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = + (Ctx->LastCurlResult != CURLE_OK) ? 
MapCurlError(Ctx->LastCurlResult) : HttpClientErrorCode::kOtherError, + .ErrorMessage = fmt::format("Request canceled (retry abandoned after: {})", Ctx->LastErrorMessage), + }; + } + else + { + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (retry abandoned)", + }; + } + Me.DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; + } + Me.SubmitFromSpec(std::move(Ctx)); + }); + return; + } + // Build response HttpClient::Response Response; Response.StatusCode = HttpResponseCode(StatusCode); Response.UploadedBytes = static_cast<int64_t>(UpBytes); Response.DownloadedBytes = static_cast<int64_t>(DownBytes); Response.ElapsedSeconds = Elapsed; - Response.Header = BuildHeaderMap(Ctx->ResponseHeaders); + // Hand the raw arena over - FindHeader scans this lazily. Build the + // parsed KeyValueMap only when caller explicitly asks (rare). + Response.HeaderArena = std::move(Ctx->HeaderArena); + if (Ctx->Spec.WantHeaderMap) + { + std::string_view View(Response.HeaderArena); + while (!View.empty()) + { + const size_t LineEnd = View.find('\n'); + std::string_view Line = LineEnd == std::string_view::npos ? View : View.substr(0, LineEnd); + View = LineEnd == std::string_view::npos ? std::string_view{} : View.substr(LineEnd + 1); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + if (auto Header = ParseHeaderLine(Line)) + { + Response.Header->insert_or_assign(std::string(Header->first), std::string(Header->second)); + } + } + } + + // Helper: produce the response IoBuffer from whichever Body path was + // taken. Preallocated path moves with zero copy; chunked fallback + // moves the single chunk or flattens N chunks into one allocation. 
+ auto BuildResponsePayload = [&]() -> IoBuffer { + if (Ctx->BodyPreallocated) + { + IoBuffer Out = std::move(Ctx->Body); + if (Ctx->BodyWriteOffset != Out.GetSize()) + { + // Server closed early - return a non-owning sub-buffer over + // the actually-received prefix. Sub-buffer holds a ref to + // Out's core so the underlying allocation stays alive; no + // memcpy. + return IoBuffer(Out, 0, Ctx->BodyWriteOffset); + } + return Out; + } + if (Ctx->BodyChunks.size() == 1) + { + return std::move(Ctx->BodyChunks[0]); + } + if (!Ctx->BodyChunks.empty()) + { + // Flatten N chunks into one IoBuffer; single alloc avoids a copy chain. + size_t Total = 0; + for (const IoBuffer& C : Ctx->BodyChunks) + { + Total += C.GetSize(); + } + IoBuffer Out(Total); + uint8_t* Dst = static_cast<uint8_t*>(Out.MutableData()); + for (const IoBuffer& C : Ctx->BodyChunks) + { + memcpy(Dst, C.GetData(), C.GetSize()); + Dst += C.GetSize(); + } + return Out; + } + return IoBuffer{}; + }; + + const bool HasBody = Ctx->BodyPreallocated ? 
Ctx->BodyWriteOffset > 0 : !Ctx->BodyChunks.empty(); if (CurlResult != CURLE_OK) { @@ -562,258 +1615,261 @@ struct AsyncHttpClient::Impl if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg); + ZEN_WARN("AsyncHttpClient failure: token={} {} '{}': ({}) '{}'", + Ctx->TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + static_cast<int>(CurlResult), + ErrorMsg); } - if (!Ctx->Body.empty()) + if (HasBody) { - Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + Response.ResponsePayload = BuildResponsePayload(); } Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)}; } - else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty()) + else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || !HasBody) { // No payload } else { - IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); - ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders); + IoBuffer PayloadBuffer = BuildResponsePayload(); + if (Ctx->BodyContentType != ZenContentType::kUnknownContentType) + { + PayloadBuffer.SetContentType(Ctx->BodyContentType); + } const HttpResponseCode Code = HttpResponseCode(StatusCode); if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound) { - ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri); + ZEN_WARN("AsyncHttpClient request failed: token={} {} '{}': status={}", + Ctx->TokenId, + AsyncRequestMethodName(Ctx->Spec.Method), + Ctx->Spec.Url, + static_cast<int>(Code)); } Response.ResponsePayload = std::move(PayloadBuffer); } - // Dispatch the user callback off the strand so a slow callback - // cannot starve the curl_multi poll loop. 
- asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable { - try + // Token reaches terminal state. Ctx (and its embedded TokenState) is + // destroyed when this scope ends; late Cancel() calls find no entry in + // m_Transfers and become no-ops. + DispatchCallback(std::move(Ctx->Callback), std::move(Response), *Ctx); + } + + // -- Async verb implementations -------------------------------------- + + AsyncRequestToken Submit(std::unique_ptr<TransferContext> Ctx) + { + // Allocate token ID + state up front so callers can cancel before the + // posted submit even runs. Token::State is shared between the user-held + // AsyncRequestToken and the TransferContext (no separate strand-side map). + const uint64_t Id = m_NextTokenId.fetch_add(1, std::memory_order_relaxed); + Ctx->TokenId = Id; + + auto State = std::make_shared<AsyncRequestToken::State>(); + State->CancelFn = [WeakSelf = weak_from_this(), Id]() { + auto Self = WeakSelf.lock(); + if (!Self) { - Cb(std::move(Response)); + return; } - catch (const std::exception& Ex) + asio::post(Self->m_Strand, [Self, Id]() { Self->HandleCancel(Id); }); + }; + Ctx->TokenStateRef = State; + + asio::post(m_Strand, [this, Ctx = std::move(Ctx)]() mutable { + if (m_ShuttingDown) { - ZEN_SCOPED_LOG(LogRef); - ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what()); + HttpClient::Response Resp; + Resp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled (client shutting down)", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(Resp), *Ctx); + return; } + SubmitFromSpec(std::move(Ctx)); }); - } - // -- Async verb implementations -------------------------------------- + return AsyncRequestToken(std::move(State)); + } - void DoAsyncGet(std::string Url, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + 
void HandleCancel(uint64_t Id) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Get"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto It = m_Transfers.find(Id); + if (It != m_Transfers.end()) + { + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + + Ctx->Cancelled = true; + if (Ctx->CurlHandle) + { + curl_multi_remove_handle(m_Multi, Ctx->CurlHandle); + ReleaseHandle(Ctx->CurlHandle); + Ctx->CurlHandle = nullptr; + } + + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + OnSlotFreed(); + return; + } + + // Cancel landed during the retry backoff: take the parked Ctx + Timer, + // cancel the timer (the in-flight async_wait fires later with + // operation_aborted but finds no entry, so it's a no-op), and dispatch + // the cancel callback now so the user observes immediate cancellation + // rather than waiting out the backoff. 
+ auto RetIt = m_RetryingTransfers.find(Id); + if (RetIt != m_RetryingTransfers.end()) + { + std::unique_ptr<TransferContext> Ctx = std::move(RetIt->second.Ctx); + std::shared_ptr<asio::steady_timer> Timer = std::move(RetIt->second.Timer); + m_RetryingTransfers.erase(RetIt); + Timer->cancel(); + + HttpClient::Response CancelResp; + CancelResp.Error = HttpClient::ErrorContext{ + .ErrorCode = HttpClientErrorCode::kRequestCancelled, + .ErrorMessage = "Request canceled", + }; + DispatchCallback(std::move(Ctx->Callback), std::move(CancelResp), *Ctx); + return; + } + + // Cancel landed before SubmitFromSpec ran (Cancel posted between Submit + // returning and the io thread executing the posted SubmitFromSpec). + // State.Cancelled has already been set by Cancel(); SubmitFromSpec + // checks it and synthesizes the cancel callback when the transfer + // eventually arrives. Nothing to do here. } - void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncGet(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, - [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Head"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Get; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + 
Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); } - void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Delete"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Head; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + return Submit(std::move(Ctx)); } - void DoAsyncPost(std::string Url, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Post"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_POST, 1L); - curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->HeaderList = 
BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Delete; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + return Submit(std::move(Ctx)); } - void DoAsyncPostWithPayload(std::string Url, - IoBuffer Payload, - ZenContentType ContentType, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader) + AsyncRequestToken DoAsyncPost(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Payload = std::move(Payload), - ContentType, - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, {}); - curl_easy_setopt(Handle, CURLOPT_POST, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->PayloadBuffer = std::move(Payload); - Ctx->HeaderList = - BuildHeaderList(AdditionalHeader, - m_SessionId, - GetAccessToken(), - {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))}); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - // Set up read callback for payload data - Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); - Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); - Ctx->ReadData.Offset = 0; - - curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); - curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); - - 
SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Post; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); } - void DoAsyncPutWithPayload(std::string Url, - IoBuffer Payload, - AsyncHttpCallback Callback, - HttpClient::KeyValueMap AdditionalHeader, - HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncPostWithPayload(std::string Url, + IoBuffer Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) { - asio::post(m_Strand, - [this, - Url = std::move(Url), - Payload = std::move(Payload), - Callback = std::move(Callback), - AdditionalHeader = std::move(AdditionalHeader), - Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Put"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); - - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); - Ctx->PayloadBuffer = std::move(Payload); - Ctx->HeaderList = BuildHeaderList( - AdditionalHeader, - m_SessionId, - GetAccessToken(), - {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))}); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); - - Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); - Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); - Ctx->ReadData.Offset = 0; - - curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); - curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); - - SubmitTransfer(Handle, std::move(Ctx)); - }); + auto Ctx = 
std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PostWithPayload; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Payload = std::move(Payload); + Ctx->Spec.ContentType = ContentType; + Ctx->Spec.HasContentType = true; + return Submit(std::move(Ctx)); } - void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters) + AsyncRequestToken DoAsyncPutWithPayload(std::string Url, + IoBuffer Payload, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) { - asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable { - ZEN_TRACE_CPU("AsyncHttpClient::Put"); - if (m_ShuttingDown) - { - return; - } - CURL* Handle = AllocHandle(); - ConfigureHandle(Handle, Url, Parameters); - curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PutWithPayload; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + Ctx->Spec.Payload = std::move(Payload); + return Submit(std::move(Ctx)); + } - auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + AsyncRequestToken DoAsyncPutNoPayload(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::Put; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + return Submit(std::move(Ctx)); + } - HttpClient::KeyValueMap 
ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; - Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + AsyncRequestToken DoAsyncPutWithSource(std::string Url, + uint64_t TotalSize, + AsyncHttpReadSource Source, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) + { + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->Spec.Method = AsyncRequestMethod::PutWithSource; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.OnReadSource = std::move(Source); + Ctx->Spec.StreamingPutSize = TotalSize; + return Submit(std::move(Ctx)); + } - SubmitTransfer(Handle, std::move(Ctx)); - }); + AsyncRequestToken DoAsyncStream(std::string Url, + AsyncHttpDataCallback OnData, + AsyncHttpCallback OnComplete, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + auto Ctx = std::make_unique<TransferContext>(std::move(OnComplete)); + Ctx->Spec.Method = AsyncRequestMethod::Stream; + Ctx->Spec.Url = std::move(Url); + Ctx->Spec.AdditionalHeader = std::move(AdditionalHeader); + Ctx->Spec.Parameters = std::move(Parameters); + Ctx->Spec.OnData = std::move(OnData); + return Submit(std::move(Ctx)); } // -- Members --------------------------------------------------------- @@ -824,25 +1880,44 @@ struct AsyncHttpClient::Impl std::string m_SessionId; std::string m_UnixSocketPathUtf8; - // io_context and strand - all curl_multi operations are serialized on the - // strand, making this safe even when the io_context has multiple threads. - std::unique_ptr<asio::io_context> m_OwnedIoContext; - asio::io_context& m_IoContext; + // io_context: either privately owned (m_OwnedIoContext non-null + we spin + // m_IoThread + hold m_WorkGuard) or supplied by the caller (m_OwnedIoContext + // null; caller drives the loop). 
Declared before m_Strand so the strand + // can bind its executor from m_IoContext during member init. + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + // Single serialization point for every curl_multi operation, every async + // completion handler, and every Cleanup call. Declared after m_IoContext + // (so make_strand can read its executor) and before m_Timer (which binds + // to the strand). asio::strand<asio::io_context::executor_type> m_Strand; std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; std::thread m_IoThread; - // curl_multi and socket-action state - CURLM* m_Multi = nullptr; - std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; - std::vector<CURL*> m_HandlePool; - std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; - asio::steady_timer m_Timer; - bool m_ShuttingDown = false; + // Strand-bound; async_wait completions land back on m_Strand. + asio::steady_timer m_Timer; + CURLM* m_Multi = nullptr; + // Single TokenId-keyed map. CurlHandle lives in TransferContext; reverse + // lookup from CURL* uses CURLOPT_PRIVATE (set in SubmitTransfer). + std::unordered_map<uint64_t, std::unique_ptr<TransferContext>> m_Transfers; + // Transfers parked between curl-side completion and the next attempt's + // SubmitFromSpec. Lookups by TokenId let HandleCancel cancel the backoff + // timer without paying the full delay. 
+ struct RetryEntry + { + std::unique_ptr<TransferContext> Ctx; + std::shared_ptr<asio::steady_timer> Timer; + }; + std::unordered_map<uint64_t, RetryEntry> m_RetryingTransfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<AsyncSocketInfo>> m_Sockets; + uint32_t m_InFlight = 0; // telemetry only; storage layer caps fan-out - // Access token cache - RwLock m_AccessTokenLock; + std::atomic<bool> m_ShuttingDown{false}; HttpClientAccessToken m_CachedAccessToken; + + std::atomic<uint64_t> m_NextTokenId{1}; + std::atomic<bool> m_ShutdownDone{false}; }; ////////////////////////////////////////////////////////////////////////// @@ -850,79 +1925,134 @@ struct AsyncHttpClient::Impl // AsyncHttpClient public API AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings) -: m_Impl(std::make_unique<Impl>(BaseUri, Settings)) +: m_Impl(std::make_shared<Impl>(BaseUri, Settings)) { } AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) -: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings)) +: m_Impl(std::make_shared<Impl>(BaseUri, IoContext, Settings)) { } -AsyncHttpClient::~AsyncHttpClient() = default; +AsyncHttpClient::~AsyncHttpClient() +{ + // Drive teardown synchronously while we're guaranteed to be on a user + // thread (not the io thread). Joining and draining here ensures any + // posted lambdas that captured a shared_ptr<Impl> by value (e.g. + // Cancel after the transfer completed) are destroyed and release + // their refs before m_Impl drops; otherwise the cycle "lambda holds + // Impl ref / lambda lives in io_context owned by Impl" would pin Impl + // alive forever and leak curl_multi + socket handles. 
+ if (m_Impl) + { + m_Impl->Shutdown(); + } +} // -- Callback-based API -------------------------------------------------- -void +AsyncRequestToken AsyncHttpClient::AsyncGet(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); } -void +AsyncRequestToken AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& 
Payload, ZenContentType ContentType, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) { - m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); + return m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); } -void +AsyncRequestToken AsyncHttpClient::AsyncPut(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { - m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); + return m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); +} + +AsyncRequestToken +AsyncHttpClient::AsyncPut(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + return m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +AsyncRequestToken +AsyncHttpClient::AsyncPut(std::string_view Url, + uint64_t TotalSize, + AsyncHttpReadSource Source, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader) +{ + return m_Impl->DoAsyncPutWithSource(std::string(Url), TotalSize, std::move(Source), std::move(OnComplete), AdditionalHeader); +} + +AsyncRequestToken +AsyncHttpClient::AsyncStream(std::string_view Url, + AsyncHttpDataCallback OnData, + AsyncHttpCallback OnComplete, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + return m_Impl->DoAsyncStream(std::string(Url), std::move(OnData), std::move(OnComplete), AdditionalHeader, Parameters); } +// -- Token cancellation -------------------------------------------------- + void -AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters) +AsyncRequestToken::Cancel() { - m_Impl->DoAsyncPutNoPayload(std::string(Url), 
std::move(Callback), Parameters); + if (!m_State) + { + return; + } + if (m_State->Cancelled.exchange(true, std::memory_order_acq_rel)) + { + return; // already cancelled + } + if (m_State->CancelFn) + { + m_State->CancelFn(); + } } // -- Future-based API ---------------------------------------------------- @@ -1026,6 +2156,7 @@ AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) AsyncPut( Url, [Promise](Response R) { Promise->set_value(std::move(R)); }, + KeyValueMap{}, Parameters); return Future; } diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h index cb5f5d9a9..410399a11 100644 --- a/src/zenhttp/clients/httpclientcurlhelpers.h +++ b/src/zenhttp/clients/httpclientcurlhelpers.h @@ -232,27 +232,48 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, const std::optional<std::string>& AccessToken, const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) { - curl_slist* Headers = nullptr; + // Manual tail tracking. curl_slist_append walks to the tail on every call, + // so chained appends are O(n^2) over the header count. We append in O(1) + // by allocating each node directly and stitching it onto a tracked tail. 
+    curl_slist* Head = nullptr;
+    curl_slist* Tail = nullptr;
+    auto Append = [&](const char* Line) {
+        curl_slist* New = curl_slist_append(nullptr, Line);
+        if (!New)
+        {
+            return;
+        }
+        if (!Head)
+        {
+            Head = New;
+            Tail = New;
+        }
+        else
+        {
+            Tail->next = New;
+            Tail = New;
+        }
+    };

     for (const auto& [Key, Value] : *AdditionalHeader)
     {
         ExtendableStringBuilder<64> HeaderLine;
         HeaderLine << Key << ": " << Value;
-        Headers = curl_slist_append(Headers, HeaderLine.c_str());
+        Append(HeaderLine.c_str());
     }

     if (!SessionId.empty())
     {
         ExtendableStringBuilder<64> SessionHeader;
         SessionHeader << "UE-Session: " << SessionId;
-        Headers = curl_slist_append(Headers, SessionHeader.c_str());
+        Append(SessionHeader.c_str());
     }

     if (AccessToken.has_value())
     {
         ExtendableStringBuilder<128> AuthHeader;
         AuthHeader << "Authorization: " << AccessToken.value();
-        Headers = curl_slist_append(Headers, AuthHeader.c_str());
+        Append(AuthHeader.c_str());
     }

     bool HasContentTypeOverride = AdditionalHeader->contains("Content-Type");
@@ -264,10 +285,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
         }
         ExtendableStringBuilder<128> HeaderLine;
         HeaderLine << Key << ": " << Value;
-        Headers = curl_slist_append(Headers, HeaderLine.c_str());
+        Append(HeaderLine.c_str());
     }

-    return Headers;
+    return Head;
 }

 inline HttpClient::KeyValueMap
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 9d5846f71..b6a07250e 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -320,6 +320,62 @@ HttpClient::Response::IsSuccess() const noexcept
     return !Error && IsHttpSuccessCode(StatusCode);
 }

+std::string_view
+HttpClient::Response::FindHeader(std::string_view Name) const
+{
+    // Scan the raw arena first - the async client populates this and leaves
+    // Header empty by default. Lines are "Key: Value\r\n" (the trailing \r\n
+    // is what curl hands the HEADER callback; we keep it verbatim).
+    if (!HeaderArena.empty())
+    {
+        std::string_view View(HeaderArena);
+        while (!View.empty())
+        {
+            const size_t LineEnd = View.find('\n');
+            std::string_view Line = LineEnd == std::string_view::npos ? View : View.substr(0, LineEnd);
+            View = LineEnd == std::string_view::npos ? std::string_view{} : View.substr(LineEnd + 1);
+
+            while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+            {
+                Line.remove_suffix(1);
+            }
+            if (Line.empty())
+            {
+                continue;
+            }
+            const size_t Colon = Line.find(':');
+            if (Colon == std::string_view::npos)
+            {
+                continue;  // HTTP status line or malformed
+            }
+            std::string_view K = Line.substr(0, Colon);
+            std::string_view V = Line.substr(Colon + 1);
+            while (!K.empty() && K.back() == ' ')
+            {
+                K.remove_suffix(1);
+            }
+            while (!V.empty() && V.front() == ' ')
+            {
+                V.remove_prefix(1);
+            }
+            if (StrCaseEquals(K, Name))
+            {
+                return V;
+            }
+        }
+    }
+
+    // Fall back to the parsed map (sync client populates this).
+    for (const auto& [K, V] : *Header)
+    {
+        if (StrCaseEquals(K, Name))
+        {
+            return V;
+        }
+    }
+    return {};
+}
+
 std::string
 HttpClient::Response::ErrorMessage(std::string_view Prefix) const
 {
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index b0e097a54..ea73ff7a3 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -300,6 +300,49 @@ struct TestServerFixture

 TEST_SUITE_BEGIN("http.httpclient");

+TEST_CASE("httpclient.response.findheader.arena")
+{
+    // Async client populates HeaderArena (raw "Key: Value\r\n" bytes); FindHeader
+    // scans it lazily without building the full KeyValueMap. Exercise arena scan
+    // path independent of any live HTTP transfer.
+    HttpClient::Response Resp;
+    Resp.HeaderArena =
+        "Content-Type: application/json\r\n"
+        "ETag: \"abc-123\"\r\n"
+        "Content-Length: 42\r\n"
+        "X-Amz-Request-Id: deadbeef\r\n"
+        "\r\n";
+
+    CHECK_EQ(Resp.FindHeader("Content-Type"), "application/json");
+    CHECK_EQ(Resp.FindHeader("etag"), "\"abc-123\"");   // case-insensitive
+    CHECK_EQ(Resp.FindHeader("CONTENT-LENGTH"), "42");  // case-insensitive
+    CHECK_EQ(Resp.FindHeader("x-amz-request-id"), "deadbeef");
+    CHECK_EQ(Resp.FindHeader("missing"), "");  // absent
+    CHECK_EQ(Resp.FindHeader(""), "");         // empty name
+}
+
+TEST_CASE("httpclient.response.findheader.map_fallback")
+{
+    // Sync client populates Header KeyValueMap; FindHeader falls back to it
+    // when HeaderArena is empty.
+    HttpClient::Response Resp;
+    Resp.Header->insert_or_assign("Content-Type", "text/plain");
+    Resp.Header->insert_or_assign("ETag", "\"xyz\"");
+
+    CHECK_EQ(Resp.FindHeader("content-type"), "text/plain");
+    CHECK_EQ(Resp.FindHeader("ETag"), "\"xyz\"");
+    CHECK_EQ(Resp.FindHeader("missing"), "");
+}
+
+TEST_CASE("httpclient.response.findheader.arena_takes_priority")
+{
+    // If both populated (unusual), arena is scanned first.
+    HttpClient::Response Resp;
+    Resp.HeaderArena = "ETag: from-arena\r\n";
+    Resp.Header->insert_or_assign("ETag", "from-map");
+    CHECK_EQ(Resp.FindHeader("ETag"), "from-arena");
+}
+
 TEST_CASE("httpclient.verbs")
 {
     TestServerFixture Fixture;
diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h
index cb41626b9..d4a33e3ac 100644
--- a/src/zenhttp/include/zenhttp/asynchttpclient.h
+++ b/src/zenhttp/include/zenhttp/asynchttpclient.h
@@ -19,21 +19,96 @@ namespace zen {
 /// Completion callback for async HTTP operations.
 using AsyncHttpCallback = std::function<void(HttpClient::Response)>;

+/// Pull-mode body source for the streaming `AsyncPut` overload.
+/// Fired from
+/// the AsyncHttpClient io thread when curl needs more upload bytes; runs on
+/// the same strand as every other transfer.
+///
+/// IMPORTANT: like AsyncHttpDataCallback, this runs on the io thread and
+/// stalls curl_multi for ALL in-flight transfers if it blocks. Local-disk
+/// pread is acceptable on fast storage; for anything slower, pre-fetch into
+/// a caller-owned ring on a worker pool and have this callback pop from it.
+///
+///   Dst       - destination buffer to fill (caller writes here).
+///   MaxBytes  - maximum bytes Dst can hold this call.
+///   AbsOffset - cumulative bytes returned so far (i.e. start offset of
+///               this chunk into the request body). Strictly monotonic
+///               across calls; equals TotalSize when EOF reached.
+///
+/// Returns the number of bytes written into Dst (<= MaxBytes). Returning 0
+/// while AbsOffset < TotalSize is treated as an upload error (CURL_READFUNC_ABORT).
+using AsyncHttpReadSource = std::function<size_t(uint8_t* Dst, size_t MaxBytes, uint64_t AbsOffset)>;
+
+/// Per-chunk data callback for `AsyncStream`. Fired from the AsyncHttpClient
+/// io thread (the same thread driving curl_multi for every other transfer on
+/// this client) once per response-body chunk that curl delivers.
+///
+/// IMPORTANT: OnData runs on the io thread. Any blocking work inside it
+/// (synchronous disk I/O, `std::mutex` waits, network calls, lock contention)
+/// stalls curl_multi for ALL in-flight transfers on this client, not just
+/// this one. Treat OnData as a strand: copy/refcount the bytes into a buffer
+/// you own and hop the heavy work to a worker pool. See
+/// `S3AsyncStorage::Get` (medium tier) for the canonical pattern - fill a
+/// pre-sized IoBuffer from a bounded pool here, dispatch the positional disk
+/// write to a worker, release the buffer back to the pool from the worker.
+///
+///   Data      - pointer to received bytes; valid only for the duration of
+///               this call.
+///               Caller must consume synchronously (memcpy into
+///               a pre-arranged buffer slot, etc). No allocation or copy
+///               is performed by AsyncHttpClient.
+///   Size      - bytes available at Data for this chunk.
+///   TotalSize - declared total payload size (Content-Length, in bytes); 0
+///               if the server did not declare a length (e.g.
+///               Transfer-Encoding: chunked). Constant across all calls
+///               for a given request.
+///
+/// Returns true to continue the transfer; returning false aborts the
+/// transfer and the completion callback fires with an error response.
+using AsyncHttpDataCallback = std::function<bool(const uint8_t* Data, size_t Size, uint64_t TotalSize)>;
+
+/// Handle to an in-flight async HTTP request. Returned by every AsyncXxx call.
+/// Default-constructed token is empty (Cancel is a no-op). Calling Cancel on a
+/// non-empty token requests cancellation of the underlying transfer; the
+/// completion callback still fires once with an error response so callers do
+/// not need a second notification path.
+///
+/// The token does not keep the AsyncHttpClient alive. Callers must ensure
+/// Cancel is invoked before the client is destroyed; tokens left dangling are
+/// safe to destroy but Cancel calls after client destruction are no-ops.
+class AsyncRequestToken
+{
+public:
+    // Forward-declared so AsyncHttpClient internals can hold a typed
+    // shared_ptr<State> across worker / strand boundaries (used for the
+    // cancel-before-submit race check). State definition is private to the
+    // implementation TU.
+    struct State;
+
+    AsyncRequestToken() = default;
+
+    void Cancel();
+    bool IsValid() const { return m_State != nullptr; }
+
+private:
+    friend class AsyncHttpClient;
+    explicit AsyncRequestToken(std::shared_ptr<State> S) : m_State(std::move(S)) {}
+
+    std::shared_ptr<State> m_State;
+};
+
 /** Asynchronous HTTP client backed by curl_multi and ASIO.
  *
  * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process
- * transfers without blocking the caller. All curl_multi operations are
- * serialized on an internal strand; callers may issue requests from any
- * thread, and the io_context may have multiple threads.
+ * transfers without blocking the caller. By default the client owns an
+ * io_context and a single io thread driving it; the second constructor reuses
+ * an external io_context. All curl_multi operations are serialized on an
+ * internal strand. Callers may issue requests from any thread.
  *
- * Two construction modes:
- *   - Owned io_context: creates an internal thread (self-contained).
- *   - External io_context: caller runs the event loop.
- *
- * Completion callbacks are dispatched on the io_context (not the internal
- * strand), so a slow callback will not block the curl poll loop. Future-
- * based wrappers (Get, Post, ...) return a std::future<Response> for
- * callers that prefer blocking on a result.
+ * Completion callbacks run inline on the AsyncHttpClient io thread (same
+ * strand as the curl poll loop): heavy work (disk syscalls, lock contention,
+ * large allocations) must be hopped to a worker pool. See
+ * `AsyncHttpDataCallback` and `AsyncHttpReadSource` for the same contract on
+ * streaming. Future-based wrappers (Get, Post, ...) return a
+ * std::future<Response> for callers that prefer blocking on a result.
  */
 class AsyncHttpClient
 {
@@ -41,11 +116,15 @@ public:
     using Response = HttpClient::Response;
     using KeyValueMap = HttpClient::KeyValueMap;

-    /// Construct with an internally-owned io_context and thread.
     explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {});

     /// Construct with an externally-managed io_context. The io_context must
-    /// outlive this client and must be running (via run()) on at least one thread.
+    /// outlive the AsyncHttpClient and must be running (e.g.
+    /// via run() on
+    /// a dedicated thread, or driven by the caller). The destructor posts
+    /// the cleanup handler to that loop and blocks until it completes -
+    /// destroying the client from the same thread that drives the io_context
+    /// would deadlock. Multiple threads on the same io_context are safe;
+    /// all curl_multi operations are serialized through an internal strand.
     AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {});
     ~AsyncHttpClient();

@@ -54,38 +133,93 @@ public:
     AsyncHttpClient& operator=(const AsyncHttpClient&) = delete;

     // -- Callback-based API ----------------------------------------------
-
-    void AsyncGet(std::string_view Url,
-                  AsyncHttpCallback Callback,
-                  const KeyValueMap& AdditionalHeader = {},
-                  const KeyValueMap& Parameters = {});
-
-    void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
-
-    void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
-
-    void AsyncPost(std::string_view Url,
-                   AsyncHttpCallback Callback,
-                   const KeyValueMap& AdditionalHeader = {},
-                   const KeyValueMap& Parameters = {});
-
-    void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
-
-    void AsyncPost(std::string_view Url,
-                   const IoBuffer& Payload,
-                   ZenContentType ContentType,
-                   AsyncHttpCallback Callback,
-                   const KeyValueMap& AdditionalHeader = {});
-
-    void AsyncPut(std::string_view Url,
-                  const IoBuffer& Payload,
-                  AsyncHttpCallback Callback,
-                  const KeyValueMap& AdditionalHeader = {},
-                  const KeyValueMap& Parameters = {});
-
-    void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {});
+    //
+    // On every callback overload below the response's parsed Header map is left
+    // EMPTY by default - the raw header bytes live in Response::HeaderArena and
+    // callers should use
+    // Response::FindHeader(name) to look up individual values.
+    // This skips the per-line std::string allocations on the io thread that the
+    // sync client incurs. Callers that need to iterate all headers must opt in
+    // via AsyncRequestSpec::WantHeaderMap on the impl side; today no public
+    // overload exposes that flag because no in-tree caller needs it.
+
+    AsyncRequestToken AsyncGet(std::string_view Url,
+                               AsyncHttpCallback Callback,
+                               const KeyValueMap& AdditionalHeader = {},
+                               const KeyValueMap& Parameters = {});
+
+    AsyncRequestToken AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+    AsyncRequestToken AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+    AsyncRequestToken AsyncPost(std::string_view Url,
+                                AsyncHttpCallback Callback,
+                                const KeyValueMap& AdditionalHeader = {},
+                                const KeyValueMap& Parameters = {});
+
+    AsyncRequestToken AsyncPost(std::string_view Url,
+                                const IoBuffer& Payload,
+                                AsyncHttpCallback Callback,
+                                const KeyValueMap& AdditionalHeader = {});
+
+    AsyncRequestToken AsyncPost(std::string_view Url,
+                                const IoBuffer& Payload,
+                                ZenContentType ContentType,
+                                AsyncHttpCallback Callback,
+                                const KeyValueMap& AdditionalHeader = {});
+
+    AsyncRequestToken AsyncPut(std::string_view Url,
+                               const IoBuffer& Payload,
+                               AsyncHttpCallback Callback,
+                               const KeyValueMap& AdditionalHeader = {},
+                               const KeyValueMap& Parameters = {});
+
+    AsyncRequestToken AsyncPut(std::string_view Url,
+                               AsyncHttpCallback Callback,
+                               const KeyValueMap& AdditionalHeader = {},
+                               const KeyValueMap& Parameters = {});
+
+    /// Streaming PUT: body bytes are pulled from `Source` on demand, no
+    /// pre-materialized payload buffer. TotalSize is set as Content-Length.
+    /// `Source` runs on the AsyncHttpClient io thread - same strand
+    /// discipline as AsyncStream's OnData.
+    /// Useful for multipart parts or
+    /// medium-tier uploads where materializing the body would waste RAM.
+    AsyncRequestToken AsyncPut(std::string_view Url,
+                               uint64_t TotalSize,
+                               AsyncHttpReadSource Source,
+                               AsyncHttpCallback OnComplete,
+                               const KeyValueMap& AdditionalHeader = {});
+
+    /// Streaming GET with a caller-supplied per-chunk data callback. Bytes are
+    /// delivered to OnData as they arrive, with no internal allocation or
+    /// copy of the payload. Caller manages its own destination (positional
+    /// disk write, pre-arranged buffer slot, etc.) and must consume each
+    /// chunk synchronously before returning - the data pointer is only valid
+    /// for the duration of the call.
+    ///
+    /// WARNING: OnData runs on the AsyncHttpClient io thread. Blocking I/O,
+    /// lock contention, or any wait inside OnData stalls curl_multi for ALL
+    /// in-flight transfers on this client. See `AsyncHttpDataCallback` doc
+    /// for the recommended buffer-then-hop pattern.
+    ///
+    /// The completion Callback fires once with status / headers / error
+    /// (Response.ResponsePayload is always empty on this path).
+    ///
+    /// Useful when the caller already owns the destination file (e.g. ranged
+    /// multipart download writing per-range slices to a shared file) and
+    /// wants to avoid the curl-buffer -> in-memory IoBuffer -> file double
+    /// memcpy / lazy-commit page-fault cost of AsyncGet+worker-write.
+    AsyncRequestToken AsyncStream(std::string_view Url,
+                                  AsyncHttpDataCallback OnData,
+                                  AsyncHttpCallback OnComplete,
+                                  const KeyValueMap& AdditionalHeader = {},
+                                  const KeyValueMap& Parameters = {});

     // -- Future-based API ------------------------------------------------
+    //
+    // These wrappers discard the underlying AsyncRequestToken; callers who need
+    // to cancel an in-flight request must use the callback-based API. The
+    // returned future resolves once with the final Response (which may carry an
+    // error / cancel response if the client is shutting down).

     [[nodiscard]] std::future<Response> Get(std::string_view Url,
                                             const KeyValueMap& AdditionalHeader = {},
@@ -115,7 +249,7 @@ public:

 private:
     struct Impl;
-    std::unique_ptr<Impl> m_Impl;
+    std::shared_ptr<Impl> m_Impl;
 };

 void asynchttpclient_test_forcelink();  // internal
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 4cf3a86a8..3162823f3 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -136,6 +136,24 @@ struct HttpClientSettings
     /// nullptr disables sharing. Must outlive every HttpClient that
     /// references it. Non-curl backends ignore this field.
     HttpClientShare* OptionalShare = nullptr;
+
+    /// Max concurrent connections per host. Used by AsyncHttpClient to set
+    /// CURLMOPT_MAX_HOST_CONNECTIONS. libcurl default is small; hub-scale fanout
+    /// against a single S3 endpoint needs much higher. 0 = leave libcurl default.
+    uint32_t MaxConcurrentConnectionsPerHost = 0;
+
+    /// Max concurrent connections total. Used by AsyncHttpClient to set
+    /// CURLMOPT_MAX_TOTAL_CONNECTIONS. 0 = leave libcurl default.
+    uint32_t MaxConcurrentConnectionsTotal = 0;
+
+    /// Hint for the maximum number of in-flight requests the caller intends to
+    /// keep submitted to AsyncHttpClient. AsyncHttpClient itself does not gate
+    /// on this value - fan-out is bounded externally by the caller (e.g. the
+    /// hub's S3AsyncStorage admission semaphore). This setting is reused by
+    /// hub configuration to derive matching curl connection caps so libcurl's
+    /// CONNECTTIMEOUT does not tick on connections waiting behind the cap.
+    /// 0 = no hint.
+    uint32_t MaxConcurrentRequests = 0;
 };

 class HttpClientError : public std::runtime_error
@@ -272,9 +290,17 @@ public:
         HttpResponseCode StatusCode = HttpResponseCode::ImATeapot;
         IoBuffer ResponsePayload;  // Note: this also includes the content type

-        // Contains the response headers
+        // Contains the response headers.
+        // By default the async path leaves this
+        // empty (raw bytes are kept in HeaderArena instead) - use FindHeader
+        // for lookups, or set AsyncRequestSpec::WantHeaderMap if you need the
+        // full parsed map. The synchronous client populates this as before.
         KeyValueMap Header;

+        // Raw response header bytes, "Key: Value\r\n" lines concatenated.
+        // Populated by the async client; empty for the sync client. Owned
+        // and freed with the Response. FindHeader() scans this lazily.
+        std::string HeaderArena;
+
         // The number of bytes sent as part of the request
         int64_t UploadedBytes = 0;

@@ -321,6 +347,12 @@ public:
         // objects, returns text as-is for text types like Text, JSON, HTML etc
         std::string ToText() const;

+        // Lookup a header by name (case-insensitive). Checks HeaderArena first,
+        // then Header map. Returns the first matching value or empty string_view
+        // if absent. View is valid until this Response is destroyed (or Header
+        // map is mutated).
+        std::string_view FindHeader(std::string_view Name) const;
+
         // Returns whether the HTTP status code is considered successful (i.e in the
         // 2xx range)
         bool IsSuccess() const noexcept;
diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp
index b4e9de2f0..2afee8729 100644
--- a/src/zenserver/hub/hub.cpp
+++ b/src/zenserver/hub/hub.cpp
@@ -21,6 +21,10 @@ ZEN_THIRD_PARTY_INCLUDES_START
 #include <gsl/gsl-lite.hpp>
 ZEN_THIRD_PARTY_INCLUDES_END

+#include <zencore/thread.h>
+
+#include <thread>
+
 #if ZEN_WITH_TESTS
 #    include <zencore/testing.h>
 #    include <zencore/testutils.h>
@@ -30,8 +34,6 @@ ZEN_THIRD_PARTY_INCLUDES_END

 namespace zen {

-///////////////////////////////////////////////////////////////////////////
-
 /**
  * A timeline of events with sequence IDs and timestamps. Used to
  * track significant events for broadcasting to listeners.
@@ -81,7 +83,6 @@ public:
      */
     void IterateEventsSince(auto&& Callback, uint64_t SinceEventId)
     {
-        // Hold the lock for as short a time as possible
         eastl::fixed_vector<EventRecord, 128> EventsToProcess;
         m_Lock.WithSharedLock([&] {
             for (auto& Event : m_Events)
@@ -93,7 +94,6 @@ public:
             }
         });

-        // Now invoke the callback outside the lock
         for (auto& Event : EventsToProcess)
         {
             Callback(Event);
@@ -176,16 +176,22 @@ Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace)
 Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback)
 : m_Config(Config)
 , m_RunEnvironment(std::move(RunEnvironment))
-, m_WorkerPool(Config.OptionalProvisionWorkerPool)
+, m_ProvisionPool(Config.OptionalProvisionPool)
+, m_SpawnPool(Config.OptionalSpawnPool)
 , m_InstanceClientShare(std::make_unique<HttpClientShare>())
 , m_BackgroundWorkLatch(1)
 , m_ModuleStateChangeCallback(std::move(ModuleStateChangeCallback))
 , m_ActiveInstances(Config.InstanceLimit)
 , m_FreeActiveInstanceIndexes(Config.InstanceLimit)
 {
-    ZEN_ASSERT_FORMAT(
-        Config.OptionalProvisionWorkerPool != Config.OptionalHydrationWorkerPool || Config.OptionalProvisionWorkerPool == nullptr,
-        "Provision and hydration worker pools must be distinct to avoid deadlocks");
+    ZEN_ASSERT_FORMAT((Config.OptionalProvisionPool == nullptr) == (Config.OptionalSpawnPool == nullptr),
+                      "Provision and spawn worker pools must both be set or both be null");
+    ZEN_ASSERT_FORMAT(Config.OptionalProvisionPool != Config.OptionalSpawnPool || Config.OptionalProvisionPool == nullptr,
+                      "Provision and spawn worker pools must be distinct to avoid deadlocks");
+    ZEN_ASSERT_FORMAT(Config.OptionalProvisionPool != Config.OptionalHydrationPool || Config.OptionalProvisionPool == nullptr,
+                      "Provision and hydration worker pools must be distinct to avoid deadlocks");
+    ZEN_ASSERT_FORMAT(Config.OptionalSpawnPool != Config.OptionalHydrationPool || Config.OptionalSpawnPool
+                          == nullptr,
+                      "Spawn and hydration worker pools must be distinct to avoid deadlocks");

     HydrationBase::Configuration HydrationConfig;
     if (!m_Config.HydrationTargetSpecification.empty())
@@ -202,6 +208,11 @@ Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, Asy
     {
         HydrationConfig.Options = m_Config.HydrationOptions;
     }
+    if (Config.HydrationAsyncEnabled)
+    {
+        HydrationConfig.AsyncEnabled = true;
+        HydrationConfig.AsyncMaxConcurrentRequests = Config.HydrationAsyncMaxConcurrentRequests;
+    }
     m_Hydration = InitHydration(HydrationConfig);

     m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp");
@@ -272,7 +283,7 @@ Hub::Shutdown()

     m_WatchDog = {};

-    if (WaitForBackgroundWork && m_WorkerPool)
+    if (WaitForBackgroundWork && (m_ProvisionPool != nullptr || m_SpawnPool != nullptr))
     {
         m_BackgroundWorkLatch.CountDown();
         m_BackgroundWorkLatch.Wait();
@@ -300,7 +311,7 @@ Hub::Shutdown()
         }
     });

-    if (WaitForBackgroundWork && m_WorkerPool)
+    if (WaitForBackgroundWork && (m_ProvisionPool != nullptr || m_SpawnPool != nullptr))
     {
         m_BackgroundWorkLatch.CountDown();
         m_BackgroundWorkLatch.Wait();
@@ -317,7 +328,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
     size_t ActiveInstanceIndex = (size_t)-1;
     HubInstanceState OldState = HubInstanceState::Unprovisioned;
     {
-        RwLock::ExclusiveLockScope _(m_Lock);
+        RwLock::ExclusiveLockScope HubLock(m_Lock);

         if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end())
         {
@@ -339,23 +350,15 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
             {
                 auto NewInstance = std::make_unique<StorageServerInstance>(
                     m_RunEnvironment,
-                    *m_Hydration,
-                    StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex),
-                                                         .StateDir = m_RunEnvironment.CreateChildDir(ModuleId),
-                                                         .TempDir = m_HydrationTempPath / ModuleId,
-                                                         .HttpThreadCount = m_Config.InstanceHttpThreadCount,
-                                                         .CoreLimit =
-                                                             m_Config.InstanceCoreLimit,
-                                                         .ConfigPath = m_Config.InstanceConfigPath,
-                                                         .Malloc = m_Config.InstanceMalloc,
-                                                         .Trace = m_Config.InstanceTrace,
-                                                         .TraceHost = m_Config.InstanceTraceHost,
-                                                         .TraceFile = m_Config.InstanceTraceFile,
-                                                         .EnableHydration = m_Config.EnableHydration,
-                                                         .EnableDehydration = m_Config.EnableDehydration,
-                                                         .HydrationPackEnabled = m_Config.HydrationPackEnabled,
-                                                         .HydrationPackThresholdBytes = m_Config.HydrationPackThresholdBytes,
-                                                         .HydrationMaxPackBytes = m_Config.HydrationMaxPackBytes,
-                                                         .OptionalWorkerPool = m_Config.OptionalHydrationWorkerPool},
+                    StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex),
+                                                         .StateDir = m_RunEnvironment.CreateChildDir(ModuleId),
+                                                         .HttpThreadCount = m_Config.InstanceHttpThreadCount,
+                                                         .CoreLimit = m_Config.InstanceCoreLimit,
+                                                         .ConfigPath = m_Config.InstanceConfigPath,
+                                                         .Malloc = m_Config.InstanceMalloc,
+                                                         .Trace = m_Config.InstanceTrace,
+                                                         .TraceHost = m_Config.InstanceTraceHost,
+                                                         .TraceFile = m_Config.InstanceTraceFile},
                     ModuleId);

 #if ZEN_PLATFORM_WINDOWS
@@ -367,7 +370,8 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)

                 Instance = NewInstance->LockExclusive(/*Wait*/ true);

-                m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance);
+                m_ActiveInstances[ActiveInstanceIndex].Instance       = std::move(NewInstance);
+                m_ActiveInstances[ActiveInstanceIndex].HydrationState = {};
                 m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset();
                 m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex);
                 // Set Provisioning while both hub lock and instance lock are held so that any
@@ -378,7 +382,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
             {
                 Instance = {};
                 m_ActiveInstances[ActiveInstanceIndex].Instance.reset();
-                m_ActiveInstances[ActiveInstanceIndex].State.store(HubInstanceState::Unprovisioned);
+                UpdateInstanceState(HubLock, ActiveInstanceIndex,
+                                    HubInstanceState::Unprovisioned);
                 m_InstanceLookup.erase(std::string(ModuleId));
                 m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex);
                 throw;
@@ -419,7 +423,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
                     m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now());
                     return Response{EResponseCode::Completed};
                 case HubInstanceState::Hibernated:
-                    _.ReleaseNow();
+                    HubLock.ReleaseNow();
                     return Wake(std::string(ModuleId));
                 default:
                     return Response{EResponseCode::Rejected,
@@ -463,54 +467,54 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)

     NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Provisioning, OutInfo.Port, {});

-    if (m_WorkerPool)
+    const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr;
+    if (Async)
     {
+        auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance));
         m_BackgroundWorkLatch.AddCount(1);
         try
         {
-            m_WorkerPool->ScheduleWork(
+            m_ProvisionPool->ScheduleWork(
                 [this,
                  ModuleId = std::string(ModuleId),
                  ActiveInstanceIndex,
                  OldState,
                  IsNewInstance,
-                 Instance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance))]() {
+                 Port = OutInfo.Port,
+                 Instance = SharedInstance]() {
                     auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); });
+                    if (!RunProvisionPhase1(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port))
+                    {
+                        return;
+                    }
+
+                    m_BackgroundWorkLatch.AddCount(1);
                     try
                     {
-                        CompleteProvision(*Instance, ActiveInstanceIndex, OldState, IsNewInstance);
+                        m_SpawnPool->ScheduleWork(
+                            [this, ModuleId, ActiveInstanceIndex, OldState, IsNewInstance, Port, Instance]() {
+                                auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); });
+                                RunProvisionPhase2(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port);
+                            },
+                            WorkerThreadPool::EMode::EnableBacklog);
                     }
-                    catch (const std::exception& Ex)
+                    catch (const std::exception& DispatchEx)
                     {
-                        ZEN_ERROR("Failed async provision of module '{}': {}", ModuleId, Ex.what());
+                        // Fallback: run Phase2 inline on the ProvisionPool worker. Couples pool
+                        // lifetimes (this ProvisionPool slot now executes what should run on
+                        // SpawnPool) but better than dropping the request.
+                        ZEN_ERROR("Failed async dispatch of provision phase 2 for module '{}': {}", ModuleId, DispatchEx.what());
+                        m_BackgroundWorkLatch.CountDown();
+                        RunProvisionPhase2(*Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port);
                     }
                 },
                 WorkerThreadPool::EMode::EnableBacklog);
         }
         catch (const std::exception& DispatchEx)
         {
-            // Dispatch failed: undo latch increment and roll back state.
-            ZEN_ERROR("Failed async dispatch provision of module '{}': {}", ModuleId, DispatchEx.what());
+            ZEN_ERROR("Failed async dispatch of provision phase 1 for module '{}': {}", ModuleId, DispatchEx.what());
             m_BackgroundWorkLatch.CountDown();
-
-            // dispatch failed before the lambda ran, so ActiveInstance::State is still Provisioning
-            NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, OldState, OutInfo.Port, {});
-
-            std::unique_ptr<StorageServerInstance> DestroyInstance;
-            {
-                RwLock::ExclusiveLockScope HubLock(m_Lock);
-                ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end());
-                ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex);
-                if (IsNewInstance)
-                {
-                    DestroyInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance);
-                    m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex);
-                    m_InstanceLookup.erase(std::string(ModuleId));
-                }
-                UpdateInstanceState(HubLock, ActiveInstanceIndex, OldState);
-            }
-            DestroyInstance.reset();
-
+            RollbackFailedProvision(*SharedInstance, ActiveInstanceIndex, OldState, IsNewInstance, OutInfo.Port);
             throw;
         }
     }
@@ -519,7 +523,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
         CompleteProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance);
     }

-
-    return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed};
+    return Response{Async ? EResponseCode::Accepted : EResponseCode::Completed};
 }

 void
@@ -529,38 +533,91 @@ Hub::CompleteProvision(StorageServerInstance::ExclusiveLockedPtr& Instance,
                        bool IsNewInstance)
 {
     ZEN_TRACE_CPU("Hub::CompleteProvision");
+    const uint16_t Port = Instance.GetBasePort();
+    if (RunProvisionPhase1(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port))
+    {
+        RunProvisionPhase2(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port);
+    }
+}
+
+bool
+Hub::RunProvisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance,
+                        size_t ActiveInstanceIndex,
+                        HubInstanceState OldState,
+                        bool IsNewInstance,
+                        uint16_t Port)
+{
+    ZEN_TRACE_CPU("Hub::RunProvisionPhase1");
     const std::string ModuleId(Instance.GetModuleId());
-    const uint16_t Port = Instance.GetBasePort();
-    std::string BaseUri;  // TODO?

-    if (m_ShutdownFlag.load() == false)
+    if (m_ShutdownFlag.load())
     {
-        try
-        {
-            switch (OldState)
-            {
-                case HubInstanceState::Crashed:
-                case HubInstanceState::Unprovisioned:
-                    Instance.Provision();
-                    break;
-                case HubInstanceState::Hibernated:
-                    ZEN_ASSERT(false);  // unreachable: Provision redirects Hibernated->Wake before setting Provisioning
-                    break;
-                default:
-                    ZEN_ASSERT(false);
-            }
-            UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned);
-            NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Provisioned, Port, BaseUri);
-            Instance = {};
-            return;
-        }
-        catch (const std::exception& Ex)
+        RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port);
+        return false;
+    }
+
+    try
+    {
+        switch (OldState)
         {
-            ZEN_ERROR("Failed to provision storage server instance for module '{}': {}", ModuleId, Ex.what());
-            // Instance will be notified and removed below.
+ case HubInstanceState::Crashed: + case HubInstanceState::Unprovisioned: + HydrateInstance(ActiveInstanceIndex, ModuleId); + break; + case HubInstanceState::Hibernated: + ZEN_ASSERT(false); // unreachable: Provision redirects Hibernated->Wake before setting Provisioning + break; + default: + ZEN_ASSERT(false); } + return true; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to hydrate storage server instance for module '{}': {}", ModuleId, Ex.what()); + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + return false; + } +} + +void +Hub::RunProvisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunProvisionPhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + if (m_ShutdownFlag.load()) + { + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); + return; + } + + try + { + Instance.Provision(); + UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned); + NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Provisioned, Port, {}); + Instance = {}; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to provision storage server instance for module '{}': {}", ModuleId, Ex.what()); + RollbackFailedProvision(Instance, ActiveInstanceIndex, OldState, IsNewInstance, Port); } +} +void +Hub::RollbackFailedProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port) +{ + const std::string ModuleId(Instance.GetModuleId()); if (IsNewInstance) { NotifyStateUpdate(ModuleId, HubInstanceState::Provisioning, HubInstanceState::Unprovisioned, Port, {}); @@ -568,11 +625,11 @@ Hub::CompleteProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, std::unique_ptr<StorageServerInstance> 
DestroyInstance; { RwLock::ExclusiveLockScope HubLock(m_Lock); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end()); - ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(ModuleId) != m_InstanceLookup.end()); + ZEN_ASSERT_SLOW(m_InstanceLookup.find(ModuleId)->second == ActiveInstanceIndex); DestroyInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); - m_InstanceLookup.erase(std::string(ModuleId)); + m_InstanceLookup.erase(ModuleId); UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); } DestroyInstance.reset(); @@ -652,11 +709,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // The exclusive instance lock acquired above prevents concurrent LockExclusive callers - // from modifying instance state. The state transition to Deprovisioning happens below, - // after the hub lock is released. - + // Outside hub lock: exclusive instance lock above blocks concurrent state mutation. See Provision for the locking argument. 
ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -664,32 +717,55 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Deprovisioning, Port, {}); - if (m_WorkerPool) + const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr; + if (Async) { - std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr = - std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); - + auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( - [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr), OldState]() mutable { + m_SpawnPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, OldState, Port, Instance = SharedInstance]() { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); try { - CompleteDeprovision(*Instance, ActiveInstanceIndex, OldState); + RunDeprovisionPhase1(*Instance, ActiveInstanceIndex, OldState, Port); } catch (const std::exception& Ex) { - ZEN_ERROR("Failed async deprovision of module '{}': {}", ModuleId, Ex.what()); + // Phase1 transitions the instance to Crashed before rethrowing, + // so the watchdog will pick this up via AttemptRecoverInstance. + // The deprovision request silently morphs into a recovery cycle; + // caller already saw EResponseCode::Accepted. 
+ ZEN_ERROR("Failed async deprovision phase 1 of module '{}': {}", ModuleId, Ex.what()); + return; + } + + m_BackgroundWorkLatch.AddCount(1); + try + { + m_ProvisionPool->ScheduleWork( + [this, ModuleId, ActiveInstanceIndex, Port, Instance]() { + auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + RunDeprovisionPhase2(*Instance, ActiveInstanceIndex, Port); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + // Fallback: run Phase2 inline on the SpawnPool worker. Couples pool + // lifetimes (this SpawnPool slot now executes what should run on + // ProvisionPool) but better than dropping the request. + ZEN_ERROR("Failed async dispatch of deprovision phase 2 for module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + RunDeprovisionPhase2(*Instance, ActiveInstanceIndex, Port); } }, WorkerThreadPool::EMode::EnableBacklog); } catch (const std::exception& DispatchEx) { - // Dispatch failed: undo latch increment and roll back state. - ZEN_ERROR("Failed async dispatch deprovision of module '{}': {}", ModuleId, DispatchEx.what()); + ZEN_ERROR("Failed async dispatch of deprovision phase 1 for module '{}': {}", ModuleId, DispatchEx.what()); m_BackgroundWorkLatch.CountDown(); NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, OldState, Port, {}); @@ -708,7 +784,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI CompleteDeprovision(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{Async ? 
EResponseCode::Accepted : EResponseCode::Completed}; } Hub::Response @@ -766,12 +842,12 @@ Hub::Obliterate(const std::string& ModuleId) m_ObliteratingInstances.insert(ModuleId); Lock.ReleaseNow(); - if (m_WorkerPool) + if (m_ProvisionPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_ProvisionPool->ScheduleWork( [this, ModuleId = std::string(ModuleId)]() { auto Guard = MakeGuard([this, ModuleId]() { m_Lock.WithExclusiveLock([this, ModuleId]() { m_ObliteratingInstances.erase(ModuleId); }); @@ -817,31 +893,51 @@ Hub::Obliterate(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Obliterating, Port, {}); - if (m_WorkerPool) + const bool Async = m_ProvisionPool != nullptr && m_SpawnPool != nullptr; + if (Async) { - std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr = - std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); - + auto SharedInstance = std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance)); m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( - [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable { + m_SpawnPool->ScheduleWork( + [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Port, Instance = SharedInstance]() { auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); try { - CompleteObliterate(*Instance, ActiveInstanceIndex); + RunObliteratePhase1(*Instance, ActiveInstanceIndex, Port); } catch (const std::exception& Ex) { - ZEN_ERROR("Failed async obliterate of module '{}': {}", ModuleId, Ex.what()); + ZEN_ERROR("Failed async obliterate phase 1 of module '{}': {}", ModuleId, Ex.what()); + return; + } + + m_BackgroundWorkLatch.AddCount(1); + try + { + m_ProvisionPool->ScheduleWork( + [this, ModuleId, ActiveInstanceIndex, Port, Instance]() { + auto _ = 
MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); }); + RunObliteratePhase2(*Instance, ActiveInstanceIndex, Port); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& DispatchEx) + { + // Fallback: run Phase2 inline on the SpawnPool worker. Couples pool + // lifetimes (this SpawnPool slot now executes what should run on + // ProvisionPool) but better than dropping the request. + ZEN_ERROR("Failed async dispatch of obliterate phase 2 for module '{}': {}", ModuleId, DispatchEx.what()); + m_BackgroundWorkLatch.CountDown(); + RunObliteratePhase2(*Instance, ActiveInstanceIndex, Port); } }, WorkerThreadPool::EMode::EnableBacklog); } catch (const std::exception& DispatchEx) { - ZEN_ERROR("Failed async dispatch obliterate of module '{}': {}", ModuleId, DispatchEx.what()); + ZEN_ERROR("Failed async dispatch obliterate phase 1 of module '{}': {}", ModuleId, DispatchEx.what()); m_BackgroundWorkLatch.CountDown(); NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, OldState, Port, {}); @@ -860,15 +956,23 @@ Hub::Obliterate(const std::string& ModuleId) CompleteObliterate(Instance, ActiveInstanceIndex); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{Async ? 
EResponseCode::Accepted : EResponseCode::Completed}; } void Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex) { ZEN_TRACE_CPU("Hub::CompleteObliterate"); + const uint16_t Port = Instance.GetBasePort(); + RunObliteratePhase1(Instance, ActiveInstanceIndex, Port); + RunObliteratePhase2(Instance, ActiveInstanceIndex, Port); +} + +void +Hub::RunObliteratePhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunObliteratePhase1"); const std::string ModuleId(Instance.GetModuleId()); - const uint16_t Port = Instance.GetBasePort(); try { @@ -876,15 +980,35 @@ Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, siz } catch (const std::exception& Ex) { + // Best-effort cleanup: drop tracking and mark Unprovisioned. Transitioning to + // Crashed would let the watchdog re-provision a module the operator wanted gone. ZEN_ERROR("Failed to obliterate storage server instance for module '{}': {}", ModuleId, Ex.what()); - Instance = {}; - { - RwLock::ExclusiveLockScope HubLock(m_Lock); - UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Crashed); - } - NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Crashed, Port, {}); + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); throw; } +} + +void +Hub::RunObliteratePhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunObliteratePhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + try + { + ObliterateBackendData(ModuleId); + } + catch (const std::exception& Ex) + { + // Backend delete failed - documented leak path (see hydration.cpp Obliterate + // retry-then-fail). 
Drop local tracking and mark Unprovisioned; the watchdog + // must not re-provision a module the operator wanted gone. + ZEN_ERROR("Failed to obliterate backend data for module '{}': {}", ModuleId, Ex.what()); + NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); + RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); + return; + } NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {}); RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); @@ -894,8 +1018,19 @@ void Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState) { ZEN_TRACE_CPU("Hub::CompleteDeprovision"); + const uint16_t Port = Instance.GetBasePort(); + RunDeprovisionPhase1(Instance, ActiveInstanceIndex, OldState, Port); + RunDeprovisionPhase2(Instance, ActiveInstanceIndex, Port); +} + +void +Hub::RunDeprovisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunDeprovisionPhase1"); const std::string ModuleId(Instance.GetModuleId()); - const uint16_t Port = Instance.GetBasePort(); try { @@ -942,9 +1077,8 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si catch (const std::exception& Ex) { ZEN_ERROR("Failed to deprovision storage server instance for module '{}': {}", ModuleId, Ex.what()); - // Effectively unreachable: Shutdown() never throws and Dehydrate() failures are swallowed - // by DeprovisionLocked. Kept as a safety net; if somehow reached, transition to Crashed - // so the watchdog can attempt recovery. + // GcClient HTTP calls and Instance.Deprovision can throw on transport failure + // or watchdog races; transition to Crashed so the watchdog can attempt recovery. 
Instance = {}; { RwLock::ExclusiveLockScope HubLock(m_Lock); @@ -953,6 +1087,22 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Crashed, Port, {}); throw; } +} + +void +Hub::RunDeprovisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port) +{ + ZEN_TRACE_CPU("Hub::RunDeprovisionPhase2"); + const std::string ModuleId(Instance.GetModuleId()); + + try + { + DehydrateInstance(ActiveInstanceIndex, ModuleId); + } + catch (const std::exception& Ex) + { + ZEN_WARN("Dehydration of module {} failed during deprovisioning, current state not saved. Reason: {}", ModuleId, Ex.what()); + } NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Unprovisioned, Port, {}); RemoveInstance(Instance, ActiveInstanceIndex, ModuleId); @@ -1012,11 +1162,7 @@ Hub::Hibernate(const std::string& ModuleId) } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // Any concurrent caller that acquired the hub lock and saw Provisioned will now block on - // LockExclusive(Wait=true); by the time it acquires the lock, UpdateInstanceState below - // will have already changed the state and the re-validate above will reject it. - + // Outside hub lock: re-validate after re-locking in worker rejects races. See Provision for the locking argument. 
ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -1024,12 +1170,12 @@ Hub::Hibernate(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Hibernating, Port, {}); - if (m_WorkerPool) + if (m_SpawnPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_SpawnPool->ScheduleWork( [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, @@ -1069,7 +1215,7 @@ Hub::Hibernate(const std::string& ModuleId) CompleteHibernate(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed}; + return Response{m_SpawnPool != nullptr ? EResponseCode::Accepted : EResponseCode::Completed}; } void @@ -1148,11 +1294,7 @@ Hub::Wake(const std::string& ModuleId) } } - // NOTE: done while not holding the hub lock, to avoid blocking other operations. - // Any concurrent caller that acquired the hub lock and saw Hibernated will now block on - // LockExclusive(Wait=true); by the time it acquires the lock, UpdateInstanceState below - // will have already changed the state and the re-validate above will reject it. - + // Outside hub lock: re-validate after re-locking in worker rejects races. See Provision for the locking argument. ZEN_ASSERT(Instance); ZEN_ASSERT(ActiveInstanceIndex != (size_t)-1); @@ -1160,12 +1302,12 @@ Hub::Wake(const std::string& ModuleId) const uint16_t Port = Instance.GetBasePort(); NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Waking, Port, {}); - if (m_WorkerPool) + if (m_SpawnPool != nullptr) { m_BackgroundWorkLatch.AddCount(1); try { - m_WorkerPool->ScheduleWork( + m_SpawnPool->ScheduleWork( [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, @@ -1205,7 +1347,7 @@ Hub::Wake(const std::string& ModuleId) CompleteWake(Instance, ActiveInstanceIndex, OldState); } - return Response{m_WorkerPool ? 
EResponseCode::Accepted : EResponseCode::Completed}; + return Response{m_SpawnPool != nullptr ? EResponseCode::Accepted : EResponseCode::Completed}; } void @@ -1243,7 +1385,8 @@ Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t auto It = m_InstanceLookup.find(std::string(ModuleId)); ZEN_ASSERT_SLOW(It != m_InstanceLookup.end()); ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex); - DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); + DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance); + m_ActiveInstances[ActiveInstanceIndex].HydrationState = {}; m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex); m_InstanceLookup.erase(It); UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned); @@ -1251,23 +1394,64 @@ Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t DeleteInstance.reset(); } +HydrationConfig +Hub::MakeHydrationConfigForModule(std::string_view ModuleId, std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag) const +{ + HydrationConfig Config{.ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId, + .TempDir = m_HydrationTempPath / ModuleId, + .ModuleId = std::string(ModuleId)}; + if (m_Config.OptionalHydrationPool) + { + Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationPool, + .AbortFlag = &AbortFlag, + .PauseFlag = &PauseFlag}); + } + Config.PackEnabled = m_Config.HydrationPackEnabled; + Config.PackThresholdBytes = m_Config.HydrationPackThresholdBytes; + Config.MaxPackBytes = m_Config.HydrationMaxPackBytes; + return Config; +} + void -Hub::ObliterateBackendData(std::string_view ModuleId) +Hub::HydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId) { - std::filesystem::path ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId; - std::filesystem::path TempDir = m_HydrationTempPath / ModuleId; + if (!m_Config.EnableHydration) + { + 
ZEN_INFO("Hydration disabled; skipping hydrate for module '{}'", ModuleId); + return; + } + ZEN_TRACE_CPU("Hub::HydrateInstance"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); + std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); + m_ActiveInstances[ActiveInstanceIndex].HydrationState = Hydrator->Hydrate(); +} - HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = TempDir, .ModuleId = std::string(ModuleId)}; - if (m_Config.OptionalHydrationWorkerPool) +void +Hub::DehydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId) +{ + if (!m_Config.EnableDehydration) { - Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationWorkerPool, - .AbortFlag = &AbortFlag, - .PauseFlag = &PauseFlag}); + ZEN_INFO("Dehydration disabled; skipping dehydrate for module '{}'", ModuleId); + return; } + ZEN_TRACE_CPU("Hub::DehydrateInstance"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); + std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); + Hydrator->Dehydrate(m_ActiveInstances[ActiveInstanceIndex].HydrationState); +} + +void +Hub::ObliterateBackendData(std::string_view ModuleId) +{ + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + HydrationConfig Config = MakeHydrationConfigForModule(ModuleId, AbortFlag, PauseFlag); std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration->CreateHydrator(Config); Hydrator->Obliterate(); } @@ -1478,6 +1662,16 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId) try { Instance.Deprovision(); + try + { + DehydrateInstance(ActiveInstanceIndex, ModuleId); + } + catch 
(const std::exception& Ex) + { + ZEN_WARN("Dehydration of module {} failed during crash recovery cleanup, current state not saved. Reason: {}", + ModuleId, + Ex.what()); + } } catch (const std::exception& Ex) { @@ -1502,6 +1696,7 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId) try { + HydrateInstance(ActiveInstanceIndex, ModuleId); Instance.Provision(); UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Provisioned); NotifyStateUpdate(ModuleId, HubInstanceState::Recovering, HubInstanceState::Provisioned, Instance.GetBasePort(), /*BaseUri*/ {}); @@ -1795,7 +1990,6 @@ Hub::WatchDog() } catch (const std::exception& Ex) { - // TODO: Catch specific errors such as asserts, OOM, OOD, system_error etc ZEN_ERROR("Hub watchdog threw exception: {}", Ex.what()); } } @@ -1835,9 +2029,15 @@ namespace hub_testutils { struct TestHubPools { WorkerThreadPool ProvisionPool; + WorkerThreadPool SpawnPool; WorkerThreadPool HydrationPool; - explicit TestHubPools(int ThreadCount) : ProvisionPool(ThreadCount, "hub_test_prov"), HydrationPool(ThreadCount, "hub_test_hydr") {} + explicit TestHubPools(int ThreadCount) + : ProvisionPool(ThreadCount, "hub_test_provision") + , SpawnPool(ThreadCount, "hub_test_spawn") + , HydrationPool(ThreadCount, "hub_test_hydr") + { + } }; ZenServerEnvironment MakeHubEnvironment(const std::filesystem::path& BaseDir) @@ -1852,8 +2052,9 @@ namespace hub_testutils { { if (Pools) { - Config.OptionalProvisionWorkerPool = &Pools->ProvisionPool; - Config.OptionalHydrationWorkerPool = &Pools->HydrationPool; + Config.OptionalProvisionPool = &Pools->ProvisionPool; + Config.OptionalSpawnPool = &Pools->SpawnPool; + Config.OptionalHydrationPool = &Pools->HydrationPool; } return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), std::move(StateChangeCallback)); } diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index fa504de33..c9b673ad6 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -3,6 +3,7 @@ #pragma 
once #include "hubinstancestate.h" +#include "hydration.h" #include "resourcemetrics.h" #include "storageserverinstance.h" @@ -81,13 +82,22 @@ public: bool HydrationPackEnabled = true; uint64_t HydrationPackThresholdBytes = DefaultPackThresholdBytes; uint64_t HydrationMaxPackBytes = DefaultMaxPackBytes; + // Route S3 hydration through AsyncHttpClient. false falls back to the + // blocking S3Client path. + bool HydrationAsyncEnabled = true; + + // Hub-wide cap on concurrent S3 hydration requests (sizes the shared + // AsyncHttpClient connection pool and the admission semaphore). Only + // consulted when HydrationAsyncEnabled. + uint32_t HydrationAsyncMaxConcurrentRequests = 128; WatchDogConfiguration WatchDog; ResourceMetrics ResourceLimits; - WorkerThreadPool* OptionalProvisionWorkerPool = nullptr; - WorkerThreadPool* OptionalHydrationWorkerPool = nullptr; + WorkerThreadPool* OptionalProvisionPool = nullptr; + WorkerThreadPool* OptionalSpawnPool = nullptr; + WorkerThreadPool* OptionalHydrationPool = nullptr; }; typedef std::function< @@ -209,7 +219,8 @@ public: private: const Configuration m_Config; ZenServerEnvironment m_RunEnvironment; - WorkerThreadPool* m_WorkerPool = nullptr; + WorkerThreadPool* m_ProvisionPool = nullptr; + WorkerThreadPool* m_SpawnPool = nullptr; // Declared early so it destructs late: every HttpClient referencing the // share (watchdog ActivityCheckClient, GC client, in-flight worker // tasks) is required to be gone before this member runs its dtor. @@ -264,6 +275,12 @@ private: // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules. std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min(); + + // Cached hydration state returned by Hydrate, consumed by Dehydrate. 
Synchronized + // by the caller's StorageServerInstance ExclusiveLockedPtr (HydrateInstance and + // DehydrateInstance run under it); RemoveInstance also writes under the Hub + // exclusive lock when finalizing slot teardown. + CbObject HydrationState; }; // UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive @@ -319,8 +336,42 @@ private: void CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex); void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState); + + // Provision/Deprovision/Obliterate are split into two phases scheduled on different worker + // pools. The Phase1/Phase2 helpers are shared between sync and async code paths so behavior + // cannot diverge between them. + bool RunProvisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + void RunProvisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + void RollbackFailedProvision(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + bool IsNewInstance, + uint16_t Port); + + void RunDeprovisionPhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, + size_t ActiveInstanceIndex, + HubInstanceState OldState, + uint16_t Port); + void RunDeprovisionPhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + + void RunObliteratePhase1(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + void 
RunObliteratePhase2(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, uint16_t Port); + void RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId); - void ObliterateBackendData(std::string_view ModuleId); + HydrationConfig MakeHydrationConfigForModule(std::string_view ModuleId, + std::atomic<bool>& AbortFlag, + std::atomic<bool>& PauseFlag) const; + void HydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId); + void DehydrateInstance(size_t ActiveInstanceIndex, std::string_view ModuleId); + void ObliterateBackendData(std::string_view ModuleId); // Notifications may fire slightly out of sync with the Hub's internal State flag. // The guarantee is that notifications are sent in the correct order, but the State diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp index 621af8a46..2df326fab 100644 --- a/src/zenserver/hub/hydration.cpp +++ b/src/zenserver/hub/hydration.cpp @@ -2,6 +2,12 @@ #include "hydration.h" +#include "s3asyncstorage.h" + +#include <zenhttp/asynchttpclient.h> +#include <zenutil/cloud/s3requestbuilder.h> +#include <zenutil/cloud/s3response.h> + #include <zencore/basicfile.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -30,6 +36,7 @@ #include <unordered_set> #if ZEN_WITH_TESTS +# include <zencore/process.h> # include <zencore/testing.h> # include <zencore/testutils.h> # include <zencore/workthreadpool.h> @@ -159,6 +166,12 @@ namespace hydration_impl { std::atomic<uint64_t> FirstScheduleUs{UINT64_MAX}; std::atomic<uint64_t> FirstStartUs{UINT64_MAX}; + // Admission-gate wait. AdmissionWaitTotalUs is summed across all requests that blocked on + // the storage-layer concurrency semaphore; AdmissionWaitMaxUs is the worst single wait + // observed. Total exposes back-pressure cost; Max surfaces head-of-line blocking. 
+ std::atomic<uint64_t> AdmissionWaitTotalUs{0}; + std::atomic<uint64_t> AdmissionWaitMaxUs{0}; + void RecordScheduled() { uint64_t Now = PhaseClock.GetElapsedTimeUs(); @@ -195,6 +208,15 @@ namespace hydration_impl { { } } + + void RecordAdmissionWait(uint64_t Us) + { + AdmissionWaitTotalUs.fetch_add(Us, std::memory_order_relaxed); + uint64_t Prev = AdmissionWaitMaxUs.load(std::memory_order_relaxed); + while (Us > Prev && !AdmissionWaitMaxUs.compare_exchange_weak(Prev, Us, std::memory_order_relaxed)) + { + } + } }; struct DehydrateStatistics @@ -362,6 +384,116 @@ namespace hydration_impl { uint64_t m_MultipartChunkSize; }; + S3AsyncStorageStats AsyncStatsFrom(PhaseStats& Stats) + { + return S3AsyncStorageStats{Stats.RequestCount, + Stats.RequestTotalUs, + Stats.RequestMaxUs, + Stats.Bytes, + Stats.InFlight, + Stats.InFlightPeak, + Stats.FirstScheduleUs, + Stats.FirstStartUs, + Stats.AdmissionWaitTotalUs, + Stats.AdmissionWaitMaxUs, + Stats.PhaseClock}; + } + + class S3AsyncStorageAdapter : public StorageBase + { + public: + static constexpr std::string_view Type = "s3-async"; + + S3AsyncStorageAdapter(AsyncHttpClient& Client, + S3RequestBuilder& Builder, + S3AsyncStorage::CredentialsCallback GetCreds, + std::string KeyPrefix, + uint64_t MultipartChunkSize, + std::shared_ptr<AdmissionSemaphore> Admission, + uint32_t AdmissionCap) + : m_Client(Client) + , m_Builder(Builder) + , m_GetCreds(std::move(GetCreds)) + , m_KeyPrefix(std::move(KeyPrefix)) + , m_MultipartChunkSize(MultipartChunkSize) + , m_Admission(std::move(Admission)) + , m_AdmissionCap(AdmissionCap) + , m_Storage( + std::make_unique<S3AsyncStorage>(Client, Builder, m_GetCreds, m_KeyPrefix, m_MultipartChunkSize, m_Admission, m_AdmissionCap)) + { + } + + virtual std::string Describe() const override { return fmt::format("s3-async://{}/{}"sv, m_Builder.BucketName(), m_KeyPrefix); } + + virtual void SaveMetadata(const CbObject& Data) override; + virtual CbObject LoadMetadata() override; + virtual CbObject 
GetSettings() override + { + CbObjectWriter Writer; + Writer << "MultipartChunkSize"sv << m_MultipartChunkSize; + return Writer.Save(); + } + virtual void ParseSettings(const CbObjectView& Settings) override + { + m_MultipartChunkSize = Settings["MultipartChunkSize"sv].AsUInt64(DefaultMultipartChunkSize); + m_Storage = std::make_unique<S3AsyncStorage>(m_Client, + m_Builder, + m_GetCreds, + m_KeyPrefix, + m_MultipartChunkSize, + m_Admission, + m_AdmissionCap); + } + virtual std::vector<IoHash> List() override { return m_Storage->List(); } + + virtual void Put(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Put(Work, WorkerPool, Hash, Size, SourcePath, AsyncStats); + } + + virtual void Get(ParallelWork& Work, + WorkerThreadPool& WorkerPool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Get(Work, WorkerPool, Hash, Size, DestinationPath, AsyncStats); + } + + virtual void Touch(ParallelWork& Work, WorkerThreadPool& WorkerPool, const IoHash& Hash, PhaseStats& Stats) override + { + Stats.Files.fetch_add(1, std::memory_order_relaxed); + S3AsyncStorageStats AsyncStats = AsyncStatsFrom(Stats); + m_Storage->Touch(Work, WorkerPool, Hash, AsyncStats); + } + + virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override + { + ZEN_UNUSED(WorkerPool); + m_Storage->DeleteAll(Work); + } + + private: + AsyncHttpClient& m_Client; + S3RequestBuilder& m_Builder; + S3AsyncStorage::CredentialsCallback m_GetCreds; + std::string m_KeyPrefix; + uint64_t m_MultipartChunkSize; + std::shared_ptr<AdmissionSemaphore> m_Admission; + uint32_t 
m_AdmissionCap = 0; + std::unique_ptr<S3AsyncStorage> m_Storage; + }; + /////////////////////////////////////////////////////////////////////// // FileStorage implementations @@ -701,9 +833,14 @@ namespace hydration_impl { RenameFile(ChunkPath, DestinationPath, Ec); if (Ec) { - Chunk.Content = IoBufferBuilder::MakeFromFile(ChunkPath); - Chunk.Content.SetDeleteOnClose(true); - WriteFile(DestinationPath, Chunk.Content); + // Cross-volume rename failed; copy the temp file to the destination + // using explicit positional reads (no mmap). Caller is responsible + // for the source file's eventual cleanup via the temp dir sweep. + BasicFile Src(ChunkPath, BasicFile::Mode::kRead); + IoBuffer Body = Src.ReadAll(); + WriteFile(DestinationPath, Body); + std::error_code RemoveEc; + std::filesystem::remove(ChunkPath, RemoveEc); } } } @@ -759,6 +896,63 @@ namespace hydration_impl { } } + void S3AsyncStorageAdapter::SaveMetadata(const CbObject& Data) + { + ZEN_TRACE_CPU("S3AsyncStorageAdapter::SaveMetadata"); + BinaryWriter Output; + SaveCompactBinary(Output, Data); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorageAdapter::SaveMetadata: no credentials available"sv); + } + + std::string Key = fmt::format("{}/incremental-state.cbo", m_KeyPrefix); + std::string Path = m_Builder.KeyToPath(Key); + std::string Hash = Sha256ToHex(ComputeSha256(Output.GetData(), Output.GetSize())); + HttpClient::KeyValueMap Signed = m_Builder.SignRequest(Creds, "PUT", Path, "", Hash); + IoBuffer Payload(IoBuffer::Clone, Output.GetData(), Output.GetSize()); + + HttpClient::Response Resp = m_Client.Put(Path, Payload, Signed).get(); + if (!Resp.IsSuccess()) + { + throw zen::runtime_error("Failed to save incremental metadata to '{}': {}"sv, Key, S3ErrorMessage("S3 PUT failed", Resp)); + } + } + + CbObject S3AsyncStorageAdapter::LoadMetadata() + { + ZEN_TRACE_CPU("S3AsyncStorageAdapter::LoadMetadata"); + SigV4Credentials Creds = 
m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorageAdapter::LoadMetadata: no credentials available"sv); + } + + std::string Key = fmt::format("{}/incremental-state.cbo", m_KeyPrefix); + std::string Path = m_Builder.KeyToPath(Key); + HttpClient::KeyValueMap Signed = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + + HttpClient::Response Resp = m_Client.Get(Path, Signed).get(); + if (!Resp.IsSuccess()) + { + if (Resp.StatusCode == HttpResponseCode::NotFound) + { + return {}; + } + throw zen::runtime_error("Failed to load incremental metadata from '{}': {}"sv, Key, S3ErrorMessage("S3 GET failed", Resp)); + } + + CbValidateError Error; + CbObject Meta = ValidateAndReadCompactBinaryObject(std::move(Resp.ResponsePayload), Error); + if (Error != CbValidateError::None) + { + throw zen::runtime_error("Failed to parse incremental metadata from '{}': {}"sv, Key, ToString(Error)); + } + return Meta; + } + /////////////////////////////////////////////////////////////////////// // IncrementalHydrator: the only HydrationStrategyBase implementation. // Summary emission for hydrate/dehydrate operations. @@ -802,6 +996,7 @@ namespace hydration_impl { // (Stats.PackUpload), loose Touch (Stats.Touch), and pack-blob Touch (Stats.PackTouch). // Per-request data is collected per PhaseStats by Storage::Put / Storage::Touch and // reported as a single combined "Requests" line. + // const uint64_t UpReqCount = Stats.Upload.RequestCount.load() + Stats.PackUpload.RequestCount.load() + Stats.Touch.RequestCount.load() + Stats.PackTouch.RequestCount.load(); const uint64_t UpReqTotalUs = Stats.Upload.RequestTotalUs.load() + Stats.PackUpload.RequestTotalUs.load() + @@ -827,6 +1022,16 @@ namespace hydration_impl { Stats.PackTouch.FirstStartUs.load()}); const uint64_t UpQueueUs = QueueWaitUs(UpFirstSchedUs, UpFirstStartUs); + // Storage-layer admission wait. 
Sum across all four upload sub-phases gives total + // time blocked acquiring slots on the dispatcher; max is the worst single wait + // observed. Both zero when admission is disabled (file-backend / blocking S3). + const uint64_t UpAdmTotalUs = Stats.Upload.AdmissionWaitTotalUs.load() + Stats.PackUpload.AdmissionWaitTotalUs.load() + + Stats.Touch.AdmissionWaitTotalUs.load() + Stats.PackTouch.AdmissionWaitTotalUs.load(); + const uint64_t UpAdmMaxUs = std::max({Stats.Upload.AdmissionWaitMaxUs.load(), + Stats.PackUpload.AdmissionWaitMaxUs.load(), + Stats.Touch.AdmissionWaitMaxUs.load(), + Stats.PackTouch.AdmissionWaitMaxUs.load()}); + const uint64_t LooseFiles = Stats.Upload.Files.load(); const uint64_t LooseBytes = Stats.Upload.Bytes.load(); const uint64_t TouchFiles = Stats.Touch.Files.load(); @@ -852,7 +1057,7 @@ namespace hydration_impl { " List existing: {}\n" " Pack: {} {} packs, {} files, {}, {}bits/s\n" " Upload: {} loose {} files ({}), packed {} blobs ({}), touched {} loose ({}) + {} packs ({}), {}bits/s\n" - " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}\n" + " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}, admission wait avg {}/req max {}\n" " Save metadata: {}\n" " Clean: {}", Prefix, @@ -896,6 +1101,8 @@ namespace hydration_impl { NiceTimeSpanUs(UpReqMaxUs), UpPeak, NiceTimeSpanUs(UpQueueUs), + NiceTimeSpanUs(SafeAvg(UpAdmTotalUs, UpReqCount)), + NiceTimeSpanUs(UpAdmMaxUs), NiceTimeSpanUs(Stats.SaveMetadataUs.load()), NiceTimeSpanUs(Stats.CleanUs.load())); } @@ -924,6 +1131,12 @@ namespace hydration_impl { const uint64_t DlFirstStartUs = std::min(Stats.Download.FirstStartUs.load(), Stats.PackDownload.FirstStartUs.load()); const uint64_t QueueUs = QueueWaitUs(DlFirstSchedUs, DlFirstStartUs); + // Storage-layer admission wait. Sum across loose + pack downloads gives total + // dispatcher block time; max is the worst single wait. 
Both zero when admission + // is disabled (file-backend / blocking S3). + const uint64_t DlAdmTotalUs = Stats.Download.AdmissionWaitTotalUs.load() + Stats.PackDownload.AdmissionWaitTotalUs.load(); + const uint64_t DlAdmMaxUs = std::max(Stats.Download.AdmissionWaitMaxUs.load(), Stats.PackDownload.AdmissionWaitMaxUs.load()); + const uint64_t PackCount = Stats.PackCount.load(); const uint64_t PackedFiles = Stats.PackedFiles.load(); const uint64_t PackUnpackUs = Stats.PackUnpackUs.load(); @@ -942,7 +1155,7 @@ namespace hydration_impl { " Load metadata: {}\n" " Create dirs: {} {} dirs, {} dirs/s\n" " Download: {} loose {} files ({}), packed {} blobs ({}), {}bits/s\n" - " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}\n" + " Requests: {} reqs, avg {}/req, max {}/req, peak in-flight {}, queue wait {}, admission wait avg {}/req max {}\n" " Unpack: {} {} packs, {} files ({}), {}bits/s\n" " Clean: {}\n" " Finalize: {}\n" @@ -969,6 +1182,8 @@ namespace hydration_impl { NiceTimeSpanUs(DlReqMaxUs), DlPeak, NiceTimeSpanUs(QueueUs), + NiceTimeSpanUs(SafeAvg(DlAdmTotalUs, DlReqCount)), + NiceTimeSpanUs(DlAdmMaxUs), NiceTimeSpanUs(PackUnpackUs), ThousandsNum(PackCount), ThousandsNum(PackedFiles), @@ -1157,20 +1372,26 @@ namespace hydration_impl { // the hash is a meta-hash combining the embedded RawHash with the file size, which // avoids a collision between an uncompressed file and a same-content compressed file. // All other files use a streaming raw hash via BasicFile + IoHashStream (sequential - // read, friendlier to the Windows cache manager than mmap). + // reads). All reads are explicit positional reads via BasicFile - no mmap, no + // IoBufferBuilder::MakeFromFile materialization. 
void HashFileContent(const std::filesystem::path& AbsPath, Entry& Out) { + BasicFile File(AbsPath, BasicFile::Mode::kRead); + if (AbsPath.extension().empty()) { std::string_view Rel = Out.RelativePath; std::string_view First = Rel.substr(0, Rel.find('/')); if (First.ends_with("cas")) { + // Read compressed bytes into a heap IoBuffer (single positional read, no mmap) + // and probe with FromCompressed. On success, derive a meta-hash from the + // embedded RawHash + size and return without hashing the bytes themselves. + IoBuffer Compressed = File.ReadAll(); IoHash RawHash; uint64_t RawSize; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(IoBufferBuilder::MakeFromFile(AbsPath)), RawHash, RawSize); - if (Compressed) + CompressedBuffer Probe = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Compressed)), RawHash, RawSize); + if (Probe) { IoHashStream Hasher; Hasher.Append(RawHash.Hash, sizeof(RawHash.Hash)); @@ -1178,10 +1399,10 @@ namespace hydration_impl { Out.Hash = Hasher.GetHash(); return; } + // Not a compressed file - fall through to streaming raw hash. } } - BasicFile File(AbsPath, BasicFile::Mode::kRead); IoHashStream Hasher; File.StreamFile([&Hasher](const void* Data, uint64_t Size) { Hasher.Append(Data, Size); }); Out.Hash = Hasher.GetHash(); @@ -1264,9 +1485,7 @@ namespace hydration_impl { // Returns one PackPlan per pack to build (empty if no packs are produced). std::vector<PackPlan> PlanPacks(std::vector<Entry>& Entries, uint64_t Threshold, uint64_t MaxPackBytes) { - // 1. Group small-file Entries[] indices by content hash. Every index in a group - // shares the same bytes, so any one of them sources the pack content; all of - // them get tagged IsPacked once the pack hash is known. + // Group small-file Entries[] indices by content hash. 
std::unordered_map<IoHash, EntryGroup, IoHash::Hasher> UniqueMap; for (size_t Index = 0; Index < Entries.size(); ++Index) { @@ -1277,7 +1496,6 @@ namespace hydration_impl { UniqueMap[Entries[Index].Hash].push_back(Index); } - // Need at least 2 unique groups for any pack to survive the "discard 1-entry packs" rule. if (UniqueMap.size() < 2) { return {}; @@ -1286,7 +1504,7 @@ namespace hydration_impl { auto GroupHash = [&](const EntryGroup& G) -> const IoHash& { return Entries[G.front()].Hash; }; auto GroupSize = [&](const EntryGroup& G) -> uint64_t { return Entries[G.front()].Size; }; - // 2. Deterministic order: ascending IoHash. Drain the map so the index vectors move. + // Sort groups by ascending IoHash for deterministic pack composition. std::vector<EntryGroup> Ordered; Ordered.reserve(UniqueMap.size()); for (auto& [h, g] : UniqueMap) @@ -1295,7 +1513,7 @@ namespace hydration_impl { } std::sort(Ordered.begin(), Ordered.end(), [&](const EntryGroup& A, const EntryGroup& B) { return GroupHash(A) < GroupHash(B); }); - // 3. Bin-pack greedily under MaxPackBytes. + // Bin-pack greedily under MaxPackBytes. std::vector<PackPlan> Plans; PackPlan Current; uint64_t CurrentSize = 0; @@ -1815,6 +2033,17 @@ namespace hydration_impl { const std::vector<PackPlan> Pending = m_Config.PackEnabled ? PlanPacks(Entries, m_Config.PackThresholdBytes, m_Config.MaxPackBytes) : std::vector<PackPlan>{}; + // Pre-build absolute paths once. MakeSafeAbsolutePath does path normalization + + // Windows long-path prefix application; ~microseconds per entry. With 100k+ + // entries the per-iter cost in the dispatch loop adds up. Mirrors the Hydrate + // side which already pre-builds EntryPaths. 
+ std::vector<std::filesystem::path> EntryPaths; + EntryPaths.reserve(Entries.size()); + for (const Entry& CurrentEntry : Entries) + { + EntryPaths.push_back(MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath)); + } + uint64_t DehydrateDurationMs = 0; { // Upload, PackUpload, Touch, and PackTouch share one ParallelWork; reset all @@ -1829,9 +2058,9 @@ namespace hydration_impl { // Schedule loose-CAS uploads first so they begin running while the pack-build // loop below executes serially on this thread. - for (const Entry& CurrentEntry : Entries) + for (size_t I = 0; I < Entries.size(); ++I) { - if (CurrentEntry.IsPacked) + if (Entries[I].IsPacked) { continue; // pack phase covers it } @@ -1839,9 +2068,9 @@ namespace hydration_impl { Work, *m_Threading.WorkerPool, ExistsLookup, - CurrentEntry.Hash, - CurrentEntry.Size, - MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath), + Entries[I].Hash, + Entries[I].Size, + EntryPaths[I], Stats.Upload, Stats.Touch); } @@ -1930,6 +2159,8 @@ namespace hydration_impl { } catch (const std::exception& Ex) { + // Failure is OK to swallow: next dehydrate or fresh hydrate falls back to + // the older state still on the backend. ZEN_WARN("Dehydration of module '{}' failed: {}. Leaving server state '{}'", m_Config.ModuleId, Ex.what(), @@ -2143,6 +2374,8 @@ namespace hydration_impl { } catch (const std::exception& Ex) { + // Failure is OK to swallow: starts the instance with empty state, next + // dehydrate re-publishes from whatever the running instance materializes. ZEN_WARN("Hydration of module '{}' failed: {}. Cleaning server state '{}'", m_Config.ModuleId, Ex.what(), @@ -2227,6 +2460,20 @@ private: Ref<ImdsCredentialProvider> m_CredentialProvider; std::unique_ptr<S3Client> m_Client; uint64_t m_DefaultMultipartChunkSize; + + // Async path: when Config.AsyncEnabled, build AsyncHttpClient + S3RequestBuilder + // shared across all per-module storage instances. Null otherwise. 
+ std::unique_ptr<AsyncHttpClient> m_AsyncClient; + std::unique_ptr<S3RequestBuilder> m_AsyncBuilder; + + // Storage-layer admission gate, shared across all per-module S3AsyncStorage + // instances. Initial slot count = AsyncMaxConcurrentRequests. + std::shared_ptr<AdmissionSemaphore> m_AsyncAdmission; + uint32_t m_AsyncAdmissionCap = 0; + + // Captures m_Credentials / m_CredentialProvider state via callable so per-module + // S3AsyncStorage instances stay decoupled from the credential rotation logic. + S3AsyncStorage::CredentialsCallback BuildCredentialsCallback(); }; HydrationBase::HydrationBase(const Configuration& Config) @@ -2357,6 +2604,49 @@ S3Hydration::S3Hydration(const Configuration& Config) : HydrationBase(Config) ClientOptions.HttpSettings.RetryCount = 3; m_Client = std::make_unique<S3Client>(ClientOptions); + + if (Config.AsyncEnabled) + { + // Curl conn caps pinned to the request cap so handles never sit on + // libcurl's internal queue waiting for a connection slot (CONNECTTIMEOUT + // would tick there). With one S3 endpoint all connections go to the same + // host: PerHost is the binding cap, Total mirrors. MaxConcurrentRequests + // is a hint shared with the storage admission semaphore below. + HttpClientSettings AsyncSettings = ClientOptions.HttpSettings; + AsyncSettings.MaxConcurrentRequests = Config.AsyncMaxConcurrentRequests; + AsyncSettings.MaxConcurrentConnectionsPerHost = Config.AsyncMaxConcurrentRequests; + AsyncSettings.MaxConcurrentConnectionsTotal = Config.AsyncMaxConcurrentRequests; + + m_AsyncBuilder = std::make_unique<S3RequestBuilder>(m_Region, m_Bucket, m_Endpoint, m_PathStyle); + m_AsyncClient = std::make_unique<AsyncHttpClient>(m_AsyncBuilder->Endpoint(), AsyncSettings); + + // Storage-layer admission: paces S3 fan-out at the same in-flight cap + // curl uses for connection limits. 
Acquire happens on the dispatcher + // thread that drives Hydrate/Dehydrate, so back-pressure flows back to + // the caller without blocking io strand or hydration-pool workers. + m_AsyncAdmissionCap = Config.AsyncMaxConcurrentRequests; + m_AsyncAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(m_AsyncAdmissionCap)); + ZEN_INFO("S3 hydration: async path enabled (max-concurrent-requests={})", Config.AsyncMaxConcurrentRequests); + } + else + { + ZEN_INFO("S3 hydration: blocking S3Client path"); + } +} + +S3AsyncStorage::CredentialsCallback +S3Hydration::BuildCredentialsCallback() +{ + if (m_CredentialProvider) + { + Ref<ImdsCredentialProvider> Provider = m_CredentialProvider; + return [Provider]() { + SigV4Credentials Creds = Provider->GetCredentials(); + return Creds; + }; + } + SigV4Credentials Creds = m_Credentials; + return [Creds]() { return Creds; }; } std::unique_ptr<HydrationStrategyBase> @@ -2365,6 +2655,19 @@ S3Hydration::CreateHydrator(const HydrationConfig& Config) using namespace hydration_impl; std::string KeyPrefix = m_KeyPrefixRoot.empty() ? std::string(Config.ModuleId) : fmt::format("{}/{}"sv, m_KeyPrefixRoot, Config.ModuleId); + + if (m_AsyncClient) + { + return std::make_unique<IncrementalHydrator>(Config, + std::make_unique<S3AsyncStorageAdapter>(*m_AsyncClient, + *m_AsyncBuilder, + BuildCredentialsCallback(), + std::move(KeyPrefix), + m_DefaultMultipartChunkSize, + m_AsyncAdmission, + m_AsyncAdmissionCap), + m_Excludes); + } return std::make_unique<IncrementalHydrator>( Config, std::make_unique<S3Storage>(*m_Client, std::move(KeyPrefix), Config.TempDir, m_DefaultMultipartChunkSize), @@ -2993,10 +3296,20 @@ TEST_CASE("hydration.file.concurrent") // The MinIO binary must be present in the same directory as the test executable (copied by xmake). // --------------------------------------------------------------------------- +namespace { + // Per-binary unique MinIO port. 
+ uint16_t AllocateHydrationMinioTestPort() + { + static const uint16_t Base = static_cast<uint16_t>(20000u + (static_cast<uint32_t>(GetCurrentProcessId()) % 30000u)); + static std::atomic<uint16_t> Slot{0}; + return Base + Slot.fetch_add(1, std::memory_order_relaxed); + } +} // namespace + TEST_CASE("hydration.s3.dehydrate_hydrate") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19011; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3055,7 +3368,7 @@ TEST_CASE("hydration.s3.concurrent") // N modules dehydrate and hydrate concurrently against MinIO. // Each module has a distinct ModuleId, so S3 key prefixes don't overlap. MinioProcessOptions MinioOpts; - MinioOpts.Port = 19013; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3149,7 +3462,7 @@ TEST_CASE("hydration.s3.concurrent") TEST_CASE("hydration.s3.obliterate") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19019; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3215,7 +3528,7 @@ TEST_CASE("hydration.s3.obliterate") TEST_CASE("hydration.s3.config_overrides") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19015; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3293,7 +3606,7 @@ TEST_CASE("hydration.s3.config_overrides") TEST_CASE("hydration.s3.dehydrate_hydrate.performance" * doctest::skip()) { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19010; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3425,7 +3738,7 @@ TEST_CASE("hydration.file.incremental") 
TEST_CASE("hydration.s3.incremental") { MinioProcessOptions MinioOpts; - MinioOpts.Port = 19017; + MinioOpts.Port = AllocateHydrationMinioTestPort(); MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); Minio.CreateBucket("zen-hydration-test"); @@ -3524,6 +3837,328 @@ TEST_CASE("hydration.create_hydrator_rejects_invalid_config") CHECK_THROWS(InitHydration({})); } +TEST_CASE("hydration.s3async.dehydrate_hydrate") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = HydrationTemp, .ModuleId = "s3async_roundtrip"}; + + WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); + Hydration->CreateHydrator(Config)->Hydrate(); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + CreateSmallTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + CreateSmallTestTree(ServerStateDir); + WriteFile(ServerStateDir / 
"v2marker.bin", CreateSemiRandomBlob(64)); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + CleanDirectory(ServerStateDir, true); + Hydration->CreateHydrator(Config)->Hydrate(); + + CHECK(std::filesystem::exists(ServerStateDir / "v2marker.bin")); + CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "file_b.bin")); + CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "nested" / "file_c.bin")); +} + +// Exercises all three Put tiers (Small/Medium/Multipart) plus pack uploads in +// one round-trip. CreateTestTree adds 256K/512K/9M/63M blobs on top of the +// small-file set; the small files get packed, the 9M lands in PutMedium, and +// the 63M lands in PutMultipart. +TEST_CASE("hydration.s3async.dehydrate_hydrate.all_tiers") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = fmt::format( + R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"chunksize":{}}}}})", + Minio.Endpoint(), + 5u * 1024u * 1024u); // 5 MiB chunks -> multipart threshold ~6.25 MiB + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + TestThreading Threading(8); + 
HydrationConfig Config{.ServerStateDir = ServerStateDir, + .TempDir = HydrationTemp, + .ModuleId = "s3async_all_tiers", + .Threading = Threading.Options}; + + // CreateTestTree: small files (pack candidates) + 256K (Small), 512K (Medium), + // 9M (Medium), 63M (Multipart). + auto Files = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, Files); +} + +TEST_CASE("hydration.s3async.concurrent") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + constexpr int kModuleCount = 6; + constexpr int kThreadCount = 4; + + TestThreading Threading(kThreadCount); + + ScopedTemporaryDirectory TempDir; + + struct ModuleData + { + HydrationConfig Config; + std::vector<std::pair<std::filesystem::path, IoBuffer>> Files; + }; + std::vector<ModuleData> Modules(kModuleCount); + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + for (int I = 0; I < kModuleCount; ++I) + { + std::string ModuleId = fmt::format("s3async_concurrent_{}"sv, I); + std::filesystem::path StateDir = TempDir.Path() / ModuleId / "state"; + std::filesystem::path TempPath = TempDir.Path() / ModuleId / "temp"; + 
CreateDirectories(StateDir); + CreateDirectories(TempPath); + + Modules[I].Config.ServerStateDir = StateDir; + Modules[I].Config.TempDir = TempPath; + Modules[I].Config.ModuleId = ModuleId; + Modules[I].Config.Threading = Threading.Options; + Modules[I].Files = CreateTestTree(StateDir); + } + + { + WorkerThreadPool Pool(kThreadCount, "hydration_s3async_dehy"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (int I = 0; I < kModuleCount; ++I) + { + Work.ScheduleWork(Pool, [&Hydration, &Config = Modules[I].Config](std::atomic<bool>&) { + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + }); + } + Work.Wait(); + CHECK_FALSE(Work.IsAborted()); + } + + { + WorkerThreadPool Pool(kThreadCount, "hydration_s3async_hy"); + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + for (int I = 0; I < kModuleCount; ++I) + { + Work.ScheduleWork(Pool, [&Hydration, &Config = Modules[I].Config](std::atomic<bool>&) { + CleanDirectory(Config.ServerStateDir, true); + Hydration->CreateHydrator(Config)->Hydrate(); + }); + } + Work.Wait(); + CHECK_FALSE(Work.IsAborted()); + } + + for (int I = 0; I < kModuleCount; ++I) + { + VerifyTree(Modules[I].Config.ServerStateDir, Modules[I].Files); + } +} + +TEST_CASE("hydration.s3async.obliterate") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + 
CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + constexpr std::string_view ModuleId = "s3async_obliterate"sv; + + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, .TempDir = HydrationTemp, .ModuleId = std::string(ModuleId)}; + + CreateSmallTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(CbObject()); + + auto ListModuleObjects = [&]() { + S3ClientOptions Opts; + Opts.BucketName = "zen-hydration-test"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = Minio.RootUser(); + Opts.Credentials.SecretAccessKey = Minio.RootPassword(); + S3Client Client(Opts); + return Client.ListObjects(fmt::format("{}/"sv, ModuleId)); + }; + + CHECK(!ListModuleObjects().Objects.empty()); + + CreateSmallTestTree(ServerStateDir); + WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64)); + + Hydration->CreateHydrator(Config)->Obliterate(); + + CHECK(ListModuleObjects().Objects.empty()); + CHECK(std::filesystem::is_empty(ServerStateDir)); + CHECK(std::filesystem::is_empty(HydrationTemp)); +} + +TEST_CASE("hydration.s3async.incremental") +{ + MinioProcessOptions MinioOpts; + MinioOpts.Port = AllocateHydrationMinioTestPort(); + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + + 
ScopedTemporaryDirectory TempDir; + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + constexpr std::string_view ModuleId = "s3async_incremental"sv; + + TestThreading Threading(8); + HydrationBase::Configuration BaseConfig; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + BaseConfig.Options = std::move(Root).AsObject(); + } + BaseConfig.AsyncEnabled = true; + auto Hydration = InitHydration(BaseConfig); + + HydrationConfig Config{.ServerStateDir = ServerStateDir, + .TempDir = HydrationTemp, + .ModuleId = std::string(ModuleId), + .Threading = Threading.Options}; + + // Mirrors hydration.s3.incremental but with BaseConfig.AsyncEnabled set so the + // async S3AsyncStorageAdapter handles I/O. Each Dehydrate empties + // ServerStateDir as a side effect; subsequent Hydrate/Dehydrate calls + // thread the prior HydrationState so incremental dehydrate can hit its + // cache instead of re-uploading. 
+ + CbObject HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + CHECK_FALSE(HydrationState); + + auto TestFiles = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, TestFiles); + + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + CHECK(std::filesystem::is_empty(ServerStateDir)); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + + TestFiles = CreateTestTree(ServerStateDir); + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); + + HydrationState = Hydration->CreateHydrator(Config)->Hydrate(); + VerifyTree(ServerStateDir, TestFiles); + + Hydration->CreateHydrator(Config)->Dehydrate(HydrationState); +} + TEST_SUITE_END(); void diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h index d9a3dda5b..55db41738 100644 --- a/src/zenserver/hub/hydration.h +++ b/src/zenserver/hub/hydration.h @@ -103,6 +103,13 @@ public: // settings (e.g. S3 "chunksize") live inside Options["settings"]. The common // `excludes` entry is parsed once by HydrationBase and shared across modules. CbObject Options; + + // Routes S3 hydration through AsyncHttpClient + S3AsyncStorage when true; false + // falls back to the blocking S3Client + S3Storage path. Ignored by FileHydration. + bool AsyncEnabled = false; + + // Per-AsyncHttpClient max in-flight requests. Only consulted when AsyncEnabled. + uint32_t AsyncMaxConcurrentRequests = 128; }; // Parses common Options entries (`excludes`) into m_Excludes, applying built-in diff --git a/src/zenserver/hub/s3asyncstorage.cpp b/src/zenserver/hub/s3asyncstorage.cpp new file mode 100644 index 000000000..b8bbc55b2 --- /dev/null +++ b/src/zenserver/hub/s3asyncstorage.cpp @@ -0,0 +1,2808 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "s3asyncstorage.h" + +#include <zencore/basicfile.h> +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zenutil/cloud/s3response.h> + +#include <cstring> + +#if ZEN_WITH_TESTS +# include <zencore/iohash.h> +# include <zencore/process.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zenutil/cloud/minioprocess.h> +#endif + +namespace zen { + +using namespace std::literals; + +namespace { + // Bounded write-buffer pool. OnData callbacks on the io strand call + // Acquire to obtain a fixed-size IoBuffer; workers call Release after + // the dispatched write completes. The pool retains at most as many + // buffers as future Acquire calls remain (PendingAcquires), so it + // stabilises at peak write concurrency without dragging buffers past + // the point they can ever be re-used. + class WriteBufferPool + { + public: + WriteBufferPool(size_t InBufferSize, uint32_t TotalAcquires) : m_BufferSize(InBufferSize), m_PendingAcquires(TotalAcquires) {} + + size_t BufferSize() const { return m_BufferSize; } + + IoBuffer Acquire() + { + IoBuffer Buf; + bool Exhausted = false; + m_Lock.WithExclusiveLock([&] { + if (m_PendingAcquires == 0) + { + Exhausted = true; + return; + } + --m_PendingAcquires; + if (!m_Pool.empty()) + { + Buf = std::move(m_Pool.back()); + m_Pool.pop_back(); + } + }); + if (Exhausted) + { + // Caller (OnData / OnComplete tail flush on the curl io strand) + // over-acquired vs the precomputed TotalBlocks budget. Throw a + // real exception so the boundary catch in + // AsyncCurlStreamWriteCallback can surface it; never let + // ZEN_ASSERT propagate through curl's C frames. 
+ throw zen::runtime_error("WriteBufferPool exhausted: more buffers requested than expected"); + } + if (!Buf) + { + Buf = IoBuffer(m_BufferSize); + } + return Buf; + } + + void Release(IoBuffer Buf) + { + m_Lock.WithExclusiveLock([&] { + if (m_Pool.size() < m_PendingAcquires) + { + m_Pool.push_back(std::move(Buf)); + } + }); + } + + // Returns the slot consumed by Acquire() back to the budget without + // supplying a buffer; for paths where the buffer was already moved into a + // lambda (e.g. ScheduleWork threw after the move). + void RestoreAcquireSlot() + { + m_Lock.WithExclusiveLock([&] { ++m_PendingAcquires; }); + } + + private: + const size_t m_BufferSize; + RwLock m_Lock; + std::vector<IoBuffer> m_Pool; + uint32_t m_PendingAcquires = 0; + }; + + // Acquire one admission slot on the dispatcher thread. Returns a refcounted + // handle whose deleter releases the slot on last drop, so callers thread it + // through worker lambdas and AsyncXxx callbacks; whichever runs last fires + // the release. nullptr when admission is disabled (Sem == nullptr). + // + // Stats may be null when the caller has no per-request stat block (DeleteAll); + // the slot is still held but the wait time is not recorded. + // + // Exception safety: counting_semaphore::acquire is noexcept; only the + // shared_ptr control-block allocation can throw. The standard's deleter- + // invocation guarantee for shared_ptr(p, d) only kicks in when p is non- + // null (p == nullptr here), so libstdc++/MSVC happen to invoke the deleter + // on alloc failure but it is not portably required - hence the explicit + // guard. 
+ std::shared_ptr<void> AcquireAdmissionSlot(const std::shared_ptr<AdmissionSemaphore>& Sem, S3AsyncStorageStats* Stats = nullptr) + { + if (!Sem) + { + return nullptr; + } + Stopwatch AdmWait; + Sem->acquire(); + auto ReleaseGuard = MakeGuard([&Sem] { Sem->release(); }); + std::shared_ptr<void> SlotRef(nullptr, [Sem](void*) { Sem->release(); }); + ReleaseGuard.Dismiss(); + if (Stats) + { + Stats->RecordAdmissionWait(AdmWait.GetElapsedTimeUs()); + } + return SlotRef; + } +} // namespace + +S3AsyncStorage::S3AsyncStorage(AsyncHttpClient& Client, + S3RequestBuilder& Builder, + CredentialsCallback GetCreds, + std::string KeyPrefix, + uint64_t MultipartChunkSize, + std::shared_ptr<AdmissionSemaphore> Admission, + uint32_t AdmissionCap) +: m_Client(Client) +, m_Builder(Builder) +, m_GetCreds(std::move(GetCreds)) +, m_KeyPrefix(std::move(KeyPrefix)) +, m_MultipartChunkSize(MultipartChunkSize) +, m_Admission(std::move(Admission)) +, m_AdmissionCap(m_Admission ? AdmissionCap : 0u) +{ +} + +std::string +S3AsyncStorage::CasKey(const IoHash& Hash) const +{ + return fmt::format("{}/cas/{}", m_KeyPrefix, Hash); +} + +std::string +S3AsyncStorage::CasPath(const IoHash& Hash) const +{ + return m_Builder.KeyToPath(CasKey(Hash)); +} + +// Per-range writes target the prealloc'd file at-offset; BasicFile::Write is positional and concurrency-safe. 
+struct S3AsyncStorage::GetMultipartState +{ + GetMultipartState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats, uint32_t TotalBlocks) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + , Buffers(WriteBufferSize, TotalBlocks) + { + } + + static constexpr size_t WriteBufferSize = 512u * 1024u; + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> DestFile; + std::filesystem::path DestPath; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + uint32_t TotalRanges = 0; + std::atomic<uint32_t> PendingRanges{0}; + std::atomic<bool> AnyFailed{false}; + RwLock ErrorLock; + std::string FirstError; + + WriteBufferPool Buffers; +}; + +// Shared state for a single-stream GET (medium tier - one ranged request, +// streamed body). OnData fills 512 KiB IoBuffers from a shared pool on the +// io strand and dispatches each filled buffer to a worker for one positional +// Write. PendingWork counts in-flight writes plus a +1 slot for the stream +// itself; the side that drops it to zero (last writer or OnComplete after +// the final flush) finalises the request. 
+struct S3AsyncStorage::GetStreamState +{ + GetStreamState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats, uint32_t TotalBlocks) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + , Buffers(WriteBufferSize, TotalBlocks) + { + } + + static constexpr size_t WriteBufferSize = 512u * 1024u; + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> DestFile; + std::filesystem::path DestPath; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + uint64_t ExpectedSize = 0; + + IoBuffer ActiveBuf; + size_t BufFill = 0; + uint64_t NextAbsOffset = 0; + uint64_t TotalReceived = 0; + + std::atomic<uint32_t> PendingWork{1}; // +1 for stream completion + std::atomic<bool> Failed{false}; + RwLock ErrorLock; + std::string FirstError; + + WriteBufferPool Buffers; +}; + +// Shared state for an in-flight multipart upload. PendingParts.fetch_sub(acq_rel) +// in FinalizePutPart is the publication barrier for ETag writes; CompleteMultipart +// (gated on Remaining == 1) sees every part's slot. ETagLock guards FirstError; +// AnyFailed short-circuits the pipeline on first error. +struct S3AsyncStorage::PutMultipartState +{ + PutMultipartState(ParallelWork& InWork, WorkerThreadPool& InPool, S3AsyncStorageStats InStats) + : Work(InWork) + , Pool(InPool) + , Stats(InStats) + { + } + + ParallelWork& Work; + WorkerThreadPool& Pool; + std::shared_ptr<BasicFile> File; + std::shared_ptr<ParallelWork::ExternalWorkToken> Token; + IoHash ContentHash; + S3AsyncStorageStats Stats; + // Snapshot of credentials taken once at PutMultipart entry. Reused across + // all per-part signing + Complete/Abort. Avoids per-part m_GetCreds() + // shared-lock contention on the credential provider. 
+ SigV4Credentials Creds; + std::string Key; + std::string Path; + std::string UploadId; + uint64_t TotalSize = 0; + uint64_t PartSize = 0; + uint32_t TotalParts = 0; + std::atomic<uint32_t> PendingParts{0}; + // Monotonic dispatch cursor. Initialized to PreAcquired (min(TotalParts, + // AdmissionCap)) in PutMultipart. Each AsyncPut completion racing the + // handoff path issues fetch_add(1); whichever completion observes a value + // < TotalParts owns the next dispatch. + std::atomic<uint32_t> NextPartToDispatch{0}; + uint32_t PreAcquired = 0; + // Counts pending ReadRange calls into File. Used to release File from the + // worker thread that does the last read so the close runs off the curl + // io strand. + std::atomic<uint32_t> ReadsRemaining{0}; + std::atomic<bool> AnyFailed{false}; + + RwLock ETagLock; + std::vector<std::string> ETags; // indexed by PartNumber-1 + std::string FirstError; +}; + +namespace { + // Sets AnyFailed (CAS) and captures FirstError on first failure only. + // Does NOT touch the dispatch cursor; callers that need to drain the + // undispatched tail invoke ClaimUndispatchedParts and fan out skips. + void RecordPutPartFailure(S3AsyncStorage::PutMultipartState& State, const std::string& Err) + { + // Log every failure so correlated S3 errors (e.g. AccessDenied on one + // part, ServiceUnavailable on another) all reach the operator. CAS still + // keeps the first message for the eventual Token->Fail. + ZEN_WARN("S3AsyncStorage::PutMultipart '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.AnyFailed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ETagLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } + + // Claims every part index from NextPartToDispatch up to TotalParts. + // Returns the count claimed; caller is responsible for firing one + // FinalizePutPart per claimed index so PendingParts can reach 1 and + // trip CompleteMultipart/AbortMultipart. 
Idempotent on repeated calls + // (subsequent invocations return 0). + uint32_t ClaimUndispatchedParts(S3AsyncStorage::PutMultipartState& State) + { + const uint32_t Claimed = State.NextPartToDispatch.exchange(State.TotalParts, std::memory_order_acq_rel); + return Claimed < State.TotalParts ? State.TotalParts - Claimed : 0u; + } + + void RecordGetPartFailure(S3AsyncStorage::GetMultipartState& State, const std::string& Err) + { + ZEN_WARN("S3AsyncStorage::GetMultipart '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.AnyFailed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ErrorLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } + + void RecordGetStreamFailure(S3AsyncStorage::GetStreamState& State, const std::string& Err) + { + ZEN_WARN("S3AsyncStorage::Get '{}': {}", State.ContentHash, Err); + bool ExpectedFalse = false; + if (State.Failed.compare_exchange_strong(ExpectedFalse, true)) + { + State.ErrorLock.WithExclusiveLock([&] { State.FirstError = Err; }); + } + } +} // namespace + +// Three tiers selected by Size: +// - Small (< kPutSmallThreshold): single PUT, body materialized into one IoBuffer; SHA in one pass. +// - Medium (< MultipartThreshold): single PUT, streaming source; two-pass (hash, then upload) with bounded RAM. +// - Large (>= MultipartThreshold): S3 multipart; per-part body materialized in DispatchPartUpload. 
+namespace { + constexpr uint64_t kPutSmallThreshold = 512u * 1024u; + constexpr size_t kPutReadChunk = 256u * 1024u; +} // namespace + +void +S3AsyncStorage::Put(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + const uint64_t MultipartThreshold = m_MultipartChunkSize + (m_MultipartChunkSize / 4); + if (Size >= MultipartThreshold) + { + PutMultipart(Work, Pool, Hash, Size, SourcePath, Stats); + return; + } + if (Size < kPutSmallThreshold) + { + PutSmall(Work, Pool, Hash, Size, SourcePath, Stats); + return; + } + PutMedium(Work, Pool, Hash, Size, SourcePath, Stats); +} + +void +S3AsyncStorage::PutSmall(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // Acquire one admission slot on the dispatcher; SlotRef threads through + // worker lambda + AsyncPut callback so the slot is held until the network + // transfer completes (or the chain is destroyed on shutdown/cancel). + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + // Snapshot creds once on the dispatcher (single shared-lock acquire on the + // credential provider) and capture by value into the worker. Mirrors + // PutMultipart and avoids per-fanout contention on m_GetCreds. + SigV4Credentials Creds = m_GetCreds(); + + // Capture Stats BY VALUE: S3AsyncStorageStats is a struct of references to + // the long-lived PhaseStats atomics (PhaseStats outlives all in-flight + // transfers via the ParallelWork latch). The Stats wrapper struct itself + // is a local in the caller (S3AsyncStorageAdapter::Put) and goes out of + // scope before deferred work runs, so capturing &Stats would dangle. 
+ Work.ScheduleWork( + Pool, + [this, + Hash = IoHash(Hash), + Size, + SourcePath = std::filesystem::path(SourcePath), + Token, + Stats, + Creds = std::move(Creds), + SlotRef = std::move(SlotRef)](std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': no credentials available"sv, Hash))); + return; + } + + BasicFile File(SourcePath, BasicFile::Mode::kRead); + if (File.FileSize() != Size) + { + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutSmall '{}': source size {} differs from declared {}"sv, + Hash, + File.FileSize(), + Size))); + return; + } + + // Heap IoBuffer of exact size; explicit chunked pread fills it. No + // IoBuffer auto-materialize, no mmap. SHA computed in one pass over + // the buffer once it's filled. + IoBuffer Body(static_cast<size_t>(Size)); + uint8_t* Dst = static_cast<uint8_t*>(Body.MutableData()); + uint64_t Off = 0; + while (Off < Size) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, Size - Off)); + File.Read(Dst + Off, Take, Off); + Off += Take; + } + + std::string PayloadHash = Sha256ToHex(ComputeSha256(Body.GetData(), Body.GetSize())); + std::string Path = CasPath(Hash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", PayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut (lambda capture + // before submit) does not strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + m_Client.AsyncPut( + Path, + std::move(Body), + [Hash, Token, Timer, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
static_cast<uint64_t>(Resp.UploadedBytes) : 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 PUT failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutSmall '{}': {}"sv, Hash, Err))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +void +S3AsyncStorage::PutMedium(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + // Snapshot creds on the dispatcher; see PutSmall. + SigV4Credentials Creds = m_GetCreds(); + + Work.ScheduleWork( + Pool, + [this, + Hash = IoHash(Hash), + Size, + SourcePath = std::filesystem::path(SourcePath), + Token, + Stats, + Creds = std::move(Creds), + SlotRef = std::move(SlotRef)](std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': no credentials available"sv, Hash))); + return; + } + + // Hash pass: stream-read 256 KiB chunks into a heap scratch buffer, + // feed Sha256Stream incrementally. No body materialization. The file + // handle is shared with the upload pass via shared_ptr so libcurl's + // READ callback can pread from the same handle on the io thread. 
+ auto File = std::make_shared<BasicFile>(SourcePath, BasicFile::Mode::kRead); + if (File->FileSize() != Size) + { + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMedium '{}': source size {} differs from declared {}"sv, + Hash, + File->FileSize(), + Size))); + return; + } + + Sha256Stream Hasher; + { + std::unique_ptr<uint8_t[]> Scratch(new uint8_t[kPutReadChunk]); + uint64_t Off = 0; + while (Off < Size) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, Size - Off)); + File->Read(Scratch.get(), Take, Off); + Hasher.Update(Scratch.get(), Take); + Off += Take; + } + } + std::string PayloadHash = Sha256ToHex(Hasher.Finalize()); + std::string Path = CasPath(Hash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", PayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + + // Streaming source: libcurl pulls 256 KiB chunks from the file on + // the io thread. Local-disk pread is fast enough for medium tier; + // page cache makes the second pass over the same bytes near-free. + m_Client.AsyncPut( + Path, + Size, + [File](uint8_t* DstBuf, size_t MaxBytes, uint64_t AbsOffset) -> size_t { + const uint64_t Remaining = File->FileSize() > AbsOffset ? File->FileSize() - AbsOffset : 0; + const size_t Take = static_cast<size_t>(std::min<uint64_t>(MaxBytes, Remaining)); + if (Take == 0) + { + return 0; + } + File->Read(DstBuf, Take, AbsOffset); + return Take; + }, + [Hash, Token, Timer, Size, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
Size : 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 PUT failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMedium '{}': {}"sv, Hash, Err))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +void +S3AsyncStorage::PutMultipart(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // Fires only when an exception is in flight; a missed Dismiss() on a + // non-throwing early return falls through to Token's destructor safety + // net rather than asserting on Token->Fail(nullptr). + auto FailGuard = MakeGuard([Token] { + if (auto Ex = std::current_exception()) + { + Token->Fail(Ex); + } + }); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + FailGuard.Dismiss(); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': no credentials available"sv, Hash))); + return; + } + + auto File = std::make_shared<BasicFile>(SourcePath, BasicFile::Mode::kRead); + if (File->FileSize() != Size) + { + FailGuard.Dismiss(); + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': source size {} differs from declared {}"sv, + Hash, + File->FileSize(), + Size))); + return; + } + + auto State = std::make_shared<PutMultipartState>(Work, Pool, Stats); + State->File = std::move(File); + State->Token = Token; + State->ContentHash = Hash; + State->Creds = Creds; // snapshot reused across all parts + complete/abort + State->Key = CasKey(Hash); + State->Path = CasPath(Hash); + State->TotalSize = Size; + State->PartSize = m_MultipartChunkSize; + State->TotalParts = static_cast<uint32_t>((Size + 
State->PartSize - 1) / State->PartSize); + State->PendingParts = State->TotalParts; + State->ReadsRemaining = State->TotalParts; + State->ETags.resize(State->TotalParts); + + // Admission gating: the wave acquires min(TotalParts, AdmissionCap) slots + // inside DispatchInitialPartWave on a worker thread (off the dispatcher + // and off the io strand) so a single multipart cannot starve other + // dispatcher work while it drains the semaphore. Tail parts (TotalParts > + // cap) are dispatched lazily by HandoffSlotToNextPart from each AsyncPut + // completion, so the in-flight count per upload stays <= cap. + State->PreAcquired = m_Admission ? std::min<uint32_t>(State->TotalParts, m_AdmissionCap) : State->TotalParts; + State->NextPartToDispatch.store(State->PreAcquired, std::memory_order_relaxed); + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploads", ""}}); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "POST", State->Path, CanonicalQs, S3EmptyPayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + // CreateMultipartUpload / CompleteMultipartUpload / AbortMultipartUpload are + // metadata operations - they do not move payload bytes and would distort the + // per-request stats (avg/max latency) if folded into RequestCount. Leaving + // them out means PhaseStats.RequestCount tracks data-bearing PUTs only, + // matching the sync S3Storage path's accounting more closely. + + m_Client.AsyncPost( + FullPath, + [this, State](HttpClient::Response Resp) { + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 CreateMultipartUpload failed", Resp); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, Err))); + return; + } + // See PutMultipart entry: tolerate null exception_ptr on a missed Dismiss(). 
+ auto CallbackGuard = MakeGuard([State] { + if (auto Ex = std::current_exception()) + { + State->Token->Fail(Ex); + } + }); + std::string_view Body = Resp.AsText(); + // S3 can answer 200 with an embedded <Error> body even on Create. + // Mirror CompleteMultipart's parse so the original error code/message + // surfaces instead of a generic "missing UploadId". + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + CallbackGuard.Dismiss(); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': create returned error: {} - {}"sv, + State->ContentHash, + ErrorCode, + ErrorMessage))); + return; + } + std::string_view UploadId = S3ExtractXmlValue(Body, "UploadId"); + if (UploadId.empty()) + { + CallbackGuard.Dismiss(); + State->Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMultipart '{}': missing UploadId in CreateMultipartUpload response"sv, + State->ContentHash))); + return; + } + State->UploadId = std::string(UploadId); + // Hop the wave dispatch onto a worker so the per-part admission + // acquire never blocks the io strand or the caller dispatcher. + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State](std::atomic<bool>&) { DispatchInitialPartWave(State); }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (...) + { + CallbackGuard.Dismiss(); + RecordPutPartFailure(*State, "S3 UploadPart wave schedule failed"); + for (uint32_t J = 0; J < State->TotalParts; ++J) + { + FinalizePutPart(State); + } + return; + } + CallbackGuard.Dismiss(); + }, + Headers); + FailGuard.Dismiss(); +} + +void +S3AsyncStorage::DispatchInitialPartWave(std::shared_ptr<PutMultipartState> State) +{ + const bool AdmissionEnabled = (m_Admission != nullptr); + const uint32_t InitialDispatch = AdmissionEnabled ? 
State->PreAcquired : State->TotalParts; + + for (uint32_t I = 0; I < InitialDispatch; ++I) + { + const uint32_t PartNum = I + 1; + std::shared_ptr<void> SlotRef; + try + { + if (AdmissionEnabled) + { + SlotRef = AcquireAdmissionSlot(m_Admission, &State->Stats); + } + State->Work.ScheduleWork( + State->Pool, + [this, State, PartNum, SlotRef = std::move(SlotRef)](std::atomic<bool>&) mutable { + DispatchPartUpload(State, PartNum, std::move(SlotRef)); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // Acquire/schedule failed mid-wave. Account for this part plus + // the un-iterated wave tail, then drain the lazy tail past the + // wave. Slot (if acquired) releases when SlotRef destructs. + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} schedule failed: {}", PartNum, Ex.what())); + const uint32_t WaveTail = InitialDispatch - I; + for (uint32_t J = 0; J < WaveTail; ++J) + { + FinalizePutPart(State); + } + DrainUndispatchedParts(State); + return; + } + } +} + +void +S3AsyncStorage::HandoffSlotToNextPart(std::shared_ptr<PutMultipartState> State, uint32_t PartIdx, std::shared_ptr<void> SlotRef) +{ + const uint32_t PartNum = PartIdx + 1; + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, PartNum, SlotRef = std::move(SlotRef)](std::atomic<bool>&) mutable { + DispatchPartUpload(State, PartNum, std::move(SlotRef)); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // Handoff dispatch failed. The slot we were about to hand off + // releases when SlotRef destructs at scope exit. Account for this + // never-dispatched part and drain the rest of the tail. 
+ RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} handoff failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + } +} + +void +S3AsyncStorage::DrainUndispatchedParts(std::shared_ptr<PutMultipartState> State) +{ + const uint32_t Skipped = ClaimUndispatchedParts(*State); + for (uint32_t J = 0; J < Skipped; ++J) + { + FinalizePutPart(State); + } +} + +#if ZEN_WITH_TESTS +namespace s3asyncstorage_test_hooks { + std::atomic<uint32_t> g_ForceNextPartFailures{0}; + void ForceNextPartFailures(uint32_t Count) { g_ForceNextPartFailures.store(Count, std::memory_order_relaxed); } +} // namespace s3asyncstorage_test_hooks +#endif + +void +S3AsyncStorage::DispatchPartUpload(std::shared_ptr<PutMultipartState> State, uint32_t PartNum, std::shared_ptr<void> SlotRef) +{ + const uint64_t Offset = static_cast<uint64_t>(PartNum - 1) * State->PartSize; + const uint64_t ChunkSize = std::min<uint64_t>(State->PartSize, State->TotalSize - Offset); + + // Release File on this worker thread once all reads have been served, so + // the close never runs on the curl io strand from a captured State ref in + // an AsyncPut callback. + auto ReleaseFileIfLast = [&State] { + if (State->ReadsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + State->File.reset(); + } + }; + +#if ZEN_WITH_TESTS + // Test hook: synthesize a part-level failure to drive the AbortMultipart + // path. Decrement the counter while > 0; each consumed slot fails one part. 
+ { + uint32_t Cur = s3asyncstorage_test_hooks::g_ForceNextPartFailures.load(std::memory_order_relaxed); + while (Cur > 0 && + !s3asyncstorage_test_hooks::g_ForceNextPartFailures.compare_exchange_weak(Cur, Cur - 1, std::memory_order_acq_rel)) + { + } + if (Cur > 0) + { + ReleaseFileIfLast(); + RecordPutPartFailure(*State, fmt::format("test-injected part failure (part {})", PartNum)); + FinalizePutPart(State); + DrainUndispatchedParts(State); + return; + } + } +#endif + + // Short-circuit if a foreign failure has flipped AnyFailed since this part was + // scheduled (handoff race: success callback bumps the cursor to dispatch the + // next part, drain runs concurrently). Skip the file read + sign + AsyncPut for + // a doomed transfer; FinalizePutPart accounting is symmetric to the success path. + if (State->AnyFailed.load(std::memory_order_acquire)) + { + ReleaseFileIfLast(); + FinalizePutPart(State); + return; + } + + // Explicit chunked positional pread + Sha256Stream in one pass. One heap + // IoBuffer of exact part size; no mmap, no IoBuffer auto-materialize. + IoBuffer Part(static_cast<size_t>(ChunkSize)); + uint8_t* Dst = static_cast<uint8_t*>(Part.MutableData()); + Sha256Stream Hasher; + try + { + uint64_t Read = 0; + while (Read < ChunkSize) + { + const size_t Take = static_cast<size_t>(std::min<uint64_t>(kPutReadChunk, ChunkSize - Read)); + State->File->Read(Dst + Read, Take, Offset + Read); + Hasher.Update(Dst + Read, Take); + Read += Take; + } + ReleaseFileIfLast(); + } + catch (const std::exception& Ex) + { + ReleaseFileIfLast(); + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} read failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + return; + } + + try + { + // Use the snapshot taken at PutMultipart entry; saves ~TotalParts + // shared-lock acquisitions on the credential provider. 
+ const SigV4Credentials& Creds = State->Creds; + + std::string CanonicalQs = BuildCanonicalQueryString({ + {"partNumber", fmt::format("{}", PartNum)}, + {"uploadId", State->UploadId}, + }); + std::string PayloadHash = Sha256ToHex(Hasher.Finalize()); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", State->Path, CanonicalQs, PayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + Stopwatch Timer = State->Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not strand + // InFlight elevated. + auto InFlightGuard = MakeGuard([&State] { State->Stats.EndRequest(0, 0); }); + + m_Client.AsyncPut( + FullPath, + std::move(Part), + [this, State, PartNum, Timer, ChunkSize, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? ChunkSize : 0); + if (!Resp.IsSuccess()) + { + RecordPutPartFailure(*State, S3ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNum), Resp)); + DrainUndispatchedParts(State); + FinalizePutPart(State); + return; + } + std::string_view ETag = Resp.FindHeader("etag"); + if (ETag.empty()) + { + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} response missing ETag header", PartNum)); + DrainUndispatchedParts(State); + FinalizePutPart(State); + return; + } + // Per-slot writes don't race - each part owns its slot. + // PendingParts.fetch_sub(acq_rel) in FinalizePutPart is the + // publication barrier: CompleteMultipart (gated on Remaining + // == 1) sees every ETag write. New readers must go through + // the same barrier or take ETagLock. + State->ETags[PartNum - 1].assign(ETag); + + // Try to hand the slot off to the next undispatched part. Skip + // the handoff if a prior failure has flipped AnyFailed (cursor + // was already exchanged to TotalParts by the drain there). 
+ if (!State->AnyFailed.load(std::memory_order_acquire)) + { + const uint32_t NextIdx = State->NextPartToDispatch.fetch_add(1, std::memory_order_acq_rel); + if (NextIdx < State->TotalParts) + { + HandoffSlotToNextPart(State, NextIdx, std::move(SlotRef)); + } + } + // Account for THIS part. SlotRef may already be moved-from + // (if handed off); if not, releases when this lambda destructs. + FinalizePutPart(State); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (const std::exception& Ex) + { + RecordPutPartFailure(*State, fmt::format("S3 UploadPart {} dispatch failed: {}", PartNum, Ex.what())); + FinalizePutPart(State); + DrainUndispatchedParts(State); + } +} + +void +S3AsyncStorage::FinalizePutPart(std::shared_ptr<PutMultipartState> State) +{ + const uint32_t Remaining = State->PendingParts.fetch_sub(1, std::memory_order_acq_rel); + if (Remaining != 1) + { + return; + } + + // Hop the Complete/Abort dispatch off the curl io strand. SHA over the + // parts-XML and SignRequest are CPU work; AsyncPost itself queues a strand + // wakeup. Running both inline here would pin the strand thread. + const bool Failed = State->AnyFailed.load(std::memory_order_acquire); + try + { + State->Work.ScheduleWork(State->Pool, [this, State, Failed](std::atomic<bool>&) { + if (Failed) + { + AbortMultipart(State); + } + else + { + CompleteMultipart(State); + } + }); + } + catch (...) + { + // ScheduleWork failed before Complete/Abort could run; surface the leak via Token->Fail. + State->Token->Fail(std::current_exception()); + } +} + +void +S3AsyncStorage::CompleteMultipart(std::shared_ptr<PutMultipartState> State) +{ + // Reuse the snapshot taken at PutMultipart entry. The upload-id was issued + // against these creds; refreshing here would only matter if creds expired + // mid-upload, in which case all the part uploads would already have failed. 
+ const SigV4Credentials& Creds = State->Creds; + + ExtendableStringBuilder<1024> Xml; + Xml.Append("<CompleteMultipartUpload>"); + for (uint32_t I = 0; I < State->TotalParts; ++I) + { + Xml.Append(fmt::format("<Part><PartNumber>{}</PartNumber><ETag>{}</ETag></Part>", I + 1, State->ETags[I])); + } + Xml.Append("</CompleteMultipartUpload>"); + std::string_view XmlView = Xml.ToView(); + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploadId", State->UploadId}}); + std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "POST", State->Path, CanonicalQs, PayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); + + // Metadata op; not counted in PhaseStats (see CreateMultipartUpload comment). + m_Client.AsyncPost( + FullPath, + Payload, + ZenContentType::kXML, + [State](HttpClient::Response Resp) { + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 CompleteMultipartUpload failed", Resp); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, Err))); + return; + } + std::string_view Body = Resp.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': complete returned error: {} - {}"sv, + State->ContentHash, + ErrorCode, + ErrorMessage))); + return; + } + State->Token->Complete(); + }, + Headers); +} + +void +S3AsyncStorage::AbortMultipart(std::shared_ptr<PutMultipartState> State) +{ + // Reuse the snapshot taken at PutMultipart entry. 
+ const SigV4Credentials& Creds = State->Creds; + + std::string CanonicalQs = BuildCanonicalQueryString({{"uploadId", State->UploadId}}); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "DELETE", State->Path, CanonicalQs, S3EmptyPayloadHash); + std::string FullPath = S3BuildRequestPath(State->Path, CanonicalQs); + + // Metadata op; not counted in PhaseStats (see CreateMultipartUpload comment). + m_Client.AsyncDelete( + FullPath, + [State](HttpClient::Response Resp) { + std::string FirstError; + State->ETagLock.WithExclusiveLock([&] { FirstError = std::move(State->FirstError); }); + if (!Resp.IsSuccess()) + { + State->Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::PutMultipart '{}': part failed ({}); abort also failed: {}"sv, + State->ContentHash, + FirstError, + S3ErrorMessage("S3 AbortMultipartUpload failed", Resp)))); + return; + } + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::PutMultipart '{}': {}"sv, State->ContentHash, FirstError))); + }, + Headers); +} + +void +S3AsyncStorage::Get(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + S3AsyncStorageStats& Stats) +{ + // Three tiers: in-memory AsyncGet (< 512 KiB), AsyncStream into pooled + // 512 KiB write buffers (medium), GetMultipart (>= MultiRangeThreshold). + constexpr uint64_t StreamingThreshold = 512u * 1024u; + + const uint64_t MultiRangeThreshold = m_MultipartChunkSize + (m_MultipartChunkSize / 4); + if (Size >= MultiRangeThreshold) + { + GetMultipart(Work, Pool, Hash, Size, DestinationPath, Stats); + return; + } + + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + // See PutMultipart: tolerate null exception_ptr from a refactor that + // adds a non-throwing early return without Dismiss(). 
+ auto FailGuard = MakeGuard([Token] { + if (auto Ex = std::current_exception()) + { + Token->Fail(Ex); + } + }); + + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + FailGuard.Dismiss(); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': no credentials available"sv, Hash))); + return; + } + + std::string Path = CasPath(Hash); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncGet/AsyncStream does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + + if (Size < StreamingThreshold) + { + // Small: in-memory AsyncGet, single worker-hop write at completion. + m_Client.AsyncGet( + Path, + [Hash = IoHash(Hash), Token, Timer, Stats, DestPath = DestinationPath, Size, &Work, &Pool, SlotRef = std::move(SlotRef)]( + HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? Resp.ResponsePayload.GetSize() : 0); + // No remove() on the failure / size-mismatch paths: the destination + // file is only created by the worker hop below on success, so + // nothing exists to delete here. 
+ if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 GET failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, Hash, Err))); + return; + } + if (Resp.ResponsePayload.GetSize() != Size) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': received {} bytes, expected {}"sv, + Hash, + Resp.ResponsePayload.GetSize(), + Size))); + return; + } + try + { + Work.ScheduleWork( + Pool, + [Token, DestPath, Hash, Payload = std::move(Resp.ResponsePayload)](std::atomic<bool>&) mutable { + try + { + BasicFile Out(DestPath, BasicFile::Mode::kTruncate); + Out.Write(Payload.GetData(), Payload.GetSize(), 0); + Token->Complete(); + } + catch (const std::exception& Ex) + { + std::error_code Ec; + std::filesystem::remove(DestPath, Ec); + Token->Fail(std::make_exception_ptr( + zen::runtime_error("S3AsyncStorage::Get '{}': write failed: {}"sv, Hash, Ex.what()))); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': schedule failed: {}"sv, Hash, Ex.what()))); + } + }, + Headers); + InFlightGuard.Dismiss(); + FailGuard.Dismiss(); + return; + } + + // Medium tier: AsyncStream + GetStreamState (see struct doc). 
+ const uint32_t StreamTotalBlocks = + static_cast<uint32_t>((Size + GetStreamState::WriteBufferSize - 1) / GetStreamState::WriteBufferSize); + auto State = std::make_shared<GetStreamState>(Work, Pool, Stats, StreamTotalBlocks); + State->Token = Token; + State->ContentHash = Hash; + State->DestPath = DestinationPath; + State->ExpectedSize = Size; + + try + { + State->DestFile = std::make_shared<BasicFile>(DestinationPath, BasicFile::Mode::kTruncate); + std::error_code PrepareEc = PrepareFileForScatteredWrite(State->DestFile->Handle(), Size); + if (PrepareEc) + { + throw zen::runtime_error("PrepareFileForScatteredWrite failed: {}"sv, PrepareEc.message()); + } + } + catch (const std::exception& Ex) + { + FailGuard.Dismiss(); + // Drop DestFile (if opened) before remove() so Windows lets the unlink + // proceed; otherwise we leave a 0-byte file behind on the prealloc path. + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(DestinationPath, Ec); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': preallocate failed: {}"sv, Hash, Ex.what()))); + return; + } + + m_Client.AsyncStream( + Path, + [this, State](const uint8_t* Data, size_t SizeBytes, uint64_t /*TotalSize*/) -> bool { + constexpr size_t WriteBufferSize = GetStreamState::WriteBufferSize; + if (State->TotalReceived + SizeBytes > State->ExpectedSize) + { + RecordGetStreamFailure( + *State, + fmt::format("S3 GET overflow: got {} bytes past expected {}", State->TotalReceived + SizeBytes, State->ExpectedSize)); + return false; + } + const uint8_t* Cur = Data; + size_t Remaining = SizeBytes; + while (Remaining > 0) + { + if (!State->ActiveBuf) + { + State->ActiveBuf = State->Buffers.Acquire(); + } + const size_t Avail = WriteBufferSize - State->BufFill; + const size_t Take = std::min(Remaining, Avail); + std::memcpy(State->ActiveBuf.MutableData<uint8_t>() + State->BufFill, Cur, Take); + State->BufFill += Take; + State->TotalReceived += Take; + Cur += Take; + 
Remaining -= Take; + if (State->BufFill == WriteBufferSize) + { + IoBuffer Buf = std::move(State->ActiveBuf); + const size_t Fill = State->BufFill; + const uint64_t AbsOffset = State->NextAbsOffset; + State->NextAbsOffset += Fill; + State->BufFill = 0; + State->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET write failed: {}", Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetStreamFinalised(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + // ScheduleWork failed before the worker took ownership of the + // PendingWork increment we did above; balance it here and + // abort the transfer so OnComplete tears down cleanly. Buf is + // gone (moved into the lambda then destroyed when ScheduleWork + // threw); restore the pool's acquire slot so OnComplete's + // tail-flush can still Acquire if needed. + RecordGetStreamFailure(*State, fmt::format("S3 GET schedule failed: {}", Ex.what())); + State->Buffers.RestoreAcquireSlot(); + State->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + return false; + } + } + } + return true; + }, + [this, State, Timer, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
State->TotalReceived : 0); + bool Failed = !Resp.IsSuccess(); + if (Failed) + { + RecordGetStreamFailure(*State, S3ErrorMessage("S3 GET failed", Resp)); + } + else if (State->TotalReceived != State->ExpectedSize) + { + RecordGetStreamFailure(*State, + fmt::format("S3 GET wrote {} bytes, expected {}", State->TotalReceived, State->ExpectedSize)); + Failed = true; + } + if (!Failed && State->BufFill > 0) + { + IoBuffer Buf = std::move(State->ActiveBuf); + const size_t Fill = State->BufFill; + const uint64_t AbsOffset = State->NextAbsOffset; + State->BufFill = 0; + State->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET write failed: {}", Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetStreamFinalised(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + RecordGetStreamFailure(*State, fmt::format("S3 GET tail-flush schedule failed: {}", Ex.what())); + State->Buffers.RestoreAcquireSlot(); + State->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + } + } + // PendingWork = 1 (stream slot, set in ctor) + N (per-buffer worker + // dispatched in OnData) + (optional 1 for the tail-flush dispatched + // just above). The fetch_sub returning 1 means we observed the value + // that was 1 right before our decrement - i.e. we are the last + // participant. Either this branch (no tail flush, no in-flight + // workers) or the last per-buffer worker performs the finalise. Both + // paths use acq_rel so the work-item writes are visible to whichever + // thread runs OnGetStreamFinalised. 
+ if (State->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetStreamFinalised(State); + } + }, + Headers); + InFlightGuard.Dismiss(); + FailGuard.Dismiss(); +} + +void +S3AsyncStorage::OnGetStreamFinalised(std::shared_ptr<GetStreamState> State) +{ + // All disk I/O (close + unlink on failure) is hopped into the worker pool. + // The last PendingWork dec can land on the curl io strand, and inline + // disk syscalls there would block the curl_multi poll loop. + const bool Failed = State->Failed.load(std::memory_order_acquire); + try + { + State->Work.ScheduleWork( + State->Pool, + [State, Failed](std::atomic<bool>&) { + if (Failed) + { + std::string Err; + State->ErrorLock.WithExclusiveLock([&] { Err = std::move(State->FirstError); }); + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(State->DestPath, Ec); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, State->ContentHash, Err))); + return; + } + + // No Flush: consumer reads via page cache; crash recovery re-hydrates. + State->DestFile.reset(); + State->Token->Complete(); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (...) + { + // Schedule failed; release file handle inline and surface via Token. 
+ State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(State->DestPath, Ec); + State->Token->Fail(std::current_exception()); + } +} + +void +S3AsyncStorage::GetMultipart(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + const uint64_t Chunk = m_MultipartChunkSize; + const uint32_t RangeCount = static_cast<uint32_t>((Size + Chunk - 1) / Chunk); + uint32_t TotalBlocks = 0; + for (uint32_t I = 0; I < RangeCount; ++I) + { + const uint64_t RS = std::min<uint64_t>(Chunk, Size - static_cast<uint64_t>(I) * Chunk); + TotalBlocks += static_cast<uint32_t>((RS + GetMultipartState::WriteBufferSize - 1) / GetMultipartState::WriteBufferSize); + } + + auto State = std::make_shared<GetMultipartState>(Work, Pool, Stats, TotalBlocks); + State->Token = Token; + State->ContentHash = Hash; + State->DestPath = DestinationPath; + State->TotalRanges = RangeCount; + State->PendingRanges = RangeCount; + + // RangeCount > AdmissionCap is fine: the per-range acquire below blocks the + // dispatcher (caller thread, NOT the io strand) until in-flight ranges fire + // their AsyncStream completions and release slots. In-flight ranges per + // download stay bounded by AdmissionCap. + + try + { + // Sparse + preallocated. Sparse mode lets per-range writes land at + // arbitrary offsets without the OS zero-filling intervening pages + // first; preallocating the full size up front avoids fragmentation + // and lazy-commit page faults during the writes themselves. 
+ State->DestFile = std::make_shared<BasicFile>(DestinationPath, BasicFile::Mode::kTruncate); + std::error_code PrepareEc = PrepareFileForScatteredWrite(State->DestFile->Handle(), Size); + if (PrepareEc) + { + throw zen::runtime_error("PrepareFileForScatteredWrite failed: {}"sv, PrepareEc.message()); + } + } + catch (const std::exception& Ex) + { + // Drop DestFile (if opened) before remove() so Windows lets the unlink + // proceed; otherwise we leave a 0-byte file behind on the prealloc path. + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(DestinationPath, Ec); + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::GetMultipart '{}': preallocate failed: {}"sv, Hash, Ex.what()))); + return; + } + + std::string Path = CasPath(Hash); + + // Snapshot creds once for the whole multipart Get. Mirrors PutMultipart; + // avoids N shared-lock acquisitions on the credential provider when the + // range count is large. + const SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + // Empty creds are loop-invariant; drive PendingRanges to zero in one shot + // rather than acquiring TotalRanges admission slots only to fire identical + // failures. RecordGetPartFailure CAS keeps the first message; subsequent + // completions just decrement the counter. + RecordGetPartFailure(*State, "no credentials available"); + for (uint32_t I = 0; I < State->TotalRanges; ++I) + { + OnGetPartCompleted(State); + } + return; + } + + for (uint32_t I = 0; I < State->TotalRanges; ++I) + { + const uint64_t Offset = static_cast<uint64_t>(I) * Chunk; + const uint64_t RangeSize = std::min<uint64_t>(Chunk, Size - Offset); + const uint32_t RangeIdx = I; + + try + { + // Per-range admission slot. Acquire blocks the caller thread when the + // cap is reached; in-flight ranges release slots via SlotRef destruct + // in their AsyncStream completion callback. 
Acquire on the caller + // (vs the io strand) is safe to block: GetMultipart is called from + // the provision-pool worker driving Storage::Get, never from the + // curl io thread that has to drain completions. + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", Path, "", S3EmptyPayloadHash); + Headers->emplace("Range", fmt::format("bytes={}-{}", Offset, Offset + RangeSize - 1)); + + Stopwatch Timer = State->Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncStream does not + // strand InFlight elevated. + auto InFlightGuard = MakeGuard([&State] { State->Stats.EndRequest(0, 0); }); + + // Per-range mirror of GetStreamState's buffered-write machinery (see + // the medium-tier branch in Get()). + constexpr size_t WriteBufferSize = GetMultipartState::WriteBufferSize; + + struct RangeWriteState + { + IoBuffer ActiveBuf; + size_t BufFill = 0; + uint64_t NextAbsOffset; + uint64_t TotalReceived = 0; + std::atomic<uint32_t> PendingWork{1}; // +1 for stream completion + }; + + auto WState = std::make_shared<RangeWriteState>(); + WState->NextAbsOffset = Offset; + + m_Client.AsyncStream( + Path, + [this, State, RangeSize, RangeIdx, WState](const uint8_t* Data, size_t SizeBytes, uint64_t /*TotalSize*/) -> bool { + if (WState->TotalReceived + SizeBytes > RangeSize) + { + RecordGetPartFailure(*State, + fmt::format("S3 GET range {} overflow: got {} bytes past expected {}", + RangeIdx, + WState->TotalReceived + SizeBytes, + RangeSize)); + return false; + } + const uint8_t* Cur = Data; + size_t Remaining = SizeBytes; + while (Remaining > 0) + { + if (!WState->ActiveBuf) + { + WState->ActiveBuf = State->Buffers.Acquire(); + } + const size_t Avail = WriteBufferSize - WState->BufFill; + const size_t Take = std::min(Remaining, Avail); + std::memcpy(WState->ActiveBuf.MutableData<uint8_t>() + WState->BufFill, Cur, Take); + WState->BufFill += Take; + 
WState->TotalReceived += Take; + Cur += Take; + Remaining -= Take; + if (WState->BufFill == WriteBufferSize) + { + IoBuffer Buf = std::move(WState->ActiveBuf); + const size_t Fill = WState->BufFill; + const uint64_t AbsOffset = WState->NextAbsOffset; + WState->NextAbsOffset += Fill; + WState->BufFill = 0; + WState->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, RangeIdx, WState, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, + fmt::format("S3 GET range {} write failed: {}", RangeIdx, Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetPartCompleted(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, fmt::format("S3 GET range {} schedule failed: {}", RangeIdx, Ex.what())); + State->Buffers.RestoreAcquireSlot(); + WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + return false; + } + } + } + return true; + }, + [this, State, RangeSize, RangeIdx, WState, Timer, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + State->Stats.EndRequest(Timer.GetElapsedTimeUs(), Resp.IsSuccess() ? 
WState->TotalReceived : 0); + bool Failed = !Resp.IsSuccess(); + if (Failed) + { + RecordGetPartFailure(*State, S3ErrorMessage(fmt::format("S3 GET range {} failed", RangeIdx), Resp)); + } + else if (WState->TotalReceived != RangeSize) + { + RecordGetPartFailure( + *State, + fmt::format("S3 GET range {} wrote {} bytes, expected {}", RangeIdx, WState->TotalReceived, RangeSize)); + Failed = true; + } + if (!Failed && WState->BufFill > 0) + { + IoBuffer Buf = std::move(WState->ActiveBuf); + const size_t Fill = WState->BufFill; + const uint64_t AbsOffset = WState->NextAbsOffset; + WState->BufFill = 0; + WState->PendingWork.fetch_add(1, std::memory_order_acq_rel); + try + { + State->Work.ScheduleWork( + State->Pool, + [this, State, RangeIdx, WState, Buf = std::move(Buf), Fill, AbsOffset](std::atomic<bool>&) mutable { + try + { + State->DestFile->Write(Buf.GetData(), Fill, AbsOffset); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, fmt::format("S3 GET range {} write failed: {}", RangeIdx, Ex.what())); + } + State->Buffers.Release(std::move(Buf)); + if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetPartCompleted(State); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, + fmt::format("S3 GET range {} tail-flush schedule failed: {}", RangeIdx, Ex.what())); + State->Buffers.RestoreAcquireSlot(); + WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel); + } + } + if (WState->PendingWork.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + OnGetPartCompleted(State); + } + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (const std::exception& Ex) + { + RecordGetPartFailure(*State, fmt::format("S3 GET range {} dispatch failed: {}", RangeIdx, Ex.what())); + OnGetPartCompleted(State); + } + } +} + +void +S3AsyncStorage::OnGetPartCompleted(std::shared_ptr<GetMultipartState> State) +{ + const uint32_t Remaining = 
State->PendingRanges.fetch_sub(1, std::memory_order_acq_rel); + if (Remaining != 1) + { + return; + } + + // All disk I/O hopped to the worker pool: see note in + // OnGetStreamFinalised. + const bool Failed = State->AnyFailed.load(std::memory_order_acquire); + try + { + State->Work.ScheduleWork( + State->Pool, + [State, Failed](std::atomic<bool>&) { + if (Failed) + { + std::string FirstError; + State->ErrorLock.WithExclusiveLock([&] { FirstError = std::move(State->FirstError); }); + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(State->DestPath, Ec); + State->Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Get '{}': {}"sv, State->ContentHash, FirstError))); + return; + } + + State->DestFile.reset(); + State->Token->Complete(); + }, + WorkerThreadPool::EMode::EnableBacklog); + } + catch (...) + { + State->DestFile.reset(); + std::error_code Ec; + std::filesystem::remove(State->DestPath, Ec); + State->Token->Fail(std::current_exception()); + } +} + +void +S3AsyncStorage::Touch(ParallelWork& Work, WorkerThreadPool& Pool, const IoHash& Hash, S3AsyncStorageStats& Stats) +{ + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Stats.RecordScheduled(); + + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission, &Stats); + + // Snapshot creds on the dispatcher; see PutSmall. 
+ SigV4Credentials Creds = m_GetCreds(); + + Work.ScheduleWork( + Pool, + [this, Hash = IoHash(Hash), Token, Stats, Creds = std::move(Creds), SlotRef = std::move(SlotRef)]( + std::atomic<bool>& Abort) mutable { + if (Abort.load()) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': aborted"sv, Hash))); + return; + } + try + { + if (Creds.AccessKeyId.empty()) + { + Token->Fail( + std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': no credentials available"sv, Hash))); + return; + } + + std::string Key = CasKey(Hash); + std::string Path = CasPath(Hash); + + std::vector<std::pair<std::string, std::string>> ExtraSigned{ + {"x-amz-copy-source", fmt::format("/{}/{}", m_Builder.BucketName(), AwsUriEncode(Key, false))}, + {"x-amz-metadata-directive", "REPLACE"}, + }; + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "PUT", Path, "", S3EmptyPayloadHash, ExtraSigned); + + Stopwatch Timer = Stats.BeginRequest(); + // Pair Begin with End so an alloc throw inside AsyncPut does not strand + // InFlight elevated. + auto InFlightGuard = MakeGuard([&Stats] { Stats.EndRequest(0, 0); }); + m_Client.AsyncPut( + Path, + [Hash, Token, Timer, Stats, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + Stats.EndRequest(Timer.GetElapsedTimeUs(), 0); + if (!Resp.IsSuccess()) + { + std::string Err = S3ErrorMessage("S3 Touch failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': {}"sv, Hash, Err))); + return; + } + // PUT-COPY (REPLACE) can return HTTP 200 with an <Error> body. Mirror + // the sync S3Client::Touch check; without this the dehydrate-touch + // fails silently. 
+ std::string_view Body = Resp.AsText(); + std::string_view ErrorCode; + std::string_view ErrorMessage; + if (S3ExtractError(Body, ErrorCode, ErrorMessage)) + { + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::Touch '{}': returned error: {} - {}"sv, + Hash, + ErrorCode, + ErrorMessage))); + return; + } + Token->Complete(); + }, + Headers); + InFlightGuard.Dismiss(); + } + catch (...) + { + Token->Fail(std::current_exception()); + } + }); +} + +std::vector<std::string> +S3AsyncStorage::ListAllObjects(std::string_view Prefix) +{ + constexpr std::string_view ContentsCloseTag = "</Contents>"; + + // List requests are not counted in PhaseStats; List/DeleteAll is admin + // scaffolding (test fixture teardown, manual obliterate). Mirrors the + // CreateMultipart/Complete/Abort exclusion in PutMultipart. + + // Snapshot creds once for the listing run; ListObjectsV2 pagination is + // fast enough that mid-list refresh isn't needed. + const SigV4Credentials Creds = m_GetCreds(); + if (Creds.AccessKeyId.empty()) + { + throw zen::runtime_error("S3AsyncStorage::ListAllObjects: no credentials available"sv); + } + + std::vector<std::string> Keys; + std::string ContinuationToken; + std::string RootPath = m_Builder.BucketRootPath(); + while (true) + { + std::string CanonicalQs = fmt::format("list-type=2&prefix={}", AwsUriEncode(Prefix)); + if (!ContinuationToken.empty()) + { + CanonicalQs += fmt::format("&continuation-token={}", AwsUriEncode(ContinuationToken)); + } + + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Creds, "GET", RootPath, CanonicalQs, S3EmptyPayloadHash); + + std::string FullPath = S3BuildRequestPath(RootPath, CanonicalQs); + HttpClient::Response Resp = m_Client.Get(FullPath, Headers).get(); + if (!Resp.IsSuccess()) + { + throw zen::runtime_error("S3AsyncStorage::ListAllObjects '{}': {}"sv, Prefix, S3ErrorMessage("S3 LIST failed", Resp)); + } + + if (!Resp.ResponsePayload) + { + break; + } + std::string_view 
Body(reinterpret_cast<const char*>(Resp.ResponsePayload.GetData()), Resp.ResponsePayload.GetSize()); + + std::string_view Cursor = Body; + while (true) + { + size_t ContentsOpen = Cursor.find("<Contents>"); + if (ContentsOpen == std::string_view::npos) + { + break; + } + size_t ContentsClose = Cursor.find(ContentsCloseTag, ContentsOpen); + if (ContentsClose == std::string_view::npos) + { + break; + } + std::string_view ContentsBlock = Cursor.substr(ContentsOpen, ContentsClose - ContentsOpen); + std::string_view Key = S3ExtractXmlValue(ContentsBlock, "Key"); + Keys.emplace_back(Key); + Cursor = Cursor.substr(ContentsClose + ContentsCloseTag.size()); + } + + std::string_view IsTruncated = S3ExtractXmlValue(Body, "IsTruncated"); + if (IsTruncated != "true") + { + break; + } + std::string_view NextToken = S3ExtractXmlValue(Body, "NextContinuationToken"); + if (NextToken.empty()) + { + break; + } + ContinuationToken.assign(NextToken); + } + return Keys; +} + +std::vector<IoHash> +S3AsyncStorage::List() +{ + std::string CasPrefix = fmt::format("{}/cas/", m_KeyPrefix); + std::vector<std::string> Keys = ListAllObjects(CasPrefix); + + std::vector<IoHash> Hashes; + Hashes.reserve(Keys.size()); + for (const std::string& Key : Keys) + { + size_t LastSlash = Key.rfind('/'); + if (LastSlash == std::string::npos) + { + continue; + } + IoHash Hash; + if (IoHash::TryParse(std::string_view(Key).substr(LastSlash + 1), Hash)) + { + Hashes.push_back(Hash); + } + } + return Hashes; +} + +void +S3AsyncStorage::DeleteAll(ParallelWork& Work) +{ + std::string Prefix = fmt::format("{}/", m_KeyPrefix); + std::vector<std::string> Keys = ListAllObjects(Prefix); + + // Snapshot creds once for the whole obliterate; saves N shared-lock + // acquisitions on the credential provider. 
+ const SigV4Credentials DelCreds = m_GetCreds(); + if (DelCreds.AccessKeyId.empty()) + { + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::DeleteAll: no credentials available"sv))); + return; + } + + for (const std::string& Key : Keys) + { + auto Token = std::make_shared<ParallelWork::ExternalWorkToken>(Work.RegisterExternal()); + + try + { + std::string Path = m_Builder.KeyToPath(Key); + + HttpClient::KeyValueMap DelHeaders = m_Builder.SignRequest(DelCreds, "DELETE", Path, "", S3EmptyPayloadHash); + + // DeleteAll is untracked by stats - admission slot is acquired but the + // wait is not recorded; passes Stats == nullptr to the helper. + std::shared_ptr<void> SlotRef = AcquireAdmissionSlot(m_Admission); + + m_Client.AsyncDelete( + Path, + [KeyCopy = Key, Token, SlotRef = std::move(SlotRef)](HttpClient::Response Resp) mutable { + if (!Resp.IsSuccess() && Resp.StatusCode != HttpResponseCode::NotFound) + { + std::string Err = S3ErrorMessage("S3 DELETE failed", Resp); + Token->Fail(std::make_exception_ptr(zen::runtime_error("S3AsyncStorage::DeleteAll '{}': {}"sv, KeyCopy, Err))); + return; + } + Token->Complete(); + }, + DelHeaders); + } + catch (...) + { + // Sign / acquire / dispatch threw before AsyncDelete posted. Surface via Token so + // Wait() reports the partial failure instead of silently dropping the key. + Token->Fail(std::current_exception()); + } + } +} + +void +s3asyncstorage_forcelink() +{ +} + +#if ZEN_WITH_TESTS + +namespace { + // Per-binary unique MinIO port. 
+ uint16_t AllocateMinioTestPort() + { + static const uint16_t Base = static_cast<uint16_t>(20000u + (static_cast<uint32_t>(GetCurrentProcessId()) % 30000u)); + static std::atomic<uint16_t> Slot{0}; + return Base + Slot.fetch_add(1, std::memory_order_relaxed); + } + + MinioProcessOptions MakeMinioOpts() + { + MinioProcessOptions Opts; + Opts.Port = AllocateMinioTestPort(); + return Opts; + } + + struct AsyncS3Fixture + { + MinioProcess Minio{MakeMinioOpts()}; + ScopedTemporaryDirectory TmpDir; + std::unique_ptr<AsyncHttpClient> ClientStorage; + std::unique_ptr<S3RequestBuilder> Builder; + AsyncHttpClient* Client = nullptr; + SigV4Credentials Creds; + WorkerThreadPool Pool{4}; + + AsyncS3Fixture() + { + Minio.SpawnMinioServer(); + Minio.CreateBucket("async-test"); + + Creds.AccessKeyId = "minioadmin"; + Creds.SecretAccessKey = "minioadmin"; + + Builder = std::make_unique<S3RequestBuilder>("us-east-1", "async-test", Minio.Endpoint(), /*PathStyle*/ true); + + HttpClientSettings Settings; + Settings.LogCategory = "async-s3-test"; + Settings.MaxConcurrentConnectionsPerHost = 16; + // Match production S3Hydration: cover MinIO post-spawn 503 + // (XMinioServerNotInitialized) transients before storage backend + // finishes warming up after the spawn ready-check passes. 
+ Settings.RetryCount = 3; + ClientStorage = std::make_unique<AsyncHttpClient>(Minio.Endpoint(), Settings); + Client = ClientStorage.get(); + } + }; + + std::filesystem::path WriteBlob(const std::filesystem::path& Path, const std::vector<uint8_t>& Bytes) + { + WriteFile(Path, IoBuffer(IoBuffer::Wrap, Bytes.data(), Bytes.size())); + return Path; + } + + struct StatsBlock + { + std::atomic<uint64_t> RequestCount{0}; + std::atomic<uint64_t> RequestTotalUs{0}; + std::atomic<uint64_t> RequestMaxUs{0}; + std::atomic<uint64_t> Bytes{0}; + std::atomic<uint32_t> InFlight{0}; + std::atomic<uint32_t> InFlightPeak{0}; + std::atomic<uint64_t> FirstScheduleUs{UINT64_MAX}; + std::atomic<uint64_t> FirstStartUs{UINT64_MAX}; + std::atomic<uint64_t> AdmissionWaitTotalUs{0}; + std::atomic<uint64_t> AdmissionWaitMaxUs{0}; + Stopwatch PhaseClock; + + S3AsyncStorageStats Ref() + { + return S3AsyncStorageStats{RequestCount, + RequestTotalUs, + RequestMaxUs, + Bytes, + InFlight, + InFlightPeak, + FirstScheduleUs, + FirstStartUs, + AdmissionWaitTotalUs, + AdmissionWaitMaxUs, + PhaseClock}; + } + }; +} // namespace + +TEST_SUITE_BEGIN("server.s3asyncstorage"); + +TEST_CASE("s3asyncstorage.put_get_round_trip") +{ + StatsBlock Sb; + AsyncS3Fixture F; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-A", + 8u * 1024u * 1024u); + + const uint64_t Size = 64u * 1024u; + std::vector<uint8_t> Bytes(Size); + for (size_t I = 0; I < Size; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 31u) & 0xFF); + } + + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "src.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, Size, SrcPath, Stats); + Work.Wait(); + } + + 
std::filesystem::path DstPath = F.TmpDir.Path() / "dst.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Get(Work, F.Pool, ContentHash, Size, DstPath, Stats); + Work.Wait(); + } + + REQUIRE(std::filesystem::exists(DstPath)); + REQUIRE(std::filesystem::file_size(DstPath) == Size); + + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == ContentHash); +} + +TEST_CASE("s3asyncstorage.touch_existing_object") +{ + StatsBlock Sb; + AsyncS3Fixture F; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-touch", + 8u * 1024u * 1024u); + + std::vector<uint8_t> Bytes{1, 2, 3, 4, 5}; + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "touch.bin", Bytes); + IoHash Hash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, Hash, Bytes.size(), SrcPath, Stats); + Work.Wait(); + } + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Touch(Work, F.Pool, Hash, Stats); + Work.Wait(); + } +} + +TEST_CASE("s3asyncstorage.list_returns_uploaded_hashes") +{ + StatsBlock Sb; + AsyncS3Fixture F; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-list", + 8u * 1024u * 1024u); + + const size_t N = 5; + std::vector<IoHash> Hashes; + std::vector<std::filesystem::path> SrcPaths; + for (size_t I = 0; I < N; ++I) + { + std::vector<uint8_t> Bytes(64, 
static_cast<uint8_t>(I + 1)); + std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("blob_{}.bin", I), Bytes); + Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size())); + SrcPaths.push_back(P); + } + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + for (size_t I = 0; I < N; ++I) + { + Storage.Put(Work, F.Pool, Hashes[I], 64, SrcPaths[I], Stats); + } + Work.Wait(); + } + + std::vector<IoHash> Listed = Storage.List(); + std::sort(Listed.begin(), Listed.end()); + std::vector<IoHash> Expected = Hashes; + std::sort(Expected.begin(), Expected.end()); + CHECK(Listed == Expected); +} + +TEST_CASE("s3asyncstorage.streaming_download_round_trip") +{ + // Object size sits between the streaming threshold (512 KiB) and the + // multipart threshold (PartSize + PartSize/4). Exercises the medium-tier + // AsyncStream branch in S3AsyncStorage::Get - body streams via a 512 KiB + // IoBuffer pool with per-buffer worker write hops. 
+ StatsBlock Sb; + AsyncS3Fixture F; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-stream", + 8u * 1024u * 1024u); + + const uint64_t Size = 2u * 1024u * 1024u; // 2 MiB - between 512 KiB and 10 MiB + std::vector<uint8_t> Bytes(Size); + for (size_t I = 0; I < Size; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 53u + 11u) & 0xFF); + } + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "stream.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, Size, SrcPath, Stats); + Work.Wait(); + } + + std::filesystem::path DstPath = F.TmpDir.Path() / "stream_out.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Get(Work, F.Pool, ContentHash, Size, DstPath, Stats); + Work.Wait(); + } + + REQUIRE(std::filesystem::exists(DstPath)); + REQUIRE(std::filesystem::file_size(DstPath) == Size); + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == ContentHash); +} + +TEST_CASE("s3asyncstorage.multipart_round_trip") +{ + StatsBlock Sb; + AsyncS3Fixture F; + const uint64_t PartSize = 5u * 1024u * 1024u; // MinIO accepts >=5MiB parts. + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-multipart", + PartSize); + + // Three parts: 5 MiB + 5 MiB + 1 MiB. Threshold is PartSize + PartSize/4 (= + // 6.25 MiB). Total 11 MiB triggers multipart path. 
+ const uint64_t TotalSize = 11u * 1024u * 1024u; + std::vector<uint8_t> Bytes(TotalSize); + for (size_t I = 0; I < TotalSize; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 131u + 7u) & 0xFF); + } + + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "big.bin", Bytes); + IoHash Hash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, Hash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + + std::filesystem::path DstPath = F.TmpDir.Path() / "big_out.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Get(Work, F.Pool, Hash, TotalSize, DstPath, Stats); + Work.Wait(); + } + + REQUIRE(std::filesystem::exists(DstPath)); + REQUIRE(std::filesystem::file_size(DstPath) == TotalSize); + + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == Hash); +} + +TEST_CASE("s3asyncstorage.parallel_puts") +{ + StatsBlock Sb; + AsyncS3Fixture F; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-parallel", + 8u * 1024u * 1024u); + + const size_t N = 32; + std::vector<IoHash> Hashes; + std::vector<std::filesystem::path> SrcPaths; + for (size_t I = 0; I < N; ++I) + { + std::vector<uint8_t> Bytes(1024, static_cast<uint8_t>(I)); + std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("p_{}.bin", I), Bytes); + Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size())); + SrcPaths.push_back(P); + } + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, 
WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + for (size_t I = 0; I < N; ++I) + { + Storage.Put(Work, F.Pool, Hashes[I], 1024, SrcPaths[I], Stats); + } + Work.Wait(); + + CHECK(Storage.List().size() == N); +} + +namespace { + // One MinIO + builder + creds + pool, reused across multiple admission / + // throttle test scopes. Each scope constructs its own AsyncHttpClient via + // MakeClient(Cap) so different MaxConcurrentRequests values are exercised + // without re-spawning MinIO. + struct ThrottledAsyncS3Fixture + { + MinioProcess Minio{MakeMinioOpts()}; + ScopedTemporaryDirectory TmpDir; + std::unique_ptr<S3RequestBuilder> Builder; + SigV4Credentials Creds; + WorkerThreadPool Pool{4}; + + ThrottledAsyncS3Fixture() + { + Minio.SpawnMinioServer(); + Minio.CreateBucket("async-test-throttle"); + + Creds.AccessKeyId = "minioadmin"; + Creds.SecretAccessKey = "minioadmin"; + + Builder = std::make_unique<S3RequestBuilder>("us-east-1", "async-test-throttle", Minio.Endpoint(), /*PathStyle*/ true); + } + + std::unique_ptr<AsyncHttpClient> MakeClient(uint32_t MaxConcurrentRequests) + { + HttpClientSettings Settings; + Settings.LogCategory = "async-s3-throttle-test"; + Settings.MaxConcurrentConnectionsPerHost = 16; + Settings.MaxConcurrentRequests = MaxConcurrentRequests; + // Match production S3Hydration: cover MinIO post-spawn 503 + // (XMinioServerNotInitialized) transients before storage backend + // finishes warming up after the spawn ready-check passes. + Settings.RetryCount = 3; + return std::make_unique<AsyncHttpClient>(Minio.Endpoint(), Settings); + } + }; +} // namespace + +// Non-multipart admission scenarios: small-file fanout under a cap, fanout +// with admission disabled, and admission-wait stat recording. All share one +// MinIO instance via ThrottledAsyncS3Fixture; each scope picks its own Cap +// and KeyPrefix so the bucket entries don't collide across scopes. 
+TEST_CASE("s3asyncstorage.admission.fanout") +{ + ThrottledAsyncS3Fixture F; + + // respects.async.client.cap - in-flight peak <= storage admission cap. + { + StatsBlock Sb; + const uint32_t Cap = 2; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-cap", + 8u * 1024u * 1024u, + Admission, + Cap); + + const size_t N = 8; + std::vector<IoHash> Hashes; + std::vector<std::filesystem::path> SrcPaths; + for (size_t I = 0; I < N; ++I) + { + std::vector<uint8_t> Bytes(256, static_cast<uint8_t>(I)); + std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("c_{}.bin", I), Bytes); + Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size())); + SrcPaths.push_back(P); + } + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + for (size_t I = 0; I < N; ++I) + { + Storage.Put(Work, F.Pool, Hashes[I], 256, SrcPaths[I], Stats); + } + Work.Wait(); + + std::vector<IoHash> Listed = Storage.List(); + std::sort(Listed.begin(), Listed.end()); + std::vector<IoHash> Expected = Hashes; + std::sort(Expected.begin(), Expected.end()); + CHECK(Listed == Expected); + CHECK(Sb.InFlightPeak.load() <= 2u); + } + + // unlimited_concurrent_requests - admission disabled (nullptr semaphore, + // Cap=0). Storage drains the fanout cleanly without gating. 
+ { + StatsBlock Sb; + auto Client = F.MakeClient(0); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-uncapped", + 8u * 1024u * 1024u); + + const size_t N = 8; + std::vector<IoHash> Hashes; + std::vector<std::filesystem::path> SrcPaths; + for (size_t I = 0; I < N; ++I) + { + std::vector<uint8_t> Bytes(256, static_cast<uint8_t>(I)); + std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("u_{}.bin", I), Bytes); + Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size())); + SrcPaths.push_back(P); + } + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + for (size_t I = 0; I < N; ++I) + { + Storage.Put(Work, F.Pool, Hashes[I], 256, SrcPaths[I], Stats); + } + Work.Wait(); + + std::vector<IoHash> Listed = Storage.List(); + std::sort(Listed.begin(), Listed.end()); + std::vector<IoHash> Expected = Hashes; + std::sort(Expected.begin(), Expected.end()); + CHECK(Listed == Expected); + } + + // admission.wait.us.recorded - cap=2 with 8 fanout submits forces at least + // some submissions to block on the admission semaphore. AdmissionWait + // totals must be non-zero. 
+ { + StatsBlock Sb; + const uint32_t Cap = 2; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-adm-wait", + 8u * 1024u * 1024u, + Admission, + Cap); + + const size_t N = 8; + std::vector<IoHash> Hashes; + std::vector<std::filesystem::path> SrcPaths; + for (size_t I = 0; I < N; ++I) + { + std::vector<uint8_t> Bytes(256, static_cast<uint8_t>(I)); + std::filesystem::path P = WriteBlob(F.TmpDir.Path() / fmt::format("aw_{}.bin", I), Bytes); + Hashes.push_back(IoHash::HashBuffer(Bytes.data(), Bytes.size())); + SrcPaths.push_back(P); + } + + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + for (size_t I = 0; I < N; ++I) + { + Storage.Put(Work, F.Pool, Hashes[I], 256, SrcPaths[I], Stats); + } + Work.Wait(); + + CHECK(Sb.AdmissionWaitTotalUs.load() > 0u); + CHECK(Sb.AdmissionWaitMaxUs.load() > 0u); + CHECK(Sb.InFlightPeak.load() <= Cap); + } +} + +// Drives the AbortMultipart path: a 3-part multipart upload where one part is +// forced to fail via the test hook. AnyFailed flips, FinalizePutPart hops +// off-strand once PendingParts reaches 1, AbortMultipart sends DELETE to S3 +// to discard the upload-id. +// Token must surface the original part failure; bucket must not contain the +// uploaded CAS object after Wait returns. 
+TEST_CASE("s3asyncstorage.multipart.abort_on_part_failure") +{ + StatsBlock Sb; + AsyncS3Fixture F; + const uint64_t PartSize = 5u * 1024u * 1024u; + S3AsyncStorage Storage( + *F.Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-multipart-abort", + PartSize); + + const uint64_t TotalSize = 11u * 1024u * 1024u; // 3 parts: 5+5+1 MiB + std::vector<uint8_t> Bytes(TotalSize); + for (size_t I = 0; I < TotalSize; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 23u + 9u) & 0xFF); + } + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "abort.bin", Bytes); + IoHash Hash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + // Force one part to fail; AnyFailed -> AbortMultipart. + s3asyncstorage_test_hooks::ForceNextPartFailures(1); + + bool ThrewFromWait = false; + std::string ErrMessage; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, Hash, TotalSize, SrcPath, Stats); + try + { + Work.Wait(); + } + catch (const std::exception& Ex) + { + ThrewFromWait = true; + ErrMessage = Ex.what(); + } + } + CHECK(ThrewFromWait); + CHECK(ErrMessage.find("test-injected part failure") != std::string::npos); + + // Reset hook so subsequent multipart tests aren't affected. + s3asyncstorage_test_hooks::ForceNextPartFailures(0); + + // Bucket must not contain the CAS object - AbortMultipart discards the + // upload before S3 publishes the assembled object. + std::vector<IoHash> Listed = Storage.List(); + CHECK(std::find(Listed.begin(), Listed.end(), Hash) == Listed.end()); +} + +// Multipart admission scenarios sharing one MinIO instance. Each scope picks +// its own Cap, KeyPrefix, and payload pattern so bucket entries don't collide. +// Covers: +// - under.cap: TotalParts < cap (no slot handoff needed); round-trip Put+Get. 
+//   - ranges.exceed.cap.paces: ranged Get with RangeCount > cap; per-range +//     admission acquire paces the download with InFlightPeak <= cap. 
+// - no.deadlock.streaming.get: Put + ranged Get on same hydration pool. +// - parts.exceed.cap.paces: TotalParts > cap; wave + slot-handoff completes +// successfully with InFlightPeak <= cap. +// - aborts.release.tokens: forced part failure releases held slots so a +// follow-up Put can run. +// - handoff.in.flight: cap=1 forces strictly sequential dispatch. +TEST_CASE("s3asyncstorage.admission.multipart") +{ + ThrottledAsyncS3Fixture F; + const uint64_t PartSize = 5u * 1024u * 1024u; // MinIO accepts >= 5 MiB parts. + + // under.cap: 3-part Put + Get sits within the initial wave (cap=4). + { + StatsBlock Sb; + const uint32_t Cap = 4; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-multipart-cap", + PartSize, + Admission, + Cap); + + const uint64_t TotalSize = 11u * 1024u * 1024u; + std::vector<uint8_t> Bytes(TotalSize); + for (size_t I = 0; I < TotalSize; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 17u + 5u) & 0xFF); + } + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + + std::filesystem::path DstPath = F.TmpDir.Path() / "mp_out.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats); + Work.Wait(); + } + + REQUIRE(std::filesystem::exists(DstPath)); + 
REQUIRE(std::filesystem::file_size(DstPath) == TotalSize); + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == ContentHash); + } + + // no.deadlock.streaming.get: PutMultipart fans onto hydration pool while + // io strand fires AsyncPut completions; followed by Get exercising stream + // write-back on the same pool. + { + StatsBlock Sb; + const uint32_t Cap = 4; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-stream-no-deadlock", + PartSize, + Admission, + Cap); + + const uint64_t TotalSize = 11u * 1024u * 1024u; + std::vector<uint8_t> Bytes(TotalSize); + for (size_t I = 0; I < TotalSize; ++I) + { + Bytes[I] = static_cast<uint8_t>((I * 41u + 7u) & 0xFF); + } + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "sd.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + + std::filesystem::path DstPath = F.TmpDir.Path() / "sd_out.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats); + Work.Wait(); + } + + REQUIRE(std::filesystem::exists(DstPath)); + REQUIRE(std::filesystem::file_size(DstPath) == TotalSize); + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == 
ContentHash); + } + + // parts.exceed.cap.paces: TotalParts > cap completes by pacing via slot + // handoff. Initial wave dispatches Cap parts; each completion hands its + // slot to the next undispatched part. InFlightPeak stays bounded by Cap. + { + StatsBlock Sb; + const uint32_t Cap = 2; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-mp-paces", + PartSize, + Admission, + Cap); + + const uint64_t TotalSize = 16u * 1024u * 1024u; // 4 parts > cap=2 + std::vector<uint8_t> Bytes(TotalSize, 0xA5); + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_paces.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + CHECK(Sb.InFlightPeak.load() <= Cap); + std::vector<IoHash> Listed = Storage.List(); + CHECK(std::find(Listed.begin(), Listed.end(), ContentHash) != Listed.end()); + } + + // aborts.release.tokens: forced part failure mid-flight; drain releases + // admission slots so a follow-up small Put can run without blocking. 
+ { + StatsBlock Sb; + const uint32_t Cap = 4; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-mp-abort-release", + PartSize, + Admission, + Cap); + + const uint64_t TotalSize = 11u * 1024u * 1024u; // 3 parts under cap=4 + std::vector<uint8_t> Bytes(TotalSize, 0x5A); + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_abrel.bin", Bytes); + IoHash AbortHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + s3asyncstorage_test_hooks::ForceNextPartFailures(1); + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, AbortHash, TotalSize, SrcPath, Stats); + try + { + Work.Wait(); + } + catch (...) + { + } + } + s3asyncstorage_test_hooks::ForceNextPartFailures(0); + + std::vector<uint8_t> Small(1024, 0x11); + std::filesystem::path SmallPath = WriteBlob(F.TmpDir.Path() / "mp_abrel_small.bin", Small); + IoHash SmallHash = IoHash::HashBuffer(Small.data(), Small.size()); + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, SmallHash, Small.size(), SmallPath, Stats); + Work.Wait(); + } + std::vector<IoHash> Listed = Storage.List(); + CHECK(std::find(Listed.begin(), Listed.end(), SmallHash) != Listed.end()); + } + + // handoff.in.flight: cap=1 forces strictly sequential dispatch (initial + // wave size 1, every subsequent part dispatched via handoff from the prior + // AsyncPut completion). 
+ { + StatsBlock Sb; + const uint32_t Cap = 1; + auto Admission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(Cap)); + auto Client = F.MakeClient(Cap); + S3AsyncStorage Storage( + *Client, + *F.Builder, + [&F]() { return F.Creds; }, + "module-mp-handoff", + PartSize, + Admission, + Cap); + + const uint64_t TotalSize = 16u * 1024u * 1024u; // 4 parts; handoff fires 3 times + std::vector<uint8_t> Bytes(TotalSize, 0x33); + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_handoff.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + Storage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + CHECK(Sb.InFlightPeak.load() == 1u); + std::vector<IoHash> Listed = Storage.List(); + CHECK(std::find(Listed.begin(), Listed.end(), ContentHash) != Listed.end()); + } + + // ranges.exceed.cap.paces: GetMultipart with RangeCount > AdmissionCap. + // The dispatcher loop's per-range AcquireAdmissionSlot blocks the caller + // thread until in-flight ranges fire AsyncStream completions and release + // slots. InFlightPeak per Get stays bounded by Cap. Upload first with a + // larger client cap so the Put itself doesn't pace, then re-open Storage + // with the small cap purely to exercise the Get pacing path. 
+ { + StatsBlock Sb; + const uint32_t UploadCap = 8; + auto UploadAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(UploadCap)); + auto UploadClient = F.MakeClient(UploadCap); + S3AsyncStorage UploadStorage( + *UploadClient, + *F.Builder, + [&F]() { return F.Creds; }, + "module-mp-get-paces", + PartSize, + UploadAdmission, + UploadCap); + + const uint64_t TotalSize = 16u * 1024u * 1024u; // 4 ranges with PartSize=5 MiB + std::vector<uint8_t> Bytes(TotalSize, 0x77); + std::filesystem::path SrcPath = WriteBlob(F.TmpDir.Path() / "mp_get_paces.bin", Bytes); + IoHash ContentHash = IoHash::HashBuffer(Bytes.data(), Bytes.size()); + + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = Sb.Ref(); + UploadStorage.Put(Work, F.Pool, ContentHash, TotalSize, SrcPath, Stats); + Work.Wait(); + } + + StatsBlock GetSb; + const uint32_t GetCap = 2; + auto GetAdmission = std::make_shared<AdmissionSemaphore>(static_cast<std::ptrdiff_t>(GetCap)); + auto GetClient = F.MakeClient(GetCap); + S3AsyncStorage GetStorage( + *GetClient, + *F.Builder, + [&F]() { return F.Creds; }, + "module-mp-get-paces", + PartSize, + GetAdmission, + GetCap); + + std::filesystem::path DstPath = F.TmpDir.Path() / "mp_get_paces_out.bin"; + { + std::atomic<bool> AbortFlag{false}; + std::atomic<bool> PauseFlag{false}; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + S3AsyncStorageStats Stats = GetSb.Ref(); + GetStorage.Get(Work, F.Pool, ContentHash, TotalSize, DstPath, Stats); + Work.Wait(); + } + CHECK(GetSb.InFlightPeak.load() <= GetCap); + REQUIRE(std::filesystem::exists(DstPath)); + REQUIRE(std::filesystem::file_size(DstPath) == TotalSize); + BasicFile Out(DstPath, BasicFile::Mode::kRead); + IoBuffer Read = Out.ReadAll(); + IoHash ReadHash = IoHash::HashBuffer(Read); + CHECK(ReadHash == ContentHash); + } +} + 
+TEST_SUITE_END(); + +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenserver/hub/s3asyncstorage.h b/src/zenserver/hub/s3asyncstorage.h new file mode 100644 index 000000000..1bb8c14ca --- /dev/null +++ b/src/zenserver/hub/s3asyncstorage.h @@ -0,0 +1,224 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/parallelwork.h> +#include <zencore/timer.h> +#include <zenhttp/asynchttpclient.h> +#include <zenutil/cloud/s3client.h> +#include <zenutil/cloud/s3requestbuilder.h> +#include <zenutil/cloud/sigv4.h> + +#include <atomic> +#include <filesystem> +#include <functional> +#include <limits> +#include <memory> +#include <semaphore> +#include <string> + +namespace zen { + +// Storage-layer admission semaphore. Sized at runtime to MaxConcurrentRequests; +// LeastMaxValue is the compile-time upper bound. Caller may pass nullptr to +// disable gating entirely. +using AdmissionSemaphore = std::counting_semaphore<std::numeric_limits<std::ptrdiff_t>::max()>; + +struct S3AsyncStorageStats +{ + std::atomic<uint64_t>& RequestCount; + std::atomic<uint64_t>& RequestTotalUs; + std::atomic<uint64_t>& RequestMaxUs; + std::atomic<uint64_t>& Bytes; + std::atomic<uint32_t>& InFlight; + std::atomic<uint32_t>& InFlightPeak; + std::atomic<uint64_t>& FirstScheduleUs; + std::atomic<uint64_t>& FirstStartUs; + std::atomic<uint64_t>& AdmissionWaitTotalUs; + std::atomic<uint64_t>& AdmissionWaitMaxUs; + Stopwatch& PhaseClock; + + void RecordScheduled() + { + const uint64_t Now = PhaseClock.GetElapsedTimeUs(); + uint64_t Existing = FirstScheduleUs.load(std::memory_order_relaxed); + while (Now < Existing && !FirstScheduleUs.compare_exchange_weak(Existing, Now, std::memory_order_relaxed)) + { + } + } + + Stopwatch BeginRequest() + { + const uint64_t Now = PhaseClock.GetElapsedTimeUs(); + uint64_t Existing = FirstStartUs.load(std::memory_order_relaxed); + while (Now < Existing && 
!FirstStartUs.compare_exchange_weak(Existing, Now, std::memory_order_relaxed)) + { + } + const uint32_t Current = InFlight.fetch_add(1, std::memory_order_relaxed) + 1; + uint32_t Peak = InFlightPeak.load(std::memory_order_relaxed); + while (Current > Peak && !InFlightPeak.compare_exchange_weak(Peak, Current, std::memory_order_relaxed)) + { + } + return Stopwatch{}; + } + + void EndRequest(uint64_t ElapsedUs, uint64_t BytesValue) + { + InFlight.fetch_sub(1, std::memory_order_relaxed); + RequestCount.fetch_add(1, std::memory_order_relaxed); + RequestTotalUs.fetch_add(ElapsedUs, std::memory_order_relaxed); + Bytes.fetch_add(BytesValue, std::memory_order_relaxed); + uint64_t Existing = RequestMaxUs.load(std::memory_order_relaxed); + while (ElapsedUs > Existing && !RequestMaxUs.compare_exchange_weak(Existing, ElapsedUs, std::memory_order_relaxed)) + { + } + } + + void RecordAdmissionWait(uint64_t Us) + { + AdmissionWaitTotalUs.fetch_add(Us, std::memory_order_relaxed); + uint64_t Prev = AdmissionWaitMaxUs.load(std::memory_order_relaxed); + while (Us > Prev && !AdmissionWaitMaxUs.compare_exchange_weak(Prev, Us, std::memory_order_relaxed)) + { + } + } +}; + +// Async S3 storage adapter for hub hydration. Mirrors the behavior of the +// blocking `S3Storage` (defined inside hydration.cpp) but submits requests +// via `AsyncHttpClient` and counts in-flight work against `ParallelWork` +// using `ExternalWorkToken` instead of occupying worker-pool threads. +// +// Construction: +// S3AsyncStorage Storage(AsyncClient, RequestBuilder, GetCreds, KeyPrefix, +// MultipartChunkSize); +// +// AsyncClient and RequestBuilder are owned by the caller (typically the hub +// or a test fixture). GetCreds is a callable returning the latest SigV4 +// credentials - keeps the storage decoupled from any specific provider. 
+// +// Per-call SourcePath/DestinationPath args carry temp-dir state from the +// caller (IncrementalHydrator picks paths under HydrationConfig::TempDir); +// S3AsyncStorage itself never owns a temp directory. +// +// DeleteAll blocks the caller while the prefix listing runs, then issues +// async deletes via Work. +class S3AsyncStorage +{ +public: + using CredentialsCallback = std::function<SigV4Credentials()>; + + S3AsyncStorage(AsyncHttpClient& Client, + S3RequestBuilder& Builder, + CredentialsCallback GetCreds, + std::string KeyPrefix, + uint64_t MultipartChunkSize, + std::shared_ptr<AdmissionSemaphore> Admission = nullptr, + uint32_t AdmissionCap = 0); + + std::string_view KeyPrefix() const { return m_KeyPrefix; } + + // Async data operations. Each registers a ParallelWork::ExternalWorkToken; + // the async callback completes/fails the token. Caller drives Work.Wait() + // to block until all submissions finish. + // + // Stats (must outlive in-flight callbacks): every S3 request the storage + // issues calls RecordScheduled at submit time, BeginRequest just before + // the network handoff, and EndRequest from the completion callback. Bytes + // is the payload size on success, 0 on failure or zero-payload requests. + // Multipart and ranged GET fire one Begin/EndRequest per part/range. + void Put(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats); + void Get(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + S3AsyncStorageStats& Stats); + void Touch(ParallelWork& Work, WorkerThreadPool& Pool, const IoHash& Hash, S3AsyncStorageStats& Stats); + + // Synchronous list of all CAS hashes under this module's prefix. + std::vector<IoHash> List(); + + // Synchronous delete-all under the module prefix. Lists then issues async + // deletes via Work; caller still drives Work.Wait(). 
+ void DeleteAll(ParallelWork& Work); + + // Forward declarations are public (not the structs themselves) so file-scope + // free helpers in s3asyncstorage.cpp can name the types in their signatures + // without being declared friends. The structs are defined privately in the + // .cpp; no out-of-module caller can construct or inspect them. +public: + struct PutMultipartState; + struct GetMultipartState; + struct GetStreamState; + +private: + std::string CasKey(const IoHash& Hash) const; + std::string CasPath(const IoHash& Hash) const; + + void PutSmall(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats); + void PutMedium(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats); + void PutMultipart(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& SourcePath, + S3AsyncStorageStats& Stats); + void DispatchInitialPartWave(std::shared_ptr<PutMultipartState> State); + void DispatchPartUpload(std::shared_ptr<PutMultipartState> State, uint32_t PartNum, std::shared_ptr<void> SlotRef); + void HandoffSlotToNextPart(std::shared_ptr<PutMultipartState> State, uint32_t PartIdx, std::shared_ptr<void> SlotRef); + void DrainUndispatchedParts(std::shared_ptr<PutMultipartState> State); + void FinalizePutPart(std::shared_ptr<PutMultipartState> State); + void CompleteMultipart(std::shared_ptr<PutMultipartState> State); + void AbortMultipart(std::shared_ptr<PutMultipartState> State); + + void GetMultipart(ParallelWork& Work, + WorkerThreadPool& Pool, + const IoHash& Hash, + uint64_t Size, + const std::filesystem::path& DestinationPath, + S3AsyncStorageStats& Stats); + void OnGetPartCompleted(std::shared_ptr<GetMultipartState> State); + void OnGetStreamFinalised(std::shared_ptr<GetStreamState> State); + + 
std::vector<std::string> ListAllObjects(std::string_view Prefix); + + AsyncHttpClient& m_Client; + S3RequestBuilder& m_Builder; + CredentialsCallback m_GetCreds; + std::string m_KeyPrefix; + uint64_t m_MultipartChunkSize; + std::shared_ptr<AdmissionSemaphore> m_Admission; // nullable; when null, no admission gating + uint32_t m_AdmissionCap; // initial slot count; 0 when admission disabled +}; + +void s3asyncstorage_forcelink(); + +#if ZEN_WITH_TESTS +namespace s3asyncstorage_test_hooks { + // Per-process counter consumed by DispatchPartUpload. While > 0, each + // invocation decrements it and synthesizes a part-level failure (via + // RecordPutPartFailure + drain + FinalizePutPart fan-out) instead of + // issuing the UploadPart request. Used to drive the AbortMultipart path + // from tests without fault-injecting MinIO. + void ForceNextPartFailures(uint32_t Count); +} // namespace s3asyncstorage_test_hooks +#endif + +} // namespace zen diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp index 8d36e6a46..34860eb9b 100644 --- a/src/zenserver/hub/storageserverinstance.cpp +++ b/src/zenserver/hub/storageserverinstance.cpp @@ -2,8 +2,6 @@ #include "storageserverinstance.h" -#include "hydration.h" - #include <zencore/assertfmt.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> @@ -14,12 +12,8 @@ namespace zen { -StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironment, - HydrationBase& Hydration, - const Configuration& Config, - std::string_view ModuleId) -: m_Hydration(Hydration) -, m_Config(Config) +StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId) +: m_Config(Config) , m_ModuleId(ModuleId) , m_ServerInstance(RunEnvironment, ZenServerInstance::ServerMode::kStorageServer) { @@ -138,7 +132,6 @@ StorageServerInstance::ProvisionLocked() ZEN_INFO("Provisioning storage server instance for module 
'{}', at '{}'", m_ModuleId, m_Config.StateDir); try { - Hydrate(); SpawnServerProcess(); } catch (const std::exception& Ex) @@ -156,19 +149,6 @@ StorageServerInstance::DeprovisionLocked() { ZEN_TRACE_CPU("StorageServerInstance::DeprovisionLocked"); ShutdownServerProcess(); - - // Crashed or Hibernated: process already dead; skip Shutdown. - // Dehydrate preserves instance state for future re-provisioning. Failure means saved state - // may be stale or absent, but the process is already dead so the slot can still be released. - // Swallow the exception and proceed with cleanup rather than leaving the module stuck. - try - { - Dehydrate(); - } - catch (const std::exception& Ex) - { - ZEN_WARN("Dehydration of module {} failed during deprovisioning, current state not saved. Reason: {}", m_ModuleId, Ex.what()); - } } void @@ -176,12 +156,6 @@ StorageServerInstance::ObliterateLocked() { ZEN_TRACE_CPU("StorageServerInstance::ObliterateLocked"); ShutdownServerProcess(); - - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - Hydrator->Obliterate(); } void @@ -218,54 +192,6 @@ StorageServerInstance::WakeLocked() } } -void -StorageServerInstance::Hydrate() -{ - if (!m_Config.EnableHydration) - { - ZEN_INFO("Hydration disabled; skipping hydrate for module '{}'", m_ModuleId); - return; - } - ZEN_TRACE_CPU("StorageServerInstance::Hydrate"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - m_HydrationState = Hydrator->Hydrate(); -} - -void -StorageServerInstance::Dehydrate() -{ - if (!m_Config.EnableDehydration) - { - ZEN_INFO("Dehydration disabled; skipping dehydrate for module '{}'", m_ModuleId); - return; - } - 
ZEN_TRACE_CPU("StorageServerInstance::Dehydrate"); - std::atomic<bool> AbortFlag{false}; - std::atomic<bool> PauseFlag{false}; - HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag); - std::unique_ptr<HydrationStrategyBase> Hydrator = m_Hydration.CreateHydrator(Config); - Hydrator->Dehydrate(m_HydrationState); -} - -HydrationConfig -StorageServerInstance::MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag) -{ - HydrationConfig Config{.ServerStateDir = m_Config.StateDir, .TempDir = m_Config.TempDir, .ModuleId = m_ModuleId}; - if (m_Config.OptionalWorkerPool) - { - Config.Threading.emplace( - HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalWorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag}); - } - Config.PackEnabled = m_Config.HydrationPackEnabled; - Config.PackThresholdBytes = m_Config.HydrationPackThresholdBytes; - Config.MaxPackBytes = m_Config.HydrationMaxPackBytes; - - return Config; -} - StorageServerInstance::SharedLockedPtr::SharedLockedPtr() : m_Lock(nullptr), m_Instance(nullptr) { } diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h index a2f376a23..599431300 100644 --- a/src/zenserver/hub/storageserverinstance.h +++ b/src/zenserver/hub/storageserverinstance.h @@ -2,18 +2,12 @@ #pragma once -#include "hydration.h" - -#include <zencore/compactbinary.h> #include <zenutil/zenserverprocess.h> -#include <atomic> #include <filesystem> namespace zen { -class WorkerThreadPool; - /** * Storage Server Instance * @@ -28,7 +22,6 @@ public: { uint16_t BasePort; std::filesystem::path StateDir; - std::filesystem::path TempDir; uint32_t HttpThreadCount = 0; // Automatic int CoreLimit = 0; // Automatic std::filesystem::path ConfigPath; @@ -36,19 +29,9 @@ public: std::string Trace; std::string TraceHost; std::string TraceFile; - bool EnableHydration = true; - bool EnableDehydration = true; - bool HydrationPackEnabled = true; - uint64_t 
HydrationPackThresholdBytes = DefaultPackThresholdBytes; - uint64_t HydrationMaxPackBytes = DefaultMaxPackBytes; - - WorkerThreadPool* OptionalWorkerPool = nullptr; }; - StorageServerInstance(ZenServerEnvironment& RunEnvironment, - HydrationBase& Hydration, - const Configuration& Config, - std::string_view ModuleId); + StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId); ~StorageServerInstance(); inline std::string_view GetModuleId() const { return m_ModuleId; } @@ -146,13 +129,10 @@ private: void WakeLocked(); mutable RwLock m_Lock; - HydrationBase& m_Hydration; const Configuration m_Config; std::string m_ModuleId; ZenServerInstance m_ServerInstance; - CbObject m_HydrationState; - #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; #endif @@ -160,10 +140,6 @@ private: void SpawnServerProcess(); void ShutdownServerProcess(); - void Hydrate(); - void Dehydrate(); - HydrationConfig MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag); - friend class SharedLockedPtr; friend class ExclusiveLockedPtr; }; diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 303b1f1b2..a2a366a80 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -188,18 +188,29 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HubInstanceConfigPath), "<instance config>"); - const uint32_t DefaultHubInstanceProvisionThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + // Provision pool: scheduling-bound, blocked on hydrate. clamp(cpu/8, 4, 16). + const uint32_t DefaultHubInstanceProvisionThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 16u); + // Spawn pool: blocking on CreateProcess + /health, RSS-bounded. clamp(cpu/8, 4, 16). 
+ const uint32_t DefaultHubInstanceSpawnThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 16u); Options.add_option("hub", "", "hub-instance-provision-threads", - fmt::format("Number of threads for instance provisioning (default {})", DefaultHubInstanceProvisionThreadCount), + fmt::format("Hub-wide hydrate/dehydrate scheduling pool size (default {})", DefaultHubInstanceProvisionThreadCount), cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceProvisionThreadCount) ->default_value(fmt::format("{}", DefaultHubInstanceProvisionThreadCount)), "<threads>"); Options.add_option("hub", "", + "hub-instance-spawn-threads", + fmt::format("Hub-wide child-process spawn/despawn pool size (default {})", DefaultHubInstanceSpawnThreadCount), + cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceSpawnThreadCount) + ->default_value(fmt::format("{}", DefaultHubInstanceSpawnThreadCount)), + "<threads>"); + + Options.add_option("hub", + "", "hub-hydration-target-spec", "Specification for hydration target. 'file://<path>' prefix indicates file storage at <path>. Defaults to " "<data-dir>/servers/hydration_storage", @@ -214,13 +225,18 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HydrationTargetConfigPath), "<path>"); - const uint32_t DefaultHubHydrationThreadCount = Max(GetHardwareConcurrency() / 4u, 2u); + // Hydration pool: per-file workers inside one hydrate/dehydrate call (xxhash, pack + // build, file I/O; S3 I/O only on the blocking path when HubHydrationAsyncEnabled is + // false). At 30ms latency on loopback the AsyncHttpClient strand is the bottleneck; + // hydration=8/12/16/32 all landed within run-to-run noise (~10s) at 1000 modules. + // clamp(cpu/8, 4, 12). 
+ const uint32_t DefaultHubHydrationThreadCount = Min(Max(GetHardwareConcurrency() / 8u, 4u), 12u); Options.add_option( "hub", "", "hub-hydration-threads", - fmt::format("Number of threads for hydration/dehydration (default {})", DefaultHubHydrationThreadCount), + fmt::format("Per-file worker pool size inside hydrate/dehydrate (default {})", DefaultHubHydrationThreadCount), cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationThreadCount)->default_value(fmt::format("{}", DefaultHubHydrationThreadCount)), "<threads>"); @@ -273,6 +289,30 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", + "hub-hydration-async-enabled", + "Route S3 hydration through AsyncHttpClient. false falls back to the blocking S3Client path.", + cxxopts::value<bool>(m_ServerOptions.HubHydrationAsyncEnabled)->default_value("true"), + "<bool>"); + + // Async S3 in-flight cap: also drives CURLMOPT_MAXCONNECTS and curl per-host conn limits in + // the AsyncHttpClient. Sized so big modules (~hundreds of parallel ranges) saturate the pipe + // without admission-wait, and the conn cache stays large enough to avoid eviction churn on + // reused conns. clamp(cpu*4, 128, 512). + const uint32_t DefaultHubHydrationAsyncMaxConcurrentRequests = Min(Max(GetHardwareConcurrency() * 4u, 128u), 512u); + + Options.add_option( + "hub", + "", + "hub-hydration-async-max-concurrent-requests", + fmt::format("Max in-flight S3 requests submitted to the AsyncHttpClient. 
Only used when --hub-hydration-async-enabled " + "(default {}).", + DefaultHubHydrationAsyncMaxConcurrentRequests), + cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationAsyncMaxConcurrentRequests) + ->default_value(fmt::format("{}", DefaultHubHydrationAsyncMaxConcurrentRequests)), + "<n>"); + + Options.add_option("hub", + "", "hub-watchdog-cycle-interval-ms", "Interval between watchdog cycles in milliseconds", cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.CycleIntervalMs)->default_value("3000"), @@ -401,6 +441,7 @@ ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) Options.AddOption("hub.instance.provisionthreads"sv, m_ServerOptions.HubInstanceProvisionThreadCount, "hub-instance-provision-threads"sv); + Options.AddOption("hub.instance.spawnthreads"sv, m_ServerOptions.HubInstanceSpawnThreadCount, "hub-instance-spawn-threads"sv); Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv); Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv); @@ -412,6 +453,10 @@ ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) m_ServerOptions.HubHydrationPackThresholdBytes, "hub-hydration-pack-threshold-bytes"sv); Options.AddOption("hub.hydration.maxpackbytes"sv, m_ServerOptions.HubHydrationMaxPackBytes, "hub-hydration-max-pack-bytes"sv); + Options.AddOption("hub.hydration.async.enabled"sv, m_ServerOptions.HubHydrationAsyncEnabled, "hub-hydration-async-enabled"sv); + Options.AddOption("hub.hydration.async.maxconcurrentrequests"sv, + m_ServerOptions.HubHydrationAsyncMaxConcurrentRequests, + "hub-hydration-async-max-concurrent-requests"sv); Options.AddOption("hub.watchdog.cycleintervalms"sv, m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv); Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv, @@ -585,9 +630,9 @@ ZenHubServer::Initialize(const 
ZenHubServerConfig& ServerConfig, ZenServerState: // the main test range. ZenServerEnvironment::SetBaseChildId(1000); - m_ProvisionWorkerPool = - std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision"); - m_HydrationWorkerPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration"); + m_ProvisionPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision"); + m_SpawnPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceSpawnThreadCount), "hub_spawn"); + m_HydrationPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration"); m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; @@ -695,22 +740,24 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) { ZEN_INFO("instantiating Hub"); Hub::Configuration HubConfig{ - .UseJobObject = ServerConfig.HubUseJobObject, - .BasePortNumber = ServerConfig.HubBasePortNumber, - .InstanceLimit = ServerConfig.HubInstanceLimit, - .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, - .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, - .InstanceMalloc = ServerConfig.HubInstanceMalloc, - .InstanceTrace = ServerConfig.HubInstanceTrace, - .InstanceTraceHost = ServerConfig.HubInstanceTraceHost, - .InstanceTraceFile = ServerConfig.HubInstanceTraceFile, - .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, - .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, - .EnableHydration = ServerConfig.HubEnableHydration, - .EnableDehydration = ServerConfig.HubEnableDehydration, - .HydrationPackEnabled = ServerConfig.HubHydrationPackEnabled, - .HydrationPackThresholdBytes = ServerConfig.HubHydrationPackThresholdBytes, - .HydrationMaxPackBytes = ServerConfig.HubHydrationMaxPackBytes, + .UseJobObject = 
ServerConfig.HubUseJobObject, + .BasePortNumber = ServerConfig.HubBasePortNumber, + .InstanceLimit = ServerConfig.HubInstanceLimit, + .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, + .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, + .InstanceMalloc = ServerConfig.HubInstanceMalloc, + .InstanceTrace = ServerConfig.HubInstanceTrace, + .InstanceTraceHost = ServerConfig.HubInstanceTraceHost, + .InstanceTraceFile = ServerConfig.HubInstanceTraceFile, + .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, + .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, + .EnableHydration = ServerConfig.HubEnableHydration, + .EnableDehydration = ServerConfig.HubEnableDehydration, + .HydrationPackEnabled = ServerConfig.HubHydrationPackEnabled, + .HydrationPackThresholdBytes = ServerConfig.HubHydrationPackThresholdBytes, + .HydrationMaxPackBytes = ServerConfig.HubHydrationMaxPackBytes, + .HydrationAsyncEnabled = ServerConfig.HubHydrationAsyncEnabled, + .HydrationAsyncMaxConcurrentRequests = ServerConfig.HubHydrationAsyncMaxConcurrentRequests, .WatchDog = { .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), @@ -722,9 +769,10 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), }, - .ResourceLimits = ResolveLimits(ServerConfig), - .OptionalProvisionWorkerPool = m_ProvisionWorkerPool.get(), - .OptionalHydrationWorkerPool = m_HydrationWorkerPool.get()}; + .ResourceLimits = ResolveLimits(ServerConfig), + .OptionalProvisionPool = m_ProvisionPool.get(), + .OptionalSpawnPool = m_SpawnPool.get(), + .OptionalHydrationPool = m_HydrationPool.get()}; if (!ServerConfig.HydrationTargetConfigPath.empty()) { @@ -842,9 +890,8 @@ 
ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi } consul::ServiceRegistrationInfo Info; - Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId); - Info.ServiceName = "zen-hub"; - // Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri? + Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId); + Info.ServiceName = "zen-hub"; Info.Port = static_cast<uint16_t>(EffectivePort); Info.HealthEndpoint = "health"; Info.Tags = std::vector<std::pair<std::string, std::string>>{ diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index 6416792a6..0d5ac600f 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -53,18 +53,22 @@ struct ZenHubServerConfig : public ZenServerConfig bool HubHydrationPackEnabled = true; // Concatenate small files into raw CAS pack blobs during dehydrate uint64_t HubHydrationPackThresholdBytes = DefaultPackThresholdBytes; // Files strictly smaller than this are pack candidates uint64_t HubHydrationMaxPackBytes = DefaultMaxPackBytes; // Upper bound on a pack's concatenation size - std::string HubInstanceHttpClass = "asio"; - std::string HubInstanceMalloc; - std::string HubInstanceTrace; - std::string HubInstanceTraceHost; - std::string HubInstanceTraceFile; - uint32_t HubInstanceHttpThreadCount = 0; // Automatic - uint32_t HubInstanceProvisionThreadCount = 0; // Synchronous provisioning - uint32_t HubHydrationThreadCount = 0; // Synchronous hydration/dehydration - int HubInstanceCoreLimit = 0; // Automatic - std::filesystem::path HubInstanceConfigPath; // Path to Lua config file - std::string HydrationTargetSpecification; // hydration/dehydration target specification - std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification) + bool HubHydrationAsyncEnabled = true; // Route S3 hydration through AsyncHttpClient 
+ uint32_t HubHydrationAsyncMaxConcurrentRequests = + 128; // Hub-wide cap on concurrent S3 hydration requests (only when HubHydrationAsyncEnabled) + std::string HubInstanceHttpClass = "asio"; + std::string HubInstanceMalloc; + std::string HubInstanceTrace; + std::string HubInstanceTraceHost; + std::string HubInstanceTraceFile; + uint32_t HubInstanceHttpThreadCount = 0; // Automatic + uint32_t HubInstanceProvisionThreadCount = 0; // Hub-wide hydrate/dehydrate scheduling pool size + uint32_t HubInstanceSpawnThreadCount = 0; // Hub-wide child process spawn/despawn pool size + uint32_t HubHydrationThreadCount = 0; // Internal hydration parallelism (per-file) + int HubInstanceCoreLimit = 0; // Automatic + std::filesystem::path HubInstanceConfigPath; // Path to Lua config file + std::string HydrationTargetSpecification; // hydration/dehydration target specification + std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification) ZenHubWatchdogConfig WatchdogConfig; uint64_t HubProvisionDiskLimitBytes = 0; uint32_t HubProvisionDiskLimitPercent = 0; @@ -137,8 +141,9 @@ private: bool m_DebugOptionForcedCrash = false; std::unique_ptr<HttpProxyHandler> m_Proxy; - std::unique_ptr<WorkerThreadPool> m_ProvisionWorkerPool; - std::unique_ptr<WorkerThreadPool> m_HydrationWorkerPool; + std::unique_ptr<WorkerThreadPool> m_ProvisionPool; + std::unique_ptr<WorkerThreadPool> m_SpawnPool; + std::unique_ptr<WorkerThreadPool> m_HydrationPool; std::unique_ptr<Hub> m_Hub; std::unique_ptr<HttpHubService> m_HubService; diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index ab80cfcc7..83443b98b 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -4,6 +4,7 @@ #include <zenutil/cloud/imdscredentials.h> #include <zenutil/cloud/minioprocess.h> +#include <zenutil/cloud/s3response.h> #include <zencore/except_fmt.h> #include <zencore/iobuffer.h> @@ -19,210 +20,21 @@ 
ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { - - /// The SHA-256 hash of an empty payload, precomputed - constexpr std::string_view EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; - - /// Simple XML value extractor. Finds the text content between <Tag> and </Tag>. - /// This is intentionally minimal - we only need to parse ListBucketResult responses. - /// Returns a string_view into the original XML when no entity decoding is needed. - std::string_view ExtractXmlValue(std::string_view Xml, std::string_view Tag) - { - std::string OpenTag = fmt::format("<{}>", Tag); - std::string CloseTag = fmt::format("</{}>", Tag); - - size_t Start = Xml.find(OpenTag); - if (Start == std::string_view::npos) - { - return {}; - } - Start += OpenTag.size(); - - size_t End = Xml.find(CloseTag, Start); - if (End == std::string_view::npos) - { - return {}; - } - - return Xml.substr(Start, End - Start); - } - - /// Decode the five standard XML entities (&amp; &lt; &gt; &quot; &apos;) into a StringBuilderBase. - void DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) - { - if (Input.find('&') == std::string_view::npos) - { - Out.Append(Input); - return; - } - - for (size_t i = 0; i < Input.size(); ++i) - { - if (Input[i] == '&') - { - std::string_view Remaining = Input.substr(i); - if (Remaining.starts_with("&amp;")) - { - Out.Append('&'); - i += 4; - } - else if (Remaining.starts_with("&lt;")) - { - Out.Append('<'); - i += 3; - } - else if (Remaining.starts_with("&gt;")) - { - Out.Append('>'); - i += 3; - } - else if (Remaining.starts_with("&quot;")) - { - Out.Append('"'); - i += 5; - } - else if (Remaining.starts_with("&apos;")) - { - Out.Append('\''); - i += 5; - } - else - { - Out.Append(Input[i]); - } - } - else - { - Out.Append(Input[i]); - } - } - } - - /// Convenience: decode XML entities and return as std::string. 
- std::string DecodeXmlEntities(std::string_view Input) - { - if (Input.find('&') == std::string_view::npos) - { - return std::string(Input); - } - - ExtendableStringBuilder<256> Sb; - DecodeXmlEntities(Input, Sb); - return Sb.ToString(); - } - - /// Join a path and canonical query string into a full request path for the HTTP client. - std::string BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) - { - if (CanonicalQS.empty()) - { - return std::string(Path); - } - return fmt::format("{}?{}", Path, CanonicalQS); - } - - /// Case-insensitive header lookup in an HttpClient response header map. - const std::string* FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) - { - for (const auto& [K, V] : *Headers) - { - if (StrCaseCompare(K, Name) == 0) - { - return &V; - } - } - return nullptr; - } - - /// Extract Code/Message from an S3 XML error body. Returns true if an <Error> element was - /// found, even if Code/Message are empty. - bool ExtractS3Error(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) - { - if (Body.find("<Error>") == std::string_view::npos) - { - return false; - } - OutCode = ExtractXmlValue(Body, "Code"); - OutMessage = ExtractXmlValue(Body, "Message"); - return true; - } - - /// True if the response indicates S3 throttling (503 SlowDown / ServiceUnavailable / 429). - /// Code is checked on both the HTTP status and the XML error code so we catch proxies that - /// return 200 with a SlowDown body. 
- bool IsS3Throttled(const HttpClient::Response& Response, std::string_view ErrorCode) - { - const int Status = static_cast<int>(Response.StatusCode); - if (Status == 503 || Status == 429) - { - return true; - } - if (ErrorCode == "SlowDown" || ErrorCode == "ServiceUnavailable" || ErrorCode == "ThrottlingException" || - ErrorCode == "RequestLimitExceeded" || ErrorCode == "TooManyRequests") - { - return true; - } - return false; - } - - /// Build a human-readable error message for a failed S3 response. When the response body - /// contains an S3 `<Error>` element, the Code and Message fields are included in the string - /// so transient 4xx/5xx failures (SignatureDoesNotMatch, AuthorizationHeaderMalformed, etc.) - /// show up in logs instead of being swallowed. Falls back to the generic HTTP/transport - /// message when no XML body is available (HEAD responses, transport errors). - /// Also emits a distinct `S3 THROTTLED` warning when the response indicates throttling so - /// callers can grep for it without parsing combined error text. 
- std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) - { - if (!Response.Error.has_value() && Response.ResponsePayload) - { - std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); - std::string_view Code; - std::string_view Message; - if (ExtractS3Error(Body, Code, Message) && (!Code.empty() || !Message.empty())) - { - ExtendableStringBuilder<256> Decoded; - DecodeXmlEntities(Message, Decoded); - if (IsS3Throttled(Response, Code)) - { - ZEN_WARN("S3 THROTTLED [{}] status={} code='{}' message='{}'", - Prefix, - static_cast<int>(Response.StatusCode), - Code, - Decoded.ToView()); - } - return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); - } - } - if (IsS3Throttled(Response, {})) - { - ZEN_WARN("S3 THROTTLED [{}] status={} (no XML body)", Prefix, static_cast<int>(Response.StatusCode)); - } - return Response.ErrorMessage(Prefix); - } - -} // namespace - std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; S3Client::S3Client(const S3ClientOptions& Options) : m_Log(logging::Get("s3")) -, m_BucketName(Options.BucketName) -, m_Region(Options.Region) -, m_Endpoint(Options.Endpoint) -, m_PathStyle(Options.PathStyle) +, m_Builder(Options.Region, Options.BucketName, Options.Endpoint, Options.PathStyle) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) -, m_HttpClient(BuildEndpoint(), Options.HttpSettings) +, m_HttpClient(std::string(m_Builder.Endpoint()), Options.HttpSettings) , m_Verbose(Options.HttpSettings.Verbose) { - m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", - m_BucketName, - m_Region, + m_Builder.BucketName(), + m_Builder.Region(), m_HttpClient.GetBaseUri(), - m_PathStyle ? "path-style" : "virtual-hosted"); + m_Builder.PathStyle() ? 
"path-style" : "virtual-hosted"); } S3Client::~S3Client() = default; @@ -235,19 +47,15 @@ S3Client::GetCurrentCredentials() SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); if (!Creds.AccessKeyId.empty()) { - // Invalidate the signing key cache when the access key changes, and update stored - // credentials atomically under the same lock so callers see a consistent snapshot. - RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); - if (Creds.AccessKeyId != m_Credentials.AccessKeyId) - { - m_CachedDateStamp.clear(); - } + // Update stored credentials atomically so callers see a consistent snapshot. + // The builder's signing-key cache is keyed on (DateStamp, AccessKeyId), so it + // self-invalidates on the next Sign() call when either rotates. + RwLock::ExclusiveLockScope _(m_CredentialsLock); m_Credentials = Creds; - // Return Creds directly - avoids reading m_Credentials after releasing the lock, - // which would race with another concurrent write. return Creds; } // IMDS returned empty credentials; fall back to the last known-good credentials. 
+ RwLock::SharedLockScope _(m_CredentialsLock); return m_Credentials; } return m_Credentials; @@ -261,153 +69,6 @@ S3Client::BuildNoCredentialsError(std::string Context) return Err; } -std::string -S3Client::BuildEndpoint() const -{ - if (!m_Endpoint.empty()) - { - return m_Endpoint; - } - - if (m_PathStyle) - { - // Path-style: https://s3.region.amazonaws.com - return fmt::format("https://s3.{}.amazonaws.com", m_Region); - } - - // Virtual-hosted style: https://bucket.s3.region.amazonaws.com - return fmt::format("https://{}.s3.{}.amazonaws.com", m_BucketName, m_Region); -} - -std::string -S3Client::BuildHostHeader() const -{ - if (!m_Endpoint.empty()) - { - // Extract host from custom endpoint URL (strip scheme) - std::string_view Ep = m_Endpoint; - if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) - { - Ep = Ep.substr(Pos + 3); - } - // Strip trailing slash - if (!Ep.empty() && Ep.back() == '/') - { - Ep = Ep.substr(0, Ep.size() - 1); - } - return std::string(Ep); - } - - if (m_PathStyle) - { - return fmt::format("s3.{}.amazonaws.com", m_Region); - } - - return fmt::format("{}.s3.{}.amazonaws.com", m_BucketName, m_Region); -} - -std::string -S3Client::KeyToPath(std::string_view Key) const -{ - if (m_PathStyle) - { - return fmt::format("/{}/{}", m_BucketName, Key); - } - return fmt::format("/{}", Key); -} - -std::string -S3Client::BucketRootPath() const -{ - if (m_PathStyle) - { - return fmt::format("/{}/", m_BucketName); - } - return "/"; -} - -Sha256Digest -S3Client::GetSigningKey(std::string_view DateStamp) -{ - // Fast path: shared lock for cache hit (common case - key only changes once per day) - { - RwLock::SharedLockScope SharedLock(m_SigningKeyLock); - if (m_CachedDateStamp == DateStamp) - { - return m_CachedSigningKey; - } - } - - // Slow path: exclusive lock to recompute the signing key - RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); - - // Double-check after acquiring exclusive lock (another thread may have updated it) - if 
(m_CachedDateStamp == DateStamp) - { - return m_CachedSigningKey; - } - - std::string SecretPrefix = fmt::format("AWS4{}", m_Credentials.SecretAccessKey); - - Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); - SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); - - Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); - Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); - m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); - m_CachedDateStamp = std::string(DateStamp); - - return m_CachedSigningKey; -} - -HttpClient::KeyValueMap -S3Client::SignRequest(const SigV4Credentials& Credentials, - std::string_view Method, - std::string_view Path, - std::string_view CanonicalQueryString, - std::string_view PayloadHash, - std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders) -{ - std::string AmzDate = GetAmzTimestamp(); - - // Build sorted headers to sign (must be sorted by lowercase name) - std::vector<std::pair<std::string, std::string>> HeadersToSign; - HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); - HeadersToSign.emplace_back("host", m_Host); - HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); - HeadersToSign.emplace_back("x-amz-date", AmzDate); - if (!Credentials.SessionToken.empty()) - { - HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); - } - for (const auto& [K, V] : ExtraSignedHeaders) - { - HeadersToSign.emplace_back(K, V); - } - std::sort(HeadersToSign.begin(), HeadersToSign.end()); - - std::string_view DateStamp(AmzDate.data(), 8); - Sha256Digest SigningKey = GetSigningKey(DateStamp); - - SigV4SignedHeaders Signed = - SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); - - HttpClient::KeyValueMap Result; - Result->emplace("Authorization", std::move(Signed.Authorization)); - 
Result->emplace("x-amz-date", std::move(Signed.AmzDate)); - Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); - if (!Credentials.SessionToken.empty()) - { - Result->emplace("x-amz-security-token", Credentials.SessionToken); - } - for (const auto& [K, V] : ExtraSignedHeaders) - { - Result->emplace(K, V); - } - - return Result; -} - S3Result S3Client::PutObject(std::string_view Key, IoBuffer Content) { @@ -417,12 +78,12 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); // Hash the payload std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, "", PayloadHash); HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); if (!Response.IsSuccess()) @@ -448,9 +109,9 @@ S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFileP return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "GET", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) @@ -484,9 +145,9 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran return S3GetObjectResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = 
m_Builder.SignRequest(Credentials, "GET", Path, "", S3EmptyPayloadHash); Headers->emplace("Range", fmt::format("bytes={}-{}", RangeStart, RangeStart + RangeSize - 1)); HttpClient::Response Response = m_HttpClient.Get(Path, Headers); @@ -537,9 +198,9 @@ S3Client::DeleteObject(std::string_view Key) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "DELETE", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); if (!Response.IsSuccess()) @@ -565,17 +226,17 @@ S3Client::Touch(std::string_view Key) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); // x-amz-copy-source is always "/bucket/key" regardless of addressing style. // Key must be URI-encoded except for '/' separators. When source and destination // are identical, REPLACE is required; COPY is rejected with InvalidRequest. 
const std::array<std::pair<std::string, std::string>, 2> ExtraSigned{{ - {"x-amz-copy-source", fmt::format("/{}/{}", m_BucketName, AwsUriEncode(Key, false))}, + {"x-amz-copy-source", fmt::format("/{}/{}", m_Builder.BucketName(), AwsUriEncode(Key, false))}, {"x-amz-metadata-directive", "REPLACE"}, }}; - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, "", EmptyPayloadHash, ExtraSigned); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, "", S3EmptyPayloadHash, ExtraSigned); HttpClient::Response Response = m_HttpClient.Put(Path, IoBuffer{}, Headers); if (!Response.IsSuccess()) @@ -589,7 +250,7 @@ S3Client::Touch(std::string_view Key) std::string_view ResponseBody = Response.AsText(); std::string_view ErrorCode; std::string_view ErrorMessage; - if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + if (S3ExtractError(ResponseBody, ErrorCode, ErrorMessage)) { std::string Err = fmt::format("S3 Touch '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); @@ -612,9 +273,9 @@ S3Client::HeadObject(std::string_view Key) return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "HEAD", Path, "", EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "HEAD", Path, "", S3EmptyPayloadHash); HttpClient::Response Response = m_HttpClient.Head(Path, Headers); if (!Response.IsSuccess()) @@ -632,17 +293,17 @@ S3Client::HeadObject(std::string_view Key) S3ObjectInfo Info; Info.Key = std::string(Key); - if (const std::string* V = FindResponseHeader(Response.Header, "content-length")) + if (const std::string* V = S3FindResponseHeader(Response.Header, "content-length")) { Info.Size = ParseInt<uint64_t>(*V).value_or(0); } - if (const std::string* V = FindResponseHeader(Response.Header, "etag")) + if 
(const std::string* V = S3FindResponseHeader(Response.Header, "etag")) { Info.ETag = *V; } - if (const std::string* V = FindResponseHeader(Response.Header, "last-modified")) + if (const std::string* V = S3FindResponseHeader(Response.Header, "last-modified")) { Info.LastModified = *V; } @@ -687,10 +348,10 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); - std::string RootPath = BucketRootPath(); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "GET", RootPath, CanonicalQS, EmptyPayloadHash); + std::string RootPath = m_Builder.BucketRootPath(); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "GET", RootPath, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); + std::string FullPath = S3BuildRequestPath(RootPath, CanonicalQS); HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); if (!Response.IsSuccess()) { @@ -722,11 +383,11 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) std::string_view ContentsXml = Remaining.substr(ContentsStart, ContentsEnd - ContentsStart + 11); S3ObjectInfo Info; - Info.Key = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "Key")); - Info.ETag = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "ETag")); - Info.LastModified = std::string(ExtractXmlValue(ContentsXml, "LastModified")); + Info.Key = S3DecodeXmlEntities(S3ExtractXmlValue(ContentsXml, "Key")); + Info.ETag = S3DecodeXmlEntities(S3ExtractXmlValue(ContentsXml, "ETag")); + Info.LastModified = std::string(S3ExtractXmlValue(ContentsXml, "LastModified")); - std::string_view SizeStr = ExtractXmlValue(ContentsXml, "Size"); + std::string_view SizeStr = S3ExtractXmlValue(ContentsXml, "Size"); if (!SizeStr.empty()) { Info.Size = ParseInt<uint64_t>(SizeStr).value_or(0); @@ -741,13 +402,13 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) } // Check if there are more 
pages - std::string_view IsTruncated = ExtractXmlValue(ResponseBody, "IsTruncated"); + std::string_view IsTruncated = S3ExtractXmlValue(ResponseBody, "IsTruncated"); if (IsTruncated != "true") { break; } - std::string_view NextToken = ExtractXmlValue(ResponseBody, "NextContinuationToken"); + std::string_view NextToken = S3ExtractXmlValue(ResponseBody, "NextContinuationToken"); if (NextToken.empty()) { break; @@ -779,12 +440,12 @@ S3Client::CreateMultipartUpload(std::string_view Key) return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "POST", Path, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); if (!Response.IsSuccess()) { @@ -800,7 +461,7 @@ S3Client::CreateMultipartUpload(std::string_view Key) // <UploadId>...</UploadId> // </InitiateMultipartUploadResult> std::string_view ResponseBody = Response.AsText(); - std::string_view UploadId = ExtractXmlValue(ResponseBody, "UploadId"); + std::string_view UploadId = S3ExtractXmlValue(ResponseBody, "UploadId"); if (UploadId.empty()) { std::string Err = "failed to parse UploadId from CreateMultipartUpload response"; @@ -824,7 +485,7 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P return S3UploadPartResult{S3Result{std::move(Err)}, {}}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({ {"partNumber", fmt::format("{}", PartNumber)}, {"uploadId", std::string(UploadId)}, @@ -832,9 
+493,9 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "PUT", Path, CanonicalQS, PayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); if (!Response.IsSuccess()) { @@ -844,7 +505,7 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P } // Extract ETag from response headers - const std::string* ETag = FindResponseHeader(Response.Header, "etag"); + const std::string* ETag = S3FindResponseHeader(Response.Header, "etag"); if (!ETag) { std::string Err = "S3 UploadPart response missing ETag header"; @@ -870,7 +531,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); // Build the CompleteMultipartUpload XML payload @@ -885,12 +546,12 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view XmlView = XmlBody.ToView(); std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "POST", Path, CanonicalQS, PayloadHash); Headers->emplace("Content-Type", "application/xml"); IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response 
Response = m_HttpClient.Post(FullPath, Payload, Headers); if (!Response.IsSuccess()) { @@ -903,7 +564,7 @@ S3Client::CompleteMultipartUpload(std::string_view Key, std::string_view ResponseBody = Response.AsText(); std::string_view ErrorCode; std::string_view ErrorMessage; - if (ExtractS3Error(ResponseBody, ErrorCode, ErrorMessage)) + if (S3ExtractError(ResponseBody, ErrorCode, ErrorMessage)) { std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned error: {} - {}", Key, ErrorCode, ErrorMessage); ZEN_WARN("{}", Err); @@ -926,12 +587,12 @@ S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) return S3Result{std::move(Err)}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); - HttpClient::KeyValueMap Headers = SignRequest(Credentials, "DELETE", Path, CanonicalQS, EmptyPayloadHash); + HttpClient::KeyValueMap Headers = m_Builder.SignRequest(Credentials, "DELETE", Path, CanonicalQS, S3EmptyPayloadHash); - std::string FullPath = BuildRequestPath(Path, CanonicalQS); + std::string FullPath = S3BuildRequestPath(Path, CanonicalQS); HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); if (!Response.IsSuccess()) { @@ -968,15 +629,15 @@ S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view M return {}; } - std::string Path = KeyToPath(Key); + std::string Path = m_Builder.KeyToPath(Key); std::string Scheme = "https"; - if (!m_Endpoint.empty() && m_Endpoint.starts_with("http://")) + if (!m_Builder.Endpoint().empty() && m_Builder.Endpoint().starts_with("http://")) { Scheme = "http"; } - return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); + return GeneratePresignedUrl(Credentials, Method, Scheme, m_Builder.Host(), Path, m_Builder.Region(), "s3", ExpiresIn); } S3Result @@ -1004,8 +665,26 @@ 
S3Client::PutObjectMultipart(std::string_view Key, const std::string& UploadId = InitResult.UploadId; - // Upload parts sequentially - // TODO: upload parts in parallel for improved throughput on large uploads + // Cleanup helper: AbortMultipartUpload itself can throw on transport failure; + // inside the catch (...) below that would replace the original exception with + // a less actionable transport one. Swallow + log. + auto SafeAbort = [this, &Key, &UploadId]() noexcept { + try + { + AbortMultipartUpload(Key, UploadId); + } + catch (const std::exception& Ex) + { + ZEN_WARN("S3 AbortMultipartUpload '{}' threw during cleanup: {}", Key, Ex.what()); + } + catch (...) + { + ZEN_WARN("S3 AbortMultipartUpload '{}' threw during cleanup", Key); + } + }; + + // Sequential upload by design; for parallel multipart use S3AsyncStorage::PutMultipart + // in the hub hydration path. std::vector<std::pair<uint32_t, std::string>> PartETags; uint64_t Offset = 0; @@ -1020,7 +699,7 @@ S3Client::PutObjectMultipart(std::string_view Key, S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, std::move(PartContent)); if (!PartResult) { - AbortMultipartUpload(Key, UploadId); + SafeAbort(); return S3Result{std::move(PartResult.Error)}; } @@ -1032,13 +711,13 @@ S3Client::PutObjectMultipart(std::string_view Key, S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); if (!CompleteResult) { - AbortMultipartUpload(Key, UploadId); + SafeAbort(); return CompleteResult; } } catch (...) 
{ - AbortMultipartUpload(Key, UploadId); + SafeAbort(); throw; } @@ -1077,25 +756,25 @@ TEST_CASE("s3client.xml_extract") "<Contents><Key>test/file.txt</Key><Size>1234</Size>" "<ETag>\"abc123\"</ETag><LastModified>2024-01-01T00:00:00Z</LastModified></Contents>"; - CHECK(ExtractXmlValue(Xml, "Key") == "test/file.txt"); - CHECK(ExtractXmlValue(Xml, "Size") == "1234"); - CHECK(ExtractXmlValue(Xml, "ETag") == "\"abc123\""); - CHECK(ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); - CHECK(ExtractXmlValue(Xml, "NonExistent") == ""); + CHECK(S3ExtractXmlValue(Xml, "Key") == "test/file.txt"); + CHECK(S3ExtractXmlValue(Xml, "Size") == "1234"); + CHECK(S3ExtractXmlValue(Xml, "ETag") == "\"abc123\""); + CHECK(S3ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); + CHECK(S3ExtractXmlValue(Xml, "NonExistent") == ""); } TEST_CASE("s3client.xml_entity_decode") { - CHECK(DecodeXmlEntities("no entities") == "no entities"); - CHECK(DecodeXmlEntities("a&amp;b") == "a&b"); - CHECK(DecodeXmlEntities("&lt;tag&gt;") == "<tag>"); - CHECK(DecodeXmlEntities("&quot;hello&apos;") == "\"hello'"); - CHECK(DecodeXmlEntities("&amp;&amp;") == "&&"); - CHECK(DecodeXmlEntities("") == ""); + CHECK(S3DecodeXmlEntities("no entities") == "no entities"); + CHECK(S3DecodeXmlEntities("a&amp;b") == "a&b"); + CHECK(S3DecodeXmlEntities("&lt;tag&gt;") == "<tag>"); + CHECK(S3DecodeXmlEntities("&quot;hello&apos;") == "\"hello'"); + CHECK(S3DecodeXmlEntities("&amp;&amp;") == "&&"); + CHECK(S3DecodeXmlEntities("") == ""); // Key with entities as S3 would return it std::string_view Xml = "<Key>path/file&amp;name&lt;1&gt;.txt</Key>"; - CHECK(DecodeXmlEntities(ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); + CHECK(S3DecodeXmlEntities(S3ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); } TEST_CASE("s3client.path_style_addressing") @@ -1129,6 +808,97 @@ TEST_CASE("s3client.virtual_hosted_addressing") CHECK(Client.Region() == "eu-west-1"); } +TEST_CASE("s3requestbuilder.path_style_paths") +{ + S3RequestBuilder Builder("us-east-1", 
"test-bucket", "http://localhost:9000", /*PathStyle*/ true); + + CHECK(Builder.Region() == "us-east-1"); + CHECK(Builder.BucketName() == "test-bucket"); + CHECK(Builder.PathStyle()); + CHECK(Builder.Endpoint() == "http://localhost:9000"); + CHECK(Builder.Host() == "localhost:9000"); + CHECK(Builder.KeyToPath("foo/bar") == "/test-bucket/foo/bar"); + CHECK(Builder.BucketRootPath() == "/test-bucket/"); +} + +TEST_CASE("s3requestbuilder.virtual_hosted_paths") +{ + S3RequestBuilder Builder("eu-west-1", "my-bucket", /*Endpoint*/ "", /*PathStyle*/ false); + + CHECK(Builder.Endpoint() == "https://my-bucket.s3.eu-west-1.amazonaws.com"); + CHECK(Builder.Host() == "my-bucket.s3.eu-west-1.amazonaws.com"); + CHECK(Builder.KeyToPath("foo/bar") == "/foo/bar"); + CHECK(Builder.BucketRootPath() == "/"); +} + +TEST_CASE("s3requestbuilder.derived_endpoint_path_style") +{ + S3RequestBuilder Builder("us-west-2", "another-bucket", /*Endpoint*/ "", /*PathStyle*/ true); + CHECK(Builder.Endpoint() == "https://s3.us-west-2.amazonaws.com"); + CHECK(Builder.Host() == "s3.us-west-2.amazonaws.com"); +} + +TEST_CASE("s3requestbuilder.sign_request_headers") +{ + S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", /*PathStyle*/ true); + + SigV4Credentials Creds; + Creds.AccessKeyId = "AKIDEXAMPLE"; + Creds.SecretAccessKey = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + + HttpClient::KeyValueMap Headers = Builder.SignRequest(Creds, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + // All three required SigV4 headers present. + CHECK(Headers->find("Authorization") != Headers->end()); + CHECK(Headers->find("x-amz-date") != Headers->end()); + CHECK(Headers->find("x-amz-content-sha256") != Headers->end()); + + // SessionToken absent -> no x-amz-security-token header. 
+ CHECK(Headers->find("x-amz-security-token") == Headers->end()); + + const std::string& Auth = Headers->find("Authorization")->second; + CHECK(Auth.starts_with("AWS4-HMAC-SHA256 ")); + CHECK(Auth.find("Credential=AKIDEXAMPLE/") != std::string::npos); + CHECK(Auth.find("/us-east-1/s3/aws4_request") != std::string::npos); + CHECK(Auth.find("Signature=") != std::string::npos); +} + +TEST_CASE("s3requestbuilder.session_token_emits_security_header") +{ + S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", true); + + SigV4Credentials Creds; + Creds.AccessKeyId = "ASIA-tmp"; + Creds.SecretAccessKey = "secret"; + Creds.SessionToken = "sts-session-token-value"; + + HttpClient::KeyValueMap Headers = Builder.SignRequest(Creds, "PUT", "/bucket/key", "", S3EmptyPayloadHash); + + auto It = Headers->find("x-amz-security-token"); + REQUIRE(It != Headers->end()); + CHECK(It->second == "sts-session-token-value"); +} + +TEST_CASE("s3requestbuilder.signing_key_cache_invalidates_on_key_rotate") +{ + // Two consecutive SignRequest() calls with the same date but different AccessKeyId + // must produce different Authorization signatures, proving the cache is keyed + // on (DateStamp, AccessKeyId) and not date alone. 
+ S3RequestBuilder Builder("us-east-1", "bucket", "http://localhost:9000", true); + + SigV4Credentials A; + A.AccessKeyId = "AKIDEXAMPLE"; + A.SecretAccessKey = "secretA"; + HttpClient::KeyValueMap HA = Builder.SignRequest(A, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + SigV4Credentials B; + B.AccessKeyId = "AKIDEXAMPLE2"; + B.SecretAccessKey = "secretB"; + HttpClient::KeyValueMap HB = Builder.SignRequest(B, "GET", "/bucket/foo", "", S3EmptyPayloadHash); + + CHECK(HA->find("Authorization")->second != HB->find("Authorization")->second); +} + TEST_CASE("s3client.minio_integration") { using namespace std::literals; diff --git a/src/zenutil/cloud/s3requestbuilder.cpp b/src/zenutil/cloud/s3requestbuilder.cpp new file mode 100644 index 000000000..7bf200308 --- /dev/null +++ b/src/zenutil/cloud/s3requestbuilder.cpp @@ -0,0 +1,149 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/s3requestbuilder.h> + +#include <zencore/fmtutils.h> + +#include <algorithm> + +namespace zen { + +S3RequestBuilder::S3RequestBuilder(std::string Region, std::string BucketName, std::string Endpoint, bool PathStyle) +: m_Region(std::move(Region)) +, m_BucketName(std::move(BucketName)) +, m_Endpoint(DeriveEndpoint(Endpoint, m_Region, m_BucketName, PathStyle)) +, m_Host(HostFromEndpoint(m_Endpoint)) +, m_PathStyle(PathStyle) +{ +} + +std::string +S3RequestBuilder::DeriveEndpoint(std::string_view Endpoint, std::string_view Region, std::string_view BucketName, bool PathStyle) +{ + if (!Endpoint.empty()) + { + return std::string(Endpoint); + } + if (PathStyle) + { + return fmt::format("https://s3.{}.amazonaws.com", Region); + } + return fmt::format("https://{}.s3.{}.amazonaws.com", BucketName, Region); +} + +std::string +S3RequestBuilder::HostFromEndpoint(std::string_view Endpoint) +{ + std::string_view Ep = Endpoint; + if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) + { + Ep = Ep.substr(Pos + 3); + } + if (!Ep.empty() && Ep.back() == '/') + { + Ep 
= Ep.substr(0, Ep.size() - 1); + } + return std::string(Ep); +} + +std::string +S3RequestBuilder::KeyToPath(std::string_view Key) const +{ + if (m_PathStyle) + { + return fmt::format("/{}/{}", m_BucketName, Key); + } + return fmt::format("/{}", Key); +} + +std::string +S3RequestBuilder::BucketRootPath() const +{ + if (m_PathStyle) + { + return fmt::format("/{}/", m_BucketName); + } + return "/"; +} + +// Cached for the (DateStamp, AccessKeyId) tuple. DateStamp rolls over at UTC +// midnight; concurrent callers around the rollover may each compute a fresh +// key once before the cache settles on the new day's value. Harmless extra +// HMACs, no signature corruption. +Sha256Digest +S3RequestBuilder::GetSigningKey(std::string_view DateStamp, const SigV4Credentials& Credentials) +{ + { + RwLock::SharedLockScope _(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp && m_CachedAccessKeyId == Credentials.AccessKeyId) + { + return m_CachedSigningKey; + } + } + + RwLock::ExclusiveLockScope _(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp && m_CachedAccessKeyId == Credentials.AccessKeyId) + { + return m_CachedSigningKey; + } + + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); + m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + m_CachedDateStamp = std::string(DateStamp); + m_CachedAccessKeyId = Credentials.AccessKeyId; + + return m_CachedSigningKey; +} + +HttpClient::KeyValueMap +S3RequestBuilder::SignRequest(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Path, + std::string_view CanonicalQueryString, + std::string_view PayloadHash, + std::span<const 
std::pair<std::string, std::string>> ExtraSignedHeaders) +{ + std::string AmzDate = GetAmzTimestamp(); + + std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.reserve(4 + ExtraSignedHeaders.size()); + HeadersToSign.emplace_back("host", m_Host); + HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); + HeadersToSign.emplace_back("x-amz-date", AmzDate); + if (!Credentials.SessionToken.empty()) + { + HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); + } + for (const auto& [K, V] : ExtraSignedHeaders) + { + HeadersToSign.emplace_back(K, V); + } + std::sort(HeadersToSign.begin(), HeadersToSign.end()); + + std::string_view DateStamp(AmzDate.data(), 8); + Sha256Digest SigningKey = GetSigningKey(DateStamp, Credentials); + + SigV4SignedHeaders Signed = + SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); + + HttpClient::KeyValueMap Result; + Result->emplace("Authorization", std::move(Signed.Authorization)); + Result->emplace("x-amz-date", std::move(Signed.AmzDate)); + Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); + if (!Credentials.SessionToken.empty()) + { + Result->emplace("x-amz-security-token", Credentials.SessionToken); + } + for (const auto& [K, V] : ExtraSignedHeaders) + { + Result->emplace(K, V); + } + return Result; +} + +} // namespace zen diff --git a/src/zenutil/cloud/s3response.cpp b/src/zenutil/cloud/s3response.cpp new file mode 100644 index 000000000..a9e7f0208 --- /dev/null +++ b/src/zenutil/cloud/s3response.cpp @@ -0,0 +1,181 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/cloud/s3response.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +namespace zen { + +std::string_view +S3ExtractXmlValue(std::string_view Xml, std::string_view Tag) +{ + std::string OpenTag = fmt::format("<{}>", Tag); + std::string CloseTag = fmt::format("</{}>", Tag); + + size_t Start = Xml.find(OpenTag); + if (Start == std::string_view::npos) + { + return {}; + } + Start += OpenTag.size(); + + size_t End = Xml.find(CloseTag, Start); + if (End == std::string_view::npos) + { + return {}; + } + + return Xml.substr(Start, End - Start); +} + +void +S3DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) +{ + if (Input.find('&') == std::string_view::npos) + { + Out.Append(Input); + return; + } + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '&') + { + std::string_view Remaining = Input.substr(i); + if (Remaining.starts_with("&amp;")) + { + Out.Append('&'); + i += 4; + } + else if (Remaining.starts_with("&lt;")) + { + Out.Append('<'); + i += 3; + } + else if (Remaining.starts_with("&gt;")) + { + Out.Append('>'); + i += 3; + } + else if (Remaining.starts_with("&quot;")) + { + Out.Append('"'); + i += 5; + } + else if (Remaining.starts_with("&apos;")) + { + Out.Append('\''); + i += 5; + } + else + { + Out.Append(Input[i]); + } + } + else + { + Out.Append(Input[i]); + } + } +} + +std::string +S3DecodeXmlEntities(std::string_view Input) +{ + if (Input.find('&') == std::string_view::npos) + { + return std::string(Input); + } + ExtendableStringBuilder<256> Sb; + S3DecodeXmlEntities(Input, Sb); + return Sb.ToString(); +} + +std::string +S3BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) +{ + if (CanonicalQS.empty()) + { + return std::string(Path); + } + return fmt::format("{}?{}", Path, CanonicalQS); +} + +const std::string* +S3FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) +{ + for (const auto& [K, V] : *Headers) + { + if (StrCaseCompare(K, Name) == 0) + { + return &V; 
+ } + } + return nullptr; +} + +bool +S3ExtractError(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage) +{ + if (Body.find("<Error>") == std::string_view::npos) + { + return false; + } + OutCode = S3ExtractXmlValue(Body, "Code"); + OutMessage = S3ExtractXmlValue(Body, "Message"); + // Treat malformed bodies (Error tag present but no parseable Code/Message) + // as a parse miss: callers format "<prefix>: <Code> - <Message>", and an + // empty render would be indistinguishable from "no error". Status-only + // triage is still covered by S3IsThrottled (called with an empty ErrorCode) + // and by S3ErrorMessage's fallback to Response.ErrorMessage. + return !OutCode.empty() || !OutMessage.empty(); +} + +bool +S3IsThrottled(const HttpClient::Response& Response, std::string_view ErrorCode) +{ + const int Status = static_cast<int>(Response.StatusCode); + if (Status == 503 || Status == 429) + { + return true; + } + if (ErrorCode == "SlowDown" || ErrorCode == "ServiceUnavailable" || ErrorCode == "ThrottlingException" || + ErrorCode == "RequestLimitExceeded" || ErrorCode == "TooManyRequests") + { + return true; + } + return false; +} + +std::string +S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response) +{ + if (!Response.Error.has_value() && Response.ResponsePayload) + { + std::string_view Body(reinterpret_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string_view Code; + std::string_view Message; + if (S3ExtractError(Body, Code, Message)) + { + ExtendableStringBuilder<256> Decoded; + S3DecodeXmlEntities(Message, Decoded); + if (S3IsThrottled(Response, Code)) + { + ZEN_WARN("S3 THROTTLED [{}] status={} code='{}' message='{}'", + Prefix, + static_cast<int>(Response.StatusCode), + Code, + Decoded.ToView()); + } + return fmt::format("{}: HTTP status ({}) {} - {}", Prefix, static_cast<int>(Response.StatusCode), Code, Decoded.ToView()); + } + } + if (S3IsThrottled(Response, {})) + { + ZEN_WARN("S3 THROTTLED [{}] 
status={} (no XML body)", Prefix, static_cast<int>(Response.StatusCode)); + } + return Response.ErrorMessage(Prefix); +} + +} // namespace zen diff --git a/src/zenutil/cloud/sigv4.cpp b/src/zenutil/cloud/sigv4.cpp index 055ccb2ad..34bd7f5f3 100644 --- a/src/zenutil/cloud/sigv4.cpp +++ b/src/zenutil/cloud/sigv4.cpp @@ -53,6 +53,62 @@ ComputeSha256(const void* Data, size_t Size) return Result; } +Sha256Stream::Sha256Stream() +{ + EVP_MD_CTX* Ctx = EVP_MD_CTX_new(); + ZEN_ASSERT(Ctx != nullptr); + int Rc = EVP_DigestInit_ex(Ctx, EVP_sha256(), nullptr); + ZEN_ASSERT(Rc == 1); + m_Ctx = Ctx; +} + +Sha256Stream::~Sha256Stream() +{ + if (m_Ctx) + { + EVP_MD_CTX_free(static_cast<EVP_MD_CTX*>(m_Ctx)); + m_Ctx = nullptr; + } +} + +Sha256Stream::Sha256Stream(Sha256Stream&& Other) noexcept : m_Ctx(Other.m_Ctx) +{ + Other.m_Ctx = nullptr; +} + +Sha256Stream& +Sha256Stream::operator=(Sha256Stream&& Other) noexcept +{ + if (this != &Other) + { + if (m_Ctx) + { + EVP_MD_CTX_free(static_cast<EVP_MD_CTX*>(m_Ctx)); + } + m_Ctx = Other.m_Ctx; + Other.m_Ctx = nullptr; + } + return *this; +} + +void +Sha256Stream::Update(const void* Data, size_t Size) +{ + int Rc = EVP_DigestUpdate(static_cast<EVP_MD_CTX*>(m_Ctx), Data, Size); + ZEN_ASSERT(Rc == 1); +} + +Sha256Digest +Sha256Stream::Finalize() +{ + Sha256Digest Result; + unsigned int Len = 0; + int Rc = EVP_DigestFinal_ex(static_cast<EVP_MD_CTX*>(m_Ctx), Result.data(), &Len); + ZEN_ASSERT(Rc == 1); + ZEN_ASSERT(Len == 32); + return Result; +} + Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) { @@ -171,6 +227,62 @@ ComputeSha256(const void* Data, size_t Size) return BcryptHash(GetBcryptHandles().Sha256, Data, Size); } +Sha256Stream::Sha256Stream() +{ + BCRYPT_HASH_HANDLE Handle = nullptr; + NTSTATUS Status = BCryptCreateHash(GetBcryptHandles().Sha256, &Handle, nullptr, 0, nullptr, 0, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + m_Ctx = Handle; +} + +Sha256Stream::~Sha256Stream() +{ + if 
(m_Ctx) + { + BCryptDestroyHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx)); + m_Ctx = nullptr; + } +} + +Sha256Stream::Sha256Stream(Sha256Stream&& Other) noexcept : m_Ctx(Other.m_Ctx) +{ + Other.m_Ctx = nullptr; +} + +Sha256Stream& +Sha256Stream::operator=(Sha256Stream&& Other) noexcept +{ + if (this != &Other) + { + if (m_Ctx) + { + BCryptDestroyHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx)); + } + m_Ctx = Other.m_Ctx; + Other.m_Ctx = nullptr; + } + return *this; +} + +void +Sha256Stream::Update(const void* Data, size_t Size) +{ + NTSTATUS Status = BCryptHashData(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx), + reinterpret_cast<PUCHAR>(const_cast<void*>(Data)), + static_cast<ULONG>(Size), + 0); + ZEN_ASSERT(NT_SUCCESS(Status)); +} + +Sha256Digest +Sha256Stream::Finalize() +{ + Sha256Digest Result; + NTSTATUS Status = BCryptFinishHash(static_cast<BCRYPT_HASH_HANDLE>(m_Ctx), Result.data(), static_cast<ULONG>(Result.size()), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + return Result; +} + Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) { @@ -251,6 +363,8 @@ GetAmzTimestamp() std::string AwsUriEncode(std::string_view Input, bool EncodeSlash) { + static constexpr char kHex[] = "0123456789ABCDEF"; + ExtendableStringBuilder<256> Result; for (char C : Input) { @@ -264,7 +378,11 @@ AwsUriEncode(std::string_view Input, bool EncodeSlash) } else { - Result.Append(fmt::format("%{:02X}", static_cast<unsigned char>(C))); + // Hand-rolled hex encode; avoids per-char fmt::format std::string alloc + // inside a loop that runs once per non-unreserved input character. 
+            const uint8_t Byte = static_cast<uint8_t>(C);
+            const char Encoded[3] = {'%', kHex[Byte >> 4], kHex[Byte & 0x0F]};
+            Result.Append(std::string_view(Encoded, 3));
         }
     }
     return std::string(Result.ToView());
@@ -313,18 +431,10 @@ SignRequestV4(const SigV4Credentials& Credentials,
 
     std::string DateStamp = GetDateStamp(Result.AmzDate);
 
-    // Step 1: Create canonical request
-    // CanonicalRequest =
-    // HTTPRequestMethod + '\n' +
-    // CanonicalURI + '\n' +
-    // CanonicalQueryString + '\n' +
-    // CanonicalHeaders + '\n' +
-    // SignedHeaders + '\n' +
-    // HexEncode(Hash(RequestPayload))
-
+    // Step 1: canonical request.
     std::string CanonicalUri = AwsUriEncode(Url, false);
 
-    // Build canonical headers and signed headers (headers must be sorted by lowercase name)
+    // Headers assumed pre-sorted by lowercase name.
     ExtendableStringBuilder<512> CanonicalHeadersSb;
     ExtendableStringBuilder<256> SignedHeadersSb;
 
@@ -342,30 +452,45 @@ SignRequestV4(const SigV4Credentials& Credentials,
         SignedHeadersSb.Append(Headers[i].first);
     }
 
-    std::string SignedHeaders = std::string(SignedHeadersSb.ToView());
-
-    std::string CanonicalRequest = fmt::format("{}\n{}\n{}\n{}\n{}\n{}",
-                                               Method,
-                                               CanonicalUri,
-                                               CanonicalQueryString,
-                                               CanonicalHeadersSb.ToView(),
-                                               SignedHeaders,
-                                               PayloadHash);
-
-    // Step 2: Create the string to sign
-    std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service);
+    std::string_view SignedHeaders = SignedHeadersSb.ToView();
+
+    ExtendableStringBuilder<2048> CanonicalRequestSb;
+    CanonicalRequestSb.Append(Method);
+    CanonicalRequestSb.Append('\n');
+    CanonicalRequestSb.Append(CanonicalUri);
+    CanonicalRequestSb.Append('\n');
+    CanonicalRequestSb.Append(CanonicalQueryString);
+    CanonicalRequestSb.Append('\n');
+    CanonicalRequestSb.Append(CanonicalHeadersSb.ToView());
+    CanonicalRequestSb.Append('\n');
+    CanonicalRequestSb.Append(SignedHeaders);
+    CanonicalRequestSb.Append('\n');
+    CanonicalRequestSb.Append(PayloadHash);
+    std::string_view CanonicalRequest = CanonicalRequestSb.ToView();
+
+    // Step 2: string-to-sign.
+    ExtendableStringBuilder<128> CredentialScopeSb;
+    CredentialScopeSb.Append(DateStamp);
+    CredentialScopeSb.Append('/');
+    CredentialScopeSb.Append(Region);
+    CredentialScopeSb.Append('/');
+    CredentialScopeSb.Append(Service);
+    CredentialScopeSb.Append("/aws4_request");
+    std::string_view CredentialScope = CredentialScopeSb.ToView();
 
     Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest);
     std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash);
 
-    std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", Result.AmzDate, CredentialScope, CanonicalRequestHex);
-
-    // Step 3: Calculate the signing key
-    // kDate = HMAC("AWS4" + SecretKey, DateStamp)
-    // kRegion = HMAC(kDate, Region)
-    // kService = HMAC(kRegion, Service)
-    // kSigning = HMAC(kService, "aws4_request")
+    ExtendableStringBuilder<256> StringToSignSb;
+    StringToSignSb.Append("AWS4-HMAC-SHA256\n");
+    StringToSignSb.Append(Result.AmzDate);
+    StringToSignSb.Append('\n');
+    StringToSignSb.Append(CredentialScope);
+    StringToSignSb.Append('\n');
+    StringToSignSb.Append(CanonicalRequestHex);
+    std::string_view StringToSign = StringToSignSb.ToView();
 
+    // Step 3: derive signing key (kDate -> kRegion -> kService -> kSigning HMAC chain).
     Sha256Digest DerivedSigningKey;
     if (!SigningKeyPtr)
     {
@@ -380,11 +505,11 @@ SignRequestV4(const SigV4Credentials& Credentials,
         SigningKeyPtr = &DerivedSigningKey;
     }
 
-    // Step 4: Calculate the signature
+    // Step 4: signature.
     Sha256Digest Signature = ComputeHmacSha256(*SigningKeyPtr, StringToSign);
     std::string SignatureHex = Sha256ToHex(Signature);
 
-    // Step 5: Build the Authorization header
+    // Step 5: Authorization header.
     Result.Authorization = fmt::format("AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
                                        Credentials.AccessKeyId,
                                        CredentialScope,
@@ -489,6 +614,69 @@ TEST_CASE("sigv4.sha256")
     CHECK(HelloHex == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824");
 }
 
+TEST_CASE("sigv4.sha256stream.matches_oneshot")
+{
+    // Empty input.
+    {
+        Sha256Stream S;
+        Sha256Digest Streamed = S.Finalize();
+        Sha256Digest OneShot = ComputeSha256("", 0);
+        CHECK(Streamed == OneShot);
+    }
+
+    // Single update.
+    {
+        Sha256Stream S;
+        S.Update("hello", 5);
+        Sha256Digest Streamed = S.Finalize();
+        Sha256Digest OneShot = ComputeSha256("hello");
+        CHECK(Streamed == OneShot);
+    }
+
+    // Multi-chunk update; result must equal one-shot over concatenated bytes.
+    {
+        const std::string Whole = "the quick brown fox jumps over the lazy dog";
+        Sha256Stream S;
+        S.Update(Whole.data(), 4);
+        S.Update(Whole.data() + 4, 16);
+        S.Update(Whole.data() + 20, Whole.size() - 20);
+        Sha256Digest Streamed = S.Finalize();
+        Sha256Digest OneShot = ComputeSha256(Whole.data(), Whole.size());
+        CHECK(Streamed == OneShot);
+    }
+
+    // Large input fed in 256 KiB chunks (matches PutMedium hash pass shape).
+    {
+        std::vector<uint8_t> Big(1u * 1024u * 1024u + 7u);
+        for (size_t I = 0; I < Big.size(); ++I)
+        {
+            Big[I] = static_cast<uint8_t>((I * 31u + 11u) & 0xFF);
+        }
+        Sha256Stream S;
+        constexpr size_t kChunk = 256u * 1024u;
+        size_t Off = 0;
+        while (Off < Big.size())
+        {
+            const size_t Take = std::min(kChunk, Big.size() - Off);
+            S.Update(Big.data() + Off, Take);
+            Off += Take;
+        }
+        Sha256Digest Streamed = S.Finalize();
+        Sha256Digest OneShot = ComputeSha256(Big.data(), Big.size());
+        CHECK(Streamed == OneShot);
+    }
+
+    // Move construct + finalize on the moved-to instance.
+    {
+        Sha256Stream A;
+        A.Update("abc", 3);
+        Sha256Stream B(std::move(A));
+        B.Update("def", 3);
+        Sha256Digest Result = B.Finalize();
+        CHECK(Result == ComputeSha256("abcdef"));
+    }
+}
+
 TEST_CASE("sigv4.hmac_sha256")
 {
     // RFC 4231 Test Case 2
diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h
index 1ce2a768e..f09788b82 100644
--- a/src/zenutil/include/zenutil/cloud/s3client.h
+++ b/src/zenutil/include/zenutil/cloud/s3client.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include <zenutil/cloud/imdscredentials.h>
+#include <zenutil/cloud/s3requestbuilder.h>
 #include <zenutil/cloud/sigv4.h>
 
 #include <zencore/iobuffer.h>
@@ -187,8 +188,8 @@ public:
     /// @param ExpiresIn URL validity duration (default 1 hour, max 7 days)
     std::string GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1));
 
-    std::string_view BucketName() const { return m_BucketName; }
-    std::string_view Region() const { return m_Region; }
+    std::string_view BucketName() const { return m_Builder.BucketName(); }
+    std::string_view Region() const { return m_Builder.Region(); }
 
 private:
     /// Shared implementation for pre-signed URL generation
@@ -196,31 +197,6 @@ private:
 
     LoggerRef Log() { return m_Log; }
 
-    /// Build the endpoint URL for the bucket
-    std::string BuildEndpoint() const;
-
-    /// Build the host header value
-    std::string BuildHostHeader() const;
-
-    /// Build the S3 object path from a key, accounting for path-style addressing
-    std::string KeyToPath(std::string_view Key) const;
-
-    /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style)
-    std::string BucketRootPath() const;
-
-    /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256.
-    /// Additional x-amz-* headers that must participate in the signature are passed via
-    /// ExtraSignedHeaders (lowercase name, value); they are also copied into the returned map.
-    HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials,
-                                        std::string_view Method,
-                                        std::string_view Path,
-                                        std::string_view QueryString,
-                                        std::string_view PayloadHash,
-                                        std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {});
-
-    /// Get or compute the signing key for the given date stamp, caching across requests on the same day
-    Sha256Digest GetSigningKey(std::string_view DateStamp);
-
     /// Get the current credentials, either from the provider or from static config
     SigV4Credentials GetCurrentCredentials();
 
@@ -242,20 +218,12 @@ private:
     std::string BuildNoCredentialsError(std::string Context);
 
     LoggerRef m_Log;
-    std::string m_BucketName;
-    std::string m_Region;
-    std::string m_Endpoint;
-    std::string m_Host;
-    bool m_PathStyle;
+    S3RequestBuilder m_Builder;
+    mutable RwLock m_CredentialsLock;
     SigV4Credentials m_Credentials;
     Ref<ImdsCredentialProvider> m_CredentialProvider;
     HttpClient m_HttpClient;
     bool m_Verbose = false;
-
-    // Cached signing key (only changes once per day, protected by RwLock for thread safety)
-    mutable RwLock m_SigningKeyLock;
-    std::string m_CachedDateStamp;
-    Sha256Digest m_CachedSigningKey{};
 };
 
 void s3client_forcelink();
diff --git a/src/zenutil/include/zenutil/cloud/s3requestbuilder.h b/src/zenutil/include/zenutil/cloud/s3requestbuilder.h
new file mode 100644
index 000000000..c46167fba
--- /dev/null
+++ b/src/zenutil/include/zenutil/cloud/s3requestbuilder.h
@@ -0,0 +1,71 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenutil/cloud/sigv4.h>
+
+#include <zencore/thread.h>
+#include <zenhttp/httpclient.h>
+
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+// Stateless builder of signed S3 requests, shared between blocking S3Client
+// (HttpClient-backed) and S3AsyncStorage (AsyncHttpClient-backed). Owns the
+// per-day signing-key cache so identical signatures across consecutive
+// requests do not redo HMAC derivation.
+//
+// Configuration (region, bucket, endpoint, addressing style) is fixed at
+// construction. Credentials are passed per Sign() call. The cache is keyed on
+// (DateStamp, AccessKeyId), so any AccessKeyId rotation invalidates the
+// derived signing key. STS / IMDS rotate AccessKeyId and SecretAccessKey
+// together, so this is sufficient there. Callers hand-rotating SecretAccessKey
+// while reusing AccessKeyId would sign with a stale derived key - not a usage
+// pattern this builder supports.
+class S3RequestBuilder
+{
+public:
+    S3RequestBuilder(std::string Region, std::string BucketName, std::string Endpoint, bool PathStyle);
+
+    std::string_view Region() const { return m_Region; }
+    std::string_view BucketName() const { return m_BucketName; }
+    std::string_view Endpoint() const { return m_Endpoint; }
+    std::string_view Host() const { return m_Host; }
+    bool PathStyle() const { return m_PathStyle; }
+
+    std::string KeyToPath(std::string_view Key) const;
+    std::string BucketRootPath() const;
+
+    // Sign a request. Returns headers including Authorization, x-amz-date,
+    // x-amz-content-sha256, and x-amz-security-token (when SessionToken set).
+    // ExtraSignedHeaders are lowercase-name pairs that must participate in the
+    // signature; they are also copied into the returned map.
+    HttpClient::KeyValueMap SignRequest(const SigV4Credentials& Credentials,
+                                        std::string_view Method,
+                                        std::string_view Path,
+                                        std::string_view CanonicalQueryString,
+                                        std::string_view PayloadHash,
+                                        std::span<const std::pair<std::string, std::string>> ExtraSignedHeaders = {});
+
+private:
+    Sha256Digest GetSigningKey(std::string_view DateStamp, const SigV4Credentials& Credentials);
+
+    static std::string DeriveEndpoint(std::string_view Endpoint, std::string_view Region, std::string_view BucketName, bool PathStyle);
+    static std::string HostFromEndpoint(std::string_view Endpoint);
+
+    std::string m_Region;
+    std::string m_BucketName;
+    std::string m_Endpoint;
+    std::string m_Host;
+    bool m_PathStyle;
+
+    mutable RwLock m_SigningKeyLock;
+    std::string m_CachedDateStamp;
+    std::string m_CachedAccessKeyId;
+    Sha256Digest m_CachedSigningKey{};
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cloud/s3response.h b/src/zenutil/include/zenutil/cloud/s3response.h
new file mode 100644
index 000000000..9dec3215b
--- /dev/null
+++ b/src/zenutil/include/zenutil/cloud/s3response.h
@@ -0,0 +1,51 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/string.h>
+#include <zenhttp/httpclient.h>
+
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+// Helpers for parsing S3 (or S3-compatible) HTTP responses. Shared between
+// the blocking S3Client and the async S3AsyncStorage path so XML/error
+// handling stays consistent across implementations.
+//
+// The XML parser is intentionally minimal: only ListBucketResult / Error
+// shapes are needed. CDATA, namespaces, and attributes are not handled.
+
+constexpr std::string_view S3EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
+
+// Find text between <Tag> and </Tag>. Returns a view into Xml; empty if not found.
+std::string_view S3ExtractXmlValue(std::string_view Xml, std::string_view Tag);
+
+// Decode the five standard XML entities into Out.
+void S3DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out);
+
+// Convenience overload returning std::string.
+std::string S3DecodeXmlEntities(std::string_view Input);
+
+// Append a canonical query string to a path: "Path?CanonicalQS" or "Path" when QS empty.
+std::string S3BuildRequestPath(std::string_view Path, std::string_view CanonicalQS);
+
+// Case-insensitive header lookup in a response header map. Returns nullptr if not found.
+const std::string* S3FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name);
+
+// Extract Code/Message from an S3 <Error> body. Returns true only when at least one of
+// Code/Message parsed non-empty - malformed bodies (Error tag present but no parseable
+// children) are reported as parse miss to keep formatted error strings non-empty.
+bool S3ExtractError(std::string_view Body, std::string_view& OutCode, std::string_view& OutMessage);
+
+// Detect throttling: HTTP 503/429 or known S3 throttle codes (SlowDown, ServiceUnavailable, etc).
+// Code is checked on both the HTTP status and the XML error code.
+bool S3IsThrottled(const HttpClient::Response& Response, std::string_view ErrorCode);
+
+// Build a human-readable error message for a failed S3 response. Falls back to the
+// generic HTTP/transport message when no <Error> body is available. Logs a
+// distinct "S3 THROTTLED" warning when the response indicates throttling.
+std::string S3ErrorMessage(std::string_view Prefix, const HttpClient::Response& Response);
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cloud/sigv4.h b/src/zenutil/include/zenutil/cloud/sigv4.h
index 9ac08df76..a16daf2cf 100644
--- a/src/zenutil/include/zenutil/cloud/sigv4.h
+++ b/src/zenutil/include/zenutil/cloud/sigv4.h
@@ -19,6 +19,26 @@ using Sha256Digest = std::array<uint8_t, 32>;
 Sha256Digest ComputeSha256(const void* Data, size_t Size);
 Sha256Digest ComputeSha256(std::string_view Data);
 
+/// Streaming SHA-256: feed data in chunks, finalize once.
+/// Move-only; copying digest state is not supported by the underlying APIs.
+class Sha256Stream
+{
+public:
+    Sha256Stream();
+    ~Sha256Stream();
+
+    Sha256Stream(const Sha256Stream&) = delete;
+    Sha256Stream& operator=(const Sha256Stream&) = delete;
+    Sha256Stream(Sha256Stream&& Other) noexcept;
+    Sha256Stream& operator=(Sha256Stream&& Other) noexcept;
+
+    void Update(const void* Data, size_t Size);
+    Sha256Digest Finalize(); // single-use; further Update/Finalize is undefined
+
+private:
+    void* m_Ctx = nullptr;
+};
+
 /// Compute HMAC-SHA256 with the given key and data
 Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize);
 Sha256Digest ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data); |