diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zencompute | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zencompute')
38 files changed, 13620 insertions, 0 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md new file mode 100644 index 000000000..f5188123f --- /dev/null +++ b/src/zencompute/CLAUDE.md @@ -0,0 +1,232 @@ +# zencompute Module + +Lambda-style compute function service. Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results. + +## Directory Structure + +``` +src/zencompute/ +├── include/zencompute/ # Public headers +│ ├── computeservice.h # ComputeServiceSession public API +│ ├── httpcomputeservice.h # HTTP service wrapper +│ ├── orchestratorservice.h # Worker registry and orchestration +│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push +│ ├── recordingreader.h # Recording/replay reader API +│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP) +│ └── mockimds.h # Test helper for cloud metadata +├── runners/ # Execution backends +│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup +│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling) +│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW +│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation +│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing +│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux +│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances +│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories +├── recording/ # Recording/replay subsystem +│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk +│ └── recordingreader.cpp # Read recordings back +├── timeline/ +│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking +├── testing/ +│ └── mockimds.cpp # Mock IMDS for cloud metadata tests +├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines) +├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines) +├── httporchestrator.cpp # Orchestrator HTTP API + WebSocket push +├── orchestratorservice.cpp # Worker registry, health probing +└── cloudmetadata.cpp # IMDS probing, termination monitoring +``` + +## Key Classes + +### `ComputeServiceSession` (computeservice.h) +Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: +- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup` +- Scheduler thread that drains `m_UpdatedActions` and drives state transitions +- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` +- Queue map: `m_Queues` (QueueEntry objects) +- Action history ring: `m_ActionHistory` (bounded deque, default 1000) + +**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. + +### `RunnerAction` (runners/functionrunner.h) +Shared ref-counted struct representing one action through its lifecycle. + +**Key fields:** +- `ActionLsn` — global unique sequence number +- `QueueId` — 0 for implicit/unqueued actions +- `Worker` — descriptor + content hash +- `ActionObj` — CbObject with the action spec +- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread +- `RetryCount` — atomic int tracking how many times the action has been rescheduled +- `Timestamps[State::_Count]` — timestamp of each state transition + +**State machine (forward-only under normal flow, atomic):** +``` +New → Pending → Submitting → Running → Completed + → Failed + → Abandoned + → Cancelled +``` +`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +### `LocalProcessRunner` (runners/localrunner.h) +Base for all local execution. Platform runners subclass this and override: +- `SubmitAction()` — fork/exec the worker process +- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject) +- `CancelRunningActions()` — signal all processes during shutdown +- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default) + +**Infrastructure owned by LocalProcessRunner:** +- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop +- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction` +- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion +- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression) +- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion) + +**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations: +- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK` +- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000 +- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 + +### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. + +### `HttpComputeService` (include/zencompute/httpcomputeservice.h) +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. + +## Action Lifecycle (End to End) + +1. **HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()` +2. **Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()` +3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event +4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners +5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread) +6. **Process launch** → Submitting → Running, added to `m_RunningMap` +7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs +8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()` +9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue. +10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement + +### Action Rescheduling + +Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API. + +**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier. + +**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. + +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. + +**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. + +## Queue System + +Queues group actions from a single client session. A `QueueEntry` (internal) tracks: +- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled) +- `ActiveCount` — pending + running actions (atomic) +- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics) +- `ActiveLsns` — for cancellation lookup (under `m_Lock`) +- `FinishedLsns` — moved here when actions complete +- `IdleSince` — used for 15-minute automatic expiry +- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit + +**Queue state machine (`QueueState` enum):** +``` +Active → Draining → Cancelled + \ ↑ + ─────────────────────/ +``` +- **Active** — accepts new work, schedules pending work, finishes running work (initial state) +- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled) +- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining) + +Key operations: +- `CreateQueue(Tag)` → returns `QueueId` +- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation +- `CancelQueue(QueueId)` → marks all active LSNs for cancellation +- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible) +- `GetQueueCompleted(QueueId)` → CbWriter output of finished results +- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()` + +## HTTP API + +All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`). + +### Global endpoints +| Method | Path | Description | +|--------|------|-------------| +| POST | `abandon` | Transition session to Abandoned state (409 if invalid) | +| GET | `jobs/history` | Action history (last N, with timestamps per state) | +| GET | `jobs/running` | In-flight actions with CPU metrics | +| GET | `jobs/completed` | Actions with results available | +| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{worker}` | Submit action for specific worker | +| POST | `jobs` | Submit action (worker resolved from descriptor) | +| GET | `workers` | List worker IDs | +| GET | `workers/all` | All workers with full descriptors | +| GET/POST | `workers/{worker}` | Get/register worker | + +### Queue-scoped endpoints +Queue ref is capture(1) in all `queues/{queueref}/...` routes. + +| Method | Path | Description | +|--------|------|-------------| +| GET | `queues` | List queue IDs | +| POST | `queues` | Create queue | +| GET/DELETE | `queues/{queueref}` | Status / delete | +| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) | +| GET | `queues/{queueref}/completed` | Queue's completed results | +| GET | `queues/{queueref}/history` | Queue's action history | +| GET | `queues/{queueref}/running` | Queue's running actions | +| POST | `queues/{queueref}/jobs` | Submit to queue | +| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | + +Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. + +## Concurrency Model + +**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: +1. `m_ResultsLock` +2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") +3. `m_PendingLock` +4. `m_QueueLock` + +**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. + +**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. + +**Thread ownership:** +- Scheduler thread — drives state transitions, owns `m_PendingActions` +- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock +- Worker pool threads — async submission, brief `SubmitAction()` calls +- HTTP threads — read-only access to results, queue status + +## Sandbox Layout + +Each action gets a unique numbered directory under `m_SandboxPath`: +``` +scratch/{counter}/ + worker/ ← worker binaries (or bind-mounted on Linux) + inputs/ ← decompressed action inputs + outputs/ ← written by worker process +``` + +On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`. + +## Adding a New HTTP Endpoint + +1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp` +2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h` +3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure +4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())` +5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values + +## Adding a New Runner Platform + +1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/` +2. Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)` +3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample) +4. `ProcessHandle` convention: store pid as `reinterpret_cast<void*>(static_cast<intptr_t>(pid))` for consistency with the base class +5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp` diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp new file mode 100644 index 000000000..65bac895f --- /dev/null +++ b/src/zencompute/cloudmetadata.cpp @@ -0,0 +1,1014 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/cloudmetadata.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +// All major cloud providers expose instance metadata at this link-local address. +// It is only routable from within a cloud VM; on bare-metal the TCP connect will +// fail, which is how we distinguish cloud from non-cloud environments. +static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; + +// Short connect timeout so that detection on non-cloud machines is fast. The IMDS +// is a local service on the hypervisor so 200ms is generous for actual cloud VMs. +static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) +{ +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint) +: m_Log(logging::Get("cloud")) +, m_DataDir(std::move(DataDir)) +, m_ImdsEndpoint(std::move(ImdsEndpoint)) +{ + ZEN_TRACE_CPU("CloudMetadata::CloudMetadata"); + + std::error_code Ec; + std::filesystem::create_directories(m_DataDir, Ec); + + DetectProvider(); + + if (m_Info.Provider != CloudProvider::None) + { + StartTerminationMonitor(); + } +} + +CloudMetadata::~CloudMetadata() +{ + ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata"); + m_MonitorEnabled = false; + m_MonitorEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +CloudProvider +CloudMetadata::GetProvider() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); +} + +CloudInstanceInfo +CloudMetadata::GetInstanceInfo() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info; }); +} + +bool +CloudMetadata::IsTerminationPending() const +{ + return m_TerminationPending.load(std::memory_order_relaxed); +} + +std::string +CloudMetadata::GetTerminationReason() const +{ + return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; }); +} + +void +CloudMetadata::Describe(CbWriter& Writer) const +{ + ZEN_TRACE_CPU("CloudMetadata::Describe"); + CloudInstanceInfo Info = GetInstanceInfo(); + + if (Info.Provider == CloudProvider::None) + { + return; + } + + Writer.BeginObject("cloud"); + Writer << "provider" << ToString(Info.Provider); + Writer << "instance_id" << Info.InstanceId; + Writer << "availability_zone" << Info.AvailabilityZone; + Writer << "is_spot" << Info.IsSpot; + Writer << "is_autoscaling" << Info.IsAutoscaling; + Writer << "termination_pending" << IsTerminationPending(); + + if (IsTerminationPending()) + { + Writer << "termination_reason" << GetTerminationReason(); + } + + Writer.EndObject(); +} + +void +CloudMetadata::DetectProvider() +{ + ZEN_TRACE_CPU("CloudMetadata::DetectProvider"); + + if (TryDetectAWS()) + { + return; + } + + if (TryDetectAzure()) + { + return; + } + + if (TryDetectGCP()) + { + return; + } + + ZEN_DEBUG("no cloud provider detected"); +} + +// AWS detection uses IMDSv2 which requires a session token obtained via PUT before +// any GET requests are allowed. This is more secure than IMDSv1 (which allowed +// unauthenticated GETs) and is the default on modern EC2 instances. The token has +// a 300-second TTL and is reused for termination polling. +bool +CloudMetadata::TryDetectAWS() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping AWS detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing AWS IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + // IMDSv2: acquire session token. The TTL header is mandatory; we request + // 300s which is sufficient for the detection phase. The token is also + // stored in m_AwsToken for reuse by the termination polling thread. + HttpClient::KeyValueMap TokenHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"}); + HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders); + + if (!TokenResponse.IsSuccess()) + { + ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_AwsToken = std::string(TokenResponse.AsText()); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders); + if (IdResponse.IsSuccess()) + { + m_Info.InstanceId = std::string(IdResponse.AsText()); + } + + HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders); + if (AzResponse.IsSuccess()) + { + m_Info.AvailabilityZone = std::string(AzResponse.AsText()); + } + + // "spot" vs "on-demand" — determines whether the instance can be + // reclaimed by AWS with a 2-minute warning + HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); + if (LifecycleResponse.IsSuccess()) + { + m_Info.IsSpot = (LifecycleResponse.AsText() == "spot"); + } + + // This endpoint only exists on instances managed by an Auto Scaling + // Group. A successful response (regardless of value) means autoscaling. + HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + m_Info.IsAutoscaling = true; + } + + m_Info.Provider = CloudProvider::AWS; + + ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Azure IMDS returns a single JSON document for the entire instance metadata, +// unlike AWS and GCP which use separate plain-text endpoints per field. The +// "Metadata: true" header is required; requests without it are rejected. +// The api-version parameter is mandatory and pins the response schema. +bool +CloudMetadata::TryDetectAzure() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping Azure detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing Azure IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders); + + if (!InstanceResponse.IsSuccess()) + { + ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError); + WriteSentinelFile(SentinelPath); + return false; + } + + const json11::Json& Compute = Json["compute"]; + + m_Info.InstanceId = Compute["vmId"].string_value(); + m_Info.AvailabilityZone = Compute["location"].string_value(); + + // Azure spot VMs have priority "Spot"; regular VMs have "Regular" + std::string Priority = Compute["priority"].string_value(); + m_Info.IsSpot = (Priority == "Spot"); + + // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + std::string VmssName = Compute["vmScaleSetName"].string_value(); + m_Info.IsAutoscaling = !VmssName.empty(); + + m_Info.Provider = CloudProvider::Azure; + + ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests. +// Unlike AWS, there is no session token; the header itself is the auth mechanism +// (it prevents SSRF attacks since browsers won't send custom headers to the +// metadata endpoint). Each metadata field is fetched from a separate URL. +bool +CloudMetadata::TryDetectGCP() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping GCP detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing GCP metadata service"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + // Fetch instance ID + HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders); + + if (!IdResponse.IsSuccess()) + { + ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_Info.InstanceId = std::string(IdResponse.AsText()); + + // GCP returns the fully-qualified zone path "projects/<num>/zones/<zone>". + // Strip the prefix to get just the zone name (e.g. "us-central1-a"). + HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders); + if (ZoneResponse.IsSuccess()) + { + std::string_view Zone = ZoneResponse.AsText(); + if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos) + { + Zone = Zone.substr(Pos + 1); + } + m_Info.AvailabilityZone = std::string(Zone); + } + + // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE") + HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders); + if (PreemptibleResponse.IsSuccess()) + { + m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE"); + } + + // Check for maintenance event + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + m_TerminationPending = true; + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + } + } + + m_Info.Provider = CloudProvider::GCP; + + ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Sentinel files are empty marker files whose mere existence signals that a +// previous detection attempt for a given provider failed. This avoids paying +// the connect-timeout cost on every startup for providers that are known to +// be absent. The files persist across process restarts; delete them manually +// (or remove the DataDir) to force re-detection. +void +CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path) +{ + try + { + BasicFile File; + File.Open(Path, BasicFile::Mode::kTruncate); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what()); + } +} + +bool +CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const +{ + return zen::IsFile(Path); +} + +void +CloudMetadata::ClearSentinelFiles() +{ + std::error_code Ec; + std::filesystem::remove(m_DataDir / ".isNotAWS", Ec); + std::filesystem::remove(m_DataDir / ".isNotAzure", Ec); + std::filesystem::remove(m_DataDir / ".isNotGCP", Ec); +} + +void +CloudMetadata::StartTerminationMonitor() +{ + ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId); + + m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this}; +} + +void +CloudMetadata::TerminationMonitorThread() +{ + SetCurrentThreadName("cloud_term_mon"); + + // Poll every 5 seconds. The Event is used as an interruptible sleep so + // that the destructor can wake us up immediately for a clean shutdown. + while (m_MonitorEnabled) + { + m_MonitorEvent.Wait(5000); + m_MonitorEvent.Reset(); + + if (!m_MonitorEnabled) + { + return; + } + + PollTermination(); + } +} + +void +CloudMetadata::PollTermination() +{ + try + { + CloudProvider Provider = m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); + + if (Provider == CloudProvider::AWS) + { + PollAWSTermination(); + } + else if (Provider == CloudProvider::Azure) + { + PollAzureTermination(); + } + else if (Provider == CloudProvider::GCP) + { + PollGCPTermination(); + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("termination poll error: {}", Ex.what()); + } +} + +// AWS termination signals: +// - /spot/instance-action: returns 200 with a JSON body ~2 minutes before +// a spot instance is reclaimed. Returns 404 when no action is pending. +// - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state. +// "InService" is normal; anything else (e.g. "Terminated:Wait") means +// the instance is being cycled out. +void +CloudMetadata::PollAWSTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response SpotResponse = ImdsClient.Get("/latest/meta-data/spot/instance-action", AuthHeaders); + if (SpotResponse.IsSuccess()) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); }); + ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText()); + } + return; + } + + HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + std::string_view State = AutoscaleResponse.AsText(); + if (State.find("InService") == std::string_view::npos) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); }); + ZEN_WARN("AWS autoscaling termination detected: {}", State); + } + } + } +} + +// Azure Scheduled Events API returns a JSON array of upcoming platform events. +// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events. +// Other event types like "Freeze" (live migration) are non-destructive and +// ignored. The Events array is empty when nothing is pending. +void +CloudMetadata::PollAzureTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response EventsResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders); + + if (!EventsResponse.IsSuccess()) + { + return; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(EventsResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + return; + } + + const json11::Json::array& Events = Json["Events"].array_items(); + for (const auto& Evt : Events) + { + std::string EventType = Evt["EventType"].string_value(); + if (EventType == "Preempt" || EventType == "Terminate" || EventType == "Reboot") + { + if (!m_TerminationPending.exchange(true)) + { + std::string EventStatus = Evt["EventStatus"].string_value(); + m_ReasonLock.WithExclusiveLock( + [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", EventType, EventStatus); }); + ZEN_WARN("Azure termination event detected: {} ({})", EventType, EventStatus); + } + return; + } + } +} + +// GCP maintenance-event returns "NONE" when nothing is pending, and a +// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is +// about to be live-migrated or terminated. Preemptible/spot VMs get a +// 30-second warning before termination. +void +CloudMetadata::PollGCPTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + ZEN_WARN("GCP maintenance event detected: {}", Event); + } + } + } +} + +} // namespace zen::compute + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +# include <zencompute/mockimds.h> + +# include <zencore/filesystem.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zenhttp/httpserver.h> + +# include <memory> +# include <thread> + +namespace zen::compute { + +TEST_SUITE_BEGIN("compute.cloudmetadata"); + +// --------------------------------------------------------------------------- +// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService +// --------------------------------------------------------------------------- + +struct TestImdsServer +{ + MockImdsService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); } + + std::filesystem::path DataDir() const { return m_TmpDir->Path() / "cloud"; } + + std::unique_ptr<CloudMetadata> CreateCloud() { return std::make_unique<CloudMetadata>(DataDir(), Endpoint()); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.aws") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + SUBCASE("detection basics") + { + Imds.Mock.Aws.InstanceId = "i-abc123"; + Imds.Mock.Aws.AvailabilityZone = "us-west-2b"; + Imds.Mock.Aws.LifeCycle = "on-demand"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "i-abc123"); + CHECK(Info.AvailabilityZone == "us-west-2b"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("autoscaling instance") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("spot termination") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate a spot interruption notice appearing + Imds.Mock.Aws.SpotAction = R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos); + } + + SUBCASE("autoscaling termination") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate ASG cycling the instance out + Imds.Mock.Aws.AutoscalingState = "Terminated:Wait"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos); + } + + SUBCASE("no termination when InService") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.azure") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::Azure; + + SUBCASE("detection basics") + { + Imds.Mock.Azure.VmId = "vm-test-1234"; + Imds.Mock.Azure.Location = "westeurope"; + Imds.Mock.Azure.Priority = "Regular"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::Azure); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "vm-test-1234"); + CHECK(Info.AvailabilityZone == "westeurope"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Azure.Priority = "Spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("vmss instance") + { + Imds.Mock.Azure.VmScaleSetName = "my-vmss"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("preempt termination") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Preempt"; + Imds.Mock.Azure.ScheduledEventStatus = "Scheduled"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos); + } + + SUBCASE("terminate event") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Terminate"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos); + } + + SUBCASE("no termination when events empty") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.gcp") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::GCP; + + SUBCASE("detection basics") + { + Imds.Mock.Gcp.InstanceId = "9876543210"; + Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b"; + Imds.Mock.Gcp.Preemptible = "FALSE"; + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::GCP); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "9876543210"); + CHECK(Info.AvailabilityZone == "europe-west1-b"); // zone prefix stripped + CHECK(Info.IsSpot == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("preemptible instance") + { + Imds.Mock.Gcp.Preemptible = "TRUE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("maintenance event during detection") + { + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + // GCP sets termination pending immediately during detection if a + // maintenance event is active + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("maintenance event during polling") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("no termination when NONE") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// No provider +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.no_provider") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId.empty()); + CHECK(Info.AvailabilityZone.empty()); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); +} + +// --------------------------------------------------------------------------- +// Sentinel file management +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.sentinel_files") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto DataDir = Imds.DataDir(); + + SUBCASE("sentinels are written on failed detection") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("ClearSentinelFiles removes sentinels") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + + Cloud->ClearSentinelFiles(); + + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("only failed providers get sentinels") + { + // Switch to AWS — Azure and GCP never probed, so no sentinels for them + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } +} + +TEST_SUITE_END(); + +void +cloudmetadata_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp new file mode 100644 index 000000000..838d741b6 --- /dev/null +++ b/src/zencompute/computeservice.cpp @@ -0,0 +1,2236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" +# include "recording/actionrecorder.h" +# include "runners/localrunner.h" +# include "runners/remotehttprunner.h" +# if ZEN_PLATFORM_LINUX +# include "runners/linuxrunner.h" +# elif ZEN_PLATFORM_WINDOWS +# include "runners/windowsrunner.h" +# elif ZEN_PLATFORM_MAC +# include "runners/macrunner.h" +# endif + +# include <zencompute/recordingreader.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/trace.h> +# include <zencore/workthreadpool.h> +# include <zenutil/workerpools.h> +# include <zentelemetry/stats.h> +# include <zenhttp/httpclient.h> + +# include <set> +# include <deque> +# include <map> +# include <thread> +# include <unordered_map> +# include <unordered_set> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <EASTL/hash_set.h> +ZEN_THIRD_PARTY_INCLUDES_END + +using namespace std::literals; + +namespace zen { + +const char* +ToString(compute::ComputeServiceSession::SessionState State) +{ + using enum compute::ComputeServiceSession::SessionState; + switch (State) + { + case Created: + return "Created"; + case Ready: + return "Ready"; + case Draining: + return "Draining"; + case Paused: + return "Paused"; + case Abandoned: + return "Abandoned"; + case Sunset: + return "Sunset"; + } + return "Unknown"; +} + +const char* +ToString(compute::ComputeServiceSession::QueueState State) +{ + using enum compute::ComputeServiceSession::QueueState; + switch (State) + { + case Active: + return "active"; + case Draining: + return "draining"; + case Cancelled: + return "cancelled"; + } + return "unknown"; +} + +} // namespace zen + +namespace zen::compute { + +using SessionState = ComputeServiceSession::SessionState; + +static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); + +////////////////////////////////////////////////////////////////////////// + +struct ComputeServiceSession::Impl +{ + ComputeServiceSession* m_ComputeServiceSession; + ChunkResolver& m_ChunkResolver; + LoggerRef m_Log{logging::Get("compute")}; + + Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver) + : m_ComputeServiceSession(InComputeServiceSession) + , m_ChunkResolver(InChunkResolver) + , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + { + // Create a non-expiring, non-deletable implicit queue for legacy endpoints + auto Result = CreateQueue("implicit"sv, {}, {}); + m_ImplicitQueueId = Result.QueueId; + m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; }); + + m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this}; + } + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + bool RequestStateTransition(SessionState NewState); + void AbandonAllActions(); + + LoggerRef Log() { return m_Log; } + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + std::string m_OrchestratorEndpoint; + std::filesystem::path m_OrchestratorBasePath; + Stopwatch m_OrchestratorQueryTimer; + std::unordered_set<std::string> m_KnownWorkerUris; + + void UpdateCoordinatorState(); + + // Worker registration and discovery + + struct FunctionDefinition + { + std::string FunctionName; + Guid FunctionVersion; + Guid BuildSystemVersion; + IoHash WorkerId; + }; + + void RegisterWorker(CbPackage Worker); + WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + + // Action scheduling and tracking + + std::atomic<SessionState> m_SessionState{SessionState::Created}; + std::atomic<int32_t> m_ActionsCounter = 0; // sequence number + metrics::Meter m_ArrivalRate; + + RwLock m_PendingLock; + std::map<int, Ref<RunnerAction>> m_PendingActions; + + RwLock m_RunningLock; + std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; + + RwLock m_ResultsLock; + std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; + metrics::Meter m_ResultRate; + std::atomic<uint64_t> m_RetiredCount{0}; + + EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority); + EnqueueResult EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority); + + void GetCompleted(CbWriter& Cbo); + + HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + std::thread m_SchedulingThread; + std::atomic<bool> m_SchedulingThreadEnabled{true}; + Event m_SchedulingThreadEvent; + + void SchedulerThreadFunction(); + void SchedulePendingActions(); + + // Workers + + RwLock m_WorkerLock; + std::unordered_map<IoHash, CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + std::vector<IoHash> GetKnownWorkerIds(); + void SyncWorkersToRunner(FunctionRunner& Runner); + + // Runners + + DeferredDirectoryDeleter m_DeferredDeleter; + WorkerThreadPool& m_LocalSubmitPool; + WorkerThreadPool& m_RemoteSubmitPool; + RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup; + RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup; + + void ShutdownRunners(); + + // Recording + + void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + void StopRecording(); + + std::unique_ptr<ActionRecorder> m_Recorder; + + // History tracking + + RwLock m_ActionHistoryLock; + std::deque<ComputeServiceSession::ActionHistoryEntry> m_ActionHistory; + size_t m_HistoryLimit = 1000; + + // Queue tracking + + using QueueState = ComputeServiceSession::QueueState; + + struct QueueEntry : RefCounted + { + int QueueId; + bool Implicit{false}; + std::atomic<QueueState> State{QueueState::Active}; + std::atomic<int> ActiveCount{0}; // pending + running + std::atomic<int> CompletedCount{0}; // successfully completed + std::atomic<int> FailedCount{0}; // failed + std::atomic<int> AbandonedCount{0}; // abandoned + std::atomic<int> CancelledCount{0}; // cancelled + std::atomic<uint64_t> IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work + + RwLock m_Lock; + std::unordered_set<int> ActiveLsns; // for cancellation lookup + std::unordered_set<int> FinishedLsns; // completed/failed/cancelled LSNs + + std::string Tag; + CbObject Metadata; + CbObject Config; + }; + + int m_ImplicitQueueId{0}; + std::atomic<int> m_QueueCounter{0}; + RwLock m_QueueLock; + std::unordered_map<int, Ref<QueueEntry>> m_Queues; + + Ref<QueueEntry> FindQueue(int QueueId) + { + Ref<QueueEntry> Queue; + m_QueueLock.WithSharedLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + Queue = It->second; + } + }); + return Queue; + } + + ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config); + std::vector<int> GetQueueIds(); + ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId); + CbObject GetQueueMetadata(int QueueId); + CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DeleteQueue(int QueueId); + void DrainQueue(int QueueId); + ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState); + void ExpireCompletedQueues(); + + Stopwatch m_QueueExpiryTimer; + + std::vector<ComputeServiceSession::RunningActionInfo> GetRunningActions(); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetActionHistory(int Limit); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit); + + // Action submission + + [[nodiscard]] size_t QueryCapacity(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action); + [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] size_t GetSubmittedActionCount(); + + // Updates + + RwLock m_UpdatedActionsLock; + std::vector<Ref<RunnerAction>> m_UpdatedActions; + + void HandleActionUpdates(); + void PostUpdate(RunnerAction* Action); + + static constexpr int kDefaultMaxRetries = 3; + int GetMaxRetriesForQueue(int QueueId); + + ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + + ActionCounts GetActionCounts() + { + ActionCounts Counts; + Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { + size_t Count = 0; + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + ++Count; + } + } + return Count; + }); + return Counts; + } + + void EmitStats(CbObjectWriter& Cbo) + { + Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); + m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); + m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); + m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + Cbo << "actions_submitted"sv << GetSubmittedActionCount(); + EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); + EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); + } +}; + +bool +ComputeServiceSession::Impl::IsHealthy() +{ + return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned; +} + +bool +ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) +{ + SessionState Current = m_SessionState.load(std::memory_order_relaxed); + + for (;;) + { + if (Current == NewState) + { + return true; + } + + // Validate the transition + bool Valid = false; + + switch (Current) + { + case SessionState::Created: + Valid = (NewState == SessionState::Ready); + break; + case SessionState::Ready: + Valid = (NewState == SessionState::Draining); + break; + case SessionState::Draining: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused); + break; + case SessionState::Paused: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset); + break; + case SessionState::Abandoned: + Valid = (NewState == SessionState::Sunset); + break; + case SessionState::Sunset: + Valid = false; + break; + } + + // Allow jumping directly to Abandoned or Sunset from any non-terminal state + if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned) + { + Valid = true; + } + if (NewState == SessionState::Sunset && Current != SessionState::Sunset) + { + Valid = true; + } + + if (!Valid) + { + ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState)); + return false; + } + + if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState)); + + if (NewState == SessionState::Abandoned) + { + AbandonAllActions(); + } + + return true; + } + + // CAS failed, Current was updated — retry with the new value + } +} + +void +ComputeServiceSession::Impl::AbandonAllActions() +{ + // Collect all pending actions and mark them as Abandoned + std::vector<Ref<RunnerAction>> PendingToAbandon; + + m_PendingLock.WithSharedLock([&] { + PendingToAbandon.reserve(m_PendingActions.size()); + for (auto& [Lsn, Action] : m_PendingActions) + { + PendingToAbandon.push_back(Action); + } + }); + + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + + // Collect all running actions and mark them as Abandoned, then + // best-effort cancel via the local runner group + std::vector<Ref<RunnerAction>> RunningToAbandon; + + m_RunningLock.WithSharedLock([&] { + RunningToAbandon.reserve(m_RunningMap.size()); + for (auto& [Lsn, Action] : m_RunningMap) + { + RunningToAbandon.push_back(Action); + } + }); + + for (auto& Action : RunningToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + m_LocalRunnerGroup.CancelAction(Action->ActionLsn); + } + + ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size()); +} + +void +ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_OrchestratorEndpoint = Endpoint; +} + +void +ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_OrchestratorBasePath = std::move(BasePath); +} + +void +ComputeServiceSession::Impl::UpdateCoordinatorState() +{ + ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); + if (m_OrchestratorEndpoint.empty()) + { + return; + } + + // Poll faster when we have no discovered workers yet so remote runners come online quickly + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } + + m_OrchestratorQueryTimer.Reset(); + + try + { + HttpClient Client(m_OrchestratorEndpoint); + + HttpClient::Response Response = Client.Get("/orch/agents"); + + if (!Response.IsSuccess()) + { + ZEN_WARN("orchestrator query failed with status {}", static_cast<int>(Response.StatusCode)); + return; + } + + CbObject WorkerList = Response.AsObject(); + + std::unordered_set<std::string> ValidWorkerUris; + + for (auto& Item : WorkerList["workers"sv]) + { + CbObjectView Worker = Item.AsObjectView(); + + uint64_t Dt = Worker["dt"sv].AsUInt64(); + bool Reachable = Worker["reachable"sv].AsBool(); + std::string_view Uri = Worker["uri"sv].AsString(); + + // Skip stale workers (not seen in over 30 seconds) + if (Dt > 30000) + { + continue; + } + + // Skip workers that are not confirmed reachable + if (!Reachable) + { + continue; + } + + std::string UriStr{Uri}; + ValidWorkerUris.insert(UriStr); + + // Skip workers we already know about + if (m_KnownWorkerUris.contains(UriStr)) + { + continue; + } + + ZEN_INFO("discovered new worker at {}", UriStr); + + m_KnownWorkerUris.insert(UriStr); + + auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + SyncWorkersToRunner(*NewRunner); + m_RemoteRunnerGroup.AddRunner(NewRunner); + } + + // Remove workers that are no longer valid (stale or unreachable) + for (auto It = m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();) + { + if (!ValidWorkerUris.contains(*It)) + { + const std::string& ExpiredUri = *It; + ZEN_INFO("removing expired worker at {}", ExpiredUri); + + m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; }); + + It = m_KnownWorkerUris.erase(It); + } + else + { + ++It; + } + } + } + catch (const HttpClientError& Ex) + { + ZEN_WARN("orchestrator query error: {}", Ex.what()); + } + catch (const std::exception& Ex) + { + ZEN_WARN("orchestrator query unexpected error: {}", Ex.what()); + } +} + +void +ComputeServiceSession::Impl::WaitUntilReady() +{ + if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty()) + { + ZEN_INFO("waiting for remote runners..."); + + constexpr int MaxWaitSeconds = 120; + + for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++) + { + if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed)) + { + ZEN_WARN("shutdown requested while waiting for remote runners"); + return; + } + + const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity(); + + if (Capacity > 0) + { + ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity); + break; + } + + zen::Sleep(1000); + } + } + else + { + ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available"); + } + + RequestStateTransition(SessionState::Ready); +} + +void +ComputeServiceSession::Impl::Shutdown() +{ + RequestStateTransition(SessionState::Sunset); + + m_SchedulingThreadEnabled = false; + m_SchedulingThreadEvent.Set(); + if (m_SchedulingThread.joinable()) + { + m_SchedulingThread.join(); + } + + ShutdownRunners(); + + m_DeferredDeleter.Shutdown(); +} + +void +ComputeServiceSession::Impl::ShutdownRunners() +{ + m_LocalRunnerGroup.Shutdown(); + m_RemoteRunnerGroup.Shutdown(); +} + +void +ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) +{ + ZEN_INFO("starting recording to '{}'", RecordingPath); + + m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); + + ZEN_INFO("started recording to '{}'", RecordingPath); +} + +void +ComputeServiceSession::Impl::StopRecording() +{ + ZEN_INFO("stopping recording"); + + m_Recorder = nullptr; + + ZEN_INFO("stopped recording"); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::Impl::GetRunningActions() +{ + std::vector<ComputeServiceSession::RunningActionInfo> Result; + m_RunningLock.WithSharedLock([&] { + Result.reserve(m_RunningMap.size()); + for (const auto& [Lsn, Action] : m_RunningMap) + { + Result.push_back({.Lsn = Lsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .CpuUsagePercent = Action->CpuUsagePercent.load(std::memory_order_relaxed), + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed)}); + } + }); + return Result; +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::Impl::GetActionHistory(int Limit) +{ + RwLock::SharedLockScope _(m_ActionHistoryLock); + + if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size()) + { + return std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end()); + } + + return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end()); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit) +{ + // Resolve the queue and snapshot its finished LSN set + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + std::unordered_set<int> FinishedLsns; + + Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; }); + + // Filter the global history to entries belonging to this queue. + // m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering. + std::vector<ActionHistoryEntry> Result; + + m_ActionHistoryLock.WithSharedLock([&] { + for (const auto& Entry : m_ActionHistory) + { + if (FinishedLsns.contains(Entry.Lsn)) + { + Result.push_back(Entry); + } + } + }); + + if (Limit > 0 && static_cast<size_t>(Limit) < Result.size()) + { + Result.erase(Result.begin(), Result.end() - Limit); + } + + return Result; +} + +void +ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) +{ + ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + RwLock::ExclusiveLockScope _(m_WorkerLock); + + const IoHash& WorkerId = Worker.GetObject().GetHash(); + + if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) + { + // Note that since the convention currently is that WorkerId is equal to the hash + // of the worker descriptor there is no chance that we get a second write with a + // different descriptor. Thus we only need to call this the first time, when the + // worker is added + + m_LocalRunnerGroup.RegisterWorker(Worker); + m_RemoteRunnerGroup.RegisterWorker(Worker); + + if (m_Recorder) + { + m_Recorder->RegisterWorker(Worker); + } + + CbObject WorkerObj = Worker.GetObject(); + + // Populate worker database + + const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerObj["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } + } +} + +void +ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) +{ + ZEN_TRACE_CPU("SyncWorkersToRunner"); + + std::vector<CbPackage> Workers; + + { + RwLock::SharedLockScope _(m_WorkerLock); + Workers.reserve(m_WorkerMap.size()); + for (const auto& [Id, Pkg] : m_WorkerMap) + { + Workers.push_back(Pkg); + } + } + + for (const CbPackage& Worker : Workers) + { + Runner.RegisterWorker(Worker); + } +} + +WorkerDesc +ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) +{ + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + const CbPackage& Desc = It->second; + return {Desc, WorkerId}; + } + + return {}; +} + +std::vector<IoHash> +ComputeServiceSession::Impl::GetKnownWorkerIds() +{ + std::vector<IoHash> WorkerIds; + + m_WorkerLock.WithSharedLock([&] { + WorkerIds.reserve(m_WorkerMap.size()); + for (const auto& [WorkerId, _] : m_WorkerMap) + { + WorkerIds.push_back(WorkerId); + } + }); + + return WorkerIds; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction"); + + // Resolve function to worker + + IoHash WorkerId{IoHash::Zero}; + CbPackage WorkerPackage; + + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + m_WorkerLock.WithSharedLock([&] { + for (const FunctionDefinition& FuncDef : m_FunctionList) + { + if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && + FuncDef.BuildSystemVersion == BuildSystemVersion) + { + WorkerId = FuncDef.WorkerId; + + break; + } + } + + if (WorkerId != IoHash::Zero) + { + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + WorkerPackage = It->second; + } + } + }); + + if (WorkerId == IoHash::Zero) + { + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker matches the action specification"; + + return {0, Writer.Save()}; + } + + if (WorkerPackage) + { + return EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); + } + + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker found despite match"; + + return {0, Writer.Save()}; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction"); + + if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) + { + CbObjectWriter Writer; + Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); + return {0, Writer.Save()}; + } + + const int ActionLsn = ++m_ActionsCounter; + + m_ArrivalRate.Mark(); + + Ref<RunnerAction> Pending{new RunnerAction(m_ComputeServiceSession)}; + + Pending->ActionLsn = ActionLsn; + Pending->QueueId = QueueId; + Pending->Worker = Worker; + Pending->ActionId = ActionObj.GetHash(); + Pending->ActionObj = ActionObj; + Pending->Priority = RequestPriority; + + // For now simply put action into pending state, so we can do batch scheduling + + ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + + Pending->SetActionState(RunnerAction::State::Pending); + + if (m_Recorder) + { + m_Recorder->RecordAction(Pending); + } + + CbObjectWriter Writer; + Writer << "lsn" << Pending->ActionLsn; + Writer << "worker" << Pending->Worker.WorkerId; + Writer << "action" << Pending->ActionId; + + return {Pending->ActionLsn, Writer.Save()}; +} + +SubmitResult +ComputeServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action) +{ + // Loosely round-robin scheduling of actions across runners. + // + // It's not entirely clear what this means given that submits + // can come in across multiple threads, but it's probably better + // than always starting with the first runner. + // + // Longer term we should track the state of the individual + // runners and make decisions accordingly. + + SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); + if (Result.IsAccepted) + { + return Result; + } + + return m_RemoteRunnerGroup.SubmitAction(Action); +} + +size_t +ComputeServiceSession::Impl::GetSubmittedActionCount() +{ + return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); +} + +HttpResponseCode +ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) + { + return HttpResponseCode::Accepted; + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) + { + return HttpResponseCode::Accepted; + } + } + + return HttpResponseCode::NotFound; +} + +HttpResponseCode +ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) + { + if (It->second->ActionId == ActionId) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + for (const auto& [K, Pending] : m_PendingActions) + { + if (Pending->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + for (const auto& [K, v] : m_RunningMap) + { + if (v->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + return HttpResponseCode::NotFound; +} + +void +ComputeServiceSession::Impl::RetireActionResult(int ActionLsn) +{ + m_DeferredDeleter.MarkReady(ActionLsn); +} + +void +ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) +{ + Cbo.BeginArray("completed"); + + m_ResultsLock.WithSharedLock([&] { + for (auto& [Lsn, Action] : m_ResultsMap) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Lsn; + Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); +} + +////////////////////////////////////////////////////////////////////////// +// Queue management + +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + const int QueueId = ++m_QueueCounter; + + Ref<QueueEntry> Queue{new QueueEntry()}; + Queue->QueueId = QueueId; + Queue->Tag = Tag; + Queue->Metadata = std::move(Metadata); + Queue->Config = std::move(Config); + Queue->IdleSince = GetHifreqTimerValue(); + + m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; }); + + ZEN_DEBUG("created queue {}", QueueId); + + return {.QueueId = QueueId}; +} + +std::vector<int> +ComputeServiceSession::Impl::GetQueueIds() +{ + std::vector<int> Ids; + + m_QueueLock.WithSharedLock([&] { + Ids.reserve(m_Queues.size()); + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + Ids.push_back(Id); + } + } + }); + + return Ids; +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::Impl::GetQueueStatus(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + const int Active = Queue->ActiveCount.load(std::memory_order_relaxed); + const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed); + const int Failed = Queue->FailedCount.load(std::memory_order_relaxed); + const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed); + const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed); + const QueueState QState = Queue->State.load(); + + return { + .IsValid = true, + .QueueId = QueueId, + .ActiveCount = Active, + .CompletedCount = Completed, + .FailedCount = Failed, + .AbandonedCount = AbandonedN, + .CancelledCount = CancelledN, + .State = QState, + .IsComplete = (Active == 0), + }; +} + +CbObject +ComputeServiceSession::Impl::GetQueueMetadata(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Metadata; +} + +CbObject +ComputeServiceSession::Impl::GetQueueConfig(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Config; +} + +void +ComputeServiceSession::Impl::CancelQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + Queue->State.store(QueueState::Cancelled); + + // Collect active LSNs snapshot for cancellation + std::vector<int> LsnsToCancel; + + Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); }); + + // Identify which LSNs are still pending (not yet dispatched to a runner) + std::vector<Ref<RunnerAction>> PendingActionsToCancel; + std::vector<int> RunningLsnsToCancel; + + m_PendingLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) + { + PendingActionsToCancel.push_back(It->second); + } + } + }); + + m_RunningLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + { + RunningLsnsToCancel.push_back(Lsn); + } + } + }); + + // Cancel pending actions by marking them as Cancelled; they will flow through + // HandleActionUpdates and eventually be removed from the pending map. + for (auto& Action : PendingActionsToCancel) + { + Action->SetActionState(RunnerAction::State::Cancelled); + } + + // Best-effort cancellation of running actions via the local runner group. + // Also set the action state to Cancelled directly so a subsequent Failed + // transition from the runner is blocked (Cancelled > Failed in the enum). + for (int Lsn : RunningLsnsToCancel) + { + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) + { + It->second->SetActionState(RunnerAction::State::Cancelled); + } + }); + m_LocalRunnerGroup.CancelAction(Lsn); + } + + m_RemoteRunnerGroup.CancelRemoteQueue(QueueId); + + ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled", + QueueId, + PendingActionsToCancel.size(), + RunningLsnsToCancel.size()); + + // Wake up the scheduler to process the cancelled actions + m_SchedulingThreadEvent.Set(); +} + +void +ComputeServiceSession::Impl::DeleteQueue(int QueueId) +{ + // Never delete the implicit queue + { + Ref<QueueEntry> Queue = FindQueue(QueueId); + if (Queue && Queue->Implicit) + { + return; + } + } + + // Cancel any active work first + CancelQueue(QueueId); + + m_QueueLock.WithExclusiveLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + m_Queues.erase(It); + } + }); +} + +void +ComputeServiceSession::Impl::DrainQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + QueueState Expected = QueueState::Active; + Queue->State.compare_exchange_strong(Expected, QueueState::Draining); + ZEN_INFO("draining queue {}", QueueId); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority); + + if (Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); + + if (Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +void +ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + Cbo.BeginArray("completed"); + + if (Queue) + { + Queue->m_Lock.WithSharedLock([&] { + m_ResultsLock.WithSharedLock([&] { + for (int Lsn : Queue->FinishedLsns) + { + if (m_ResultsMap.contains(Lsn)) + { + Cbo << Lsn; + } + } + }); + }); + } + + Cbo.EndArray(); +} + +void +ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState) +{ + if (QueueId == 0) + { + return; + } + + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return; + } + + Queue->m_Lock.WithExclusiveLock([&] { + Queue->ActiveLsns.erase(Lsn); + Queue->FinishedLsns.insert(Lsn); + }); + + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } + + switch (ActionState) + { + case RunnerAction::State::Completed: + Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Abandoned: + Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Cancelled: + Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed); + break; + default: + Queue->FailedCount.fetch_add(1, std::memory_order_relaxed); + break; + } +} + +void +ComputeServiceSession::Impl::ExpireCompletedQueues() +{ + static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000; + + std::vector<int> ExpiredQueueIds; + + m_QueueLock.WithSharedLock([&] { + for (const auto& [Id, Queue] : m_Queues) + { + if (Queue->Implicit) + { + continue; + } + const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed); + if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle); + if (ElapsedMs >= ExpiryTimeMs) + { + ExpiredQueueIds.push_back(Id); + } + } + } + }); + + for (int QueueId : ExpiredQueueIds) + { + ZEN_INFO("expiring idle queue {}", QueueId); + DeleteQueue(QueueId); + } +} + +void +ComputeServiceSession::Impl::SchedulePendingActions() +{ + ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); + int ScheduledCount = 0; + size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + + static Stopwatch DumpRunningTimer; + + auto _ = MakeGuard([&] { + ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); + + if (DumpRunningTimer.GetElapsedTimeMs() > 30000) + { + DumpRunningTimer.Reset(); + + std::set<int> RunningList; + m_RunningLock.WithSharedLock([&] { + for (auto& [K, V] : m_RunningMap) + { + RunningList.insert(K); + } + }); + + ExtendableStringBuilder<1024> RunningString; + for (int i : RunningList) + { + if (RunningString.Size()) + { + RunningString << ", "; + } + + RunningString.Append(IntNum(i)); + } + + ZEN_INFO("running: {}", RunningString); + } + }); + + size_t Capacity = QueryCapacity(); + + if (!Capacity) + { + _.Dismiss(); + + return; + } + + std::vector<Ref<RunnerAction>> ActionsToSchedule; + + // Pull actions to schedule from the pending queue, we will + // try to submit these to the runner outside of the lock. Note + // that because of how the state transitions work it's not + // actually the case that all of these actions will still be + // pending by the time we try to submit them, but that's fine. + // + // Also note that the m_PendingActions list is not maintained + // here, that's done periodically in SchedulePendingActions() + + m_PendingLock.WithExclusiveLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) + { + return; + } + + if (m_PendingActions.empty()) + { + return; + } + + for (auto& [Lsn, Pending] : m_PendingActions) + { + switch (Pending->ActionState()) + { + case RunnerAction::State::Pending: + ActionsToSchedule.push_back(Pending); + break; + + case RunnerAction::State::Submitting: + break; // already claimed by async submission + + case RunnerAction::State::Running: + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + case RunnerAction::State::Abandoned: + case RunnerAction::State::Cancelled: + break; + + default: + case RunnerAction::State::New: + ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn); + break; + } + } + + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { + if (A->Priority != B->Priority) + { + return A->Priority > B->Priority; + } + return A->ActionLsn < B->ActionLsn; + }); + + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + + PendingCount = m_PendingActions.size(); + }); + + if (ActionsToSchedule.empty()) + { + _.Dismiss(); + return; + } + + ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + + Stopwatch SubmitTimer; + std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule); + + int NotAcceptedCount = 0; + int ScheduledActionCount = 0; + + for (const SubmitResult& SubResult : SubmitResults) + { + if (SubResult.IsAccepted) + { + ++ScheduledActionCount; + } + else + { + ++NotAcceptedCount; + } + } + + ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); + + ScheduledCount += ScheduledActionCount; + PendingCount -= ScheduledActionCount; +} + +void +ComputeServiceSession::Impl::SchedulerThreadFunction() +{ + SetCurrentThreadName("Function_Scheduler"); + + auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); }); + + do + { + int TimeoutMs = 500; + + auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + + if (PendingCount) + { + TimeoutMs = 100; + } + + const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs); + + if (m_SchedulingThreadEnabled == false) + { + return; + } + + if (WasSignaled) + { + m_SchedulingThreadEvent.Reset(); + } + + ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", + m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), + PendingCount, + m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), + m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), + TimeoutMs); + + HandleActionUpdates(); + + // Auto-transition Draining → Paused when all work is done + if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) + { + size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + + if (Pending == 0 && Running == 0) + { + SessionState Expected = SessionState::Draining; + if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: Draining -> Paused (all work completed)"); + } + } + } + + UpdateCoordinatorState(); + SchedulePendingActions(); + + static constexpr uint64_t QueueExpirySweepIntervalMs = 30000; + if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs) + { + m_QueueExpiryTimer.Reset(); + ExpireCompletedQueues(); + } + } while (m_SchedulingThreadEnabled); +} + +void +ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action) +{ + m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); + m_SchedulingThreadEvent.Set(); +} + +int +ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) +{ + if (QueueId == 0) + { + return kDefaultMaxRetries; + } + + CbObject Config = GetQueueConfig(QueueId); + + if (Config) + { + int Value = Config["max_retries"].AsInt32(0); + + if (Value > 0) + { + return Value; + } + } + + return kDefaultMaxRetries; +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) +{ + Ref<RunnerAction> Action; + RunnerAction::State State; + RescheduleResult ValidationError; + bool Removed = false; + + // Find, validate, and remove atomically under a single lock scope to prevent + // concurrent RescheduleAction calls from double-removing the same action. + m_ResultsLock.WithExclusiveLock([&] { + auto It = m_ResultsMap.find(ActionLsn); + if (It == m_ResultsMap.end()) + { + ValidationError = {.Success = false, .Error = "Action not found in results"}; + return; + } + + Action = It->second; + State = Action->ActionState(); + + if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned) + { + ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"}; + return; + } + + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries) + { + ValidationError = {.Success = false, .Error = "Retry limit reached"}; + return; + } + + m_ResultsMap.erase(It); + Removed = true; + }); + + if (!Removed) + { + return ValidationError; + } + + if (Action->QueueId != 0) + { + Ref<QueueEntry> Queue = FindQueue(Action->QueueId); + + if (Queue) + { + Queue->m_Lock.WithExclusiveLock([&] { + Queue->FinishedLsns.erase(ActionLsn); + Queue->ActiveLsns.insert(ActionLsn); + }); + + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + + if (State == RunnerAction::State::Failed) + { + Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed); + } + else + { + Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed); + } + } + } + + // Reset action state — this calls PostUpdate() internally + Action->ResetActionStateToPending(); + + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount); + + return {.Success = true, .RetryCount = NewRetryCount}; +} + +void +ComputeServiceSession::Impl::HandleActionUpdates() +{ + ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates"); + + // Drain the update queue atomically + std::vector<Ref<RunnerAction>> UpdatedActions; + m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); + + std::unordered_set<int> SeenLsn; + + // Process each action's latest state, deduplicating by LSN. + // + // This is safe because state transitions are monotonically increasing by enum + // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so + // SetActionState rejects any transition to a lower-ranked state. By the time + // we read ActionState() here, it reflects the highest state reached — making + // the first occurrence per LSN authoritative and duplicates redundant. + for (Ref<RunnerAction>& Action : UpdatedActions) + { + const int ActionLsn = Action->ActionLsn; + + if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) + { + switch (Action->ActionState()) + { + // Newly enqueued — add to pending map for scheduling + case RunnerAction::State::Pending: + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + break; + + // Async submission in progress — remains in pending map + case RunnerAction::State::Submitting: + break; + + // Dispatched to a runner — move from pending to running + case RunnerAction::State::Running: + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); + }); + }); + ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); + break; + + // Terminal states — move to results, record history, notify queue + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + case RunnerAction::State::Abandoned: + case RunnerAction::State::Cancelled: + { + auto TerminalState = Action->ActionState(); + + // Automatic retry for Failed/Abandoned actions with retries remaining. + // Skip retries when the session itself is abandoned — those actions + // were intentionally abandoned and should not be rescheduled. + if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) && + m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned) + { + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + + if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) + { + // Remove from whichever active map the action is in before resetting + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + // Reset triggers PostUpdate() which re-enters the action as Pending + Action->ResetActionStateToPending(); + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + + ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})", + Action->ActionId, + ActionLsn, + NewRetryCount, + MaxRetries); + break; + } + } + + // Remove from whichever active map the action is in + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + m_ResultsLock.WithExclusiveLock([&] { + m_ResultsMap[ActionLsn] = Action; + + // Append to bounded action history ring + m_ActionHistoryLock.WithExclusiveLock([&] { + ActionHistoryEntry Entry{.Lsn = ActionLsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .WorkerId = Action->Worker.WorkerId, + .ActionDescriptor = Action->ActionObj, + .ExecutionLocation = std::move(Action->ExecutionLocation), + .Succeeded = TerminalState == RunnerAction::State::Completed, + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed), + .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; + + std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps)); + + m_ActionHistory.push_back(std::move(Entry)); + + if (m_ActionHistory.size() > m_HistoryLimit) + { + m_ActionHistory.pop_front(); + } + }); + }); + m_RetiredCount.fetch_add(1); + m_ResultRate.Mark(1); + ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", + Action->ActionId, + ActionLsn, + TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + break; + } + } + } + } +} + +size_t +ComputeServiceSession::Impl::QueryCapacity() +{ + return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); +} + +std::vector<SubmitResult> +ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); + std::vector<SubmitResult> Results(Actions.size()); + + // First try submitting the batch to local runners in parallel + + std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); + std::vector<size_t> RemoteIndices; + std::vector<Ref<RunnerAction>> RemoteActions; + + for (size_t i = 0; i < Actions.size(); ++i) + { + if (LocalResults[i].IsAccepted) + { + Results[i] = std::move(LocalResults[i]); + } + else + { + RemoteIndices.push_back(i); + RemoteActions.push_back(Actions[i]); + } + } + + // Submit remaining actions to remote runners in parallel + if (!RemoteActions.empty()) + { + std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteIndices.size(); ++j) + { + Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + } + } + + return Results; +} + +////////////////////////////////////////////////////////////////////////// + +ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver) +{ + m_Impl = std::make_unique<Impl>(this, InChunkResolver); +} + +ComputeServiceSession::~ComputeServiceSession() +{ + Shutdown(); +} + +bool +ComputeServiceSession::IsHealthy() +{ + return m_Impl->IsHealthy(); +} + +void +ComputeServiceSession::WaitUntilReady() +{ + m_Impl->WaitUntilReady(); +} + +void +ComputeServiceSession::Shutdown() +{ + m_Impl->Shutdown(); +} + +ComputeServiceSession::SessionState +ComputeServiceSession::GetSessionState() const +{ + return m_Impl->m_SessionState.load(std::memory_order_relaxed); +} + +bool +ComputeServiceSession::RequestStateTransition(SessionState NewState) +{ + return m_Impl->RequestStateTransition(NewState); +} + +void +ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_Impl->SetOrchestratorEndpoint(Endpoint); +} + +void +ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_Impl->SetOrchestratorBasePath(std::move(BasePath)); +} + +void +ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) +{ + m_Impl->StartRecording(InResolver, RecordingPath); +} + +void +ComputeServiceSession::StopRecording() +{ + m_Impl->StopRecording(); +} + +ComputeServiceSession::ActionCounts +ComputeServiceSession::GetActionCounts() +{ + return m_Impl->GetActionCounts(); +} + +void +ComputeServiceSession::EmitStats(CbObjectWriter& Cbo) +{ + m_Impl->EmitStats(Cbo); +} + +std::vector<IoHash> +ComputeServiceSession::GetKnownWorkerIds() +{ + return m_Impl->GetKnownWorkerIds(); +} + +WorkerDesc +ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) +{ + return m_Impl->GetWorkerDescriptor(WorkerId); +} + +void +ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner"); + +# if ZEN_PLATFORM_LINUX + auto* NewRunner = new LinuxProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_WINDOWS + auto* NewRunner = new WindowsProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_MAC + auto* NewRunner = + new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions); +# endif + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void +ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); + + auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool); + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority); +} +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config)); +} + +CbObject +ComputeServiceSession::GetQueueMetadata(int QueueId) +{ + return m_Impl->GetQueueMetadata(QueueId); +} + +CbObject +ComputeServiceSession::GetQueueConfig(int QueueId) +{ + return m_Impl->GetQueueConfig(QueueId); +} + +std::vector<int> +ComputeServiceSession::GetQueueIds() +{ + return m_Impl->GetQueueIds(); +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::GetQueueStatus(int QueueId) +{ + return m_Impl->GetQueueStatus(QueueId); +} + +void +ComputeServiceSession::CancelQueue(int QueueId) +{ + m_Impl->CancelQueue(QueueId); +} + +void +ComputeServiceSession::DrainQueue(int QueueId) +{ + m_Impl->DrainQueue(QueueId); +} + +void +ComputeServiceSession::DeleteQueue(int QueueId) +{ + m_Impl->DeleteQueue(QueueId); +} + +void +ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + m_Impl->GetQueueCompleted(QueueId, Cbo); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority); +} + +void +ComputeServiceSession::RegisterWorker(CbPackage Worker) +{ + m_Impl->RegisterWorker(Worker); +} + +HttpResponseCode +ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + return m_Impl->GetActionResult(ActionLsn, OutResultPackage); +} + +HttpResponseCode +ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + return m_Impl->FindActionResult(ActionId, OutResultPackage); +} + +void +ComputeServiceSession::RetireActionResult(int ActionLsn) +{ + m_Impl->RetireActionResult(ActionLsn); +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RescheduleAction(int ActionLsn) +{ + return m_Impl->RescheduleAction(ActionLsn); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::GetRunningActions() +{ + return m_Impl->GetRunningActions(); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::GetActionHistory(int Limit) +{ + return m_Impl->GetActionHistory(Limit); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::GetQueueHistory(int QueueId, int Limit) +{ + return m_Impl->GetQueueHistory(QueueId, Limit); +} + +void +ComputeServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +ComputeServiceSession::PostUpdate(RunnerAction* Action) +{ + m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +computeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp new file mode 100644 index 000000000..e82a40781 --- /dev/null +++ b/src/zencompute/httpcomputeservice.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httpcomputeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/system.h> +# include <zencore/thread.h> +# include <zencore/trace.h> +# include <zencore/uid.h> +# include <zenstore/cidstore.h> +# include <zentelemetry/stats.h> + +# include <span> +# include <unordered_map> + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); + +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; +auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); }; + +////////////////////////////////////////////////////////////////////////// + +struct HttpComputeService::Impl +{ + HttpComputeService* m_Self; + CidStore& m_CidStore; + IHttpStatsService& m_StatsService; + LoggerRef m_Log; + std::filesystem::path m_BaseDir; + HttpRequestRouter m_Router; + ComputeServiceSession m_ComputeService; + SystemMetricsTracker m_MetricsTracker; + + // Metrics + + metrics::OperationTiming m_HttpRequests; + + // Per-remote-queue metadata, shared across all lookup maps below. + + struct RemoteQueueInfo : RefCounted + { + int QueueId = 0; + Oid Token; + std::string IdempotencyKey; // empty if no idempotency key was provided + std::string ClientHostname; // empty if no hostname was provided + }; + + // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // All maps are guarded by m_RemoteQueueLock. + + RwLock m_RemoteQueueLock; + std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info + std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info + std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info + + LoggerRef Log() { return m_Log; } + + int ResolveQueueToken(const Oid& Token); + int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture); + + struct IngestStats + { + int Count = 0; + int NewCount = 0; + uint64_t Bytes = 0; + uint64_t NewBytes = 0; + }; + + IngestStats IngestPackageAttachments(const CbPackage& Package); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + + void RegisterRoutes(); + + Impl(HttpComputeService* Self, + CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) + : m_Self(Self) + , m_CidStore(InCidStore) + , m_StatsService(StatsService) + , m_Log(logging::Get("compute")) + , m_BaseDir(BaseDir) + , m_ComputeService(InCidStore) + { + m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.WaitUntilReady(); + m_StatsService.RegisterHandler("compute", *m_Self); + RegisterRoutes(); + } +}; + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::RegisterRoutes() +{ + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + m_Router.AddMatcher("queue", DecimalMatcher); + m_Router.AddMatcher("oidtoken", OidMatcher); + m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || OidMatcher(Str); }); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "abandon", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + + if (Success) + { + CbObjectWriter Cbo; + Cbo << "state"sv + << "Abandoned"sv; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Abandoned from current state"sv; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + m_ComputeService.GetCompleted(Cbo); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Running = m_ComputeService.GetRunningActions(); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + // Once we've initiated the response we can mark the result + // as retired, allowing the service to free any associated + // resources. Note that there still needs to be a delay + // to allow the transmission to complete, it would be better + // if we could issue this once the response is fully sent... + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersAllGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2))); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + // Local-only queue listing and creation + + m_Router.RegisterRoute( + "queues", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObjectWriter Cbo; + Cbo.BeginArray("queues"sv); + + for (const int QueueId : m_ComputeService.GetQueueIds()) + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + continue; + } + + Cbo.BeginObject(); + WriteQueueDescription(Cbo, QueueId, Status); + Cbo.EndObject(); + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kPost: + { + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + ComputeServiceSession::CreateQueueResult Result = + m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + + CbObjectWriter Cbo; + Cbo << "queue_id"sv << Result.QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + // Queue creation routes — these remain separate since local creates a plain queue + // while remote additionally generates an OID token for external access. + + m_Router.RegisterRoute( + "queues/remote", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // Extract optional fields from the request body. + // idempotency_key: when present, we return the existing remote queue token for this + // key rather than creating a new queue, making the endpoint safe to call concurrently. + // hostname: human-readable origin context stored alongside the queue for diagnostics. + // metadata: arbitrary CbObject metadata propagated from the originating queue. + // config: arbitrary CbObject config propagated from the originating queue. + std::string IdempotencyKey; + std::string ClientHostname; + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + IdempotencyKey = std::string(Body["idempotency_key"sv].AsString()); + ClientHostname = std::string(Body["hostname"sv].AsString()); + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + // Stamp the forwarding node's hostname into the metadata so that the + // remote side knows which node originated the queue. + if (!ClientHostname.empty()) + { + CbObjectWriter MetaWriter; + for (auto Field : Metadata) + { + MetaWriter.AddField(Field.GetName(), Field); + } + MetaWriter << "via"sv << ClientHostname; + Metadata = MetaWriter.Save(); + } + + RwLock::ExclusiveLockScope _(m_RemoteQueueLock); + + if (!IdempotencyKey.empty()) + { + if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end()) + { + Ref<RemoteQueueInfo> Existing = It->second; + if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid) + { + CbObjectWriter Cbo; + Cbo << "queue_token"sv << Existing->Token.ToString(); + Cbo << "queue_id"sv << Existing->QueueId; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // Queue has since expired — clean up stale entries and fall through to create a new one + m_RemoteQueuesByToken.erase(Existing->Token); + m_RemoteQueuesByQueueId.erase(Existing->QueueId); + m_RemoteQueuesByTag.erase(It); + } + } + + ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + Ref<RemoteQueueInfo> InfoRef(new RemoteQueueInfo()); + InfoRef->QueueId = Result.QueueId; + InfoRef->Token = Oid::NewOid(); + InfoRef->IdempotencyKey = std::move(IdempotencyKey); + InfoRef->ClientHostname = std::move(ClientHostname); + + m_RemoteQueuesByToken[InfoRef->Token] = InfoRef; + m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef; + if (!InfoRef->IdempotencyKey.empty()) + { + m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef; + } + + CbObjectWriter Cbo; + Cbo << "queue_token"sv << InfoRef->Token.ToString(); + Cbo << "queue_id"sv << InfoRef->QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. + + m_Router.RegisterRoute( + "queues/{queueref}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kDelete: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.CancelQueue(QueueId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "queues/{queueref}/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.DrainQueue(QueueId); + + // Return updated queue status + Status = m_ComputeService.GetQueueStatus(QueueId); + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + m_ComputeService.GetQueueCompleted(QueueId, Cbo); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + if (QueueId == 0) + { + return; + } + // Filter global running list to this queue + auto AllRunning = m_ComputeService.GetRunningActions(); + std::vector<ComputeServiceSession::RunningActionInfo> Running; + for (auto& Info : AllRunning) + if (Info.QueueId == QueueId) + Running.push_back(Info); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2)); + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ZEN_UNUSED(QueueId); + + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); +} + +////////////////////////////////////////////////////////////////////////// + +HttpComputeService::HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) +: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +{ +} + +HttpComputeService::~HttpComputeService() +{ + m_Impl->m_StatsService.UnregisterHandler("compute", *this); +} + +void +HttpComputeService::Shutdown() +{ + m_Impl->m_ComputeService.Shutdown(); +} + +ComputeServiceSession::ActionCounts +HttpComputeService::GetActionCounts() +{ + return m_Impl->m_ComputeService.GetActionCounts(); +} + +const char* +HttpComputeService::BaseUri() const +{ + return "/compute/"; +} + +void +HttpComputeService::HandleRequest(HttpServerRequest& Request) +{ + ZEN_TRACE_CPU("HttpComputeService::HandleRequest"); + metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests); + + if (m_Impl->m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpComputeService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + m_Impl->m_ComputeService.EmitStats(Cbo); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status) +{ + Cbo << "queue_id"sv << Status.QueueId; + Cbo << "active_count"sv << Status.ActiveCount; + Cbo << "completed_count"sv << Status.CompletedCount; + Cbo << "failed_count"sv << Status.FailedCount; + Cbo << "abandoned_count"sv << Status.AbandonedCount; + Cbo << "cancelled_count"sv << Status.CancelledCount; + Cbo << "state"sv << ToString(Status.State); + Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled); + Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining); + Cbo << "is_complete"sv << Status.IsComplete; + + if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId)) + { + Cbo << "metadata"sv << Meta; + } + + if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId)) + { + Cbo << "config"sv << Cfg; + } + + { + RwLock::SharedLockScope $(m_RemoteQueueLock); + if (auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end()) + { + Cbo << "queue_token"sv << It->second->Token.ToString(); + if (!It->second->ClientHostname.empty()) + { + Cbo << "hostname"sv << It->second->ClientHostname; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + +int +HttpComputeService::Impl::ResolveQueueToken(const Oid& Token) +{ + RwLock::SharedLockScope $(m_RemoteQueueLock); + + auto It = m_RemoteQueuesByToken.find(Token); + + if (It != m_RemoteQueuesByToken.end()) + { + return It->second->QueueId; + } + + return 0; +} + +int +HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture) +{ + if (OidMatcher(Capture)) + { + // Remote OID token — accessible from any client + const Oid Token = Oid::FromHexString(Capture); + const int QueueId = ResolveQueueToken(Token); + + if (QueueId == 0) + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + return QueueId; + } + + // Local integer queue ID — restricted to local machine requests + if (!HttpReq.IsLocalMachineRequest()) + { + HttpReq.WriteResponse(HttpResponseCode::Forbidden); + return 0; + } + + return ParseInt<int>(Capture).value_or(0); +} + +HttpComputeService::Impl::IngestStats +HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +{ + IngestStats Stats; + + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + Stats.Bytes += CompressedSize; + ++Stats.Count; + + const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + Stats.NewBytes += CompressedSize; + ++Stats.NewCount; + } + } + + return Stats; +} + +bool +HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList) +{ + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + return NeedList.empty(); +} + +void +HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds()) + { + Cbo << WorkerId; + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq) +{ + std::vector<IoHash> WorkerIds = m_ComputeService.GetKnownWorkerIds(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + for (const IoHash& WorkerId : WorkerIds) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject(); + Cbo.EndObject(); + } + + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId) +{ + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); + } + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject WorkerSpec = HttpReq.ReadPayloadObject(); + + HashKeySet ChunkSet; + WorkerSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerSpec); + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + m_ComputeService.RegisterWorker(WorkerPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + ResponseWriter.AddHash(Hash); + }); + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); + CbObject WorkerSpec = WorkerSpecPackage.GetObject(); + + std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_ComputeService.RegisterWorker(WorkerSpecPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } +} + +////////////////////////////////////////////////////////////////////////// + +void +httpcomputeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp new file mode 100644 index 000000000..6cbe01e04 --- /dev/null +++ b/src/zencompute/httporchestrator.cpp @@ -0,0 +1,650 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httporchestrator.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencompute/orchestratorservice.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/system.h> + +namespace zen::compute { + +// Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes +static bool +IsValidWorkerId(std::string_view Id) +{ + if (Id.size() < 3 || Id.size() > 64) + { + return false; + } + for (char c : Id) + { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') + { + continue; + } + return false; + } + return true; +} + +// Shared announce payload parser used by both the HTTP POST route and the +// WebSocket message handler. Returns the worker ID on success (empty on +// validation failure). The returned WorkerAnnouncement has string_view +// fields that reference the supplied CbObjectView, so the CbObject must +// outlive the returned announcement. +static std::string_view +ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann) +{ + Ann.Id = Data["id"].AsString(""); + Ann.Uri = Data["uri"].AsString(""); + + if (!IsValidWorkerId(Ann.Id)) + { + return {}; + } + + if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://")) + { + return {}; + } + + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Platform = Data["platform"].AsString(""); + Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); + Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); + Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); + Ann.BytesReceived = Data["bytes_received"].AsUInt64(0); + Ann.BytesSent = Data["bytes_sent"].AsUInt64(0); + Ann.ActionsPending = Data["actions_pending"].AsInt32(0); + Ann.ActionsRunning = Data["actions_running"].AsInt32(0); + Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0); + Ann.ActiveQueues = Data["active_queues"].AsInt32(0); + Ann.Provisioner = Data["provisioner"].AsString(""); + + if (auto Metrics = Data["metrics"].AsObjectView()) + { + Ann.Cpus = Metrics["lp_count"].AsInt32(0); + if (Ann.Cpus <= 0) + { + Ann.Cpus = 1; + } + } + + return Ann.Id; +} + +HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) +, m_Hostname(GetMachineName()) +{ + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + + // dummy endpoint for websocket clients + m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + Cbo << "hostname" << std::string_view(m_Hostname); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provision", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + + if (WorkerId.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash " + "characters and uri must start with http:// or https://"); + } + + m_Service->AnnounceWorker(Ann); + + HttpReq.WriteResponse(HttpResponseCode::OK); + +# if ZEN_WITH_WEBSOCKETS + // Notify push thread that state may have changed + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "agents", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline/{workerid}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::string_view WorkerId = Req.GetCapture(1); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + auto LimitStr = Params.GetValue("limit"); + + std::optional<DateTime> From; + std::optional<DateTime> To; + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + int Limit = !LimitStr.empty() ? zen::ParseInt<int>(LimitStr).value_or(0) : 0; + + CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit); + + if (!Result) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + + DateTime From = DateTime(0); + DateTime To = DateTime::Now(); + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + CbObject Result = m_Service->GetAllTimelines(From, To); + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + // Client tracking endpoints + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::ClientAnnouncement Ann; + Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero); + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Address = HttpReq.GetRemoteAddress(); + + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + Ann.Metadata = CbObject::Clone(MetadataView); + } + + std::string ClientId = m_Service->AnnounceClient(Ann); + + CbObjectWriter ResponseObj; + ResponseObj << "id" << std::string_view(ClientId); + HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save()); + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/update", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + CbObject MetadataObj; + CbObject Data = HttpReq.ReadPayloadObject(); + if (Data) + { + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + MetadataObj = CbObject::Clone(MetadataView); + } + } + + if (m_Service->UpdateClient(ClientId, std::move(MetadataObj))) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/complete", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + if (m_Service->CompleteClient(ClientId)) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "clients/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit)); + }, + HttpVerb::kGet); + +# if ZEN_WITH_WEBSOCKETS + + // Start the WebSocket push thread + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); +# endif +} + +HttpOrchestratorService::~HttpOrchestratorService() +{ + Shutdown(); +} + +void +HttpOrchestratorService::Shutdown() +{ +# if ZEN_WITH_WEBSOCKETS + if (!m_PushEnabled.exchange(false)) + { + return; + } + + // Stop the push thread first, before touching connections. This ensures + // the push thread is no longer reading m_WsConnections or calling into + // m_Service when we start tearing things down. + m_PushEvent.Set(); + if (m_PushThread.joinable()) + { + m_PushThread.join(); + } + + // Clean up worker WebSocket connections — collect IDs under lock, then + // notify the service outside the lock to avoid lock-order inversions. + std::vector<std::string> WorkerIds; + m_WorkerWsLock.WithExclusiveLock([&] { + WorkerIds.reserve(m_WorkerWsMap.size()); + for (const auto& [Conn, Id] : m_WorkerWsMap) + { + WorkerIds.push_back(Id); + } + m_WorkerWsMap.clear(); + }); + for (const auto& Id : WorkerIds) + { + m_Service->SetWorkerWebSocketConnected(Id, false); + } + + // Now that the push thread is gone, release all dashboard connections. + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); +# endif +} + +const char* +HttpOrchestratorService::BaseUri() const +{ + return "/orch/"; +} + +void +HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + if (!m_PushEnabled.load()) + { + return; + } + + ZEN_INFO("WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Wake push thread to send initial state immediately + m_PushEvent.Set(); +} + +void +HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + // Only handle binary messages from workers when the feature is enabled. + if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary) + { + return; + } + + std::string WorkerId = HandleWorkerWebSocketMessage(Msg); + if (WorkerId.empty()) + { + return; + } + + // Check if this is a new worker WebSocket connection + bool IsNewWorkerWs = false; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It == m_WorkerWsMap.end()) + { + m_WorkerWsMap[&Conn] = WorkerId; + IsNewWorkerWs = true; + } + }); + + if (IsNewWorkerWs) + { + m_Service->SetWorkerWebSocketConnected(WorkerId, true); + } + + m_PushEvent.Set(); +} + +std::string +HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) +{ + // Workers send CbObject in native binary format over the WebSocket to + // avoid the lossy CbObject↔JSON round-trip. + CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); + if (!Data) + { + ZEN_WARN("worker WebSocket message is not a valid CbObject"); + return {}; + } + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + if (WorkerId.empty()) + { + ZEN_WARN("invalid worker announcement via WebSocket"); + return {}; + } + + m_Service->AnnounceWorker(Ann); + return std::string(WorkerId); +} + +void +HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, + [[maybe_unused]] uint16_t Code, + [[maybe_unused]] std::string_view Reason) +{ + ZEN_INFO("WebSocket client disconnected (code {})", Code); + + // Check if this was a worker WebSocket connection; collect the ID under + // the worker lock, then notify the service outside the lock. + std::string DisconnectedWorkerId; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It != m_WorkerWsMap.end()) + { + DisconnectedWorkerId = std::move(It->second); + m_WorkerWsMap.erase(It); + } + }); + + if (!DisconnectedWorkerId.empty()) + { + m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false); + m_PushEvent.Set(); + } + + if (!m_PushEnabled.load()) + { + return; + } + + // Remove from dashboard connections + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} +# endif + +////////////////////////////////////////////////////////////////////////// +// +// Push thread +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::PushThreadFunction() +{ + SetCurrentThreadName("orch_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(2000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + // Snapshot current connections + std::vector<Ref<WebSocketConnection>> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + continue; + } + + // Build combined JSON with worker list, provisioning history, clients, and client history + CbObject WorkerList = m_Service->GetWorkerList(); + CbObject History = m_Service->GetProvisioningHistory(50); + CbObject ClientList = m_Service->GetClientList(); + CbObject ClientHistory = m_Service->GetClientHistory(50); + + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname)); + + // Emit workers array from worker list + ExtendableStringBuilder<2048> WorkerJson; + WorkerList.ToJson(WorkerJson); + std::string_view WorkerJsonView = WorkerJson.ToView(); + // Strip outer braces: {"workers":[...]} -> "workers":[...] + if (WorkerJsonView.size() >= 2) + { + JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit events array from history + ExtendableStringBuilder<2048> HistoryJson; + History.ToJson(HistoryJson); + std::string_view HistoryJsonView = HistoryJson.ToView(); + if (HistoryJsonView.size() >= 2) + { + JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit clients array from client list + ExtendableStringBuilder<2048> ClientJson; + ClientList.ToJson(ClientJson); + std::string_view ClientJsonView = ClientJson.ToView(); + if (ClientJsonView.size() >= 2) + { + JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit client_events array from client history + ExtendableStringBuilder<2048> ClientHistoryJson; + ClientHistory.ToJson(ClientHistoryJson); + std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView(); + if (ClientHistoryJsonView.size() >= 2) + { + JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); + } + + JsonBuilder.Append("}"); + std::string_view Json = JsonBuilder.ToView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } + } +} +# endif + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +#include <atomic> +#include <filesystem> +#include <string> +#include <thread> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. + * + * To avoid a ~200ms connect timeout penalty on every startup when running on + * bare-metal or non-cloud machines, failed probes write sentinel files + * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have + * a sentinel present. Delete the sentinel files to force re-detection. + * + * When a provider is detected, a background thread polls for termination + * signals every 5 seconds (spot interruption, autoscaling lifecycle changes, + * scheduled maintenance). The termination state is exposed as an atomic bool + * so the compute server can include it in coordinator announcements and react + * to imminent shutdown. + * + * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a + * shared RwLock; the background monitor thread acquires the exclusive lock + * only when writing the termination reason (a one-time transition). The + * termination-pending flag itself is a lock-free atomic. + * + * Usage: + * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud"); + * if (Cloud->IsTerminationPending()) { ... } + * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB + */ +class CloudMetadata +{ +public: + /** Synchronously probes cloud providers and starts the termination monitor + * if a provider is detected. Creates DataDir if it does not exist. + */ + explicit CloudMetadata(std::filesystem::path DataDir); + + /** Synchronously probes cloud providers at the given IMDS endpoint. + * Intended for testing — allows redirecting all IMDS queries to a local + * mock HTTP server instead of the real 169.254.169.254 endpoint. + */ + CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); + + /** Stops the termination monitor thread and joins it. */ + ~CloudMetadata(); + + CloudMetadata(const CloudMetadata&) = delete; + CloudMetadata& operator=(const CloudMetadata&) = delete; + + CloudProvider GetProvider() const; + CloudInstanceInfo GetInstanceInfo() const; + bool IsTerminationPending() const; + std::string GetTerminationReason() const; + + /** Writes a "cloud" sub-object into the compact binary writer if a provider + * was detected. No-op when running on bare metal. + */ + void Describe(CbWriter& Writer) const; + + /** Executes a single termination-poll cycle for the detected provider. + * Public so tests can drive poll cycles synchronously without relying on + * the background thread's 5-second timer. + */ + void PollTermination(); + + /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure, + * .isNotGCP) from DataDir so subsequent detection probes are not skipped. + * Primarily intended for tests that need to reset state between sub-cases. + */ + void ClearSentinelFiles(); + +private: + /** Tries each provider in order, stops on first successful detection. */ + void DetectProvider(); + bool TryDetectAWS(); + bool TryDetectAzure(); + bool TryDetectGCP(); + + void WriteSentinelFile(const std::filesystem::path& Path); + bool HasSentinelFile(const std::filesystem::path& Path) const; + + void StartTerminationMonitor(); + void TerminationMonitorThread(); + void PollAWSTermination(); + void PollAzureTermination(); + void PollGCPTermination(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_DataDir; + std::string m_ImdsEndpoint; + + mutable RwLock m_InfoLock; + CloudInstanceInfo m_Info; + + std::atomic<bool> m_TerminationPending{false}; + + mutable RwLock m_ReasonLock; + std::string m_TerminationReason; + + // IMDSv2 session token, acquired during AWS detection and reused for + // subsequent termination polling. Has a 300s TTL on the AWS side; if it + // expires mid-run the poll requests will get 401s which we treat as + // non-terminal (the monitor simply retries next cycle). + std::string m_AwsToken; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorEnabled{true}; + Event m_MonitorEvent; +}; + +void cloudmetadata_forcelink(); // internal + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h new file mode 100644 index 000000000..65ec5f9ee --- /dev/null +++ b/src/zencompute/include/zencompute/computeservice.h @@ -0,0 +1,262 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/iohash.h> +# include <zenstore/zenstore.h> +# include <zenhttp/httpcommon.h> + +# include <filesystem> + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class ComputeServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. + */ +class ComputeServiceSession final +{ +public: + /** + * Session lifecycle state machine. + * + * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset + * Backward transitions: Draining -> Ready, Paused -> Ready + * Automatic transition: Draining -> Paused (when pending + running reaches 0) + * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset + * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out) + * + * | State | Accept new actions | Schedule pending | Finish running | + * |-----------|-------------------|-----------------|----------------| + * | Created | No | No | N/A | + * | Ready | Yes | Yes | Yes | + * | Draining | No | Yes | Yes | + * | Paused | No | No | No | + * | Abandoned | No | No | No (all abandoned) | + * | Sunset | No | No | No | + */ + enum class SessionState + { + Created, // Initial state before WaitUntilReady completes + Ready, // Normal operating state; accepts and schedules work + Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty + Paused, // Idle; no work accepted or scheduled; can resume to Ready + Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out + Sunset // Terminal; triggers full shutdown + }; + + ComputeServiceSession(ChunkResolver& InChunkResolver); + ~ComputeServiceSession(); + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + SessionState GetSessionState() const; + + // Request a state transition. Returns false if the transition is invalid. + // Sunset can be reached from any non-Sunset state. + bool RequestStateTransition(SessionState NewState); + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Queue management + // + // Queues group actions submitted by a single client session. They allow + // cancelling or polling completion of all actions in the group. + + struct CreateQueueResult + { + int QueueId = 0; // 0 if creation failed + }; + + enum class QueueState + { + Active, + Draining, + Cancelled, + }; + + struct QueueStatus + { + bool IsValid = false; + int QueueId = 0; + int ActiveCount = 0; // pending + running (not yet completed) + int CompletedCount = 0; // successfully completed + int FailedCount = 0; // failed + int AbandonedCount = 0; // abandoned + int CancelledCount = 0; // cancelled + QueueState State = QueueState::Active; + bool IsComplete = false; // ActiveCount == 0 + }; + + [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {}); + [[nodiscard]] std::vector<int> GetQueueIds(); + [[nodiscard]] QueueStatus GetQueueStatus(int QueueId); + [[nodiscard]] CbObject GetQueueMetadata(int QueueId); + [[nodiscard]] CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DrainQueue(int QueueId); + void DeleteQueue(int QueueId); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + + // Queue-scoped action submission. Actions submitted via these methods are + // tracked under the given queue in addition to the global LSN-based tracking. + + [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + // Action rescheduling + + struct RescheduleResult + { + bool Success = false; + std::string Error; + int RetryCount = 0; + }; + + [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + + void GetCompleted(CbWriter&); + + // Running action tracking + + struct RunningActionInfo + { + int Lsn; + int QueueId; + IoHash ActionId; + float CpuUsagePercent; // -1.0 if not yet sampled + float CpuSeconds; // 0.0 if not yet sampled + }; + + [[nodiscard]] std::vector<RunningActionInfo> GetRunningActions(); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + int QueueId = 0; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + std::string ExecutionLocation; + bool Succeeded; + float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled + int RetryCount = 0; // number of times this action was rescheduled + // sized to match RunnerAction::State::_Count but we can't use the enum here + // for dependency reasons, so just use a fixed size array and static assert in + // the implementation file + uint64_t Timestamps[8] = {}; + }; + + [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); + [[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100); + + // Stats reporting + + struct ActionCounts + { + int Pending = 0; + int Running = 0; + int Completed = 0; + int ActiveQueues = 0; + }; + + [[nodiscard]] ActionCounts GetActionCounts(); + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void computeservice_forcelink(); + +} // namespace zen::compute + +namespace zen { +const char* ToString(compute::ComputeServiceSession::SessionState State); +const char* ToString(compute::ComputeServiceSession::QueueState State); +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h new file mode 100644 index 000000000..ee1cd2614 --- /dev/null +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/computeservice.h" + +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <memory> + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** + * HTTP interface for compute service + */ +class HttpComputeService : public HttpService, public IHttpStatsProvider +{ +public: + HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions = 0); + ~HttpComputeService(); + + void Shutdown(); + + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void httpcomputeservice_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h new file mode 100644 index 000000000..da5c5dfc3 --- /dev/null +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -0,0 +1,101 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <atomic> +#include <filesystem> +#include <memory> +#include <string> +#include <thread> +#include <unordered_map> +#include <vector> + +#define ZEN_WITH_WEBSOCKETS 1 + +namespace zen::compute { + +class OrchestratorService; + +// Experimental helper, to see if we can get rid of some error-prone +// boilerplate when declaring loggers as class members. + +class LoggerHelper +{ +public: + LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {} + + LoggerRef operator()() { return m_Log; } + +private: + LoggerRef m_Log; +}; + +/** + * Orchestrator HTTP service with WebSocket push support + * + * Normal HTTP requests are routed through the HttpRequestRouter as before. + * WebSocket clients connecting to /orch/ws receive periodic state broadcasts + * from a dedicated push thread, eliminating the need for polling. + */ + +class HttpOrchestratorService : public HttpService +#if ZEN_WITH_WEBSOCKETS +, + public IWebSocketHandler +#endif +{ +public: + explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~HttpOrchestratorService(); + + HttpOrchestratorService(const HttpOrchestratorService&) = delete; + HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + + /** + * Gracefully shut down the WebSocket push thread and release connections. + * Must be called while the ASIO io_context is still alive. The destructor + * also calls this, so it is safe (but not ideal) to omit the explicit call. + */ + void Shutdown(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + // IWebSocketHandler +#if ZEN_WITH_WEBSOCKETS + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; +#endif + +private: + HttpRequestRouter m_Router; + LoggerHelper Log{"orch"}; + std::unique_ptr<OrchestratorService> m_Service; + std::string m_Hostname; + + // WebSocket push + +#if ZEN_WITH_WEBSOCKETS + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::thread m_PushThread; + std::atomic<bool> m_PushEnabled{false}; + Event m_PushEvent; + void PushThreadFunction(); + + // Worker WebSocket connections (worker→orchestrator persistent links) + RwLock m_WorkerWsLock; + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); +#endif +}; + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h new file mode 100644 index 000000000..521722e63 --- /dev/null +++ b/src/zencompute/include/zencompute/mockimds.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/cloudmetadata.h> +#include <zenhttp/httpserver.h> + +#include <string> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. + * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). + * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. */ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). + std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. + std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. + */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h new file mode 100644 index 000000000..071e902b3 --- /dev/null +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -0,0 +1,177 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/uid.h> + +# include <deque> +# include <optional> +# include <filesystem> +# include <memory> +# include <string> +# include <string_view> +# include <thread> +# include <unordered_map> + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. "windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. "horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + std::unique_ptr<WorkerTimelineStore> m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque<ProvisioningEvent> m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map<std::string, KnownClient> m_KnownClients; + + RwLock m_ClientLogLock; + std::deque<ClientEvent> m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic<bool> m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h new file mode 100644 index 000000000..3f233fae0 --- /dev/null +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -0,0 +1,129 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#include <zencompute/computeservice.h> +#include <zencompute/zencompute.h> +#include <zencore/basicfile.h> +#include <zencore/compactbinarybuilder.h> +#include <zenstore/cidstore.h> +#include <zenstore/gc.h> +#include <zenstore/zenstore.h> + +#include <filesystem> +#include <functional> +#include <unordered_map> + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + +class RecordingReaderBase +{ + RecordingReaderBase(const RecordingReaderBase&) = delete; + RecordingReaderBase& operator=(const RecordingReaderBase&) = delete; + +public: + RecordingReaderBase() = default; + virtual ~RecordingReaderBase() = 0; + virtual std::unordered_map<IoHash, CbPackage> ReadWorkers() = 0; + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism) = 0; + virtual size_t GetActionCount() const = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Reader for recordings done via the zencompute recording system, which + * have a shared chunk store and a log of actions with pointers into the + * chunk store for their data. + */ +class RecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit RecordingReader(const std::filesystem::path& RecordingPath); + ~RecordingReader(); + + virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override; + + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + +private: + std::filesystem::path m_RecordingLogDir; + BasicFile m_WorkerDataFile; + BasicFile m_ActionDataFile; + GcManager m_Gc; + CidStore m_CidStore{m_Gc}; + + // ChunkResolver interface + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + + struct ActionEntry + { + IoHash ActionId; + uint64_t Offset; + uint64_t Size; + }; + + std::vector<ActionEntry> m_Actions; + + void ScanActions(); +}; + +////////////////////////////////////////////////////////////////////////// + +struct LocalResolver : public ChunkResolver +{ + LocalResolver(const LocalResolver&) = delete; + LocalResolver& operator=(const LocalResolver&) = delete; + + LocalResolver() = default; + ~LocalResolver() = default; + + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + void Add(const IoHash& Cid, IoBuffer Data); + +private: + RwLock MapLock; + std::unordered_map<IoHash, IoBuffer> Attachments; +}; + +/** + * This is a reader for UE/DDB recordings, which have a different layout on + * disk (no shared chunk store) + */ +class UeRecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit UeRecordingReader(const std::filesystem::path& RecordingPath); + ~UeRecordingReader(); + + virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override; + virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + +private: + std::filesystem::path m_RecordingDir; + LocalResolver m_LocalResolver; + std::vector<std::filesystem::path> m_WorkDirs; + + CbPackage ReadAction(std::filesystem::path WorkDir); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h new file mode 100644 index 000000000..00be4d4a0 --- /dev/null +++ b/src/zencompute/include/zencompute/zencompute.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +namespace zen { + +void zencompute_forcelinktests(); + +} diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp new file mode 100644 index 000000000..9ea695305 --- /dev/null +++ b/src/zencompute/orchestratorservice.cpp @@ -0,0 +1,710 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/orchestratorservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zenhttp/httpclient.h> + +# include "timeline/workertimeline.h" + +namespace zen::compute { + +OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_TimelineStore(std::make_unique<WorkerTimelineStore>(DataDir / "timelines")) +, m_EnableWorkerWebSocket(EnableWorkerWebSocket) +{ + m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this}; +} + +OrchestratorService::~OrchestratorService() +{ + m_ProbeThreadEnabled = false; + m_ProbeThreadEvent.Set(); + if (m_ProbeThread.joinable()) + { + m_ProbeThread.join(); + } +} + +CbObject +OrchestratorService::GetWorkerList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "uri" << Worker.BaseUri; + Cbo << "hostname" << Worker.Hostname; + if (!Worker.Platform.empty()) + { + Cbo << "platform" << std::string_view(Worker.Platform); + } + Cbo << "cpus" << Worker.Cpus; + Cbo << "cpu_usage" << Worker.CpuUsagePercent; + Cbo << "memory_total" << Worker.MemoryTotalBytes; + Cbo << "memory_used" << Worker.MemoryUsedBytes; + Cbo << "bytes_received" << Worker.BytesReceived; + Cbo << "bytes_sent" << Worker.BytesSent; + Cbo << "actions_pending" << Worker.ActionsPending; + Cbo << "actions_running" << Worker.ActionsRunning; + Cbo << "actions_completed" << Worker.ActionsCompleted; + Cbo << "active_queues" << Worker.ActiveQueues; + if (!Worker.Provisioner.empty()) + { + Cbo << "provisioner" << std::string_view(Worker.Provisioner); + } + if (Worker.Reachable != ReachableState::Unknown) + { + Cbo << "reachable" << (Worker.Reachable == ReachableState::Reachable); + } + if (Worker.WsConnected) + { + Cbo << "ws_connected" << true; + } + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker"); + + bool IsNew = false; + std::string EvictedId; + std::string EvictedHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end()); + + // If a different worker ID already maps to the same URI, the old entry + // is stale (e.g. a previous Horde lease on the same machine). Remove it + // so the dashboard doesn't show duplicates. + if (IsNew) + { + for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It) + { + if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id) + { + EvictedId = It->first; + EvictedHostname = It->second.Hostname; + m_KnownWorkers.erase(It); + break; + } + } + } + + auto& Worker = m_KnownWorkers[std::string(Ann.Id)]; + Worker.BaseUri = Ann.Uri; + Worker.Hostname = Ann.Hostname; + if (!Ann.Platform.empty()) + { + Worker.Platform = Ann.Platform; + } + Worker.Cpus = Ann.Cpus; + Worker.CpuUsagePercent = Ann.CpuUsagePercent; + Worker.MemoryTotalBytes = Ann.MemoryTotalBytes; + Worker.MemoryUsedBytes = Ann.MemoryUsedBytes; + Worker.BytesReceived = Ann.BytesReceived; + Worker.BytesSent = Ann.BytesSent; + Worker.ActionsPending = Ann.ActionsPending; + Worker.ActionsRunning = Ann.ActionsRunning; + Worker.ActionsCompleted = Ann.ActionsCompleted; + Worker.ActiveQueues = Ann.ActiveQueues; + if (!Ann.Provisioner.empty()) + { + Worker.Provisioner = Ann.Provisioner; + } + Worker.LastSeen.Reset(); + }); + + if (!EvictedId.empty()) + { + ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id); + RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname); + } + + if (IsNew) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname); + } +} + +bool +OrchestratorService::IsWorkerWebSocketEnabled() const +{ + return m_EnableWorkerWebSocket; +} + +void +OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected) +{ + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(std::string(WorkerId)); + if (It == m_KnownWorkers.end()) + { + return; + } + + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.WsConnected = Connected; + It->second.Reachable = Connected ? ReachableState::Reachable : ReachableState::Unreachable; + + if (Connected) + { + ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + } + else + { + ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId); + } + }); + + // Record provisioning events for state transitions outside the lock + if (Connected && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname); + } + else if (!Connected && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname); + } +} + +CbObject +OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline"); + + Ref<WorkerTimeline> Timeline = m_TimelineStore->Find(WorkerId); + if (!Timeline) + { + return {}; + } + + std::vector<WorkerTimeline::Event> Events; + + if (From || To) + { + DateTime StartTime = From.value_or(DateTime(0)); + DateTime EndTime = To.value_or(DateTime::Now()); + Events = Timeline->QueryTimeline(StartTime, EndTime); + } + else if (Limit > 0) + { + Events = Timeline->QueryRecent(Limit); + } + else + { + Events = Timeline->QueryRecent(); + } + + WorkerTimeline::TimeRange Range = Timeline->GetTimeRange(); + + CbObjectWriter Cbo; + Cbo << "worker_id" << WorkerId; + Cbo << "event_count" << static_cast<int32_t>(Timeline->GetEventCount()); + + if (Range) + { + Cbo.AddDateTime("time_first", Range.First); + Cbo.AddDateTime("time_last", Range.Last); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : Events) + { + Cbo.BeginObject(); + Cbo << "type" << WorkerTimeline::ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged) + { + Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState); + Cbo << "state" << RunnerAction::ToString(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetAllTimelines(DateTime From, DateTime To) +{ + ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines"); + + DateTime StartTime = From; + DateTime EndTime = To; + + auto AllInfo = m_TimelineStore->GetAllWorkerInfo(); + + CbObjectWriter Cbo; + Cbo.AddDateTime("from", StartTime); + Cbo.AddDateTime("to", EndTime); + + Cbo.BeginArray("workers"); + for (const auto& Info : AllInfo) + { + if (!Info.Range || Info.Range.Last < StartTime || Info.Range.First > EndTime) + { + continue; + } + + Cbo.BeginObject(); + Cbo << "worker_id" << Info.WorkerId; + Cbo.AddDateTime("time_first", Info.Range.First); + Cbo.AddDateTime("time_last", Info.Range.Last); + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void +OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname) +{ + ProvisioningEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .WorkerId = std::string(WorkerId), + .Hostname = std::string(Hostname), + }; + + m_ProvisioningLogLock.WithExclusiveLock([&] { + m_ProvisioningLog.push_back(std::move(Evt)); + while (m_ProvisioningLog.size() > kMaxProvisioningEvents) + { + m_ProvisioningLog.pop_front(); + } + }); +} + +CbObject +OrchestratorService::GetProvisioningHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("events"); + + m_ProvisioningLogLock.WithSharedLock([&] { + // Return last N events, newest first + int Count = 0; + for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ProvisioningEvent::Type::Joined: + Cbo << "type" + << "joined"; + break; + case ProvisioningEvent::Type::Left: + Cbo << "type" + << "left"; + break; + case ProvisioningEvent::Type::Returned: + Cbo << "type" + << "returned"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "worker_id" << std::string_view(Evt.WorkerId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +std::string +OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceClient"); + + std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString()); + + bool IsNew = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(ClientId); + IsNew = (It == m_KnownClients.end()); + + auto& Client = m_KnownClients[ClientId]; + Client.SessionId = Ann.SessionId; + Client.Hostname = Ann.Hostname; + if (!Ann.Address.empty()) + { + Client.Address = Ann.Address; + } + if (Ann.Metadata) + { + Client.Metadata = Ann.Metadata; + } + Client.LastSeen.Reset(); + }); + + if (IsNew) + { + RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname); + } + else + { + RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname); + } + + return ClientId; +} + +bool +OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata) +{ + ZEN_TRACE_CPU("OrchestratorService::UpdateClient"); + + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + if (Metadata) + { + It->second.Metadata = std::move(Metadata); + } + It->second.LastSeen.Reset(); + } + }); + + return Found; +} + +bool +OrchestratorService::CompleteClient(std::string_view ClientId) +{ + ZEN_TRACE_CPU("OrchestratorService::CompleteClient"); + + std::string Hostname; + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + Hostname = It->second.Hostname; + m_KnownClients.erase(It); + } + }); + + if (Found) + { + RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname); + } + + return Found; +} + +CbObject +OrchestratorService::GetClientList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientList"); + CbObjectWriter Cbo; + Cbo.BeginArray("clients"); + + m_KnownClientsLock.WithSharedLock([&] { + for (const auto& [ClientId, Client] : m_KnownClients) + { + Cbo.BeginObject(); + Cbo << "id" << ClientId; + if (Client.SessionId) + { + Cbo << "session_id" << Client.SessionId; + } + Cbo << "hostname" << std::string_view(Client.Hostname); + if (!Client.Address.empty()) + { + Cbo << "address" << std::string_view(Client.Address); + } + Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs(); + if (Client.Metadata) + { + Cbo << "metadata" << Client.Metadata; + } + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetClientHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("client_events"); + + m_ClientLogLock.WithSharedLock([&] { + int Count = 0; + for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ClientEvent::Type::Connected: + Cbo << "type" + << "connected"; + break; + case ClientEvent::Type::Disconnected: + Cbo << "type" + << "disconnected"; + break; + case ClientEvent::Type::Updated: + Cbo << "type" + << "updated"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "client_id" << std::string_view(Evt.ClientId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname) +{ + ClientEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .ClientId = std::string(ClientId), + .Hostname = std::string(Hostname), + }; + + m_ClientLogLock.WithExclusiveLock([&] { + m_ClientLog.push_back(std::move(Evt)); + while (m_ClientLog.size() > kMaxClientEvents) + { + m_ClientLog.pop_front(); + } + }); +} + +void +OrchestratorService::ProbeThreadFunction() +{ + ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction"); + SetCurrentThreadName("orch_probe"); + + bool IsFirstProbe = true; + + do + { + if (!IsFirstProbe) + { + m_ProbeThreadEvent.Wait(5'000); + m_ProbeThreadEvent.Reset(); + } + else + { + IsFirstProbe = false; + } + + if (m_ProbeThreadEnabled == false) + { + return; + } + + m_ProbeThreadEvent.Reset(); + + // Snapshot worker IDs and URIs under shared lock + struct WorkerSnapshot + { + std::string Id; + std::string Uri; + bool WsConnected = false; + }; + std::vector<WorkerSnapshot> Snapshots; + + m_KnownWorkersLock.WithSharedLock([&] { + Snapshots.reserve(m_KnownWorkers.size()); + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected}); + } + }); + + // Probe each worker outside the lock + for (const auto& Snap : Snapshots) + { + if (m_ProbeThreadEnabled == false) + { + return; + } + + // Workers with an active WebSocket connection are known-reachable; + // skip the HTTP health probe for them. + if (Snap.WsConnected) + { + continue; + } + + ReachableState NewState = ReachableState::Unreachable; + + try + { + HttpClient Client(Snap.Uri, + {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}}); + HttpClient::Response Response = Client.Get("/health/"); + if (Response.IsSuccess()) + { + NewState = ReachableState::Reachable; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } + + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(Snap.Id); + if (It != m_KnownWorkers.end()) + { + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.Reachable = NewState; + It->second.LastProbed.Reset(); + + if (PrevState != NewState) + { + if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri); + } + else if (NewState == ReachableState::Reachable) + { + ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); + } + else if (PrevState == ReachableState::Reachable) + { + ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); + } + else + { + ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri); + } + } + } + }); + + // Record provisioning events for state transitions outside the lock + if (PrevState != NewState) + { + if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname); + } + else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname); + } + } + } + + // Sweep expired clients (5-minute timeout) + static constexpr int64_t kClientTimeoutMs = 5 * 60 * 1000; + + struct ExpiredClient + { + std::string Id; + std::string Hostname; + }; + std::vector<ExpiredClient> ExpiredClients; + + m_KnownClientsLock.WithExclusiveLock([&] { + for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();) + { + if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs) + { + ExpiredClients.push_back({It->first, It->second.Hostname}); + It = m_KnownClients.erase(It); + } + else + { + ++It; + } + } + }); + + for (const auto& Expired : ExpiredClients) + { + ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id); + RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname); + } + } while (m_ProbeThreadEnabled); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/recording/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp new file mode 100644 index 000000000..90141ca55 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.cpp @@ -0,0 +1,258 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "actionrecorder.h" + +#include "../runners/functionrunner.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RecordingFileWriter::RecordingFileWriter() +{ +} + +RecordingFileWriter::~RecordingFileWriter() +{ + Close(); +} + +void +RecordingFileWriter::Open(std::filesystem::path FilePath) +{ + using namespace std::literals; + + m_File.Open(FilePath, BasicFile::Mode::kTruncate); + m_File.Write("----DDC2----DATA", 16, 0); + m_FileOffset = 16; + + std::filesystem::path TocPath = FilePath.replace_extension(".ztoc"); + m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate); + + m_TocWriter << "version"sv << 1; + m_TocWriter.BeginArray("toc"sv); +} + +void +RecordingFileWriter::Close() +{ + m_TocWriter.EndArray(); + CbObject Toc = m_TocWriter.Save(); + + std::error_code Ec; + m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec); +} + +void +RecordingFileWriter::AppendObject(const CbObject& Object, const IoHash& ObjectHash) +{ + RwLock::ExclusiveLockScope _(m_FileLock); + + MemoryView ObjectView = Object.GetBuffer().GetView(); + + std::error_code Ec; + m_File.Write(ObjectView, m_FileOffset, Ec); + + if (Ec) + { + throw std::system_error(Ec, "failed writing to archive"); + } + + m_TocWriter.BeginArray(); + m_TocWriter.AddHash(ObjectHash); + m_TocWriter.AddInteger(m_FileOffset); + m_TocWriter.AddInteger(gsl::narrow<int>(ObjectView.GetSize())); + m_TocWriter.EndArray(); + + m_FileOffset += ObjectView.GetSize(); +} + +////////////////////////////////////////////////////////////////////////// + +ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath) +: m_ChunkResolver(InChunkResolver) +, m_RecordingLogDir(RecordingLogPath) +{ + std::error_code Ec; + CreateDirectories(m_RecordingLogDir, Ec); + + if (Ec) + { + ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat"); + m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat"); + + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +ActionRecorder::~ActionRecorder() +{ + Shutdown(); +} + +void +ActionRecorder::Shutdown() +{ + m_CidStore.Flush(); +} + +void +ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage) +{ + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + + m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId); + + std::unordered_set<IoHash> AddedChunks; + uint64_t AddedBytes = 0; + + // First add all attachments from the worker package itself + + for (const CbAttachment& Attachment : WorkerPackage.GetAttachments()) + { + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + IoBuffer Data = Buffer.GetCompressed().Flatten().AsIoBuffer(); + + const IoHash ChunkHash = Buffer.DecodeRawHash(); + + CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly); + + AddedChunks.insert(ChunkHash); + + if (Result.New) + { + AddedBytes += Data.GetSize(); + } + } + + // Not all attachments will be present in the worker package, so we need to add + // all referenced chunks to ensure that the recording is self-contained and not + // referencing data in the main CID store + + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + + WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + + if (!AddedChunks.contains(AttachmentCid)) + { + IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly); + + if (Result.New) + { + AddedBytes += AttachmentData.GetSize(); + } + } + else + { + ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId); + } + + AddedChunks.insert(AttachmentCid); + } + }); + + ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes); +} + +bool +ActionRecorder::RecordAction(Ref<RunnerAction> Action) +{ + bool AllGood = true; + + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsHash(); + IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData); + + if (ChunkData) + { + if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary) + { + IoHash DecompressedHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize); + + OodleCompressor Compressor; + OodleCompressionLevel CompressionLevel; + uint64_t BlockSize = 0; + if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize)) + { + if (Compressor == OodleCompressor::NotSet) + { + CompositeBuffer Decompressed = Compressed.DecompressToComposite(); + CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed), + OodleCompressor::Mermaid, + OodleCompressionLevel::Fast, + BlockSize); + + ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer(); + } + } + } + + const uint64_t ChunkSize = ChunkData.GetSize(); + + m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly); + ++m_ChunkCounter; + m_ChunkBytesCounter.fetch_add(ChunkSize); + } + else + { + AllGood = false; + + ZEN_WARN("could not resolve chunk {}", AttachData); + } + }); + + if (AllGood) + { + m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId); + ++m_ActionsCounter; + + return true; + } + else + { + return false; + } +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recording/actionrecorder.h b/src/zencompute/recording/actionrecorder.h new file mode 100644 index 000000000..2827b6ac7 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.h @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/computeservice.h> +#include <zencompute/zencompute.h> +#include <zencore/basicfile.h> +#include <zencore/compactbinarybuilder.h> +#include <zenstore/cidstore.h> +#include <zenstore/gc.h> +#include <zenstore/zenstore.h> + +#include <filesystem> +#include <functional> +#include <map> +#include <unordered_map> + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + +struct RecordingFileWriter +{ + RecordingFileWriter(RecordingFileWriter&&) = delete; + RecordingFileWriter& operator=(RecordingFileWriter&&) = delete; + + RwLock m_FileLock; + BasicFile m_File; + uint64_t m_FileOffset = 0; + CbObjectWriter m_TocWriter; + BasicFile m_TocFile; + + RecordingFileWriter(); + ~RecordingFileWriter(); + + void Open(std::filesystem::path FilePath); + void Close(); + void AppendObject(const CbObject& Object, const IoHash& ObjectHash); +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Recording "runner" implementation + * + * This class writes out all actions and their attachments to a recording directory + * in a format that can be read back by the RecordingReader. + * + * The contents of the recording directory will be self-contained, with all referenced + * attachments stored in the recording directory itself, so that the recording can be + * moved or shared without needing to maintain references to the main CID store. + * + */ + +class ActionRecorder +{ +public: + ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath); + ~ActionRecorder(); + + ActionRecorder(const ActionRecorder&) = delete; + ActionRecorder& operator=(const ActionRecorder&) = delete; + + void Shutdown(); + void RegisterWorker(const CbPackage& WorkerPackage); + bool RecordAction(Ref<RunnerAction> Action); + +private: + ChunkResolver& m_ChunkResolver; + std::filesystem::path m_RecordingLogDir; + + RecordingFileWriter m_WorkersFile; + RecordingFileWriter m_ActionsFile; + GcManager m_Gc; + CidStore m_CidStore{m_Gc}; + std::atomic<int> m_ChunkCounter{0}; + std::atomic<uint64_t> m_ChunkBytesCounter{0}; + std::atomic<int> m_ActionsCounter{0}; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recording/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp new file mode 100644 index 000000000..1c1a119cf --- /dev/null +++ b/src/zencompute/recording/recordingreader.cpp @@ -0,0 +1,335 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/recordingreader.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryfile.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_PLATFORM_WINDOWS +# define ZEN_BUILD_ACTION L"Build.action" +# define ZEN_WORKER_UCB L"worker.ucb" +# else +# define ZEN_BUILD_ACTION "Build.action" +# define ZEN_WORKER_UCB "worker.ucb" +# endif + +////////////////////////////////////////////////////////////////////////// + +struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor +{ + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) + { + ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick); + + if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0) + { + WorkDirs.push_back(Parent); + } + else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0) + { + WorkerDirs.push_back(Parent); + } + } + + virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) + { + ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes); + + return true; + } + + std::vector<std::filesystem::path> WorkerDirs; + std::vector<std::filesystem::path> WorkDirs; +}; + +////////////////////////////////////////////////////////////////////////// + +void +IterateOverArray(auto Array, auto Func, int TargetParallelism) +{ +# if ZEN_CONCRT_AVAILABLE + if (TargetParallelism > 1) + { + concurrency::simple_partitioner Chunker(Array.size() / TargetParallelism); + concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }); + + return; + } +# else + ZEN_UNUSED(TargetParallelism); +# endif + + for (const auto& Item : Array) + { + Func(Item); + } +} + +////////////////////////////////////////////////////////////////////////// + +RecordingReaderBase::~RecordingReaderBase() = default; + +////////////////////////////////////////////////////////////////////////// + +RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath) +{ + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +RecordingReader::~RecordingReader() +{ + m_CidStore.Flush(); +} + +size_t +RecordingReader::GetActionCount() const +{ + return m_Actions.size(); +} + +IoBuffer +RecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId)) + { + return Chunk; + } + + ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId); + + return {}; +} + +std::unordered_map<zen::IoHash, zen::CbPackage> +RecordingReader::ReadWorkers() +{ + std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap; + + { + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc"); + CbObject Toc = TocFile.Object; + + m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView Entry = It.AsArrayView(); + CbFieldViewIterator Vit = Entry.CreateViewIterator(); + + const IoHash WorkerId = Vit++->AsHash(); + const uint64_t Offset = Vit++->AsInt64(0); + const uint64_t Size = Vit++->AsInt64(0); + + IoBuffer WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size); + CbObject WorkerDesc = LoadCompactBinaryObject(WorkerRange); + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + } + }); + } + } + + // Scan actions as well (this should be called separately, ideally) + + ScanActions(); + + return WorkerMap; +} + +void +RecordingReader::ScanActions() +{ + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc"); + CbObject Toc = TocFile.Object; + + m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView ArrayEntry = It.AsArrayView(); + CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator(); + + ActionEntry Entry; + Entry.ActionId = Vit++->AsHash(); + Entry.Offset = Vit++->AsInt64(0); + Entry.Size = Vit++->AsInt64(0); + + m_Actions.push_back(Entry); + } +} + +void +RecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism) +{ + IterateOverArray( + m_Actions, + [&](const ActionEntry& Entry) { + CbObject ActionDesc = LoadCompactBinaryObject(m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size)); + + Callback(ActionDesc, Entry.ActionId); + }, + TargetParallelism); +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer +LocalResolver::FindChunkByCid(const IoHash& DecompressedId) +{ + RwLock::SharedLockScope _(MapLock); + if (auto It = Attachments.find(DecompressedId); It != Attachments.end()) + { + return It->second; + } + + return {}; +} + +void +LocalResolver::Add(const IoHash& Cid, IoBuffer Data) +{ + RwLock::ExclusiveLockScope _(MapLock); + Data.SetContentType(ZenContentType::kCompressedBinary); + Attachments[Cid] = Data; +} + +/// + +UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath) +{ +} + +UeRecordingReader::~UeRecordingReader() +{ +} + +size_t +UeRecordingReader::GetActionCount() const +{ + return m_WorkDirs.size(); +} + +IoBuffer +UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + return m_LocalResolver.FindChunkByCid(DecompressedId); +} + +std::unordered_map<zen::IoHash, zen::CbPackage> +UeRecordingReader::ReadWorkers() +{ + std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap; + + FileSystemTraversal Traversal; + RecordingTreeVisitor Visitor; + Traversal.TraverseFileSystem(m_RecordingDir, Visitor); + + m_WorkDirs = std::move(Visitor.WorkDirs); + + for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs) + { + CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb"); + CbObject WorkerDesc = WorkerFile.Object; + const IoHash& WorkerId = WorkerFile.Hash; + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten(); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + } + + return WorkerMap; +} + +void +UeRecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int ParallelismTarget) +{ + IterateOverArray( + m_WorkDirs, + [&](const std::filesystem::path& WorkDir) { + CbPackage WorkPackage = ReadAction(WorkDir); + CbObject ActionObject = WorkPackage.GetObject(); + const IoHash& ActionId = WorkPackage.GetObjectHash(); + + Callback(ActionObject, ActionId); + }, + ParallelismTarget); +} + +CbPackage +UeRecordingReader::ReadAction(std::filesystem::path WorkDir) +{ + CbPackage WorkPackage; + std::filesystem::path WorkDescPath = WorkDir / "Build.action"; + CbObjectFromFile ActionFile = LoadCompactBinaryObject(WorkDescPath); + CbObject& ActionObject = ActionFile.Object; + + WorkPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten(); + + m_LocalResolver.Add(AttachmentCid, AttachmentData); + + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + ZEN_ASSERT(AttachmentCid == RawHash); + WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + + return WorkPackage; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp new file mode 100644 index 000000000..4fad2cf70 --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "deferreddeleter.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/thread.h> + +# include <algorithm> +# include <chrono> + +namespace zen::compute { + +using namespace std::chrono_literals; + +using Clock = std::chrono::steady_clock; + +// Default deferral: how long to wait before attempting deletion. +// This gives memory-mapped file handles time to close naturally. +static constexpr auto DeferralPeriod = 60s; + +// Shortened deferral after MarkReady(): the client has collected results +// so handles should be released soon, but we still wait briefly. +static constexpr auto ReadyGracePeriod = 5s; + +// Interval between retry attempts for directories that failed deletion. +static constexpr auto RetryInterval = 5s; + +static constexpr int MaxRetries = 10; + +DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this) +{ +} + +DeferredDirectoryDeleter::~DeferredDirectoryDeleter() +{ + Shutdown(); +} + +void +DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path) +{ + { + std::lock_guard Lock(m_Mutex); + m_Queue.push_back({ActionLsn, std::move(Path)}); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::MarkReady(int ActionLsn) +{ + { + std::lock_guard Lock(m_Mutex); + m_ReadyLsns.push_back(ActionLsn); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::Shutdown() +{ + { + std::lock_guard Lock(m_Mutex); + m_Done = true; + } + m_Cv.notify_one(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } +} + +void +DeferredDirectoryDeleter::ThreadFunction() +{ + SetCurrentThreadName("ZenDirCleanup"); + + struct PendingEntry + { + int ActionLsn; + std::filesystem::path Path; + Clock::time_point ReadyTime; + int Attempts = 0; + }; + + std::vector<PendingEntry> PendingList; + + auto TryDelete = [](PendingEntry& Entry) -> bool { + std::error_code Ec; + std::filesystem::remove_all(Entry.Path, Ec); + return !Ec; + }; + + for (;;) + { + bool Shutting = false; + + // Drain the incoming queue and process MarkReady signals + + { + std::unique_lock Lock(m_Mutex); + + if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done) + { + if (PendingList.empty()) + { + m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + else + { + auto NextReady = PendingList.front().ReadyTime; + for (const auto& Entry : PendingList) + { + if (Entry.ReadyTime < NextReady) + { + NextReady = Entry.ReadyTime; + } + } + + m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + } + + // Move new items into PendingList with the full deferral deadline + auto Now = Clock::now(); + for (auto& Entry : m_Queue) + { + PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0}); + } + m_Queue.clear(); + + // Apply MarkReady: shorten ReadyTime for matching entries + for (int Lsn : m_ReadyLsns) + { + for (auto& Entry : PendingList) + { + if (Entry.ActionLsn == Lsn) + { + auto NewReady = Now + ReadyGracePeriod; + if (NewReady < Entry.ReadyTime) + { + Entry.ReadyTime = NewReady; + } + } + } + } + m_ReadyLsns.clear(); + + Shutting = m_Done; + } + + // Process items whose deferral period has elapsed (or all items on shutdown) + + auto Now = Clock::now(); + + for (size_t i = 0; i < PendingList.size();) + { + auto& Entry = PendingList[i]; + + if (!Shutting && Now < Entry.ReadyTime) + { + ++i; + continue; + } + + if (TryDelete(Entry)) + { + if (Entry.Attempts > 0) + { + ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path); + } + + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ++Entry.Attempts; + + if (Entry.Attempts >= MaxRetries) + { + ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts); + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts); + Entry.ReadyTime = Now + RetryInterval; + ++i; + } + } + } + + // Exit once shutdown is requested and nothing remains + + if (Shutting && PendingList.empty()) + { + return; + } + } +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS + +# include <zencore/testing.h> + +namespace zen::compute { + +void +deferreddeleter_forcelink() +{ +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/testutils.h> + +namespace zen::compute { + +TEST_SUITE_BEGIN("compute.deferreddeleter"); + +TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path DirToDelete = TempDir.Path() / "subdir"; + CreateDirectories(DirToDelete / "nested"); + + CHECK(std::filesystem::exists(DirToDelete)); + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(1, DirToDelete); + } + + CHECK(!std::filesystem::exists(DirToDelete)); +} + +TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories") +{ + ScopedTemporaryDirectory TempDir; + + constexpr int NumDirs = 10; + std::vector<std::filesystem::path> Dirs; + + for (int i = 0; i < NumDirs; ++i) + { + auto Dir = TempDir.Path() / std::to_string(i); + CreateDirectories(Dir / "child"); + Dirs.push_back(std::move(Dir)); + } + + { + DeferredDirectoryDeleter Deleter; + for (int i = 0; i < NumDirs; ++i) + { + CHECK(std::filesystem::exists(Dirs[i])); + Deleter.Enqueue(100 + i, Dirs[i]); + } + } + + for (const auto& Dir : Dirs) + { + CHECK(!std::filesystem::exists(Dir)); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "idempotent"; + CreateDirectories(Dir); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(42, Dir); + Deleter.Shutdown(); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path NoSuchDir = TempDir.Path() / "does_not_exist"; + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(99, NoSuchDir); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "explicit"; + CreateDirectories(Dir / "inner"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(7, Dir); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "markready"; + CreateDirectories(Dir / "child"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(50, Dir); + + // Without MarkReady the full deferral (60s) would apply. + // MarkReady shortens it to 5s, and shutdown bypasses even that. + Deleter.MarkReady(50); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_SUITE_END(); + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h new file mode 100644 index 000000000..9b010aa0f --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <condition_variable> +# include <deque> +# include <filesystem> +# include <mutex> +# include <thread> +# include <vector> + +namespace zen::compute { + +/// Deletes directories on a background thread to avoid blocking callers. +/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics). +/// +/// Enqueued directories wait for a deferral period before deletion, giving +/// file handles time to close. Call MarkReady() with the ActionLsn to shorten +/// the wait to a brief grace period (e.g. once a client has collected results). +/// On shutdown, all pending directories are deleted immediately. +class DeferredDirectoryDeleter +{ + DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete; + DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete; + +public: + DeferredDirectoryDeleter(); + ~DeferredDirectoryDeleter(); + + /// Enqueue a directory for deferred deletion, associated with an action LSN. + void Enqueue(int ActionLsn, std::filesystem::path Path); + + /// Signal that the action result has been consumed and the directory + /// can be deleted after a short grace period instead of the full deferral. + void MarkReady(int ActionLsn); + + /// Drain the queue and join the background thread. Idempotent. + void Shutdown(); + +private: + struct QueueEntry + { + int ActionLsn; + std::filesystem::path Path; + }; + + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::deque<QueueEntry> m_Queue; + std::vector<int> m_ReadyLsns; + bool m_Done = false; + std::thread m_Thread; + void ThreadFunction(); +}; + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS +namespace zen::compute { +void deferreddeleter_forcelink(); // internal +} // namespace zen::compute +#endif diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp new file mode 100644 index 000000000..768cdf1e1 --- /dev/null +++ b/src/zencompute/runners/functionrunner.cpp @@ -0,0 +1,365 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/filesystem.h> +# include <zencore/trace.h> + +# include <fmt/format.h> +# include <vector> + +namespace zen::compute { + +FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") +{ +} + +FunctionRunner::~FunctionRunner() = default; + +size_t +FunctionRunner::QueryCapacity() +{ + return 1; +} + +std::vector<SubmitResult> +FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +void +FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) +{ + if (m_DumpActions) + { + std::string UniqueId = fmt::format("{}.ddb", ActionLsn); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); + } +} + +////////////////////////////////////////////////////////////////////////// + +void +BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner) +{ + m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); +} + +size_t +BaseRunnerGroup::QueryCapacity() +{ + size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + for (const auto& Runner : m_Runners) + { + TotalCapacity += Runner->QueryCapacity(); + } + }); + return TotalCapacity; +} + +SubmitResult +BaseRunnerGroup::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); + int Index = InitialIndex; + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return {.IsAccepted = false, .Reason = "No runners available"}; + } + + do + { + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + + auto& Runner = m_Runners[Index++]; + + SubmitResult Result = Runner->SubmitAction(Action); + + if (Result.IsAccepted == true) + { + m_NextSubmitIndex = Index % RunnerCount; + + return Result; + } + + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + } while (Index != InitialIndex); + + return {.IsAccepted = false}; +} + +std::vector<SubmitResult> +BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); + } + + // Query capacity per runner and compute total + std::vector<size_t> Capacities(RunnerCount); + size_t TotalCapacity = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = m_Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + + if (TotalCapacity == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); + } + + // Distribute actions across runners proportionally to their available capacity + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + if (Capacities[i] == 0) + { + continue; + } + + size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity; + Share = std::min(Share, Capacities[i]); + + for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + } + } + + // Assign any remaining actions to runners with capacity (round-robin) + for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + { + if (Capacities[i] > PerRunnerActions[i].size()) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + ++ActionIdx; + } + } + + // Submit batches per runner + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + } + } + + // Reassemble results in original action order + std::vector<SubmitResult> Results(Actions.size()); + std::vector<size_t> PerRunnerIdx(RunnerCount, 0); + + for (size_t i = 0; i < Actions.size(); ++i) + { + size_t RunnerIdx = ActionRunnerIndex[i]; + size_t Idx = PerRunnerIdx[RunnerIdx]++; + Results[i] = std::move(PerRunnerResults[RunnerIdx][Idx]); + } + + return Results; +} + +size_t +BaseRunnerGroup::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + size_t TotalCount = 0; + + for (const auto& Runner : m_Runners) + { + TotalCount += Runner->GetSubmittedActionCount(); + } + + return TotalCount; +} + +void +BaseRunnerGroup::RegisterWorker(CbPackage Worker) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->RegisterWorker(Worker); + } +} + +void +BaseRunnerGroup::Shutdown() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->Shutdown(); + } +} + +bool +BaseRunnerGroup::CancelAction(int ActionLsn) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + if (Runner->CancelAction(ActionLsn)) + { + return true; + } + } + + return false; +} + +void +BaseRunnerGroup::CancelRemoteQueue(int QueueId) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->CancelRemoteQueue(QueueId); + } +} + +////////////////////////////////////////////////////////////////////////// + +RunnerAction::RunnerAction(ComputeServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) +{ + this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); +} + +RunnerAction::~RunnerAction() +{ +} + +bool +RunnerAction::ResetActionStateToPending() +{ + // Only allow reset from Failed or Abandoned states + State CurrentState = m_ActionState.load(); + + if (CurrentState != State::Failed && CurrentState != State::Abandoned) + { + return false; + } + + if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending)) + { + return false; + } + + // Clear timestamps from Submitting through _Count + for (int i = static_cast<int>(State::Submitting); i < static_cast<int>(State::_Count); ++i) + { + this->Timestamps[i] = 0; + } + + // Record new Pending timestamp + this->Timestamps[static_cast<int>(State::Pending)] = DateTime::Now().GetTicks(); + + // Clear execution fields + ExecutionLocation.clear(); + CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); + CpuSeconds.store(0.0f, std::memory_order_relaxed); + + // Increment retry count + RetryCount.fetch_add(1, std::memory_order_relaxed); + + // Re-enter the scheduler pipeline + m_OwnerSession->PostUpdate(this); + + return true; +} + +void +RunnerAction::SetActionState(State NewState) +{ + ZEN_ASSERT(NewState < State::_Count); + this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); + + do + { + if (State CurrentState = m_ActionState.load(); CurrentState == NewState) + { + // No state change + return; + } + else + { + if (NewState <= CurrentState) + { + // Cannot transition to an earlier or same state + return; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) + { + // Successful state change + + m_OwnerSession->PostUpdate(this); + + return; + } + } + } while (true); +} + +void +RunnerAction::SetResult(CbPackage&& Result) +{ + m_Result = std::move(Result); +} + +CbPackage& +RunnerAction::GetResult() +{ + ZEN_ASSERT(IsCompleted()); + return m_Result; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h new file mode 100644 index 000000000..f67414dbb --- /dev/null +++ b/src/zencompute/runners/functionrunner.h @@ -0,0 +1,214 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/computeservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <atomic> +# include <filesystem> +# include <vector> + +namespace zen::compute { + +struct SubmitResult +{ + bool IsAccepted = false; + std::string Reason; +}; + +/** Base interface for classes implementing a remote execution "runner" + */ +class FunctionRunner : public RefCounted +{ + FunctionRunner(FunctionRunner&&) = delete; + FunctionRunner& operator=(FunctionRunner&&) = delete; + +public: + FunctionRunner(std::filesystem::path BasePath); + virtual ~FunctionRunner() = 0; + + virtual void Shutdown() = 0; + virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; + [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; + [[nodiscard]] virtual bool IsHealthy() = 0; + [[nodiscard]] virtual size_t QueryCapacity(); + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + + // Best-effort cancellation of a specific in-flight action. Returns true if the + // cancellation signal was successfully sent. The action will transition to Cancelled + // asynchronously once the platform-level termination completes. + virtual bool CancelAction(int /*ActionLsn*/) { return false; } + + // Cancel the remote queue corresponding to the given local QueueId. + // Only meaningful for remote runners; local runners ignore this. + virtual void CancelRemoteQueue(int /*QueueId*/) {} + +protected: + std::filesystem::path m_ActionsPath; + bool m_DumpActions = false; + void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); +}; + +/** Base class for RunnerGroup that operates on generic FunctionRunner references. + * All scheduling, capacity, and lifecycle logic lives here. + */ +class BaseRunnerGroup +{ +public: + size_t QueryCapacity(); + SubmitResult SubmitAction(Ref<RunnerAction> Action); + std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + size_t GetSubmittedActionCount(); + void RegisterWorker(CbPackage Worker); + void Shutdown(); + bool CancelAction(int ActionLsn); + void CancelRemoteQueue(int QueueId); + + size_t GetRunnerCount() + { + return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); + } + +protected: + void AddRunnerInternal(FunctionRunner* Runner); + + RwLock m_RunnersLock; + std::vector<Ref<FunctionRunner>> m_Runners; + std::atomic<int> m_NextSubmitIndex{0}; +}; + +/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. + */ +template<typename RunnerType> +struct RunnerGroup : public BaseRunnerGroup +{ + void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); } + + template<typename Predicate> + size_t RemoveRunnerIf(Predicate&& Pred) + { + size_t RemovedCount = 0; + m_RunnersLock.WithExclusiveLock([&] { + auto It = m_Runners.begin(); + while (It != m_Runners.end()) + { + if (Pred(static_cast<RunnerType&>(**It))) + { + (*It)->Shutdown(); + It = m_Runners.erase(It); + ++RemovedCount; + } + else + { + ++It; + } + } + }); + return RemovedCount; + } +}; + +/** + * This represents an action going through different stages of scheduling and execution. + */ +struct RunnerAction : public RefCounted +{ + explicit RunnerAction(ComputeServiceSession* OwnerSession); + ~RunnerAction(); + + int ActionLsn = 0; + int QueueId = 0; + WorkerDesc Worker; + IoHash ActionId; + CbObject ActionObj; + int Priority = 0; + std::string ExecutionLocation; // "local" or remote hostname + + // CPU usage and total CPU time of the running process, sampled periodically by the local runner. + // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. + // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled. + std::atomic<float> CpuUsagePercent{-1.0f}; + std::atomic<float> CpuSeconds{0.0f}; + std::atomic<int> RetryCount{0}; + + enum class State + { + New, + Pending, + Submitting, + Running, + Completed, + Failed, + Abandoned, + Cancelled, + _Count + }; + + static const char* ToString(State _) + { + switch (_) + { + case State::New: + return "New"; + case State::Pending: + return "Pending"; + case State::Submitting: + return "Submitting"; + case State::Running: + return "Running"; + case State::Completed: + return "Completed"; + case State::Failed: + return "Failed"; + case State::Abandoned: + return "Abandoned"; + case State::Cancelled: + return "Cancelled"; + default: + return "Unknown"; + } + } + + static State FromString(std::string_view Name, State Default = State::Failed) + { + for (int i = 0; i < static_cast<int>(State::_Count); ++i) + { + if (Name == ToString(static_cast<State>(i))) + { + return static_cast<State>(i); + } + } + return Default; + } + + uint64_t Timestamps[static_cast<int>(State::_Count)] = {}; + + State ActionState() const { return m_ActionState; } + void SetActionState(State NewState); + + bool IsSuccess() const { return ActionState() == State::Completed; } + bool ResetActionStateToPending(); + bool IsCompleted() const + { + return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == State::Abandoned || + ActionState() == State::Cancelled; + } + + void SetResult(CbPackage&& Result); + CbPackage& GetResult(); + + ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; } + +private: + std::atomic<State> m_ActionState = State::New; + ComputeServiceSession* m_OwnerSession = nullptr; + CbPackage m_Result; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp new file mode 100644 index 000000000..e79a6c90f --- /dev/null +++ b/src/zencompute/runners/linuxrunner.cpp @@ -0,0 +1,734 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "linuxrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <sched.h> +# include <signal.h> +# include <sys/mount.h> +# include <sys/stat.h> +# include <sys/syscall.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. + + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + int MkdirIfNeeded(const char* Path, mode_t Mode) + { + if (mkdir(Path, Mode) != 0 && errno != EEXIST) + { + return -1; + } + return 0; + } + + int BindMountReadOnly(const char* Src, const char* Dst) + { + if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + return -1; + } + + // Remount read-only + if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0) + { + return -1; + } + + return 0; + } + + // Set up namespace-based sandbox isolation in the child process. + // This is called after fork(), before execve(). All operations must be + // async-signal-safe. + // + // The sandbox layout after pivot_root: + // / -> the sandbox directory (tmpfs-like, was SandboxPath) + // /usr -> bind-mount of host /usr (read-only) + // /lib -> bind-mount of host /lib (read-only) + // /lib64 -> bind-mount of host /lib64 (read-only, optional) + // /etc -> bind-mount of host /etc (read-only) + // /worker -> bind-mount of worker directory (read-only) + // /proc -> proc filesystem + // /dev -> tmpfs with null, zero, urandom + void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd) + { + // 1. Unshare user, mount, and network namespaces + if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno); + } + + // 2. Write UID/GID mappings + // Must deny setgroups first (required by kernel for unprivileged user namespaces) + { + int Fd = open("/proc/self/setgroups", O_WRONLY); + if (Fd >= 0) + { + WriteToFd(Fd, "deny", 4); + close(Fd); + } + // setgroups file may not exist on older kernels; not fatal + } + + { + // uid_map: map our UID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Uid)); + + int Fd = open("/proc/self/uid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + { + // gid_map: map our GID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Gid)); + + int Fd = open("/proc/self/gid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + // 3. Privatize the entire mount tree so our mounts don't propagate + if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno); + } + + // 4. Create mount points inside the sandbox and bind-mount system directories + + // Helper macro-like pattern for building paths inside sandbox + // We use stack buffers since we can't allocate heap memory safely + char MountPoint[4096]; + + auto BuildPath = [&](const char* Suffix) -> const char* { + snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix); + return MountPoint; + }; + + // /usr (required) + if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno); + } + if (BindMountReadOnly("/usr", BuildPath("usr")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno); + } + + // /lib (required) + if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno); + } + if (BindMountReadOnly("/lib", BuildPath("lib")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); + } + + // /lib64 (optional — not all distros have it) + { + struct stat St; + if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) + { + if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0) + { + BindMountReadOnly("/lib64", BuildPath("lib64")); + // Failure is non-fatal for lib64 + } + } + } + + // /etc (required — for resolv.conf, ld.so.cache, etc.) + if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); + } + if (BindMountReadOnly("/etc", BuildPath("etc")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); + } + + // /worker — bind-mount worker directory (contains the executable) + if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); + } + if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno); + } + + // 5. Mount /proc inside sandbox + if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno); + } + if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno); + } + + // 6. Mount tmpfs /dev and bind-mount essential device nodes + if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno); + } + if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno); + } + + // Bind-mount /dev/null, /dev/zero, /dev/urandom + { + char DevSrc[64]; + char DevDst[4096]; + + auto BindDev = [&](const char* Name) { + snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name); + snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name); + + // Create the file to mount over + int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666); + if (Fd >= 0) + { + close(Fd); + } + mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr); + // Non-fatal if individual devices fail + }; + + BindDev("null"); + BindDev("zero"); + BindDev("urandom"); + } + + // 7. pivot_root to sandbox + // pivot_root requires the new root and put_old to be mount points. + // Bind-mount sandbox onto itself to make it a mount point. + if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno); + } + + // Create .pivot_old inside sandbox + char PivotOld[4096]; + snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath); + if (MkdirIfNeeded(PivotOld, 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno); + } + + if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno); + } + + // 8. Now inside new root. Clean up old root. + if (chdir("/") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno); + } + + if (umount2("/.pivot_old", MNT_DETACH) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno); + } + + rmdir("/.pivot_old"); + } + +} // anonymous namespace + +LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("namespace sandboxing enabled for child processes"); + } +} + +SubmitResult +LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + // Pre-compute all path strings before fork() for async-signal-safety. + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::string ExePathStr; + std::string SandboxedExePathStr; + + if (m_Sandboxed) + { + // After pivot_root, the worker dir is at /worker inside the new root + std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath); + SandboxedExePathStr = SandboxedExePath.string(); + // We still need the real path for logging + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + else + { + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + + std::string BuildArg = "-Build=build.action"; + + // argv[0] should be the path the child will see + const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr; + + std::vector<char*> ArgV; + ArgV.push_back(const_cast<char*>(ChildExePath.data())); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: get uid/gid for namespace mapping, create error pipe + uid_t CurrentUid = 0; + gid_t CurrentGid = 0; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + CurrentUid = getuid(); + CurrentGid = getgid(); + + if (pipe2(ErrorPipe, O_CLOEXEC) != 0) + { + throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno)); + } + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]); + + // After pivot_root, CWD is "/" which is the sandbox root. + // execve with the sandboxed path. + execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. + char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +LinuxProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +LinuxProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +LinuxProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +static uint64_t +ReadProcStatCpuTicks(pid_t Pid) +{ + char Path[64]; + snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast<int>(Pid)); + + char Buf[256]; + int Fd = open(Path, O_RDONLY); + if (Fd < 0) + { + return 0; + } + + ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1); + close(Fd); + + if (Len <= 0) + { + return 0; + } + + Buf[Len] = '\0'; + + // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + const char* P = strrchr(Buf, ')'); + if (!P) + { + return 0; + } + + P += 2; // skip ') ' + + // Remaining fields (space-separated, 0-indexed from here): + // 0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags + // 7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime + unsigned long UTime = 0; + unsigned long STime = 0; + sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime); + return UTime + STime; +} + +void +LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + static const long ClkTck = sysconf(_SC_CLK_TCK); + + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + const uint64_t NowTicks = GetHifreqTimerValue(); + const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid); + + if (CurrentOsTicks == 0) + { + // Process gone or /proc entry unreadable — record timestamp without updating usage + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = 0; + return; + } + + // Cumulative CPU seconds (absolute, available from first sample) + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / ClkTck), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) * 1000.0 / ClkTck / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h new file mode 100644 index 000000000..266de366b --- /dev/null +++ b/src/zencompute/runners/linuxrunner.h @@ -0,0 +1,44 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +namespace zen::compute { + +/** Native Linux process runner for executing Linux worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using Linux namespaces: + user, mount, and network namespaces are unshared so the child has no network + access and can only see the sandbox directory (with system libraries bind-mounted + read-only). This requires no special privileges thanks to user namespaces. + */ +class LinuxProcessRunner : public LocalProcessRunner +{ +public: + LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp new file mode 100644 index 000000000..7aaefb06e --- /dev/null +++ b/src/zencompute/runners/localrunner.cpp @@ -0,0 +1,674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/system.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> +# include <zenstore/cidstore.h> + +# include <span> + +namespace zen::compute { + +using namespace std::literals; + +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("local_exec")) +, m_ChunkResolver(Resolver) +, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) +, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) +{ + SystemMetrics Sm = GetSystemMetricsForReporting(); + + m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); + + bool DidCleanup = false; + + if (std::filesystem::is_directory(m_ActionsPath)) + { + ZEN_INFO("Cleaning '{}'", m_ActionsPath); + + std::error_code Ec; + CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); + } + + DidCleanup = true; + } + + if (std::filesystem::is_directory(m_SandboxPath)) + { + ZEN_INFO("Cleaning '{}'", m_SandboxPath); + std::error_code Ec; + CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); + } + + DidCleanup = true; + } + + // We clean out all workers on startup since we can't know they are good. They could be bad + // due to tampering, malware (which I also mean to include AV and antimalware software) or + // other processes we have no control over + if (std::filesystem::is_directory(m_WorkerPath)) + { + ZEN_INFO("Cleaning '{}'", m_WorkerPath); + std::error_code Ec; + CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); + } + + DidCleanup = true; + } + + if (DidCleanup) + { + ZEN_INFO("Cleanup complete"); + } + + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; + +# if ZEN_PLATFORM_WINDOWS + // Suppress any error dialogs caused by missing dependencies + UINT OldMode = ::SetErrorMode(0); + ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); +# endif + + m_AcceptNewActions = true; +} + +LocalProcessRunner::~LocalProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); + } +} + +void +LocalProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } + + CancelRunningActions(); +} + +std::filesystem::path +LocalProcessRunner::CreateNewSandbox() +{ + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); + std::string UniqueId = std::to_string(++m_SandboxCounter); + std::filesystem::path Path = m_SandboxPath / UniqueId; + zen::CreateDirectories(Path); + + return Path; +} + +void +LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); + if (m_DumpActions) + { + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + const IoHash& WorkerId = WorkerPackage.GetObjectHash(); + + std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); + + ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) { + std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); + zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); + }); + + ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); + } +} + +size_t +LocalProcessRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return 0; + } + + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); + + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) + { + return 0; + } + else + { + return MaxRunningActions - InFlightCount; + } +} + +std::vector<SubmitResult> +LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity upfront and accept what fits. + // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. + + size_t Available = QueryCapacity(); + + std::vector<SubmitResult> Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) + { + const Ref<RunnerAction>& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; + } + + return Results; +} + +std::optional<LocalProcessRunner::PreparedAction> +LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return std::nullopt; + } + + if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) + { + return std::nullopt; + } + } + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + + MaybeDumpAction(ActionLsn, ActionObj); + + std::filesystem::path SandboxPath = CreateNewSandbox(); + + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + + CbPackage WorkerPackage = Action->Worker.Descriptor; + + std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); + + // Write out action + + zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + + // Manifest inputs in sandbox + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); + } + + zen::WriteFile(FilePath, DataBuffer); + }); + + Action->ExecutionLocation = "local"; + + SandboxGuard.Dismiss(); + + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} + +SubmitResult +LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; +} + +size_t +LocalProcessRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RunningMap.size(); +} + +std::filesystem::path +LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); + RwLock::SharedLockScope _(m_WorkerLock); + + std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); + + if (!std::filesystem::exists(WorkerDir)) + { + _.ReleaseNow(); + + RwLock::ExclusiveLockScope $(m_WorkerLock); + + if (!std::filesystem::exists(WorkerDir)) + { + ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); + } + } + + return WorkerDir; +} + +void +LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback) +{ + std::string_view Name = FileEntry["name"sv].AsString(); + const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + CompressedBuffer Compressed; + + if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) + { + Compressed = Attachment->AsCompressedBinary(); + } + else + { + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); + } + + uint64_t DataRawSize = 0; + IoHash DataRawHash; + Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); + + if (DataRawSize != Size) + { + throw std::runtime_error( + fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size)); + } + } + + ChunkReferenceCallback(ChunkHash, Compressed); + + std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + + SharedBuffer Decompressed = Compressed.Decompress(); + zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); +} + +void +LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback) +{ + CbObject WorkerDescription = WorkerPackage.GetObject(); + + // Manifest worker in Sandbox + + for (auto& It : WorkerDescription["executables"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + std::string_view Name = It.AsString(); + std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + + zen::CreateDirectories(DirPath); + } + + for (auto& It : WorkerDescription["files"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); +} + +CbPackage +LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) +{ + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); + std::filesystem::path OutputFile = SandboxPath / "build.output"; + FileContents OutputData = zen::ReadFile(OutputFile); + + if (OutputData.ErrorCode) + { + throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); + } + + CbPackage OutputPackage; + CbObject Output = zen::LoadCompactBinaryObject(OutputData.Flatten()); + + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalRawAttachmentBytes = 0; + + Output.IterateAttachments([&](CbFieldView Field) { + IoHash Hash = Field.AsHash(); + std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; + FileContents ChunkData = zen::ReadFile(OutputPath); + + if (ChunkData.ErrorCode) + { + throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); + } + + uint64_t ChunkDataRawSize = 0; + IoHash ChunkDataHash; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); + + if (!AttachmentBuffer) + { + throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); + } + + TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + TotalRawAttachmentBytes += ChunkDataRawSize; + + CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); + OutputPackage.AddAttachment(Attachment); + }); + + OutputPackage.SetObject(Output); + + ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", + OutputPackage.GetAttachments().size(), + NiceBytes(TotalAttachmentBytes), + NiceBytes(TotalRawAttachmentBytes)); + + return OutputPackage; +} + +void +LocalProcessRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("LocalProcessRunner_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + // On Windows it's possible to wait on process handles, so we wait for either a process to exit + // or for the monitor event to be signaled (which indicates we should check for cancellation + // or shutdown). This could be further improved by using a completion port and registering process + // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing + // with an enormous number of concurrent processes. + // + // On other platforms we just wait on the monitor event and poll for process exits at intervals. +# if ZEN_PLATFORM_WINDOWS + auto WaitOnce = [&] { + HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; + + uint32_t NumHandles = 0; + + WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); + + m_RunningLock.WithSharedLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) + { + Ref<RunningAction> Action = It->second; + + WaitHandles[NumHandles++] = Action->ProcessHandle; + } + }); + + DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); + + // return true if a handle was signaled + return (WaitResult <= NumHandles); + }; +# else + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; +# endif + + while (!WaitOnce()) + { + if (m_MonitorThreadEnabled == false) + { + return; + } + + SweepRunningActions(); + SampleRunningProcessCpu(); + } + + // Signal received + + SweepRunningActions(); + SampleRunningProcessCpu(); + } while (m_MonitorThreadEnabled); +} + +void +LocalProcessRunner::CancelRunningActions() +{ + // Base class is not directly usable — platform subclasses override this +} + +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; + + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) + { + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) + { + SampleProcessCpu(*Running); + } + } + }); +} + +void +LocalProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} + +void +LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. + // Note that this must be called without holding any local locks + // otherwise we may end up with deadlocks. + + for (Ref<RunningAction> Running : CompletedActions) + { + const int ActionLsn = Running->Action->ActionLsn; + + if (Running->ExitCode == 0) + { + try + { + // Gather outputs + + CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); + + Running->Action->SetResult(std::move(OutputPackage)); + Running->Action->SetActionState(RunnerAction::State::Completed); + + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + + // Success -- continue with next iteration of the loop + continue; + } + catch (std::exception& Ex) + { + ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + } + } + + // Failed - clean up the sandbox in the background. + + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h new file mode 100644 index 000000000..7493e980b --- /dev/null +++ b/src/zencompute/runners/localrunner.h @@ -0,0 +1,138 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include <zencore/thread.h> +# include <zencore/zencore.h> +# include <zenstore/cidstore.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/logging.h> + +# include "deferreddeleter.h" + +# include <zencore/workthreadpool.h> + +# include <atomic> +# include <filesystem> +# include <optional> +# include <thread> + +namespace zen { +class CbPackage; +} + +namespace zen::compute { + +/** Direct process spawner + + This runner simply sets up a directory structure for each job and + creates a process to perform the computation in it. It is not very + efficient and is intended mostly for testing. + + */ + +class LocalProcessRunner : public FunctionRunner +{ + LocalProcessRunner(LocalProcessRunner&&) = delete; + LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; + +public: + LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~LocalProcessRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + [[nodiscard]] virtual bool IsHealthy() override { return true; } + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; + +protected: + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + + struct RunningAction : public RefCounted + { + Ref<RunnerAction> Action; + void* ProcessHandle = nullptr; + int ExitCode = 0; + std::filesystem::path SandboxPath; + + // State for periodic CPU usage sampling + uint64_t LastCpuSampleTicks = 0; // hifreq timer value at last sample + uint64_t LastCpuOsTicks = 0; // OS CPU ticks (platform-specific units) at last sample + }; + + std::atomic_bool m_AcceptNewActions; + ChunkResolver& m_ChunkResolver; + RwLock m_WorkerLock; + std::filesystem::path m_WorkerPath; + std::atomic<int32_t> m_SandboxCounter = 0; + std::filesystem::path m_SandboxPath; + int32_t m_MaxRunningActions = 64; // arbitrary limit for testing + + // if used in conjuction with m_ResultsLock, this lock must be taken *after* + // m_ResultsLock to avoid deadlocks + RwLock m_RunningLock; + std::unordered_map<int, Ref<RunningAction>> m_RunningMap; + + std::atomic<int32_t> m_SubmittingCount = 0; + DeferredDirectoryDeleter& m_DeferredDeleter; + WorkerThreadPool& m_WorkerPool; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + virtual void SweepRunningActions(); + virtual void CancelRunningActions(); + + // Sample CPU usage for all currently running processes (throttled per-action). + void SampleRunningProcessCpu(); + + // Override in platform runners to sample one process. Called under a shared RunningLock. + virtual void SampleProcessCpu(RunningAction& /*Running*/) {} + + // Shared preamble for SubmitAction: capacity check, sandbox creation, + // worker manifesting, action writing, input manifesting. + struct PreparedAction + { + int32_t ActionLsn; + std::filesystem::path SandboxPath; + std::filesystem::path WorkerPath; + CbPackage WorkerPackage; + }; + std::optional<PreparedAction> PrepareActionSubmission(Ref<RunnerAction> Action); + + // Shared post-processing for SweepRunningActions: gather outputs, + // set state, clean sandbox. + void ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions); + + std::filesystem::path CreateNewSandbox(); + void ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback); + std::filesystem::path ManifestWorker(const WorkerDesc& Worker); + CbPackage GatherActionOutputs(std::filesystem::path SandboxPath); + + void DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp new file mode 100644 index 000000000..5cec90699 --- /dev/null +++ b/src/zencompute/runners/macrunner.cpp @@ -0,0 +1,491 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "macrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <libproc.h> +# include <sandbox.h> +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. + + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + // Build a Seatbelt profile string that denies everything by default and + // allows only the minimum needed for the worker to execute: process ops, + // system library reads, worker directory (read-only), and sandbox directory + // (read-write). Network access is denied implicitly by the deny-default policy. + std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath) + { + std::string Profile; + Profile.reserve(1024); + + Profile += "(version 1)\n"; + Profile += "(deny default)\n"; + Profile += "(allow process*)\n"; + Profile += "(allow sysctl-read)\n"; + Profile += "(allow file-read-metadata)\n"; + + // System library paths needed for dynamic linker and runtime + Profile += "(allow file-read* (subpath \"/usr\"))\n"; + Profile += "(allow file-read* (subpath \"/System\"))\n"; + Profile += "(allow file-read* (subpath \"/Library\"))\n"; + Profile += "(allow file-read* (subpath \"/dev\"))\n"; + Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n"; + Profile += "(allow file-read* (subpath \"/etc\"))\n"; + + // Worker directory: read-only + Profile += "(allow file-read* (subpath \""; + Profile += WorkerPath; + Profile += "\"))\n"; + + // Sandbox directory: read+write + Profile += "(allow file-read* file-write* (subpath \""; + Profile += SandboxPath; + Profile += "\"))\n"; + + return Profile; + } + +} // anonymous namespace + +MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("Seatbelt sandboxing enabled for child processes"); + } +} + +SubmitResult +MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("MacProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: build sandbox profile and create error pipe + std::string SandboxProfile; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr); + + if (pipe(ErrorPipe) != 0) + { + throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno)); + } + fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC); + fcntl(ErrorPipe[1], F_SETFD, FD_CLOEXEC); + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + // Apply Seatbelt sandbox profile + char* ErrorBuf = nullptr; + if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) + { + // sandbox_init failed — write error to pipe and exit + if (ErrorBuf) + { + WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0); + // WriteErrorAndExit does not return, but sandbox_free_error + // is not needed since we _exit + } + WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno); + } + if (ErrorBuf) + { + sandbox_free_error(ErrorBuf); + } + + if (chdir(SandboxPathStr.c_str()) != 0) + { + WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. + char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +MacProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +MacProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +MacProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +MacProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + struct proc_taskinfo Info; + if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0) + { + return; + } + + // pti_total_user and pti_total_system are in nanoseconds + const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // ns → ms: divide by 1,000,000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h new file mode 100644 index 000000000..d653b923a --- /dev/null +++ b/src/zencompute/runners/macrunner.h @@ -0,0 +1,43 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +namespace zen::compute { + +/** Native macOS process runner for executing Mac worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using macOS Seatbelt + (sandbox_init): no network access and no filesystem access outside the + explicitly allowed sandbox and worker directories. This requires no elevation. + */ +class MacProcessRunner : public LocalProcessRunner +{ +public: + MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp new file mode 100644 index 000000000..672636d06 --- /dev/null +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -0,0 +1,618 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "remotehttprunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/scopeguard.h> +# include <zencore/system.h> +# include <zencore/trace.h> +# include <zenhttp/httpcommon.h> +# include <zenstore/cidstore.h> + +# include <span> + +////////////////////////////////////////////////////////////////////////// + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("http_exec")) +, m_ChunkResolver{InChunkResolver} +, m_WorkerPool{InWorkerPool} +, m_HostName{HostName} +, m_BaseUrl{fmt::format("{}/compute", HostName)} +, m_Http(m_BaseUrl) +, m_InstanceId(Oid::NewOid()) +{ + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; +} + +RemoteHttpRunner::~RemoteHttpRunner() +{ + Shutdown(); +} + +void +RemoteHttpRunner::Shutdown() +{ + // TODO: should cleanly drain/cancel pending work + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +void +RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + CbPackage WorkerDesc = WorkerPackage; + + std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); + + HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); + + if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) + { + HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); + + if (DescResponse.StatusCode == HttpResponseCode::NotFound) + { + CbPackage Pkg = WorkerDesc; + + // Build response package by sending only the attachments + // the other end needs. We start with the full package and + // remove the attachments which are not needed. + + { + std::unordered_set<IoHash> Needed; + + CbObject Response = DescResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + Needed.insert(NeedHash); + } + + std::unordered_set<IoHash> ToRemove; + + for (const CbAttachment& Attachment : Pkg.GetAttachments()) + { + const IoHash& Hash = Attachment.GetHash(); + + if (Needed.find(Hash) == Needed.end()) + { + ToRemove.insert(Hash); + } + } + + for (const IoHash& Hash : ToRemove) + { + int RemovedCount = Pkg.RemoveAttachment(Hash); + + ZEN_ASSERT(RemovedCount == 1); + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); + + if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + } + else if (!IsHttpSuccessCode(DescResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + else + { + ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); + } + } + else if (WorkerResponse.StatusCode == HttpResponseCode::OK) + { + // Already known from a previous run + } + else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", + WorkerId, + m_Http.GetBaseUri(), + WorkerUrl, + (int)WorkerResponse.StatusCode, + ToString(WorkerResponse.StatusCode)); + + // TODO: propagate error + } +} + +size_t +RemoteHttpRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + size_t RunningCount = m_RemoteRunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For larger batches, submit HTTP requests in parallel via the shared worker pool + + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + +SubmitResult +RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + Action->ExecutionLocation = m_HostName; + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + // Determine the submission URL. If the action belongs to a queue, ensure a + // corresponding remote queue exists on the target node and submit via it. + + std::string SubmitUrl = "/jobs"; + if (const int QueueId = Action->QueueId; QueueId != 0) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + // Enqueue job. If the remote returns FailedDependency (424), it means it + // cannot resolve the worker/function — re-register the worker and retry once. + + CbObject Result; + HttpClient::Response WorkResponse; + HttpResponseCode WorkResponseCode{}; + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); + + RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } + + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (WorkResponseCode == HttpResponseCode::NotFound) + { + // Not all attachments are present + + // Build response package including all required attachments + + CbPackage Pkg; + Pkg.SetObject(ActionObj); + + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + // No such attachment + + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + + // TODO: include more information about the failure in the response + + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (PayloadResponse.StatusCode == HttpResponseCode::OK) + { + Result = PayloadResponse.AsObject(); + } + else + { + // Unexpected response + + const int ResponseStatusCode = (int)PayloadResponse.StatusCode; + + ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", + ActionId, + m_Http.GetBaseUri(), + SubmitUrl, + ResponseStatusCode, + ToString(ResponseStatusCode)); + + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", + ResponseStatusCode, + ToString(ResponseStatusCode), + m_Http.GetBaseUri(), + SubmitUrl)}; + } + } + + if (Result) + { + if (const int32_t LsnField = Result["lsn"].AsInt32(0)) + { + HttpRunningAction NewAction; + NewAction.Action = Action; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; + } + } + + return {}; +} + +Oid +RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) +{ + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + return It->second; + } + } + + // Build a stable idempotency key that uniquely identifies this (runner instance, local queue) + // pair. The server uses this to return the same remote queue token for concurrent or redundant + // requests, preventing orphaned remote queues when multiple threads race through here. + // Also send hostname so the server can associate the queue with its origin for diagnostics. + CbObjectWriter Body; + Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId); + Body << "hostname"sv << GetMachineName(); + if (Metadata) + { + Body << "metadata"sv << Metadata; + } + if (Config) + { + Body << "config"sv << Config; + } + + HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save()); + if (!Resp) + { + ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName); + return Oid::Zero; + } + + Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString()); + if (Token == Oid::Zero) + { + return Oid::Zero; + } + + ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName); + + RwLock::ExclusiveLockScope _(m_QueueTokenLock); + auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token); + return It->second; +} + +void +RemoteHttpRunner::CancelRemoteQueue(int QueueId) +{ + Oid Token; + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + Token = It->second; + } + } + + if (Token == Oid::Zero) + { + return; + } + + HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token)); + + if (Resp.StatusCode == HttpResponseCode::NoContent) + { + ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName); + } + else + { + ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode)); + } +} + +bool +RemoteHttpRunner::IsHealthy() +{ + if (HttpClient::Response Ready = m_Http.Get("/ready")) + { + return true; + } + else + { + // TODO: use response to propagate context + return false; + } +} + +size_t +RemoteHttpRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RemoteRunningMap.size(); +} + +void +RemoteHttpRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("RemoteHttpRunner_Monitor"); + + do + { + const int NormalWaitingTime = 200; + int WaitTimeMs = NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { + const size_t RetiredCount = SweepRunningActions(); + + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) + { + WaitTimeMs = NormalWaitingTime / 4; + } + else + { + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } + } + }); + }; + + while (!WaitOnce()) + { + SweepOnce(); + } + + // Signal received - this may mean we should quit + + SweepOnce(); + } while (m_MonitorThreadEnabled); +} + +size_t +RemoteHttpRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions"); + std::vector<HttpRunningAction> CompletedActions; + + // Poll remote for list of completed actions + + HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv); + + if (CbObject Completed = ResponseCompleted.AsObject()) + { + for (auto& FieldIt : Completed["completed"sv]) + { + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + + RunnerAction::State RemoteState = RunnerAction::FromString(StateName); + + // Always fetch to drain the result from the remote's results map, + // but only keep the result package for successfully completed actions. + HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn)); + + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.RemoteState = RemoteState; + + if (RemoteState == RunnerAction::State::Completed && ResponseJob) + { + CompletedAction.ActionResults = ResponseJob.AsPackage(); + } + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); + } + + if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) + { + // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0)) + if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0)) + { + const int32_t NewCap = zen::Max(4, CpuCount); + + if (m_MaxRunningActions > NewCap) + { + ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions); + + m_MaxRunningActions = NewCap; + } + } + } + } + + // Notify outer. Note that this has to be done without holding any local locks + // otherwise we may end up with deadlocks. + + for (HttpRunningAction& HttpAction : CompletedActions) + { + const int ActionLsn = HttpAction.Action->ActionLsn; + + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + + if (HttpAction.RemoteState == RunnerAction::State::Completed) + { + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); + } + + HttpAction.Action->SetActionState(HttpAction.RemoteState); + } + + return CompletedActions.size(); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h new file mode 100644 index 000000000..9119992a9 --- /dev/null +++ b/src/zencompute/runners/remotehttprunner.h @@ -0,0 +1,100 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include <zencore/compactbinarypackage.h> +# include <zencore/logging.h> +# include <zencore/uid.h> +# include <zencore/workthreadpool.h> +# include <zencore/zencore.h> +# include <zenhttp/httpclient.h> + +# include <atomic> +# include <filesystem> +# include <thread> +# include <unordered_map> + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** HTTP-based runner + + This implements a DDC remote compute execution strategy via REST API + + */ + +class RemoteHttpRunner : public FunctionRunner +{ + RemoteHttpRunner(RemoteHttpRunner&&) = delete; + RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; + +public: + RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool); + ~RemoteHttpRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + [[nodiscard]] virtual bool IsHealthy() override; + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; + virtual void CancelRemoteQueue(int QueueId) override; + + std::string_view GetHostName() const { return m_HostName; } + +protected: + LoggerRef Log() { return m_Log; } + +private: + LoggerRef m_Log; + ChunkResolver& m_ChunkResolver; + WorkerThreadPool& m_WorkerPool; + std::string m_HostName; + std::string m_BaseUrl; + HttpClient m_Http; + + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + + struct HttpRunningAction + { + Ref<RunnerAction> Action; + int RemoteActionLsn = 0; // Remote LSN + RunnerAction::State RemoteState = RunnerAction::State::Failed; + CbPackage ActionResults; + }; + + RwLock m_RunningLock; + std::unordered_map<int, HttpRunningAction> m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + size_t SweepRunningActions(); + + RwLock m_QueueTokenLock; + std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token + + // Stable identity for this runner instance, used as part of the idempotency key when + // creating remote queues. Generated once at construction and never changes. + Oid m_InstanceId; + + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp new file mode 100644 index 000000000..e9a1ae8b6 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.cpp @@ -0,0 +1,460 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "windowsrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/trace.h> +# include <zencore/system.h> +# include <zencore/timer.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <userenv.h> +# include <aclapi.h> +# include <sddl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + if (!m_Sandboxed) + { + return; + } + + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); + + PSID Sid = nullptr; + + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); + + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } + + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); +} + +WindowsProcessRunner::~WindowsProcessRunner() +{ + if (m_AppContainerSid) + { + FreeSid(m_AppContainerSid); + m_AppContainerSid = nullptr; + } + + if (!m_AppContainerName.empty()) + { + DeleteAppContainerProfile(m_AppContainerName.c_str()); + } +} + +void +WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask) +{ + PACL ExistingDacl = nullptr; + PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr; + + DWORD Result = GetNamedSecurityInfoW(Path.c_str(), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + &ExistingDacl, + nullptr, + &SecurityDescriptor); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); }); + + EXPLICIT_ACCESSW Access{}; + Access.grfAccessPermissions = AccessMask; + Access.grfAccessMode = SET_ACCESS; + Access.grfInheritance = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE; + Access.Trustee.TrusteeForm = TRUSTEE_IS_SID; + Access.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP; + Access.Trustee.ptstrName = static_cast<LPWSTR>(m_AppContainerSid); + + PACL NewDacl = nullptr; + + Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl); + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $1 = MakeGuard([&] { LocalFree(NewDacl); }); + + Result = SetNamedSecurityInfoW(const_cast<LPWSTR>(Path.c_str()), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + NewDacl, + nullptr); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } +} + +SubmitResult +WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Set up environment variables + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + StringBuilder<1024> EnvironmentBlock; + + for (auto& It : WorkerDescription["environment"sv]) + { + EnvironmentBlock.Append(It.AsString()); + EnvironmentBlock.Append('\0'); + } + EnvironmentBlock.Append('\0'); + EnvironmentBlock.Append('\0'); + + // Execute process - this spawns the child process immediately without waiting + // for completion + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + ExtendableWideStringBuilder<512> CommandLine; + CommandLine.Append(L'"'); + CommandLine.Append(ExePath.c_str()); + CommandLine.Append(L'"'); + CommandLine.Append(L" -Build=build.action"); + + LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; + BOOL bInheritHandles = FALSE; + DWORD dwCreationFlags = 0; + + ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + + CommandLine.EnsureNulTerminated(); + + PROCESS_INFORMATION ProcessInformation{}; + + if (m_Sandboxed) + { + // Grant AppContainer access to sandbox and worker directories + GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS); + GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE); + + // Set up extended startup info with AppContainer security capabilities + SECURITY_CAPABILITIES SecurityCapabilities{}; + SecurityCapabilities.AppContainerSid = m_AppContainerSid; + SecurityCapabilities.Capabilities = nullptr; + SecurityCapabilities.CapabilityCount = 0; + + SIZE_T AttrListSize = 0; + InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize); + + auto AttrList = static_cast<PPROC_THREAD_ATTRIBUTE_LIST>(malloc(AttrListSize)); + auto $0 = MakeGuard([&] { free(AttrList); }); + + if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize)) + { + zen::ThrowLastError("InitializeProcThreadAttributeList failed"); + } + + auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); }); + + if (!UpdateProcThreadAttribute(AttrList, + 0, + PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES, + &SecurityCapabilities, + sizeof(SecurityCapabilities), + nullptr, + nullptr)) + { + zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed"); + } + + STARTUPINFOEXW StartupInfoEx{}; + StartupInfoEx.StartupInfo.cb = sizeof(STARTUPINFOEXW); + StartupInfoEx.lpAttributeList = AttrList; + + dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfoEx.StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch sandboxed process"); + } + } + else + { + STARTUPINFO StartupInfo{}; + StartupInfo.cb = sizeof StartupInfo; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch process"); + } + } + + CloseHandle(ProcessInformation.hThread); + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WindowsProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + DWORD ExitCode = 0; + BOOL IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode); + + if (IsSuccess && ExitCode != STILL_ACTIVE) + { + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + Running->ExitCode = ExitCode; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WindowsProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // For expedience we initiate the process termination for all known + // processes before attempting to wait for them to exit. + + // Initiate termination for all known processes before waiting for them to exit. + + for (const auto& Kv : RunningMap) + { + Ref<RunningAction> Running = Kv.second; + + BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError)); + } + } + } + + // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded. + + for (auto& [Lsn, Running] : RunningMap) + { + if (Running->ProcessHandle != INVALID_HANDLE_VALUE) + { + DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); + + if (WaitResult != WAIT_OBJECT_0) + { + ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); + } + else + { + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + } + + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +WindowsProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction"); + + // Hold the shared lock while terminating to prevent the sweep thread from + // closing the handle between our lookup and TerminateProcess call. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (Target->ProcessHandle == INVALID_HANDLE_VALUE) + { + return; + } + + BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError)); + } + + return; + } + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + FILETIME CreationTime, ExitTime, KernelTime, UserTime; + if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime)) + { + return; + } + + auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast<uint64_t>(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; }; + + // FILETIME values are in 100-nanosecond intervals + const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // 100ns → ms: divide by 10000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h new file mode 100644 index 000000000..9f2385cc4 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/windows.h> + +# include <string> + +namespace zen::compute { + +/** Windows process runner using CreateProcessW for executing worker executables. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using a Windows AppContainer: + no network access (AppContainer blocks network by default when no capabilities are + granted) and no filesystem access outside explicitly granted sandbox and worker + directories. This requires no elevation. + */ +class WindowsProcessRunner : public LocalProcessRunner +{ +public: + WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + ~WindowsProcessRunner(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask); + + bool m_Sandboxed = false; + PSID m_AppContainerSid = nullptr; + std::wstring m_AppContainerName; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp new file mode 100644 index 000000000..506bec73b --- /dev/null +++ b/src/zencompute/runners/winerunner.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winerunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); +} + +SubmitResult +WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WineProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: wine <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string WinePathStr = m_WinePath; + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(WinePathStr.data()); + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno))); + } + + if (ChildPid == 0) + { + // Child process + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(WinePathStr.c_str(), ArgV.data(), Envp.data()); + + // execve only returns on failure + _exit(127); + } + + // Parent: store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WineProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WineProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h new file mode 100644 index 000000000..7df62e7c0 --- /dev/null +++ b/src/zencompute/runners/winerunner.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <string> + +namespace zen::compute { + +/** Wine-based process runner for executing Windows worker executables on Linux. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + */ +class WineProcessRunner : public LocalProcessRunner +{ +public: + WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + +private: + std::string m_WinePath = "wine"; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp new file mode 100644 index 000000000..dd09312df --- /dev/null +++ b/src/zencompute/testing/mockimds.cpp @@ -0,0 +1,205 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) + if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp new file mode 100644 index 000000000..88ef5b62d --- /dev/null +++ b/src/zencompute/timeline/workertimeline.cpp @@ -0,0 +1,430 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "workertimeline.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/basicfile.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryfile.h> + +# include <algorithm> + +namespace zen::compute { + +WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId) +{ +} + +WorkerTimeline::~WorkerTimeline() +{ +} + +void +WorkerTimeline::RecordProvisioned() +{ + AppendEvent({ + .Type = EventType::WorkerProvisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordDeprovisioned() +{ + AppendEvent({ + .Type = EventType::WorkerDeprovisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId) +{ + AppendEvent({ + .Type = EventType::ActionAccepted, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + }); +} + +void +WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason) +{ + AppendEvent({ + .Type = EventType::ActionRejected, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .Reason = std::string(Reason), + }); +} + +void +WorkerTimeline::RecordActionStateChanged(int ActionLsn, + const IoHash& ActionId, + RunnerAction::State PreviousState, + RunnerAction::State NewState) +{ + AppendEvent({ + .Type = EventType::ActionStateChanged, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .ActionState = NewState, + .PreviousState = PreviousState, + }); +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + for (const auto& Evt : m_Events) + { + if (Evt.Timestamp >= StartTime && Evt.Timestamp <= EndTime) + { + Result.push_back(Evt); + } + } + }); + + return Result; +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryRecent(int Limit) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + const int Count = std::min(Limit, gsl::narrow<int>(m_Events.size())); + auto It = m_Events.end() - Count; + Result.assign(It, m_Events.end()); + }); + + return Result; +} + +size_t +WorkerTimeline::GetEventCount() const +{ + size_t Count = 0; + m_EventsLock.WithSharedLock([&] { Count = m_Events.size(); }); + return Count; +} + +WorkerTimeline::TimeRange +WorkerTimeline::GetTimeRange() const +{ + TimeRange Range; + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Range.First = m_Events.front().Timestamp; + Range.Last = m_Events.back().Timestamp; + } + }); + return Range; +} + +void +WorkerTimeline::AppendEvent(Event&& Evt) +{ + m_EventsLock.WithExclusiveLock([&] { + while (m_Events.size() >= m_MaxEvents) + { + m_Events.pop_front(); + } + + m_Events.push_back(std::move(Evt)); + }); +} + +const char* +WorkerTimeline::ToString(EventType Type) +{ + switch (Type) + { + case EventType::WorkerProvisioned: + return "provisioned"; + case EventType::WorkerDeprovisioned: + return "deprovisioned"; + case EventType::ActionAccepted: + return "accepted"; + case EventType::ActionRejected: + return "rejected"; + case EventType::ActionStateChanged: + return "state_changed"; + default: + return "unknown"; + } +} + +static WorkerTimeline::EventType +EventTypeFromString(std::string_view Str) +{ + if (Str == "provisioned") + return WorkerTimeline::EventType::WorkerProvisioned; + if (Str == "deprovisioned") + return WorkerTimeline::EventType::WorkerDeprovisioned; + if (Str == "accepted") + return WorkerTimeline::EventType::ActionAccepted; + if (Str == "rejected") + return WorkerTimeline::EventType::ActionRejected; + if (Str == "state_changed") + return WorkerTimeline::EventType::ActionStateChanged; + return WorkerTimeline::EventType::WorkerProvisioned; +} + +void +WorkerTimeline::WriteTo(const std::filesystem::path& Path) const +{ + CbObjectWriter Cbo; + Cbo << "worker_id" << m_WorkerId; + + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Cbo.AddDateTime("time_first", m_Events.front().Timestamp); + Cbo.AddDateTime("time_last", m_Events.back().Timestamp); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : m_Events) + { + Cbo.BeginObject(); + Cbo << "type" << ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == EventType::ActionStateChanged) + { + Cbo << "prev_state" << static_cast<int32_t>(Evt.PreviousState); + Cbo << "state" << static_cast<int32_t>(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + }); + + CbObject Obj = Cbo.Save(); + + BasicFile File(Path, BasicFile::Mode::kTruncate); + File.Write(Obj.GetBuffer().GetView(), 0); +} + +void +WorkerTimeline::ReadFrom(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + CbObject Root = std::move(Loaded.Object); + + if (!Root) + { + return; + } + + std::deque<Event> LoadedEvents; + + for (CbFieldView Field : Root["events"].AsArrayView()) + { + CbObjectView EventObj = Field.AsObjectView(); + + Event Evt; + Evt.Type = EventTypeFromString(EventObj["type"].AsString()); + Evt.Timestamp = EventObj["ts"].AsDateTime(); + + Evt.ActionLsn = EventObj["lsn"].AsInt32(); + Evt.ActionId = EventObj["action_id"].AsHash(); + + if (Evt.Type == EventType::ActionStateChanged) + { + Evt.PreviousState = static_cast<RunnerAction::State>(EventObj["prev_state"].AsInt32()); + Evt.ActionState = static_cast<RunnerAction::State>(EventObj["state"].AsInt32()); + } + + std::string_view Reason = EventObj["reason"].AsString(); + if (!Reason.empty()) + { + Evt.Reason = std::string(Reason); + } + + LoadedEvents.push_back(std::move(Evt)); + } + + m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); }); +} + +WorkerTimeline::TimeRange +WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + + if (!Loaded.Object) + { + return {}; + } + + return { + .First = Loaded.Object["time_first"].AsDateTime(), + .Last = Loaded.Object["time_last"].AsDateTime(), + }; +} + +// WorkerTimelineStore + +static constexpr std::string_view kTimelineExtension = ".ztimeline"; + +WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir)) +{ + std::error_code Ec; + std::filesystem::create_directories(m_PersistenceDir, Ec); +} + +Ref<WorkerTimeline> +WorkerTimelineStore::GetOrCreate(std::string_view WorkerId) +{ + // Fast path: check if it already exists in memory + { + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + } + + // Slow path: create under exclusive lock, loading from disk if available + RwLock::ExclusiveLockScope _(m_Lock); + + auto& Entry = m_Timelines[std::string(WorkerId)]; + if (!Entry) + { + Entry = Ref<WorkerTimeline>(new WorkerTimeline(WorkerId)); + + std::filesystem::path Path = TimelinePath(WorkerId); + std::error_code Ec; + if (std::filesystem::is_regular_file(Path, Ec)) + { + Entry->ReadFrom(Path); + } + } + return Entry; +} + +Ref<WorkerTimeline> +WorkerTimelineStore::Find(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + return {}; +} + +std::vector<std::string> +WorkerTimelineStore::GetActiveWorkerIds() const +{ + std::vector<std::string> Result; + + RwLock::SharedLockScope $(m_Lock); + Result.reserve(m_Timelines.size()); + for (const auto& [Id, _] : m_Timelines) + { + Result.push_back(Id); + } + + return Result; +} + +std::vector<WorkerTimelineStore::WorkerTimelineInfo> +WorkerTimelineStore::GetAllWorkerInfo() const +{ + std::unordered_map<std::string, WorkerTimeline::TimeRange> InfoMap; + + { + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + InfoMap[Id] = Timeline->GetTimeRange(); + } + } + + std::error_code Ec; + for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec)) + { + if (!Entry.is_regular_file()) + { + continue; + } + + const auto& Path = Entry.path(); + if (Path.extension().string() != kTimelineExtension) + { + continue; + } + + std::string Id = Path.stem().string(); + if (InfoMap.find(Id) == InfoMap.end()) + { + InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path); + } + } + + std::vector<WorkerTimelineInfo> Result; + Result.reserve(InfoMap.size()); + for (auto& [Id, Range] : InfoMap) + { + Result.push_back({.WorkerId = std::move(Id), .Range = Range}); + } + return Result; +} + +void +WorkerTimelineStore::Save(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + It->second->WriteTo(TimelinePath(WorkerId)); + } +} + +void +WorkerTimelineStore::SaveAll() +{ + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + Timeline->WriteTo(TimelinePath(Id)); + } +} + +std::filesystem::path +WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const +{ + return m_PersistenceDir / (std::string(WorkerId) + std::string(kTimelineExtension)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h new file mode 100644 index 000000000..87e19bc28 --- /dev/null +++ b/src/zencompute/timeline/workertimeline.h @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../runners/functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/refcount.h> +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/thread.h> +# include <zencore/timer.h> + +# include <deque> +# include <filesystem> +# include <string> +# include <string_view> +# include <unordered_map> +# include <vector> + +namespace zen::compute { + +struct RunnerAction; + +/** Worker activity timeline for tracking and visualizing worker activity over time. + * + * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle + * events (accept, reject, state changes) with timestamps, enabling time-range queries + * for dashboard visualization. + */ +class WorkerTimeline : public RefCounted +{ +public: + explicit WorkerTimeline(std::string_view WorkerId); + ~WorkerTimeline() override; + + struct TimeRange + { + DateTime First = DateTime(0); + DateTime Last = DateTime(0); + + explicit operator bool() const { return First.GetTicks() != 0; } + }; + + enum class EventType + { + WorkerProvisioned, + WorkerDeprovisioned, + ActionAccepted, + ActionRejected, + ActionStateChanged + }; + + static const char* ToString(EventType Type); + + struct Event + { + EventType Type; + DateTime Timestamp = DateTime(0); + + // Action context (only set for action events) + int ActionLsn = 0; + IoHash ActionId; + RunnerAction::State ActionState = RunnerAction::State::New; + RunnerAction::State PreviousState = RunnerAction::State::New; + + // Optional reason (e.g. rejection reason) + std::string Reason; + }; + + /** Record that this worker has been provisioned and is available for work. */ + void RecordProvisioned(); + + /** Record that this worker has been deprovisioned and is no longer available. */ + void RecordDeprovisioned(); + + /** Record that an action was accepted by this worker. */ + void RecordActionAccepted(int ActionLsn, const IoHash& ActionId); + + /** Record that an action was rejected by this worker. */ + void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason); + + /** Record an action state transition on this worker. */ + void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState); + + /** Query events within a time range (inclusive). Returns events ordered by timestamp. */ + [[nodiscard]] std::vector<Event> QueryTimeline(DateTime StartTime, DateTime EndTime) const; + + /** Query the most recent N events. */ + [[nodiscard]] std::vector<Event> QueryRecent(int Limit = 100) const; + + /** Return the total number of recorded events. */ + [[nodiscard]] size_t GetEventCount() const; + + /** Return the time range covered by the events in this timeline. */ + [[nodiscard]] TimeRange GetTimeRange() const; + + [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; } + + /** Write the timeline to a file at the given path. */ + void WriteTo(const std::filesystem::path& Path) const; + + /** Read the timeline from a file at the given path. Replaces current in-memory events. */ + void ReadFrom(const std::filesystem::path& Path); + + /** Read only the time range from a persisted timeline file, without loading events. */ + [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path); + +private: + void AppendEvent(Event&& Evt); + + std::string m_WorkerId; + mutable RwLock m_EventsLock; + std::deque<Event> m_Events; + size_t m_MaxEvents = 10'000; +}; + +/** Manages a set of WorkerTimeline instances, keyed by worker ID. + * + * Provides thread-safe lookup and on-demand creation of timelines, backed by + * a persistence directory. Each timeline is stored as a separate file named + * {WorkerId}.ztimeline within the directory. + */ +class WorkerTimelineStore +{ +public: + explicit WorkerTimelineStore(std::filesystem::path PersistenceDir); + ~WorkerTimelineStore() = default; + + WorkerTimelineStore(const WorkerTimelineStore&) = delete; + WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete; + + /** Get the timeline for a worker, creating one if it does not exist. + * If a persisted file exists on disk it will be loaded on first access. */ + Ref<WorkerTimeline> GetOrCreate(std::string_view WorkerId); + + /** Get the timeline for a worker, or null ref if it does not exist in memory. */ + [[nodiscard]] Ref<WorkerTimeline> Find(std::string_view WorkerId); + + /** Return the worker IDs of currently loaded (in-memory) timelines. */ + [[nodiscard]] std::vector<std::string> GetActiveWorkerIds() const; + + struct WorkerTimelineInfo + { + std::string WorkerId; + WorkerTimeline::TimeRange Range; + }; + + /** Return info for all known timelines (in-memory and on-disk), including time range. */ + [[nodiscard]] std::vector<WorkerTimelineInfo> GetAllWorkerInfo() const; + + /** Persist a single worker's timeline to disk. */ + void Save(std::string_view WorkerId); + + /** Persist all in-memory timelines to disk. */ + void SaveAll(); + +private: + [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const; + + std::filesystem::path m_PersistenceDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, Ref<WorkerTimeline>> m_Timelines; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua new file mode 100644 index 000000000..ed0af66a5 --- /dev/null +++ b/src/zencompute/xmake.lua @@ -0,0 +1,19 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zencompute') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_includedirs(".", {private=true}) + add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") + add_packages("json11") + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end + + if is_plat("windows") then + add_syslinks("Userenv") + end diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp new file mode 100644 index 000000000..1f3f6d3f9 --- /dev/null +++ b/src/zencompute/zencompute.cpp @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/zencompute.h" + +#if ZEN_WITH_TESTS +# include "runners/deferreddeleter.h" +# include <zencompute/cloudmetadata.h> +#endif + +namespace zen { + +void +zencompute_forcelinktests() +{ +#if ZEN_WITH_TESTS + compute::cloudmetadata_forcelink(); + compute::deferreddeleter_forcelink(); +#endif +} + +} // namespace zen |