diff options
Diffstat (limited to 'src/zencompute')
44 files changed, 11606 insertions, 2505 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md new file mode 100644 index 000000000..f5188123f --- /dev/null +++ b/src/zencompute/CLAUDE.md @@ -0,0 +1,232 @@ +# zencompute Module + +Lambda-style compute function service. Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results. + +## Directory Structure + +``` +src/zencompute/ +├── include/zencompute/ # Public headers +│ ├── computeservice.h # ComputeServiceSession public API +│ ├── httpcomputeservice.h # HTTP service wrapper +│ ├── orchestratorservice.h # Worker registry and orchestration +│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push +│ ├── recordingreader.h # Recording/replay reader API +│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP) +│ └── mockimds.h # Test helper for cloud metadata +├── runners/ # Execution backends +│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup +│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling) +│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW +│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation +│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing +│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux +│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances +│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories +├── recording/ # Recording/replay subsystem +│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk +│ └── recordingreader.cpp # Read recordings back +├── timeline/ +│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking +├── testing/ +│ └── mockimds.cpp # Mock IMDS for cloud metadata tests +├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines) +├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines) +├── httporchestrator.cpp # Orchestrator 
HTTP API + WebSocket push +├── orchestratorservice.cpp # Worker registry, health probing +└── cloudmetadata.cpp # IMDS probing, termination monitoring +``` + +## Key Classes + +### `ComputeServiceSession` (computeservice.h) +Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: +- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup` +- Scheduler thread that drains `m_UpdatedActions` and drives state transitions +- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` +- Queue map: `m_Queues` (QueueEntry objects) +- Action history ring: `m_ActionHistory` (bounded deque, default 1000) + +**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. + +### `RunnerAction` (runners/functionrunner.h) +Shared ref-counted struct representing one action through its lifecycle. + +**Key fields:** +- `ActionLsn` — global unique sequence number +- `QueueId` — 0 for implicit/unqueued actions +- `Worker` — descriptor + content hash +- `ActionObj` — CbObject with the action spec +- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread +- `RetryCount` — atomic int tracking how many times the action has been rescheduled +- `Timestamps[State::_Count]` — timestamp of each state transition + +**State machine (forward-only under normal flow, atomic):** +``` +New → Pending → Submitting → Running → Completed + → Failed + → Abandoned + → Cancelled +``` +`SetActionState()` rejects non-forward transitions. 
The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +### `LocalProcessRunner` (runners/localrunner.h) +Base for all local execution. Platform runners subclass this and override: +- `SubmitAction()` — fork/exec the worker process +- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject) +- `CancelRunningActions()` — signal all processes during shutdown +- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default) + +**Infrastructure owned by LocalProcessRunner:** +- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop +- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction` +- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion +- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression) +- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion) + +**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations: +- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK` +- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000 +- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 + +### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. 
+ +### `HttpComputeService` (include/zencompute/httpcomputeservice.h) +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. + +## Action Lifecycle (End to End) + +1. **HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()` +2. **Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()` +3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event +4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners +5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread) +6. **Process launch** → Submitting → Running, added to `m_RunningMap` +7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs +8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()` +9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue. +10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement + +### Action Rescheduling + +Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API. + +**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier. 
+ +**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. + +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. + +**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. + +## Queue System + +Queues group actions from a single client session. A `QueueEntry` (internal) tracks: +- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled) +- `ActiveCount` — pending + running actions (atomic) +- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics) +- `ActiveLsns` — for cancellation lookup (under `m_Lock`) +- `FinishedLsns` — moved here when actions complete +- `IdleSince` — used for 15-minute automatic expiry +- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit + +**Queue state machine (`QueueState` enum):** +``` +Active → Draining → Cancelled + \ ↑ + ─────────────────────/ +``` +- **Active** — accepts new work, schedules pending work, finishes running work (initial state) +- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled) +- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining) + +Key operations: +- `CreateQueue(Tag)` → returns `QueueId` +- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation +- `CancelQueue(QueueId)` → marks all active LSNs for 
cancellation +- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible) +- `GetQueueCompleted(QueueId)` → CbWriter output of finished results +- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()` + +## HTTP API + +All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`). + +### Global endpoints +| Method | Path | Description | +|--------|------|-------------| +| POST | `abandon` | Transition session to Abandoned state (409 if invalid) | +| GET | `jobs/history` | Action history (last N, with timestamps per state) | +| GET | `jobs/running` | In-flight actions with CPU metrics | +| GET | `jobs/completed` | Actions with results available | +| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{worker}` | Submit action for specific worker | +| POST | `jobs` | Submit action (worker resolved from descriptor) | +| GET | `workers` | List worker IDs | +| GET | `workers/all` | All workers with full descriptors | +| GET/POST | `workers/{worker}` | Get/register worker | + +### Queue-scoped endpoints +Queue ref is capture(1) in all `queues/{queueref}/...` routes. 
+ +| Method | Path | Description | +|--------|------|-------------| +| GET | `queues` | List queue IDs | +| POST | `queues` | Create queue | +| GET/DELETE | `queues/{queueref}` | Status / delete | +| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) | +| GET | `queues/{queueref}/completed` | Queue's completed results | +| GET | `queues/{queueref}/history` | Queue's action history | +| GET | `queues/{queueref}/running` | Queue's running actions | +| POST | `queues/{queueref}/jobs` | Submit to queue | +| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | + +Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. + +## Concurrency Model + +**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: +1. `m_ResultsLock` +2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") +3. `m_PendingLock` +4. `m_QueueLock` + +**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. + +**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. 
+ +**Thread ownership:** +- Scheduler thread — drives state transitions, owns `m_PendingActions` +- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock +- Worker pool threads — async submission, brief `SubmitAction()` calls +- HTTP threads — read-only access to results, queue status + +## Sandbox Layout + +Each action gets a unique numbered directory under `m_SandboxPath`: +``` +scratch/{counter}/ + worker/ ← worker binaries (or bind-mounted on Linux) + inputs/ ← decompressed action inputs + outputs/ ← written by worker process +``` + +On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`. + +## Adding a New HTTP Endpoint + +1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp` +2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h` +3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure +4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())` +5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values + +## Adding a New Runner Platform + +1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/` +2. Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)` +3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample) +4. 
`ProcessHandle` convention: store pid as `reinterpret_cast<void*>(static_cast<intptr_t>(pid))` for consistency with the base class +5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp` diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp new file mode 100644 index 000000000..65bac895f --- /dev/null +++ b/src/zencompute/cloudmetadata.cpp @@ -0,0 +1,1014 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/cloudmetadata.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +// All major cloud providers expose instance metadata at this link-local address. +// It is only routable from within a cloud VM; on bare-metal the TCP connect will +// fail, which is how we distinguish cloud from non-cloud environments. +static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; + +// Short connect timeout so that detection on non-cloud machines is fast. The IMDS +// is a local service on the hypervisor so 200ms is generous for actual cloud VMs. 
+static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) +{ +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint) +: m_Log(logging::Get("cloud")) +, m_DataDir(std::move(DataDir)) +, m_ImdsEndpoint(std::move(ImdsEndpoint)) +{ + ZEN_TRACE_CPU("CloudMetadata::CloudMetadata"); + + std::error_code Ec; + std::filesystem::create_directories(m_DataDir, Ec); + + DetectProvider(); + + if (m_Info.Provider != CloudProvider::None) + { + StartTerminationMonitor(); + } +} + +CloudMetadata::~CloudMetadata() +{ + ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata"); + m_MonitorEnabled = false; + m_MonitorEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +CloudProvider +CloudMetadata::GetProvider() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); +} + +CloudInstanceInfo +CloudMetadata::GetInstanceInfo() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info; }); +} + +bool +CloudMetadata::IsTerminationPending() const +{ + return m_TerminationPending.load(std::memory_order_relaxed); +} + +std::string +CloudMetadata::GetTerminationReason() const +{ + return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; }); +} + +void +CloudMetadata::Describe(CbWriter& Writer) const +{ + ZEN_TRACE_CPU("CloudMetadata::Describe"); + CloudInstanceInfo Info = GetInstanceInfo(); + + if (Info.Provider == CloudProvider::None) + { + return; + } + + Writer.BeginObject("cloud"); + Writer << "provider" << ToString(Info.Provider); + Writer << "instance_id" << Info.InstanceId; + Writer << "availability_zone" << 
Info.AvailabilityZone; + Writer << "is_spot" << Info.IsSpot; + Writer << "is_autoscaling" << Info.IsAutoscaling; + Writer << "termination_pending" << IsTerminationPending(); + + if (IsTerminationPending()) + { + Writer << "termination_reason" << GetTerminationReason(); + } + + Writer.EndObject(); +} + +void +CloudMetadata::DetectProvider() +{ + ZEN_TRACE_CPU("CloudMetadata::DetectProvider"); + + if (TryDetectAWS()) + { + return; + } + + if (TryDetectAzure()) + { + return; + } + + if (TryDetectGCP()) + { + return; + } + + ZEN_DEBUG("no cloud provider detected"); +} + +// AWS detection uses IMDSv2 which requires a session token obtained via PUT before +// any GET requests are allowed. This is more secure than IMDSv1 (which allowed +// unauthenticated GETs) and is the default on modern EC2 instances. The token has +// a 300-second TTL and is reused for termination polling. +bool +CloudMetadata::TryDetectAWS() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping AWS detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing AWS IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + // IMDSv2: acquire session token. The TTL header is mandatory; we request + // 300s which is sufficient for the detection phase. The token is also + // stored in m_AwsToken for reuse by the termination polling thread. 
+ HttpClient::KeyValueMap TokenHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"}); + HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders); + + if (!TokenResponse.IsSuccess()) + { + ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_AwsToken = std::string(TokenResponse.AsText()); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders); + if (IdResponse.IsSuccess()) + { + m_Info.InstanceId = std::string(IdResponse.AsText()); + } + + HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders); + if (AzResponse.IsSuccess()) + { + m_Info.AvailabilityZone = std::string(AzResponse.AsText()); + } + + // "spot" vs "on-demand" — determines whether the instance can be + // reclaimed by AWS with a 2-minute warning + HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); + if (LifecycleResponse.IsSuccess()) + { + m_Info.IsSpot = (LifecycleResponse.AsText() == "spot"); + } + + // This endpoint only exists on instances managed by an Auto Scaling + // Group. A successful response (regardless of value) means autoscaling. 
+ HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + m_Info.IsAutoscaling = true; + } + + m_Info.Provider = CloudProvider::AWS; + + ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Azure IMDS returns a single JSON document for the entire instance metadata, +// unlike AWS and GCP which use separate plain-text endpoints per field. The +// "Metadata: true" header is required; requests without it are rejected. +// The api-version parameter is mandatory and pins the response schema. +bool +CloudMetadata::TryDetectAzure() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping Azure detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing Azure IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders); + + if (!InstanceResponse.IsSuccess()) + { + ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + 
ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError); + WriteSentinelFile(SentinelPath); + return false; + } + + const json11::Json& Compute = Json["compute"]; + + m_Info.InstanceId = Compute["vmId"].string_value(); + m_Info.AvailabilityZone = Compute["location"].string_value(); + + // Azure spot VMs have priority "Spot"; regular VMs have "Regular" + std::string Priority = Compute["priority"].string_value(); + m_Info.IsSpot = (Priority == "Spot"); + + // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + std::string VmssName = Compute["vmScaleSetName"].string_value(); + m_Info.IsAutoscaling = !VmssName.empty(); + + m_Info.Provider = CloudProvider::Azure; + + ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests. +// Unlike AWS, there is no session token; the header itself is the auth mechanism +// (it prevents SSRF attacks since browsers won't send custom headers to the +// metadata endpoint). Each metadata field is fetched from a separate URL. 
+bool +CloudMetadata::TryDetectGCP() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping GCP detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing GCP metadata service"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + // Fetch instance ID + HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders); + + if (!IdResponse.IsSuccess()) + { + ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_Info.InstanceId = std::string(IdResponse.AsText()); + + // GCP returns the fully-qualified zone path "projects/<num>/zones/<zone>". + // Strip the prefix to get just the zone name (e.g. "us-central1-a"). 
+ HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders); + if (ZoneResponse.IsSuccess()) + { + std::string_view Zone = ZoneResponse.AsText(); + if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos) + { + Zone = Zone.substr(Pos + 1); + } + m_Info.AvailabilityZone = std::string(Zone); + } + + // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE") + HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders); + if (PreemptibleResponse.IsSuccess()) + { + m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE"); + } + + // Check for maintenance event + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + m_TerminationPending = true; + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + } + } + + m_Info.Provider = CloudProvider::GCP; + + ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Sentinel files are empty marker files whose mere existence signals that a +// previous detection attempt for a given provider failed. This avoids paying +// the connect-timeout cost on every startup for providers that are known to +// be absent. The files persist across process restarts; delete them manually +// (or remove the DataDir) to force re-detection. 
+void +CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path) +{ + try + { + BasicFile File; + File.Open(Path, BasicFile::Mode::kTruncate); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what()); + } +} + +bool +CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const +{ + return zen::IsFile(Path); +} + +void +CloudMetadata::ClearSentinelFiles() +{ + std::error_code Ec; + std::filesystem::remove(m_DataDir / ".isNotAWS", Ec); + std::filesystem::remove(m_DataDir / ".isNotAzure", Ec); + std::filesystem::remove(m_DataDir / ".isNotGCP", Ec); +} + +void +CloudMetadata::StartTerminationMonitor() +{ + ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId); + + m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this}; +} + +void +CloudMetadata::TerminationMonitorThread() +{ + SetCurrentThreadName("cloud_term_mon"); + + // Poll every 5 seconds. The Event is used as an interruptible sleep so + // that the destructor can wake us up immediately for a clean shutdown. + while (m_MonitorEnabled) + { + m_MonitorEvent.Wait(5000); + m_MonitorEvent.Reset(); + + if (!m_MonitorEnabled) + { + return; + } + + PollTermination(); + } +} + +void +CloudMetadata::PollTermination() +{ + try + { + CloudProvider Provider = m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); + + if (Provider == CloudProvider::AWS) + { + PollAWSTermination(); + } + else if (Provider == CloudProvider::Azure) + { + PollAzureTermination(); + } + else if (Provider == CloudProvider::GCP) + { + PollGCPTermination(); + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("termination poll error: {}", Ex.what()); + } +} + +// AWS termination signals: +// - /spot/instance-action: returns 200 with a JSON body ~2 minutes before +// a spot instance is reclaimed. Returns 404 when no action is pending. 
+// - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state. +// "InService" is normal; anything else (e.g. "Terminated:Wait") means +// the instance is being cycled out. +void +CloudMetadata::PollAWSTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response SpotResponse = ImdsClient.Get("/latest/meta-data/spot/instance-action", AuthHeaders); + if (SpotResponse.IsSuccess()) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); }); + ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText()); + } + return; + } + + HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + std::string_view State = AutoscaleResponse.AsText(); + if (State.find("InService") == std::string_view::npos) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); }); + ZEN_WARN("AWS autoscaling termination detected: {}", State); + } + } + } +} + +// Azure Scheduled Events API returns a JSON array of upcoming platform events. +// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events. +// Other event types like "Freeze" (live migration) are non-destructive and +// ignored. The Events array is empty when nothing is pending. 
+void +CloudMetadata::PollAzureTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response EventsResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders); + + if (!EventsResponse.IsSuccess()) + { + return; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(EventsResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + return; + } + + const json11::Json::array& Events = Json["Events"].array_items(); + for (const auto& Evt : Events) + { + std::string EventType = Evt["EventType"].string_value(); + if (EventType == "Preempt" || EventType == "Terminate" || EventType == "Reboot") + { + if (!m_TerminationPending.exchange(true)) + { + std::string EventStatus = Evt["EventStatus"].string_value(); + m_ReasonLock.WithExclusiveLock( + [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", EventType, EventStatus); }); + ZEN_WARN("Azure termination event detected: {} ({})", EventType, EventStatus); + } + return; + } + } +} + +// GCP maintenance-event returns "NONE" when nothing is pending, and a +// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is +// about to be live-migrated or terminated. Preemptible/spot VMs get a +// 30-second warning before termination. 
void
CloudMetadata::PollGCPTermination()
{
    ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination");

    HttpClient ImdsClient(m_ImdsEndpoint,
                          {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});

    // GCP's metadata server requires the "Metadata-Flavor: Google" header
    HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"});

    HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
    if (MaintenanceResponse.IsSuccess())
    {
        std::string_view Event = MaintenanceResponse.AsText();
        // "NONE" (or an empty body) means nothing is pending — see comment above
        if (!Event.empty() && Event != "NONE")
        {
            // Only the first detection records the reason
            if (!m_TerminationPending.exchange(true))
            {
                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
                ZEN_WARN("GCP maintenance event detected: {}", Event);
            }
        }
    }
}

} // namespace zen::compute

//////////////////////////////////////////////////////////////////////////

#if ZEN_WITH_TESTS

# include <zencompute/mockimds.h>

# include <zencore/filesystem.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
# include <zenhttp/httpserver.h>

# include <memory>
# include <thread>

namespace zen::compute {

TEST_SUITE_BEGIN("compute.cloudmetadata");

// ---------------------------------------------------------------------------
// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
// ---------------------------------------------------------------------------

struct TestImdsServer
{
    MockImdsService Mock;

    void Start()
    {
        m_TmpDir.emplace();
        m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
        // The bound port is taken from Initialize()'s return value, not the
        // requested 7575; -1 signals failure to bind
        m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http");
        REQUIRE(m_Port != -1);
        m_Server->RegisterService(Mock);
        m_ServerThread = std::thread([this]() { m_Server->Run(false); });
    }

    std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); }

    std::filesystem::path DataDir() const { return m_TmpDir->Path() / "cloud"; }

    std::unique_ptr<CloudMetadata> CreateCloud() { return std::make_unique<CloudMetadata>(DataDir(), Endpoint()); }

    ~TestImdsServer()
    {
        // Shutdown order: request the server loop to exit, join the thread
        // running it, then close the server
        if (m_Server)
        {
            m_Server->RequestExit();
        }
        if (m_ServerThread.joinable())
        {
            m_ServerThread.join();
        }
        if (m_Server)
        {
            m_Server->Close();
        }
    }

private:
    std::optional<ScopedTemporaryDirectory> m_TmpDir;
    Ref<HttpServer> m_Server;
    std::thread m_ServerThread;
    int m_Port = -1;
};

// ---------------------------------------------------------------------------
// AWS
// ---------------------------------------------------------------------------

TEST_CASE("cloudmetadata.aws")
{
    TestImdsServer Imds;
    Imds.Mock.ActiveProvider = CloudProvider::AWS;

    SUBCASE("detection basics")
    {
        Imds.Mock.Aws.InstanceId = "i-abc123";
        Imds.Mock.Aws.AvailabilityZone = "us-west-2b";
        Imds.Mock.Aws.LifeCycle = "on-demand";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();

        CHECK(Cloud->GetProvider() == CloudProvider::AWS);

        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.InstanceId == "i-abc123");
        CHECK(Info.AvailabilityZone == "us-west-2b");
        CHECK(Info.IsSpot == false);
        CHECK(Info.IsAutoscaling == false);
        CHECK(Cloud->IsTerminationPending() == false);
    }

    SUBCASE("spot instance")
    {
        Imds.Mock.Aws.LifeCycle = "spot";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.IsSpot == true);
    }

    SUBCASE("autoscaling instance")
    {
        Imds.Mock.Aws.AutoscalingState = "InService";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.IsAutoscaling == true);
    }

    SUBCASE("spot termination")
    {
        Imds.Mock.Aws.LifeCycle = "spot";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CHECK(Cloud->IsTerminationPending() == false);

        // Simulate a spot interruption notice appearing
        Imds.Mock.Aws.SpotAction = R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})";
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos);
    }

    SUBCASE("autoscaling termination")
    {
        Imds.Mock.Aws.AutoscalingState = "InService";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CHECK(Cloud->IsTerminationPending() == false);

        // Simulate ASG cycling the instance out
        Imds.Mock.Aws.AutoscalingState = "Terminated:Wait";
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos);
    }

    SUBCASE("no termination when InService")
    {
        Imds.Mock.Aws.AutoscalingState = "InService";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == false);
    }
}

// ---------------------------------------------------------------------------
// Azure
// ---------------------------------------------------------------------------

TEST_CASE("cloudmetadata.azure")
{
    TestImdsServer Imds;
    Imds.Mock.ActiveProvider = CloudProvider::Azure;

    SUBCASE("detection basics")
    {
        Imds.Mock.Azure.VmId = "vm-test-1234";
        Imds.Mock.Azure.Location = "westeurope";
        Imds.Mock.Azure.Priority = "Regular";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();

        CHECK(Cloud->GetProvider() == CloudProvider::Azure);

        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.InstanceId == "vm-test-1234");
        CHECK(Info.AvailabilityZone == "westeurope");
        CHECK(Info.IsSpot == false);
        CHECK(Info.IsAutoscaling == false);
        CHECK(Cloud->IsTerminationPending() == false);
    }

    SUBCASE("spot instance")
    {
        Imds.Mock.Azure.Priority = "Spot";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.IsSpot == true);
    }

    SUBCASE("vmss instance")
    {
        Imds.Mock.Azure.VmScaleSetName = "my-vmss";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.IsAutoscaling == true);
    }

    SUBCASE("preempt termination")
    {
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CHECK(Cloud->IsTerminationPending() == false);

        // Event is injected into the mock after detection, so only the
        // explicit poll below can observe it
        Imds.Mock.Azure.ScheduledEventType = "Preempt";
        Imds.Mock.Azure.ScheduledEventStatus = "Scheduled";
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos);
    }

    SUBCASE("terminate event")
    {
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CHECK(Cloud->IsTerminationPending() == false);

        Imds.Mock.Azure.ScheduledEventType = "Terminate";
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos);
    }

    SUBCASE("no termination when events empty")
    {
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == false);
    }
}

// ---------------------------------------------------------------------------
// GCP
// ---------------------------------------------------------------------------

TEST_CASE("cloudmetadata.gcp")
{
    TestImdsServer Imds;
    Imds.Mock.ActiveProvider = CloudProvider::GCP;

    SUBCASE("detection basics")
    {
        Imds.Mock.Gcp.InstanceId = "9876543210";
        Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b";
        Imds.Mock.Gcp.Preemptible = "FALSE";
        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();

        CHECK(Cloud->GetProvider() == CloudProvider::GCP);

        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.InstanceId == "9876543210");
        CHECK(Info.AvailabilityZone == "europe-west1-b"); // zone prefix stripped
        CHECK(Info.IsSpot == false);
        CHECK(Cloud->IsTerminationPending() == false);
    }

    SUBCASE("preemptible instance")
    {
        Imds.Mock.Gcp.Preemptible = "TRUE";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
        CHECK(Info.IsSpot == true);
    }

    SUBCASE("maintenance event during detection")
    {
        Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();

        // GCP sets termination pending immediately during detection if a
        // maintenance event is active
        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
    }

    SUBCASE("maintenance event during polling")
    {
        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        CHECK(Cloud->IsTerminationPending() == false);

        Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == true);
        CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
    }

    SUBCASE("no termination when NONE")
    {
        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
        Imds.Start();

        auto Cloud = Imds.CreateCloud();
        Cloud->PollTermination();

        CHECK(Cloud->IsTerminationPending() == false);
    }
}

// ---------------------------------------------------------------------------
// No provider
// ---------------------------------------------------------------------------

TEST_CASE("cloudmetadata.no_provider")
{
    TestImdsServer Imds;
    Imds.Mock.ActiveProvider = CloudProvider::None;
    Imds.Start();

    auto Cloud = Imds.CreateCloud();

    CHECK(Cloud->GetProvider() == CloudProvider::None);

    // With no provider detected, instance info is all-defaults
    CloudInstanceInfo Info = Cloud->GetInstanceInfo();
    CHECK(Info.InstanceId.empty());
    CHECK(Info.AvailabilityZone.empty());
    CHECK(Info.IsSpot == false);
    CHECK(Info.IsAutoscaling == false);
    CHECK(Cloud->IsTerminationPending() == false);
}

// ---------------------------------------------------------------------------
// Sentinel file management
// ---------------------------------------------------------------------------

TEST_CASE("cloudmetadata.sentinel_files")
{
    TestImdsServer Imds;
    Imds.Mock.ActiveProvider = CloudProvider::None;
    Imds.Start();

    auto DataDir = Imds.DataDir();

    SUBCASE("sentinels are written on failed detection")
    {
        auto Cloud = Imds.CreateCloud();

        // Each failed provider probe leaves a ".isNotX" sentinel in the data
        // directory (presumably a negative-detection cache — confirm against
        // the detection code, which is outside this file view)
        CHECK(Cloud->GetProvider() == CloudProvider::None);
        CHECK(zen::IsFile(DataDir / ".isNotAWS"));
        CHECK(zen::IsFile(DataDir / ".isNotAzure"));
        CHECK(zen::IsFile(DataDir / ".isNotGCP"));
    }

    SUBCASE("ClearSentinelFiles removes sentinels")
    {
        auto Cloud = Imds.CreateCloud();

        CHECK(zen::IsFile(DataDir / ".isNotAWS"));
        CHECK(zen::IsFile(DataDir / ".isNotAzure"));
        CHECK(zen::IsFile(DataDir / ".isNotGCP"));

        Cloud->ClearSentinelFiles();

        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
        CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
    }

    SUBCASE("only failed providers get sentinels")
    {
        // Switch to AWS — Azure and GCP never probed, so no sentinels for them
        Imds.Mock.ActiveProvider = CloudProvider::AWS;

        auto Cloud = Imds.CreateCloud();

        CHECK(Cloud->GetProvider() == CloudProvider::AWS);
        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
        CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
    }
}

TEST_SUITE_END();

// NOTE(review): judging by the name, this empty symbol is referenced from
// elsewhere to force this test translation unit to be linked — confirm
void
cloudmetadata_forcelink()
{
}

} // namespace zen::compute

#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
new file mode 100644
index 000000000..838d741b6
--- /dev/null
+++ b/src/zencompute/computeservice.cpp
@@ -0,0 +1,2236 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#include "zencompute/computeservice.h"

#if ZEN_WITH_COMPUTE_SERVICES

# include "runners/functionrunner.h"
# include "recording/actionrecorder.h"
# include "runners/localrunner.h"
# include "runners/remotehttprunner.h"
# if ZEN_PLATFORM_LINUX
#  include "runners/linuxrunner.h"
# elif ZEN_PLATFORM_WINDOWS
#  include "runners/windowsrunner.h"
# elif ZEN_PLATFORM_MAC
#  include "runners/macrunner.h"
# endif

# include <zencompute/recordingreader.h>
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
# include <zencore/compress.h>
# include <zencore/except.h>
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
# include <zencore/logging.h>
# include <zencore/scopeguard.h>
# include <zencore/trace.h>
# include <zencore/workthreadpool.h>
# include <zenutil/workerpools.h>
# include <zentelemetry/stats.h>
# include <zenhttp/httpclient.h>

# include <set>
# include <deque>
# include <map>
# include <thread>
# include <unordered_map>
# include <unordered_set>

ZEN_THIRD_PARTY_INCLUDES_START
# include <EASTL/hash_set.h>
ZEN_THIRD_PARTY_INCLUDES_END

using namespace std::literals;

namespace zen {

// Human-readable name for a session lifecycle state (used in logs/stats)
const char*
ToString(compute::ComputeServiceSession::SessionState State)
{
    using enum compute::ComputeServiceSession::SessionState;
    switch (State)
    {
        case Created:
            return "Created";
        case Ready:
            return "Ready";
        case Draining:
            return "Draining";
        case Paused:
            return "Paused";
        case Abandoned:
            return "Abandoned";
        case Sunset:
            return "Sunset";
    }
    return "Unknown";
}

// Human-readable name for a queue state (lowercase, used in wire/status output)
const char*
ToString(compute::ComputeServiceSession::QueueState State)
{
    using enum compute::ComputeServiceSession::QueueState;
    switch (State)
    {
        case Active:
            return "active";
        case Draining:
            return "draining";
        case Cancelled:
            return "cancelled";
    }
    return "unknown";
}

} // namespace zen

namespace zen::compute {

using SessionState = ComputeServiceSession::SessionState;

// One timestamp slot per runner-action state — keeps the two declarations in sync
static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count));

//////////////////////////////////////////////////////////////////////////

// Private implementation state for ComputeServiceSession: action queues,
// worker registry, local/remote runner groups, and the scheduler thread.
struct ComputeServiceSession::Impl
{
    ComputeServiceSession* m_ComputeServiceSession;
    ChunkResolver& m_ChunkResolver;
    LoggerRef m_Log{logging::Get("compute")};

    Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver)
        : m_ComputeServiceSession(InComputeServiceSession)
        , m_ChunkResolver(InChunkResolver)
        , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
        , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
    {
        // Create a non-expiring, non-deletable implicit queue for legacy endpoints
        auto Result = CreateQueue("implicit"sv, {}, {});
        m_ImplicitQueueId = Result.QueueId;
        // Shared lock suffices: only the map structure is guarded here, and
        // the scheduler thread has not been started yet
        m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; });

        m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this};
    }

    void WaitUntilReady();
    void Shutdown();
    bool IsHealthy();

    bool RequestStateTransition(SessionState NewState);
    void AbandonAllActions();

    LoggerRef Log() { return m_Log; }

    // Orchestration

    void SetOrchestratorEndpoint(std::string_view Endpoint);
    void SetOrchestratorBasePath(std::filesystem::path BasePath);

    std::string m_OrchestratorEndpoint;
    std::filesystem::path m_OrchestratorBasePath;
    Stopwatch m_OrchestratorQueryTimer;
    std::unordered_set<std::string> m_KnownWorkerUris; // URIs with a live RemoteHttpRunner

    void UpdateCoordinatorState();

    // Worker registration and discovery

    // Maps (function name, function version, buildsystem version) to the
    // worker that implements it
    struct FunctionDefinition
    {
        std::string FunctionName;
        Guid FunctionVersion;
        Guid BuildSystemVersion;
        IoHash WorkerId;
    };

    void RegisterWorker(CbPackage Worker);
    WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId);

    // Action scheduling and tracking
    std::atomic<SessionState> m_SessionState{SessionState::Created};
    std::atomic<int32_t> m_ActionsCounter = 0; // sequence number
    metrics::Meter m_ArrivalRate;

    // Actions enqueued but not yet dispatched, keyed by LSN (std::map keeps
    // them in LSN order)
    RwLock m_PendingLock;
    std::map<int, Ref<RunnerAction>> m_PendingActions;

    // Actions currently executing on some runner
    RwLock m_RunningLock;
    std::unordered_map<int, Ref<RunnerAction>> m_RunningMap;

    // Finished actions whose results have not been retired yet
    RwLock m_ResultsLock;
    std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap;
    metrics::Meter m_ResultRate;
    std::atomic<uint64_t> m_RetiredCount{0};

    EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority);
    EnqueueResult EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority);

    void GetCompleted(CbWriter& Cbo);

    HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage);
    HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage);
    void RetireActionResult(int ActionLsn);

    // Scheduler thread: woken via the event, stopped via the enabled flag
    std::thread m_SchedulingThread;
    std::atomic<bool> m_SchedulingThreadEnabled{true};
    Event m_SchedulingThreadEvent;

    void SchedulerThreadFunction();
    void SchedulePendingActions();

    // Workers

    RwLock m_WorkerLock;
    std::unordered_map<IoHash, CbPackage> m_WorkerMap;
    std::vector<FunctionDefinition> m_FunctionList;
    std::vector<IoHash> GetKnownWorkerIds();
    void SyncWorkersToRunner(FunctionRunner& Runner);

    // Runners

    DeferredDirectoryDeleter m_DeferredDeleter;
    WorkerThreadPool& m_LocalSubmitPool;
    WorkerThreadPool& m_RemoteSubmitPool;
    RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup;
    RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup;

    void ShutdownRunners();

    // Recording

    void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
    void StopRecording();

    // Non-null while a recording session is active
    std::unique_ptr<ActionRecorder> m_Recorder;

    // History tracking

    RwLock m_ActionHistoryLock;
    std::deque<ComputeServiceSession::ActionHistoryEntry> m_ActionHistory; // oldest-first
    size_t m_HistoryLimit = 1000;
    // Queue tracking

    using QueueState = ComputeServiceSession::QueueState;

    // Per-queue bookkeeping; counters are atomics so they can be updated
    // without taking the entry lock
    struct QueueEntry : RefCounted
    {
        int QueueId;
        bool Implicit{false};
        std::atomic<QueueState> State{QueueState::Active};
        std::atomic<int> ActiveCount{0};    // pending + running
        std::atomic<int> CompletedCount{0}; // successfully completed
        std::atomic<int> FailedCount{0};    // failed
        std::atomic<int> AbandonedCount{0}; // abandoned
        std::atomic<int> CancelledCount{0}; // cancelled
        std::atomic<uint64_t> IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work

        RwLock m_Lock;
        std::unordered_set<int> ActiveLsns;   // for cancellation lookup
        std::unordered_set<int> FinishedLsns; // completed/failed/cancelled LSNs

        std::string Tag;
        CbObject Metadata;
        CbObject Config;
    };

    int m_ImplicitQueueId{0};
    std::atomic<int> m_QueueCounter{0};
    RwLock m_QueueLock;
    std::unordered_map<int, Ref<QueueEntry>> m_Queues;

    // Look up a queue by id; returns a null Ref when the id is unknown
    Ref<QueueEntry> FindQueue(int QueueId)
    {
        Ref<QueueEntry> Queue;
        m_QueueLock.WithSharedLock([&] {
            if (auto It = m_Queues.find(QueueId); It != m_Queues.end())
            {
                Queue = It->second;
            }
        });
        return Queue;
    }

    ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config);
    std::vector<int> GetQueueIds();
    ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId);
    CbObject GetQueueMetadata(int QueueId);
    CbObject GetQueueConfig(int QueueId);
    void CancelQueue(int QueueId);
    void DeleteQueue(int QueueId);
    void DrainQueue(int QueueId);
    ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
    ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);
    void GetQueueCompleted(int QueueId, CbWriter& Cbo);
    void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState);
    void ExpireCompletedQueues();

    Stopwatch m_QueueExpiryTimer;

    std::vector<ComputeServiceSession::RunningActionInfo> GetRunningActions();
    std::vector<ComputeServiceSession::ActionHistoryEntry> GetActionHistory(int Limit);
    std::vector<ComputeServiceSession::ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit);

    // Action submission

    [[nodiscard]] size_t QueryCapacity();

    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action);
    [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
    [[nodiscard]] size_t GetSubmittedActionCount();

    // Updates

    RwLock m_UpdatedActionsLock;
    std::vector<Ref<RunnerAction>> m_UpdatedActions;

    void HandleActionUpdates();
    void PostUpdate(RunnerAction* Action);

    static constexpr int kDefaultMaxRetries = 3;
    int GetMaxRetriesForQueue(int QueueId);

    ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn);

    // Aggregate counts across the pending/running/results maps and the queue table
    ActionCounts GetActionCounts()
    {
        ActionCounts Counts;
        Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
        Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
        // Completed includes results that have already been retired out of
        // the results map
        Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load();
        // Only explicitly created queues count as "active"; the implicit
        // legacy queue is excluded
        Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] {
            size_t Count = 0;
            for (const auto& [Id, Queue] : m_Queues)
            {
                if (!Queue->Implicit)
                {
                    ++Count;
                }
            }
            return Count;
        });
        return Counts;
    }

    // Emit session statistics into a compact-binary object
    void EmitStats(CbObjectWriter& Cbo)
    {
        Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed));
        m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
        m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
        m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
        Cbo << "actions_submitted"sv << GetSubmittedActionCount();
        EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo);
        EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
    }
};

bool
ComputeServiceSession::Impl::IsHealthy()
{
    // Healthy in every lifecycle state ordered before Abandoned
    // (Created/Ready/Draining/Paused)
    return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned;
}

// Validate and apply a session lifecycle transition. Uses a CAS loop so
// concurrent callers race safely; returns false (with a warning) for
// disallowed transitions. Requesting the current state is a no-op success.
bool
ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState)
{
    SessionState Current = m_SessionState.load(std::memory_order_relaxed);

    for (;;)
    {
        if (Current == NewState)
        {
            return true;
        }

        // Validate the transition
        bool Valid = false;

        switch (Current)
        {
            case SessionState::Created:
                Valid = (NewState == SessionState::Ready);
                break;
            case SessionState::Ready:
                Valid = (NewState == SessionState::Draining);
                break;
            case SessionState::Draining:
                Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused);
                break;
            case SessionState::Paused:
                Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset);
                break;
            case SessionState::Abandoned:
                Valid = (NewState == SessionState::Sunset);
                break;
            case SessionState::Sunset:
                // Terminal state — nothing leaves Sunset
                Valid = false;
                break;
        }

        // Allow jumping directly to Abandoned or Sunset from any non-terminal state
        if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned)
        {
            Valid = true;
        }
        if (NewState == SessionState::Sunset && Current != SessionState::Sunset)
        {
            Valid = true;
        }

        if (!Valid)
        {
            ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState));
            return false;
        }

        if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel))
        {
            // On success compare_exchange leaves Current untouched, so this
            // logs the state we transitioned from
            ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState));

            if (NewState == SessionState::Abandoned)
            {
                AbandonAllActions();
            }

            return true;
        }

        // CAS failed, Current was updated — retry with the new value
    }
}

void
ComputeServiceSession::Impl::AbandonAllActions()
{
    // Collect all pending actions and mark them as Abandoned
Abandoned + std::vector<Ref<RunnerAction>> PendingToAbandon; + + m_PendingLock.WithSharedLock([&] { + PendingToAbandon.reserve(m_PendingActions.size()); + for (auto& [Lsn, Action] : m_PendingActions) + { + PendingToAbandon.push_back(Action); + } + }); + + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + + // Collect all running actions and mark them as Abandoned, then + // best-effort cancel via the local runner group + std::vector<Ref<RunnerAction>> RunningToAbandon; + + m_RunningLock.WithSharedLock([&] { + RunningToAbandon.reserve(m_RunningMap.size()); + for (auto& [Lsn, Action] : m_RunningMap) + { + RunningToAbandon.push_back(Action); + } + }); + + for (auto& Action : RunningToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + m_LocalRunnerGroup.CancelAction(Action->ActionLsn); + } + + ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size()); +} + +void +ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_OrchestratorEndpoint = Endpoint; +} + +void +ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_OrchestratorBasePath = std::move(BasePath); +} + +void +ComputeServiceSession::Impl::UpdateCoordinatorState() +{ + ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); + if (m_OrchestratorEndpoint.empty()) + { + return; + } + + // Poll faster when we have no discovered workers yet so remote runners come online quickly + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 
500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } + + m_OrchestratorQueryTimer.Reset(); + + try + { + HttpClient Client(m_OrchestratorEndpoint); + + HttpClient::Response Response = Client.Get("/orch/agents"); + + if (!Response.IsSuccess()) + { + ZEN_WARN("orchestrator query failed with status {}", static_cast<int>(Response.StatusCode)); + return; + } + + CbObject WorkerList = Response.AsObject(); + + std::unordered_set<std::string> ValidWorkerUris; + + for (auto& Item : WorkerList["workers"sv]) + { + CbObjectView Worker = Item.AsObjectView(); + + uint64_t Dt = Worker["dt"sv].AsUInt64(); + bool Reachable = Worker["reachable"sv].AsBool(); + std::string_view Uri = Worker["uri"sv].AsString(); + + // Skip stale workers (not seen in over 30 seconds) + if (Dt > 30000) + { + continue; + } + + // Skip workers that are not confirmed reachable + if (!Reachable) + { + continue; + } + + std::string UriStr{Uri}; + ValidWorkerUris.insert(UriStr); + + // Skip workers we already know about + if (m_KnownWorkerUris.contains(UriStr)) + { + continue; + } + + ZEN_INFO("discovered new worker at {}", UriStr); + + m_KnownWorkerUris.insert(UriStr); + + auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + SyncWorkersToRunner(*NewRunner); + m_RemoteRunnerGroup.AddRunner(NewRunner); + } + + // Remove workers that are no longer valid (stale or unreachable) + for (auto It = m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();) + { + if (!ValidWorkerUris.contains(*It)) + { + const std::string& ExpiredUri = *It; + ZEN_INFO("removing expired worker at {}", ExpiredUri); + + m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; }); + + It = m_KnownWorkerUris.erase(It); + } + else + { + ++It; + } + } + } + catch (const HttpClientError& Ex) + { + ZEN_WARN("orchestrator query error: {}", Ex.what()); + } + catch (const std::exception& 
Ex) + { + ZEN_WARN("orchestrator query unexpected error: {}", Ex.what()); + } +} + +void +ComputeServiceSession::Impl::WaitUntilReady() +{ + if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty()) + { + ZEN_INFO("waiting for remote runners..."); + + constexpr int MaxWaitSeconds = 120; + + for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++) + { + if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed)) + { + ZEN_WARN("shutdown requested while waiting for remote runners"); + return; + } + + const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity(); + + if (Capacity > 0) + { + ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity); + break; + } + + zen::Sleep(1000); + } + } + else + { + ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available"); + } + + RequestStateTransition(SessionState::Ready); +} + +void +ComputeServiceSession::Impl::Shutdown() +{ + RequestStateTransition(SessionState::Sunset); + + m_SchedulingThreadEnabled = false; + m_SchedulingThreadEvent.Set(); + if (m_SchedulingThread.joinable()) + { + m_SchedulingThread.join(); + } + + ShutdownRunners(); + + m_DeferredDeleter.Shutdown(); +} + +void +ComputeServiceSession::Impl::ShutdownRunners() +{ + m_LocalRunnerGroup.Shutdown(); + m_RemoteRunnerGroup.Shutdown(); +} + +void +ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) +{ + ZEN_INFO("starting recording to '{}'", RecordingPath); + + m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); + + ZEN_INFO("started recording to '{}'", RecordingPath); +} + +void +ComputeServiceSession::Impl::StopRecording() +{ + ZEN_INFO("stopping recording"); + + m_Recorder = nullptr; + + ZEN_INFO("stopped recording"); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::Impl::GetRunningActions() +{ + std::vector<ComputeServiceSession::RunningActionInfo> Result; + 
    m_RunningLock.WithSharedLock([&] {
        Result.reserve(m_RunningMap.size());
        for (const auto& [Lsn, Action] : m_RunningMap)
        {
            // CPU counters are updated concurrently by the monitoring side,
            // hence the relaxed atomic loads
            Result.push_back({.Lsn = Lsn,
                              .QueueId = Action->QueueId,
                              .ActionId = Action->ActionId,
                              .CpuUsagePercent = Action->CpuUsagePercent.load(std::memory_order_relaxed),
                              .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed)});
        }
    });
    return Result;
}

// Return up to Limit of the most recent history entries; Limit <= 0 returns everything
std::vector<ComputeServiceSession::ActionHistoryEntry>
ComputeServiceSession::Impl::GetActionHistory(int Limit)
{
    RwLock::SharedLockScope _(m_ActionHistoryLock);

    if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size())
    {
        // History is oldest-first, so the newest entries are at the tail
        return std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end());
    }

    return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end());
}

// Return up to Limit of the most recent history entries that belong to the
// given queue; unknown queue ids yield an empty result
std::vector<ComputeServiceSession::ActionHistoryEntry>
ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit)
{
    // Resolve the queue and snapshot its finished LSN set
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        return {};
    }

    std::unordered_set<int> FinishedLsns;

    Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; });

    // Filter the global history to entries belonging to this queue.
    // m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering.
+ std::vector<ActionHistoryEntry> Result; + + m_ActionHistoryLock.WithSharedLock([&] { + for (const auto& Entry : m_ActionHistory) + { + if (FinishedLsns.contains(Entry.Lsn)) + { + Result.push_back(Entry); + } + } + }); + + if (Limit > 0 && static_cast<size_t>(Limit) < Result.size()) + { + Result.erase(Result.begin(), Result.end() - Limit); + } + + return Result; +} + +void +ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) +{ + ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + RwLock::ExclusiveLockScope _(m_WorkerLock); + + const IoHash& WorkerId = Worker.GetObject().GetHash(); + + if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) + { + // Note that since the convention currently is that WorkerId is equal to the hash + // of the worker descriptor there is no chance that we get a second write with a + // different descriptor. Thus we only need to call this the first time, when the + // worker is added + + m_LocalRunnerGroup.RegisterWorker(Worker); + m_RemoteRunnerGroup.RegisterWorker(Worker); + + if (m_Recorder) + { + m_Recorder->RegisterWorker(Worker); + } + + CbObject WorkerObj = Worker.GetObject(); + + // Populate worker database + + const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerObj["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } + } +} + +void +ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) +{ + ZEN_TRACE_CPU("SyncWorkersToRunner"); + + std::vector<CbPackage> Workers; + + { + RwLock::SharedLockScope _(m_WorkerLock); + Workers.reserve(m_WorkerMap.size()); + for (const auto& [Id, Pkg] : 
m_WorkerMap)
        {
            Workers.push_back(Pkg);
        }
    }

    for (const CbPackage& Worker : Workers)
    {
        Runner.RegisterWorker(Worker);
    }
}

// Looks up the descriptor for a known worker id; returns an empty descriptor
// if the worker is not registered.
WorkerDesc
ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId)
{
    RwLock::SharedLockScope _(m_WorkerLock);

    if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
    {
        const CbPackage& Desc = It->second;
        return {Desc, WorkerId};
    }

    return {};
}

// Returns the ids of all currently registered workers.
std::vector<IoHash>
ComputeServiceSession::Impl::GetKnownWorkerIds()
{
    std::vector<IoHash> WorkerIds;

    m_WorkerLock.WithSharedLock([&] {
        WorkerIds.reserve(m_WorkerMap.size());
        for (const auto& [WorkerId, _] : m_WorkerMap)
        {
            WorkerIds.push_back(WorkerId);
        }
    });

    return WorkerIds;
}

// Resolves the action's Function/FunctionVersion/BuildSystemVersion triple to
// a registered worker and enqueues it. On failure returns Lsn 0 together with
// an error object echoing the unmatched specification.
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority)
{
    ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction");

    // Resolve function to worker

    IoHash WorkerId{IoHash::Zero};
    CbPackage WorkerPackage;

    std::string_view FunctionName = ActionObject["Function"sv].AsString();
    const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
    const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();

    // Linear scan of the function list — presumably the list stays small
    // enough that an index is not warranted; TODO confirm at scale.
    m_WorkerLock.WithSharedLock([&] {
        for (const FunctionDefinition& FuncDef : m_FunctionList)
        {
            if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion &&
                FuncDef.BuildSystemVersion == BuildSystemVersion)
            {
                WorkerId = FuncDef.WorkerId;

                break;
            }
        }

        if (WorkerId != IoHash::Zero)
        {
            if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
            {
                WorkerPackage = It->second;
            }
        }
    });

    if (WorkerId == IoHash::Zero)
    {
        CbObjectWriter Writer;

        Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion;
        Writer << "error"
               << "no worker matches the action specification";

        return {0, Writer.Save()};
    }

    if (WorkerPackage)
    {
        return EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority);
    }

    // The function list referenced a worker id that is no longer in the worker map
    CbObjectWriter Writer;

    Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion;
    Writer << "error"
           << "no worker found despite match";

    return {0, Writer.Save()};
}

// Enqueues an action whose worker has already been resolved. Assigns a new
// LSN, records the action (if recording), and places it in the Pending state
// for batch scheduling. Returns Lsn 0 with an error object when the session
// is not in the Ready state.
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
{
    ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction");

    if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready)
    {
        CbObjectWriter Writer;
        Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()));
        return {0, Writer.Save()};
    }

    const int ActionLsn = ++m_ActionsCounter;

    m_ArrivalRate.Mark();

    Ref<RunnerAction> Pending{new RunnerAction(m_ComputeServiceSession)};

    Pending->ActionLsn = ActionLsn;
    Pending->QueueId = QueueId;
    Pending->Worker = Worker;
    Pending->ActionId = ActionObj.GetHash();
    Pending->ActionObj = ActionObj;
    Pending->Priority = RequestPriority;

    // For now simply put action into pending state, so we can do batch scheduling

    ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);

    Pending->SetActionState(RunnerAction::State::Pending);

    if (m_Recorder)
    {
        m_Recorder->RecordAction(Pending);
    }

    CbObjectWriter Writer;
    Writer << "lsn" << Pending->ActionLsn;
    Writer << "worker" << Pending->Worker.WorkerId;
    Writer << "action" << Pending->ActionId;

    return {Pending->ActionLsn, Writer.Save()};
}

// Submits a single action, preferring the local runner group and falling back
// to the remote group if no local runner accepts it.
SubmitResult
ComputeServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action)
{
    // Loosely round-robin scheduling of actions across runners.
    //
    // It's not entirely clear what this means given that submits
    // can come in across multiple threads, but it's probably better
    // than always starting with the first runner.
    //
    // Longer term we should track the state of the individual
    // runners and make decisions accordingly.

    SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action);
    if (Result.IsAccepted)
    {
        return Result;
    }

    return m_RemoteRunnerGroup.SubmitAction(Action);
}

// Total number of actions currently submitted across both runner groups.
size_t
ComputeServiceSession::Impl::GetSubmittedActionCount()
{
    return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount();
}

// Retrieves (and consumes) the result for the given LSN. Returns OK with the
// result package when finished, Accepted while still pending or running, and
// NotFound otherwise.
HttpResponseCode
ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
    // This lock is held for the duration of the function since we need to
    // be sure that the action doesn't change state while we are checking the
    // different data structures

    RwLock::ExclusiveLockScope _(m_ResultsLock);

    if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
    {
        // Retrieval is destructive — the entry is removed from the results map
        OutResultPackage = std::move(It->second->GetResult());

        m_ResultsMap.erase(It);

        return HttpResponseCode::OK;
    }

    {
        RwLock::SharedLockScope __(m_PendingLock);

        if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
        {
            return HttpResponseCode::Accepted;
        }
    }

    // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
    // always be taken after m_ResultsLock if both are needed

    {
        RwLock::SharedLockScope __(m_RunningLock);

        if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
        {
            return HttpResponseCode::Accepted;
        }
    }

    return HttpResponseCode::NotFound;
}

// Same contract as GetActionResult but keyed by action id (hash) instead of
// LSN; the maps are indexed by LSN so these are linear scans.
HttpResponseCode
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
    // This lock is held for the duration of the function since we need to
    // be sure that the action doesn't change state while we are checking the
    // different data structures

    RwLock::ExclusiveLockScope _(m_ResultsLock);

    for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
    {
        if (It->second->ActionId == ActionId)
        {
            OutResultPackage = std::move(It->second->GetResult());

            m_ResultsMap.erase(It);

            return HttpResponseCode::OK;
        }
    }

    {
        RwLock::SharedLockScope __(m_PendingLock);

        for (const auto& [K, Pending] : m_PendingActions)
        {
            if (Pending->ActionId == ActionId)
            {
                return HttpResponseCode::Accepted;
            }
        }
    }

    // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
    // always be taken after m_ResultsLock if both are needed

    {
        RwLock::SharedLockScope __(m_RunningLock);

        for (const auto& [K, v] : m_RunningMap)
        {
            if (v->ActionId == ActionId)
            {
                return HttpResponseCode::Accepted;
            }
        }
    }

    return HttpResponseCode::NotFound;
}

// Marks an action's result as ready for the deferred deleter to clean up.
void
ComputeServiceSession::Impl::RetireActionResult(int ActionLsn)
{
    m_DeferredDeleter.MarkReady(ActionLsn);
}

// Writes all completed actions (LSN + terminal state) into the given writer.
void
ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
{
    Cbo.BeginArray("completed");

    m_ResultsLock.WithSharedLock([&] {
        for (auto& [Lsn, Action] : m_ResultsMap)
        {
            Cbo.BeginObject();
            Cbo << "lsn"sv << Lsn;
            Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
            Cbo.EndObject();
        }
    });

    Cbo.EndArray();
}

//////////////////////////////////////////////////////////////////////////
// Queue management

// Creates a new explicit queue with the given tag, metadata and configuration
// and returns its id.
ComputeServiceSession::CreateQueueResult
ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config)
{
    const int QueueId = ++m_QueueCounter;

    Ref<QueueEntry> Queue{new QueueEntry()};
    Queue->QueueId = QueueId;
    Queue->Tag = Tag;
    Queue->Metadata = std::move(Metadata);
    Queue->Config = std::move(Config);
    Queue->IdleSince = GetHifreqTimerValue();

    m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; });

    ZEN_DEBUG("created queue {}", QueueId);

    return
{.QueueId = QueueId};
}

// Returns the ids of all explicit (non-implicit) queues.
std::vector<int>
ComputeServiceSession::Impl::GetQueueIds()
{
    std::vector<int> Ids;

    m_QueueLock.WithSharedLock([&] {
        Ids.reserve(m_Queues.size());
        for (const auto& [Id, Queue] : m_Queues)
        {
            if (!Queue->Implicit)
            {
                Ids.push_back(Id);
            }
        }
    });

    return Ids;
}

// Returns a snapshot of the queue's counters and state; IsValid is false if
// the queue does not exist.
ComputeServiceSession::QueueStatus
ComputeServiceSession::Impl::GetQueueStatus(int QueueId)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        return {};
    }

    const int Active = Queue->ActiveCount.load(std::memory_order_relaxed);
    const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed);
    const int Failed = Queue->FailedCount.load(std::memory_order_relaxed);
    const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed);
    const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed);
    const QueueState QState = Queue->State.load();

    return {
        .IsValid = true,
        .QueueId = QueueId,
        .ActiveCount = Active,
        .CompletedCount = Completed,
        .FailedCount = Failed,
        .AbandonedCount = AbandonedN,
        .CancelledCount = CancelledN,
        .State = QState,
        .IsComplete = (Active == 0),
    };
}

// Returns the queue's metadata object, or an empty object if the queue does
// not exist.
CbObject
ComputeServiceSession::Impl::GetQueueMetadata(int QueueId)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        return {};
    }

    return Queue->Metadata;
}

// Returns the queue's configuration object, or an empty object if the queue
// does not exist.
CbObject
ComputeServiceSession::Impl::GetQueueConfig(int QueueId)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        return {};
    }

    return Queue->Config;
}

// Cancels all pending work in the queue and attempts best-effort cancellation
// of running work, then marks the queue Cancelled. The implicit queue can not
// be cancelled.
void
ComputeServiceSession::Impl::CancelQueue(int QueueId)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue || Queue->Implicit)
    {
        return;
    }

    Queue->State.store(QueueState::Cancelled);

    // Collect active LSNs snapshot for cancellation
    std::vector<int> LsnsToCancel;

    Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); });

    // Identify which LSNs are still pending (not yet dispatched to a runner)
    std::vector<Ref<RunnerAction>> PendingActionsToCancel;
    std::vector<int> RunningLsnsToCancel;

    m_PendingLock.WithSharedLock([&] {
        for (int Lsn : LsnsToCancel)
        {
            if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end())
            {
                PendingActionsToCancel.push_back(It->second);
            }
        }
    });

    m_RunningLock.WithSharedLock([&] {
        for (int Lsn : LsnsToCancel)
        {
            if (m_RunningMap.find(Lsn) != m_RunningMap.end())
            {
                RunningLsnsToCancel.push_back(Lsn);
            }
        }
    });

    // Cancel pending actions by marking them as Cancelled; they will flow through
    // HandleActionUpdates and eventually be removed from the pending map.
    for (auto& Action : PendingActionsToCancel)
    {
        Action->SetActionState(RunnerAction::State::Cancelled);
    }

    // Best-effort cancellation of running actions via the local runner group.
    // Also set the action state to Cancelled directly so a subsequent Failed
    // transition from the runner is blocked (Cancelled > Failed in the enum).
    for (int Lsn : RunningLsnsToCancel)
    {
        // NOTE(review): the shared lock is re-taken per LSN here — presumably
        // to keep the critical section short while CancelAction may block.
        m_RunningLock.WithSharedLock([&] {
            if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end())
            {
                It->second->SetActionState(RunnerAction::State::Cancelled);
            }
        });
        m_LocalRunnerGroup.CancelAction(Lsn);
    }

    m_RemoteRunnerGroup.CancelRemoteQueue(QueueId);

    ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled",
             QueueId,
             PendingActionsToCancel.size(),
             RunningLsnsToCancel.size());

    // Wake up the scheduler to process the cancelled actions
    m_SchedulingThreadEvent.Set();
}

// Cancels any active work in the queue and then removes it. The implicit
// queue is never deleted.
void
ComputeServiceSession::Impl::DeleteQueue(int QueueId)
{
    // Never delete the implicit queue
    {
        Ref<QueueEntry> Queue = FindQueue(QueueId);
        if (Queue && Queue->Implicit)
        {
            return;
        }
    }

    // Cancel any active work first
    CancelQueue(QueueId);

    m_QueueLock.WithExclusiveLock([&] {
        if (auto It = m_Queues.find(QueueId); It != m_Queues.end())
        {
            m_Queues.erase(It);
        }
    });
}

// Transitions an Active queue to Draining (no new actions accepted, existing
// work runs to completion). No-op for the implicit queue; the CAS means other
// states (e.g. Cancelled) are left untouched.
void
ComputeServiceSession::Impl::DrainQueue(int QueueId)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue || Queue->Implicit)
    {
        return;
    }

    QueueState Expected = QueueState::Active;
    Queue->State.compare_exchange_strong(Expected, QueueState::Draining);
    ZEN_INFO("draining queue {}", QueueId);
}

// Enqueues an action to a specific queue after validating that the queue
// exists and is accepting work; on success the LSN is added to the queue's
// active set.
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue not found"sv;
        return {0, Writer.Save()};
    }

    QueueState QState = Queue->State.load();
    if (QState == QueueState::Cancelled)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue is cancelled"sv;
        return {0, Writer.Save()};
    }

    if (QState == QueueState::Draining)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue is draining"sv;
        return {0, Writer.Save()};
    }

    EnqueueResult Result =
EnqueueAction(QueueId, ActionObject, Priority);

    if (Result.Lsn != 0)
    {
        // Track the new action as active on the queue and clear the idle marker
        Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
        Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
        Queue->IdleSince.store(0, std::memory_order_relaxed);
    }

    return Result;
}

// Same as EnqueueActionToQueue, but for an action whose worker has already
// been resolved.
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue not found"sv;
        return {0, Writer.Save()};
    }

    QueueState QState = Queue->State.load();
    if (QState == QueueState::Cancelled)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue is cancelled"sv;
        return {0, Writer.Save()};
    }

    if (QState == QueueState::Draining)
    {
        CbObjectWriter Writer;
        Writer << "error"sv
               << "queue is draining"sv;
        return {0, Writer.Save()};
    }

    EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority);

    if (Result.Lsn != 0)
    {
        Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
        Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
        Queue->IdleSince.store(0, std::memory_order_relaxed);
    }

    return Result;
}

// Writes the LSNs of the queue's finished actions that still have a result
// available for retrieval.
void
ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
{
    Ref<QueueEntry> Queue = FindQueue(QueueId);

    Cbo.BeginArray("completed");

    if (Queue)
    {
        Queue->m_Lock.WithSharedLock([&] {
            m_ResultsLock.WithSharedLock([&] {
                for (int Lsn : Queue->FinishedLsns)
                {
                    if (m_ResultsMap.contains(Lsn))
                    {
                        Cbo << Lsn;
                    }
                }
            });
        });
    }

    Cbo.EndArray();
}

// Updates queue bookkeeping when one of its actions reaches a terminal state:
// moves the LSN from the active to the finished set, bumps the counter that
// matches the terminal state, and starts the idle timer when the queue
// empties. QueueId 0 (no queue) is a no-op.
void
ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState)
{
    if (QueueId == 0)
    {
        return;
    }

    Ref<QueueEntry> Queue = FindQueue(QueueId);

    if (!Queue)
    {
        return;
    }

    Queue->m_Lock.WithExclusiveLock([&] {
        Queue->ActiveLsns.erase(Lsn);
        Queue->FinishedLsns.insert(Lsn);
    });

    const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
    if (PreviousActive == 1)
    {
        // Queue just became idle — remember when, for the expiry sweep
        Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
    }

    switch (ActionState)
    {
        case RunnerAction::State::Completed:
            Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed);
            break;
        case RunnerAction::State::Abandoned:
            Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed);
            break;
        case RunnerAction::State::Cancelled:
            Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed);
            break;
        default:
            // Any other terminal state is counted as a failure
            Queue->FailedCount.fetch_add(1, std::memory_order_relaxed);
            break;
    }
}

// Deletes explicit queues that have been idle (no active actions) for longer
// than the expiry window.
void
ComputeServiceSession::Impl::ExpireCompletedQueues()
{
    static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000;

    std::vector<int> ExpiredQueueIds;

    m_QueueLock.WithSharedLock([&] {
        for (const auto& [Id, Queue] : m_Queues)
        {
            if (Queue->Implicit)
            {
                continue;
            }
            const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed);
            if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0)
            {
                const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle);
                if (ElapsedMs >= ExpiryTimeMs)
                {
                    ExpiredQueueIds.push_back(Id);
                }
            }
        }
    });

    // Delete outside the shared lock — DeleteQueue takes m_QueueLock exclusively
    for (int QueueId : ExpiredQueueIds)
    {
        ZEN_INFO("expiring idle queue {}", QueueId);
        DeleteQueue(QueueId);
    }
}

// Pulls pending actions (up to the current runner capacity) and submits them
// to the local/remote runner groups, highest priority first.
void
ComputeServiceSession::Impl::SchedulePendingActions()
{
    ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
    int ScheduledCount = 0;
    size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
    size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
    size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });

    static Stopwatch DumpRunningTimer;

    auto _ =
MakeGuard([&] {
        // Scope guard: emit the scheduling summary on exit (dismissed on the
        // early-out paths below where nothing was attempted)
        ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
                 ScheduledCount,
                 RunningCount,
                 m_RetiredCount.load(),
                 PendingCount,
                 ResultCount);

        if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
        {
            // Every 30s, dump the sorted set of currently running LSNs for diagnostics
            DumpRunningTimer.Reset();

            std::set<int> RunningList;
            m_RunningLock.WithSharedLock([&] {
                for (auto& [K, V] : m_RunningMap)
                {
                    RunningList.insert(K);
                }
            });

            ExtendableStringBuilder<1024> RunningString;
            for (int i : RunningList)
            {
                if (RunningString.Size())
                {
                    RunningString << ", ";
                }

                RunningString.Append(IntNum(i));
            }

            ZEN_INFO("running: {}", RunningString);
        }
    });

    size_t Capacity = QueryCapacity();

    if (!Capacity)
    {
        // No capacity anywhere — skip the summary log and try again later
        _.Dismiss();

        return;
    }

    std::vector<Ref<RunnerAction>> ActionsToSchedule;

    // Pull actions to schedule from the pending queue, we will
    // try to submit these to the runner outside of the lock. Note
    // that because of how the state transitions work it's not
    // actually the case that all of these actions will still be
    // pending by the time we try to submit them, but that's fine.
    //
    // Also note that the m_PendingActions list is not maintained
    // here, that's done periodically in SchedulePendingActions()

    m_PendingLock.WithExclusiveLock([&] {
        if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
        {
            return;
        }

        if (m_PendingActions.empty())
        {
            return;
        }

        for (auto& [Lsn, Pending] : m_PendingActions)
        {
            switch (Pending->ActionState())
            {
                case RunnerAction::State::Pending:
                    ActionsToSchedule.push_back(Pending);
                    break;

                case RunnerAction::State::Submitting:
                    break; // already claimed by async submission

                case RunnerAction::State::Running:
                case RunnerAction::State::Completed:
                case RunnerAction::State::Failed:
                case RunnerAction::State::Abandoned:
                case RunnerAction::State::Cancelled:
                    break;

                default:
                case RunnerAction::State::New:
                    ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn);
                    break;
            }
        }

        // Sort by priority descending, then by LSN ascending (FIFO within same priority)
        std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
            if (A->Priority != B->Priority)
            {
                return A->Priority > B->Priority;
            }
            return A->ActionLsn < B->ActionLsn;
        });

        // Never take more actions than the runners can currently accept
        if (ActionsToSchedule.size() > Capacity)
        {
            ActionsToSchedule.resize(Capacity);
        }

        PendingCount = m_PendingActions.size();
    });

    if (ActionsToSchedule.empty())
    {
        _.Dismiss();
        return;
    }

    ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());

    Stopwatch SubmitTimer;
    std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);

    int NotAcceptedCount = 0;
    int ScheduledActionCount = 0;

    for (const SubmitResult& SubResult : SubmitResults)
    {
        if (SubResult.IsAccepted)
        {
            ++ScheduledActionCount;
        }
        else
        {
            ++NotAcceptedCount;
        }
    }

    ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
             ScheduledActionCount,
             NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
             NotAcceptedCount);

    ScheduledCount += ScheduledActionCount;
    PendingCount -= ScheduledActionCount;
}

// Main scheduler loop: waits on the scheduling event (with timeout), drains
// action state updates, auto-advances Draining -> Paused when all work is
// done, submits pending actions, and periodically expires idle queues.
void
ComputeServiceSession::Impl::SchedulerThreadFunction()
{
    SetCurrentThreadName("Function_Scheduler");

    auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); });

    do
    {
        int TimeoutMs = 500;

        auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });

        if (PendingCount)
        {
            // Poll faster while there is work waiting to be scheduled
            TimeoutMs = 100;
        }

        const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs);

        if (m_SchedulingThreadEnabled == false)
        {
            return;
        }

        if (WasSignaled)
        {
            m_SchedulingThreadEvent.Reset();
        }

        ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
                  m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
                  PendingCount,
                  m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
                  m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
                  TimeoutMs);

        HandleActionUpdates();

        // Auto-transition Draining → Paused when all work is done
        if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
        {
            size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
            size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });

            if (Pending == 0 && Running == 0)
            {
                SessionState Expected = SessionState::Draining;
                if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
                {
                    ZEN_INFO("session state: Draining -> Paused (all work completed)");
                }
            }
        }

        UpdateCoordinatorState();
        SchedulePendingActions();

        static constexpr uint64_t QueueExpirySweepIntervalMs = 30000;
        if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs)
        {
            m_QueueExpiryTimer.Reset();
            ExpireCompletedQueues();
        }
    } while
(m_SchedulingThreadEnabled);
}

// Queues an action state update for the scheduler thread and wakes it up.
void
ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action)
{
    m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); });
    m_SchedulingThreadEvent.Set();
}

// Returns the retry limit for a queue, honoring the queue's "max_retries"
// config value when present and positive, otherwise the session default.
int
ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
{
    if (QueueId == 0)
    {
        return kDefaultMaxRetries;
    }

    CbObject Config = GetQueueConfig(QueueId);

    if (Config)
    {
        int Value = Config["max_retries"].AsInt32(0);

        if (Value > 0)
        {
            return Value;
        }
    }

    return kDefaultMaxRetries;
}

// Manually requeues a Failed/Abandoned action from the results map, undoing
// the queue bookkeeping for its previous terminal state. Fails if the action
// is unknown, not in a retryable state, or out of retries.
ComputeServiceSession::RescheduleResult
ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
{
    Ref<RunnerAction> Action;
    RunnerAction::State State;
    RescheduleResult ValidationError;
    bool Removed = false;

    // Find, validate, and remove atomically under a single lock scope to prevent
    // concurrent RescheduleAction calls from double-removing the same action.
    m_ResultsLock.WithExclusiveLock([&] {
        auto It = m_ResultsMap.find(ActionLsn);
        if (It == m_ResultsMap.end())
        {
            ValidationError = {.Success = false, .Error = "Action not found in results"};
            return;
        }

        Action = It->second;
        State = Action->ActionState();

        if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned)
        {
            ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"};
            return;
        }

        int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);
        if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries)
        {
            ValidationError = {.Success = false, .Error = "Retry limit reached"};
            return;
        }

        m_ResultsMap.erase(It);
        Removed = true;
    });

    if (!Removed)
    {
        return ValidationError;
    }

    if (Action->QueueId != 0)
    {
        Ref<QueueEntry> Queue = FindQueue(Action->QueueId);

        if (Queue)
        {
            // Move the LSN back from finished to active and undo the counters
            Queue->m_Lock.WithExclusiveLock([&] {
                Queue->FinishedLsns.erase(ActionLsn);
                Queue->ActiveLsns.insert(ActionLsn);
            });

            Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
            Queue->IdleSince.store(0, std::memory_order_relaxed);

            if (State == RunnerAction::State::Failed)
            {
                Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed);
            }
            else
            {
                Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed);
            }
        }
    }

    // Reset action state — this calls PostUpdate() internally
    Action->ResetActionStateToPending();

    int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
    ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount);

    return {.Success = true, .RetryCount = NewRetryCount};
}

// Drains the queued action state updates posted by runners and applies the
// corresponding transitions to the pending/running/results maps.
void
ComputeServiceSession::Impl::HandleActionUpdates()
{
    ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates");

    // Drain the update queue atomically
    std::vector<Ref<RunnerAction>> UpdatedActions;
    m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); });

    std::unordered_set<int> SeenLsn;

    // Process each action's latest state, deduplicating by LSN.
    //
    // This is safe because state transitions are monotonically increasing by enum
    // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so
    // SetActionState rejects any transition to a lower-ranked state. By the time
    // we read ActionState() here, it reflects the highest state reached — making
    // the first occurrence per LSN authoritative and duplicates redundant.
    for (Ref<RunnerAction>& Action : UpdatedActions)
    {
        const int ActionLsn = Action->ActionLsn;

        if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted)
        {
            switch (Action->ActionState())
            {
                // Newly enqueued — add to pending map for scheduling
                case RunnerAction::State::Pending:
                    m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
                    break;

                // Async submission in progress — remains in pending map
                case RunnerAction::State::Submitting:
                    break;

                // Dispatched to a runner — move from pending to running
                case RunnerAction::State::Running:
                    m_RunningLock.WithExclusiveLock([&] {
                        m_PendingLock.WithExclusiveLock([&] {
                            m_RunningMap.insert({ActionLsn, Action});
                            m_PendingActions.erase(ActionLsn);
                        });
                    });
                    ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
                    break;

                // Terminal states — move to results, record history, notify queue
                case RunnerAction::State::Completed:
                case RunnerAction::State::Failed:
                case RunnerAction::State::Abandoned:
                case RunnerAction::State::Cancelled:
                {
                    auto TerminalState = Action->ActionState();

                    // Automatic retry for Failed/Abandoned actions with retries remaining.
                    // Skip retries when the session itself is abandoned — those actions
                    // were intentionally abandoned and should not be rescheduled.
                    if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) &&
                        m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned)
                    {
                        int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);

                        if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
                        {
                            // Remove from whichever active map the action is in before resetting
                            m_RunningLock.WithExclusiveLock([&] {
                                m_PendingLock.WithExclusiveLock([&] {
                                    if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
                                    {
                                        m_PendingActions.erase(ActionLsn);
                                    }
                                    else
                                    {
                                        m_RunningMap.erase(FindIt);
                                    }
                                });
                            });

                            // Reset triggers PostUpdate() which re-enters the action as Pending
                            Action->ResetActionStateToPending();
                            int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);

                            ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})",
                                     Action->ActionId,
                                     ActionLsn,
                                     NewRetryCount,
                                     MaxRetries);
                            break;
                        }
                    }

                    // Remove from whichever active map the action is in
                    m_RunningLock.WithExclusiveLock([&] {
                        m_PendingLock.WithExclusiveLock([&] {
                            if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
                            {
                                m_PendingActions.erase(ActionLsn);
                            }
                            else
                            {
                                m_RunningMap.erase(FindIt);
                            }
                        });
                    });

                    m_ResultsLock.WithExclusiveLock([&] {
                        m_ResultsMap[ActionLsn] = Action;

                        // Append to bounded action history ring
                        m_ActionHistoryLock.WithExclusiveLock([&] {
                            ActionHistoryEntry Entry{.Lsn = ActionLsn,
                                                     .QueueId = Action->QueueId,
                                                     .ActionId = Action->ActionId,
                                                     .WorkerId = Action->Worker.WorkerId,
                                                     .ActionDescriptor = Action->ActionObj,
                                                     .ExecutionLocation = std::move(Action->ExecutionLocation),
                                                     .Succeeded = TerminalState == RunnerAction::State::Completed,
                                                     .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed),
                                                     .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)};

                            std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps));

                            m_ActionHistory.push_back(std::move(Entry));

                            if (m_ActionHistory.size() > m_HistoryLimit)
                            {
                                m_ActionHistory.pop_front();
                            }
                        });
                    });
                    m_RetiredCount.fetch_add(1);
                    m_ResultRate.Mark(1);
                    ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}",
                              Action->ActionId,
                              ActionLsn,
                              TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE");
                    NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
                    break;
                }
            }
        }
    }
}

// Combined free capacity of the local and remote runner groups.
size_t
ComputeServiceSession::Impl::QueryCapacity()
{
    return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity();
}

// Submits a batch of actions: the batch is first offered to the local runner
// group, then any rejects are forwarded to the remote runner group. Results
// are returned in the same order as the input actions.
std::vector<SubmitResult>
ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
    ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
    std::vector<SubmitResult> Results(Actions.size());

    // First try submitting the batch to local runners in parallel

    std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
    std::vector<size_t> RemoteIndices;
    std::vector<Ref<RunnerAction>> RemoteActions;

    for (size_t i = 0; i < Actions.size(); ++i)
    {
        if (LocalResults[i].IsAccepted)
        {
            Results[i] = std::move(LocalResults[i]);
        }
        else
        {
            // Remember original position so remote results land in order
            RemoteIndices.push_back(i);
            RemoteActions.push_back(Actions[i]);
        }
    }

    // Submit remaining actions to remote runners in parallel
    if (!RemoteActions.empty())
    {
        std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);

        for (size_t j = 0; j < RemoteIndices.size(); ++j)
        {
            Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
        }
    }

    return Results;
}

//////////////////////////////////////////////////////////////////////////
// Public ComputeServiceSession API — forwards to the pImpl

ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver)
{
    m_Impl = std::make_unique<Impl>(this, InChunkResolver);
}

ComputeServiceSession::~ComputeServiceSession()
{
    Shutdown();
}

bool
ComputeServiceSession::IsHealthy()
{
    return m_Impl->IsHealthy();
}

void
ComputeServiceSession::WaitUntilReady()
{
    m_Impl->WaitUntilReady();
}

void
ComputeServiceSession::Shutdown()
{
    m_Impl->Shutdown();
}

ComputeServiceSession::SessionState
ComputeServiceSession::GetSessionState() const
{
    return m_Impl->m_SessionState.load(std::memory_order_relaxed);
}

bool
ComputeServiceSession::RequestStateTransition(SessionState NewState)
{
    return m_Impl->RequestStateTransition(NewState);
}

void
ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint)
{
    m_Impl->SetOrchestratorEndpoint(Endpoint);
}

void
ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath)
{
    m_Impl->SetOrchestratorBasePath(std::move(BasePath));
}

void
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
    m_Impl->StartRecording(InResolver, RecordingPath);
}

void
ComputeServiceSession::StopRecording()
{
    m_Impl->StopRecording();
}

ComputeServiceSession::ActionCounts
ComputeServiceSession::GetActionCounts()
{
    return m_Impl->GetActionCounts();
}

void
ComputeServiceSession::EmitStats(CbObjectWriter& Cbo)
{
    m_Impl->EmitStats(Cbo);
}

std::vector<IoHash>
ComputeServiceSession::GetKnownWorkerIds()
{
    return m_Impl->GetKnownWorkerIds();
}

WorkerDesc
ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId)
{
    return m_Impl->GetWorkerDescriptor(WorkerId);
}

// Creates and registers a platform-appropriate local process runner, seeding
// it with the workers already known to the session.
void
ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions)
{
    ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner");

# if ZEN_PLATFORM_LINUX
    auto* NewRunner = new LinuxProcessRunner(InChunkResolver,
                                             BasePath,
                                             m_Impl->m_DeferredDeleter,
                                             m_Impl->m_LocalSubmitPool,
                                             false,
                                             MaxConcurrentActions);
# elif ZEN_PLATFORM_WINDOWS
    auto* NewRunner = new WindowsProcessRunner(InChunkResolver,
                                               BasePath,
                                               m_Impl->m_DeferredDeleter,
                                               m_Impl->m_LocalSubmitPool,
                                               false,
                                               MaxConcurrentActions);
# elif ZEN_PLATFORM_MAC
    auto* NewRunner =
        new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions);
# endif

    m_Impl->SyncWorkersToRunner(*NewRunner);
    m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner);
}

// Creates and registers a remote HTTP runner targeting another zenserver
// instance, seeding it with the workers already known to the session.
void
ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName)
{
    ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner");

    auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool);
    m_Impl->SyncWorkersToRunner(*NewRunner);
    m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner);
}

// Convenience overloads targeting the session's implicit queue
ComputeServiceSession::EnqueueResult
ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority)
{
    return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority);
}

ComputeServiceSession::EnqueueResult
ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
{
    return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority);
}
ComputeServiceSession::CreateQueueResult
ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config)
{
    return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config));
}

CbObject
ComputeServiceSession::GetQueueMetadata(int QueueId)
{
    return m_Impl->GetQueueMetadata(QueueId);
}

CbObject
ComputeServiceSession::GetQueueConfig(int QueueId)
{
    return m_Impl->GetQueueConfig(QueueId);
}

std::vector<int>
ComputeServiceSession::GetQueueIds()
{
    return m_Impl->GetQueueIds();
}

ComputeServiceSession::QueueStatus
ComputeServiceSession::GetQueueStatus(int QueueId)
{
    return m_Impl->GetQueueStatus(QueueId);
}

void
ComputeServiceSession::CancelQueue(int QueueId)
{
    m_Impl->CancelQueue(QueueId);
}

void
ComputeServiceSession::DrainQueue(int QueueId)
{
    m_Impl->DrainQueue(QueueId);
}

void
ComputeServiceSession::DeleteQueue(int QueueId)
{
    m_Impl->DeleteQueue(QueueId);
}

void
ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo)
{
    m_Impl->GetQueueCompleted(QueueId, Cbo);
}

ComputeServiceSession::EnqueueResult
ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
{
    return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority);
}

ComputeServiceSession::EnqueueResult
ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
{
    return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority);
}

void
ComputeServiceSession::RegisterWorker(CbPackage Worker)
{
    m_Impl->RegisterWorker(Worker);
}

HttpResponseCode
ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
    return m_Impl->GetActionResult(ActionLsn, OutResultPackage);
}

HttpResponseCode
ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
    return m_Impl->FindActionResult(ActionId, OutResultPackage);
}

void
ComputeServiceSession::RetireActionResult(int ActionLsn)
{
    m_Impl->RetireActionResult(ActionLsn);
}

ComputeServiceSession::RescheduleResult
ComputeServiceSession::RescheduleAction(int ActionLsn)
{
    return m_Impl->RescheduleAction(ActionLsn);
}

std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::GetRunningActions()
{
    return m_Impl->GetRunningActions();
}

std::vector<ComputeServiceSession::ActionHistoryEntry>
ComputeServiceSession::GetActionHistory(int Limit)
{
    return m_Impl->GetActionHistory(Limit);
}

std::vector<ComputeServiceSession::ActionHistoryEntry>
+ComputeServiceSession::GetQueueHistory(int QueueId, int Limit) +{ + return m_Impl->GetQueueHistory(QueueId, Limit); +} + +void +ComputeServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +ComputeServiceSession::PostUpdate(RunnerAction* Action) +{ + m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +computeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/functionrunner.cpp b/src/zencompute/functionrunner.cpp deleted file mode 100644 index 8e7c12b2b..000000000 --- a/src/zencompute/functionrunner.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "functionrunner.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinary.h> -# include <zencore/filesystem.h> - -# include <fmt/format.h> -# include <vector> - -namespace zen::compute { - -FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") -{ -} - -FunctionRunner::~FunctionRunner() = default; - -size_t -FunctionRunner::QueryCapacity() -{ - return 1; -} - -std::vector<SubmitResult> -FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) -{ - std::vector<SubmitResult> Results; - Results.reserve(Actions.size()); - - for (const Ref<RunnerAction>& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -void -FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) -{ - if (m_DumpActions) - { - std::string UniqueId = fmt::format("{}.ddb", ActionLsn); - std::filesystem::path Path = m_ActionsPath / UniqueId; - - zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); - } -} - -////////////////////////////////////////////////////////////////////////// - -RunnerAction::RunnerAction(FunctionServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) -{ - 
this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); -} - -RunnerAction::~RunnerAction() -{ -} - -void -RunnerAction::SetActionState(State NewState) -{ - ZEN_ASSERT(NewState < State::_Count); - this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); - - do - { - if (State CurrentState = m_ActionState.load(); CurrentState == NewState) - { - // No state change - return; - } - else - { - if (NewState <= CurrentState) - { - // Cannot transition to an earlier or same state - return; - } - - if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) - { - // Successful state change - - m_OwnerSession->PostUpdate(this); - - return; - } - } - } while (true); -} - -void -RunnerAction::SetResult(CbPackage&& Result) -{ - m_Result = std::move(Result); -} - -CbPackage& -RunnerAction::GetResult() -{ - ZEN_ASSERT(IsCompleted()); - return m_Result; -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/functionrunner.h b/src/zencompute/functionrunner.h deleted file mode 100644 index 6fd0d84cc..000000000 --- a/src/zencompute/functionrunner.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencompute/functionservice.h> - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <filesystem> -# include <vector> - -namespace zen::compute { - -struct SubmitResult -{ - bool IsAccepted = false; - std::string Reason; -}; - -/** Base interface for classes implementing a remote execution "runner" - */ -class FunctionRunner : public RefCounted -{ - FunctionRunner(FunctionRunner&&) = delete; - FunctionRunner& operator=(FunctionRunner&&) = delete; - -public: - FunctionRunner(std::filesystem::path BasePath); - virtual ~FunctionRunner() = 0; - - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; - - [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; - [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; - [[nodiscard]] virtual bool IsHealthy() = 0; - [[nodiscard]] virtual size_t QueryCapacity(); - [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); - -protected: - std::filesystem::path m_ActionsPath; - bool m_DumpActions = false; - void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); -}; - -template<typename RunnerType> -struct RunnerGroup -{ - void AddRunner(RunnerType* Runner) - { - m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); - } - size_t QueryCapacity() - { - size_t TotalCapacity = 0; - m_RunnersLock.WithSharedLock([&] { - for (const auto& Runner : m_Runners) - { - TotalCapacity += Runner->QueryCapacity(); - } - }); - return TotalCapacity; - } - - SubmitResult SubmitAction(Ref<RunnerAction> Action) - { - RwLock::SharedLockScope _(m_RunnersLock); - - const int InitialIndex = 
m_NextSubmitIndex.load(std::memory_order_acquire); - int Index = InitialIndex; - const int RunnerCount = gsl::narrow<int>(m_Runners.size()); - - if (RunnerCount == 0) - { - return {.IsAccepted = false, .Reason = "No runners available"}; - } - - do - { - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - - auto& Runner = m_Runners[Index++]; - - SubmitResult Result = Runner->SubmitAction(Action); - - if (Result.IsAccepted == true) - { - m_NextSubmitIndex = Index % RunnerCount; - - return Result; - } - - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - } while (Index != InitialIndex); - - return {.IsAccepted = false}; - } - - size_t GetSubmittedActionCount() - { - RwLock::SharedLockScope _(m_RunnersLock); - - size_t TotalCount = 0; - - for (const auto& Runner : m_Runners) - { - TotalCount += Runner->GetSubmittedActionCount(); - } - - return TotalCount; - } - - void RegisterWorker(CbPackage Worker) - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->RegisterWorker(Worker); - } - } - - void Shutdown() - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->Shutdown(); - } - } - -private: - RwLock m_RunnersLock; - std::vector<Ref<RunnerType>> m_Runners; - std::atomic<int> m_NextSubmitIndex{0}; -}; - -/** - * This represents an action going through different stages of scheduling and execution. 
- */ -struct RunnerAction : public RefCounted -{ - explicit RunnerAction(FunctionServiceSession* OwnerSession); - ~RunnerAction(); - - int ActionLsn = 0; - WorkerDesc Worker; - IoHash ActionId; - CbObject ActionObj; - int Priority = 0; - - enum class State - { - New, - Pending, - Running, - Completed, - Failed, - _Count - }; - - static const char* ToString(State _) - { - switch (_) - { - case State::New: - return "New"; - case State::Pending: - return "Pending"; - case State::Running: - return "Running"; - case State::Completed: - return "Completed"; - case State::Failed: - return "Failed"; - default: - return "Unknown"; - } - } - - uint64_t Timestamps[static_cast<int>(State::_Count)] = {}; - - State ActionState() const { return m_ActionState; } - void SetActionState(State NewState); - - bool IsSuccess() const { return ActionState() == State::Completed; } - bool IsCompleted() const { return ActionState() == State::Completed || ActionState() == State::Failed; } - - void SetResult(CbPackage&& Result); - CbPackage& GetResult(); - -private: - std::atomic<State> m_ActionState = State::New; - FunctionServiceSession* m_OwnerSession = nullptr; - CbPackage m_Result; -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/functionservice.cpp b/src/zencompute/functionservice.cpp deleted file mode 100644 index 0698449e9..000000000 --- a/src/zencompute/functionservice.cpp +++ /dev/null @@ -1,957 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zencompute/functionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" -# include "actionrecorder.h" -# include "localrunner.h" -# include "remotehttprunner.h" - -# include <zencompute/recordingreader.h> -# include <zencore/compactbinary.h> -# include <zencore/compactbinarybuilder.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/compress.h> -# include <zencore/except.h> -# include <zencore/filesystem.h> -# include <zencore/fmtutils.h> -# include <zencore/iobuffer.h> -# include <zencore/iohash.h> -# include <zencore/logging.h> -# include <zencore/scopeguard.h> -# include <zentelemetry/stats.h> - -# include <set> -# include <deque> -# include <map> -# include <thread> -# include <unordered_map> - -ZEN_THIRD_PARTY_INCLUDES_START -# include <EASTL/hash_set.h> -ZEN_THIRD_PARTY_INCLUDES_END - -using namespace std::literals; - -namespace zen::compute { - -////////////////////////////////////////////////////////////////////////// - -struct FunctionServiceSession::Impl -{ - FunctionServiceSession* m_FunctionServiceSession; - ChunkResolver& m_ChunkResolver; - LoggerRef m_Log{logging::Get("apply")}; - - Impl(FunctionServiceSession* InFunctionServiceSession, ChunkResolver& InChunkResolver) - : m_FunctionServiceSession(InFunctionServiceSession) - , m_ChunkResolver(InChunkResolver) - { - m_SchedulingThread = std::thread{&Impl::MonitorThreadFunction, this}; - } - - void Shutdown(); - bool IsHealthy(); - - LoggerRef Log() { return m_Log; } - - std::atomic_bool m_AcceptActions = true; - - struct FunctionDefinition - { - std::string FunctionName; - Guid FunctionVersion; - Guid BuildSystemVersion; - IoHash WorkerId; - }; - - void 
EmitStats(CbObjectWriter& Cbo) - { - m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); - Cbo << "actions_submitted"sv << GetSubmittedActionCount(); - EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); - } - - void RegisterWorker(CbPackage Worker); - WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - - std::atomic<int32_t> m_ActionsCounter = 0; // sequence number - - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; - std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock m_ResultsLock; - std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; - metrics::Meter m_ResultRate; - std::atomic<uint64_t> m_RetiredCount{0}; - - HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - std::atomic<bool> m_ShutdownRequested{false}; - - std::thread m_SchedulingThread; - std::atomic<bool> m_SchedulingThreadEnabled{true}; - Event m_SchedulingThreadEvent; - - void MonitorThreadFunction(); - void SchedulePendingActions(); - - // Workers - - RwLock m_WorkerLock; - std::unordered_map<IoHash, CbPackage> m_WorkerMap; - std::vector<FunctionDefinition> m_FunctionList; - std::vector<IoHash> GetKnownWorkerIds(); - - // Runners - - RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup; - RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup; - - EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority); - - void GetCompleted(CbWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); - - 
std::unique_ptr<ActionRecorder> m_Recorder; - - // History tracking - - RwLock m_ActionHistoryLock; - std::deque<FunctionServiceSession::ActionHistoryEntry> m_ActionHistory; - size_t m_HistoryLimit = 1000; - - std::vector<FunctionServiceSession::ActionHistoryEntry> GetActionHistory(int Limit); - - // - - [[nodiscard]] size_t QueryCapacity(); - - [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action); - [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); - [[nodiscard]] size_t GetSubmittedActionCount(); - - // Updates - - RwLock m_UpdatedActionsLock; - std::vector<Ref<RunnerAction>> m_UpdatedActions; - - void HandleActionUpdates(); - void PostUpdate(RunnerAction* Action); - - void ShutdownRunners(); -}; - -bool -FunctionServiceSession::Impl::IsHealthy() -{ - return true; -} - -void -FunctionServiceSession::Impl::Shutdown() -{ - m_AcceptActions = false; - m_ShutdownRequested = true; - - m_SchedulingThreadEnabled = false; - m_SchedulingThreadEvent.Set(); - if (m_SchedulingThread.joinable()) - { - m_SchedulingThread.join(); - } - - ShutdownRunners(); -} - -void -FunctionServiceSession::Impl::ShutdownRunners() -{ - m_LocalRunnerGroup.Shutdown(); - m_RemoteRunnerGroup.Shutdown(); -} - -void -FunctionServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) -{ - ZEN_INFO("starting recording to '{}'", RecordingPath); - - m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); - - ZEN_INFO("started recording to '{}'", RecordingPath); -} - -void -FunctionServiceSession::Impl::StopRecording() -{ - ZEN_INFO("stopping recording"); - - m_Recorder = nullptr; - - ZEN_INFO("stopped recording"); -} - -std::vector<FunctionServiceSession::ActionHistoryEntry> -FunctionServiceSession::Impl::GetActionHistory(int Limit) -{ - RwLock::SharedLockScope _(m_ActionHistoryLock); - - if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size()) - { - return 
std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end()); - } - - return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end()); -} - -void -FunctionServiceSession::Impl::RegisterWorker(CbPackage Worker) -{ - RwLock::ExclusiveLockScope _(m_WorkerLock); - - const IoHash& WorkerId = Worker.GetObject().GetHash(); - - if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) - { - // Note that since the convention currently is that WorkerId is equal to the hash - // of the worker descriptor there is no chance that we get a second write with a - // different descriptor. Thus we only need to call this the first time, when the - // worker is added - - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); - - if (m_Recorder) - { - m_Recorder->RegisterWorker(Worker); - } - - CbObject WorkerObj = Worker.GetObject(); - - // Populate worker database - - const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); - - for (auto& Item : WorkerObj["functions"sv]) - { - CbObjectView Function = Item.AsObjectView(); - - std::string_view FunctionName = Function["name"sv].AsString(); - const Guid FunctionVersion = Function["version"sv].AsUuid(); - - m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, - .FunctionVersion = FunctionVersion, - .BuildSystemVersion = WorkerBuildSystemVersion, - .WorkerId = WorkerId}); - } - } -} - -WorkerDesc -FunctionServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) -{ - RwLock::SharedLockScope _(m_WorkerLock); - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - const CbPackage& Desc = It->second; - return {Desc, WorkerId}; - } - - return {}; -} - -std::vector<IoHash> -FunctionServiceSession::Impl::GetKnownWorkerIds() -{ - std::vector<IoHash> WorkerIds; - WorkerIds.reserve(m_WorkerMap.size()); - - m_WorkerLock.WithSharedLock([&] { - for (const auto& [WorkerId, _] : 
m_WorkerMap) - { - WorkerIds.push_back(WorkerId); - } - }); - - return WorkerIds; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueAction(CbObject ActionObject, int Priority) -{ - // Resolve function to worker - - IoHash WorkerId{IoHash::Zero}; - - std::string_view FunctionName = ActionObject["Function"sv].AsString(); - const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); - const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); - - for (const FunctionDefinition& FuncDef : m_FunctionList) - { - if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && - FuncDef.BuildSystemVersion == BuildSystemVersion) - { - WorkerId = FuncDef.WorkerId; - - break; - } - } - - if (WorkerId == IoHash::Zero) - { - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker matches the action specification"; - - return {0, Writer.Save()}; - } - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - CbPackage WorkerPackage = It->second; - - return EnqueueResolvedAction(WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); - } - - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker found despite match"; - - return {0, Writer.Save()}; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - const int ActionLsn = ++m_ActionsCounter; - - Ref<RunnerAction> Pending{new RunnerAction(m_FunctionServiceSession)}; - - Pending->ActionLsn = ActionLsn; - Pending->Worker = Worker; - Pending->ActionId = ActionObj.GetHash(); - Pending->ActionObj = ActionObj; - Pending->Priority = RequestPriority; - - 
SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - ZEN_DEBUG("direct schedule LSN {}", Pending->ActionLsn); - } - else - { - ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); - - Pending->SetActionState(RunnerAction::State::Pending); - } - - if (m_Recorder) - { - m_Recorder->RecordAction(Pending); - } - - CbObjectWriter Writer; - Writer << "lsn" << Pending->ActionLsn; - Writer << "worker" << Pending->Worker.WorkerId; - Writer << "action" << Pending->ActionId; - - return {Pending->ActionLsn, Writer.Save()}; -} - -SubmitResult -FunctionServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action) -{ - // Loosely round-robin scheduling of actions across runners. - // - // It's not entirely clear what this means given that submits - // can come in across multiple threads, but it's probably better - // than always starting with the first runner. - // - // Longer term we should track the state of the individual - // runners and make decisions accordingly. 
- - SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); - if (Result.IsAccepted) - { - return Result; - } - - return m_RemoteRunnerGroup.SubmitAction(Action); -} - -size_t -FunctionServiceSession::Impl::GetSubmittedActionCount() -{ - return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); -} - -HttpResponseCode -FunctionServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - return HttpResponseCode::OK; - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } - } - - return HttpResponseCode::NotFound; -} - -HttpResponseCode -FunctionServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) - { - if (It->second->ActionId == ActionId) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - 
return HttpResponseCode::OK; - } - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) - { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) - { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - return HttpResponseCode::NotFound; -} - -void -FunctionServiceSession::Impl::GetCompleted(CbWriter& Cbo) -{ - Cbo.BeginArray("completed"); - - m_ResultsLock.WithSharedLock([&] { - for (auto& Kv : m_ResultsMap) - { - Cbo << Kv.first; - } - }); - - Cbo.EndArray(); -} - -# define ZEN_BATCH_SCHEDULER 1 - -void -FunctionServiceSession::Impl::SchedulePendingActions() -{ - int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); - - static Stopwatch DumpRunningTimer; - - auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. 
{} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); - - if (DumpRunningTimer.GetElapsedTimeMs() > 30000) - { - DumpRunningTimer.Reset(); - - std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { - for (auto& [K, V] : m_RunningMap) - { - RunningList.insert(K); - } - }); - - ExtendableStringBuilder<1024> RunningString; - for (int i : RunningList) - { - if (RunningString.Size()) - { - RunningString << ", "; - } - - RunningString.Append(IntNum(i)); - } - - ZEN_INFO("running: {}", RunningString); - } - }); - -# if ZEN_BATCH_SCHEDULER - size_t Capacity = QueryCapacity(); - - if (!Capacity) - { - _.Dismiss(); - - return; - } - - std::vector<Ref<RunnerAction>> ActionsToSchedule; - - // Pull actions to schedule from the pending queue, we will try to submit these to the runner outside of the lock - - m_PendingLock.WithExclusiveLock([&] { - if (m_ShutdownRequested) - { - return; - } - - if (m_PendingActions.empty()) - { - return; - } - - size_t NumActionsToSchedule = std::min(Capacity, m_PendingActions.size()); - - auto PendingIt = m_PendingActions.begin(); - const auto PendingEnd = m_PendingActions.end(); - - while (NumActionsToSchedule && PendingIt != PendingEnd) - { - const Ref<RunnerAction>& Pending = PendingIt->second; - - switch (Pending->ActionState()) - { - case RunnerAction::State::Pending: - ActionsToSchedule.push_back(Pending); - break; - - case RunnerAction::State::Running: - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - break; - - default: - case RunnerAction::State::New: - ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn); - break; - } - - ++PendingIt; - --NumActionsToSchedule; - } - - PendingCount = m_PendingActions.size(); - }); - - if (ActionsToSchedule.empty()) - { - _.Dismiss(); - return; - } - - ZEN_INFO("attempting schedule of {} pending actions", 
ActionsToSchedule.size()); - - auto SubmitResults = SubmitActions(ActionsToSchedule); - - // Move successfully scheduled actions to the running map and remove - // from pending queue. It's actually possible that by the time we get - // to this stage some of the actions may have already completed, so - // they should not always be added to the running map - - eastl::hash_set<int> ScheduledActions; - - for (size_t i = 0; i < ActionsToSchedule.size(); ++i) - { - const Ref<RunnerAction>& Pending = ActionsToSchedule[i]; - const SubmitResult& SubResult = SubmitResults[i]; - - if (SubResult.IsAccepted) - { - ScheduledActions.insert(Pending->ActionLsn); - } - } - - ScheduledCount += (int)ActionsToSchedule.size(); - -# else - m_PendingLock.WithExclusiveLock([&] { - while (!m_PendingActions.empty()) - { - if (m_ShutdownRequested) - { - return; - } - - // Here it would be good if we could decide to pop immediately to avoid - // holding the lock while creating processes etc - const Ref<RunnerAction>& Pending = m_PendingActions.begin()->second; - FunctionRunner::SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - - ZEN_DEBUG("action {} ({}) PENDING -> RUNNING", Pending->ActionId, Pending->ActionLsn); - - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({Pending->ActionLsn, Pending}); - - RunningCount = m_RunningMap.size(); - }); - - m_PendingActions.pop_front(); - - PendingCount = m_PendingActions.size(); - ++ScheduledCount; - } - else - { - // Runner could not accept the job, leave it on the pending queue - - return; - } - } - }); -# endif -} - -void -FunctionServiceSession::Impl::MonitorThreadFunction() -{ - SetCurrentThreadName("FunctionServiceSession_Monitor"); - - auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); - - do - { - int TimeoutMs = 1000; - - if (m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); })) - { - TimeoutMs = 100; - } - - 
const bool Timedout = m_SchedulingThreadEvent.Wait(TimeoutMs); - - if (m_SchedulingThreadEnabled == false) - { - return; - } - - HandleActionUpdates(); - - // Schedule pending actions - - SchedulePendingActions(); - - if (!Timedout) - { - m_SchedulingThreadEvent.Reset(); - } - } while (m_SchedulingThreadEnabled); -} - -void -FunctionServiceSession::Impl::PostUpdate(RunnerAction* Action) -{ - m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); -} - -void -FunctionServiceSession::Impl::HandleActionUpdates() -{ - std::vector<Ref<RunnerAction>> UpdatedActions; - - m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); - - std::unordered_set<int> SeenLsn; - std::unordered_set<int> RunningLsn; - - for (Ref<RunnerAction>& Action : UpdatedActions) - { - const int ActionLsn = Action->ActionLsn; - - if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) - { - switch (Action->ActionState()) - { - case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); - break; - - case RunnerAction::State::Running: - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); - }); - ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); - break; - - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - m_ResultsLock.WithExclusiveLock([&] { - m_ResultsMap[ActionLsn] = Action; - - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); - - m_ActionHistoryLock.WithExclusiveLock([&] { - ActionHistoryEntry Entry{.Lsn = ActionLsn, - .ActionId = Action->ActionId, - .WorkerId = Action->Worker.WorkerId, - 
.ActionDescriptor = Action->ActionObj, - .Succeeded = Action->ActionState() == RunnerAction::State::Completed}; - - std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps)); - - m_ActionHistory.push_back(std::move(Entry)); - - if (m_ActionHistory.size() > m_HistoryLimit) - { - m_ActionHistory.pop_front(); - } - }); - }); - m_RetiredCount.fetch_add(1); - m_ResultRate.Mark(1); - ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", - Action->ActionId, - ActionLsn, - Action->ActionState() == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - break; - } - } - } -} - -size_t -FunctionServiceSession::Impl::QueryCapacity() -{ - return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); -} - -std::vector<SubmitResult> -FunctionServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) -{ - std::vector<SubmitResult> Results; - - for (const Ref<RunnerAction>& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -////////////////////////////////////////////////////////////////////////// - -FunctionServiceSession::FunctionServiceSession(ChunkResolver& InChunkResolver) -{ - m_Impl = std::make_unique<Impl>(this, InChunkResolver); -} - -FunctionServiceSession::~FunctionServiceSession() -{ - Shutdown(); -} - -bool -FunctionServiceSession::IsHealthy() -{ - return m_Impl->IsHealthy(); -} - -void -FunctionServiceSession::Shutdown() -{ - m_Impl->Shutdown(); -} - -void -FunctionServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) -{ - m_Impl->StartRecording(InResolver, RecordingPath); -} - -void -FunctionServiceSession::StopRecording() -{ - m_Impl->StopRecording(); -} - -void -FunctionServiceSession::EmitStats(CbObjectWriter& Cbo) -{ - m_Impl->EmitStats(Cbo); -} - -std::vector<IoHash> -FunctionServiceSession::GetKnownWorkerIds() -{ - return m_Impl->GetKnownWorkerIds(); -} - -WorkerDesc 
-FunctionServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) -{ - return m_Impl->GetWorkerDescriptor(WorkerId); -} - -void -FunctionServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath) -{ - m_Impl->m_LocalRunnerGroup.AddRunner(new LocalProcessRunner(InChunkResolver, BasePath)); -} - -void -FunctionServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) -{ - m_Impl->m_RemoteRunnerGroup.AddRunner(new RemoteHttpRunner(InChunkResolver, BasePath, HostName)); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueAction(CbObject ActionObject, int Priority) -{ - return m_Impl->EnqueueAction(ActionObject, Priority); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - return m_Impl->EnqueueResolvedAction(Worker, ActionObj, RequestPriority); -} - -void -FunctionServiceSession::RegisterWorker(CbPackage Worker) -{ - m_Impl->RegisterWorker(Worker); -} - -HttpResponseCode -FunctionServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - return m_Impl->GetActionResult(ActionLsn, OutResultPackage); -} - -HttpResponseCode -FunctionServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - return m_Impl->FindActionResult(ActionId, OutResultPackage); -} - -std::vector<FunctionServiceSession::ActionHistoryEntry> -FunctionServiceSession::GetActionHistory(int Limit) -{ - return m_Impl->GetActionHistory(Limit); -} - -void -FunctionServiceSession::GetCompleted(CbWriter& Cbo) -{ - m_Impl->GetCompleted(Cbo); -} - -void -FunctionServiceSession::PostUpdate(RunnerAction* Action) -{ - m_Impl->PostUpdate(Action); -} - -////////////////////////////////////////////////////////////////////////// - -void -function_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff 
--git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp new file mode 100644 index 000000000..e82a40781 --- /dev/null +++ b/src/zencompute/httpcomputeservice.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httpcomputeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/system.h> +# include <zencore/thread.h> +# include <zencore/trace.h> +# include <zencore/uid.h> +# include <zenstore/cidstore.h> +# include <zentelemetry/stats.h> + +# include <span> +# include <unordered_map> + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); + +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; +auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); }; + +////////////////////////////////////////////////////////////////////////// + +struct HttpComputeService::Impl +{ + HttpComputeService* m_Self; + CidStore& m_CidStore; + IHttpStatsService& m_StatsService; + LoggerRef m_Log; + std::filesystem::path m_BaseDir; + HttpRequestRouter m_Router; + ComputeServiceSession m_ComputeService; + SystemMetricsTracker m_MetricsTracker; + + // Metrics + + metrics::OperationTiming m_HttpRequests; + + // Per-remote-queue metadata, shared across all lookup maps below. 
+ + struct RemoteQueueInfo : RefCounted + { + int QueueId = 0; + Oid Token; + std::string IdempotencyKey; // empty if no idempotency key was provided + std::string ClientHostname; // empty if no hostname was provided + }; + + // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // All maps are guarded by m_RemoteQueueLock. + + RwLock m_RemoteQueueLock; + std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info + std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info + std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info + + LoggerRef Log() { return m_Log; } + + int ResolveQueueToken(const Oid& Token); + int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture); + + struct IngestStats + { + int Count = 0; + int NewCount = 0; + uint64_t Bytes = 0; + uint64_t NewBytes = 0; + }; + + IngestStats IngestPackageAttachments(const CbPackage& Package); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + + void RegisterRoutes(); + + Impl(HttpComputeService* Self, + CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) + : m_Self(Self) + , m_CidStore(InCidStore) + , m_StatsService(StatsService) + , m_Log(logging::Get("compute")) + , m_BaseDir(BaseDir) + , m_ComputeService(InCidStore) + { + m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.WaitUntilReady(); + m_StatsService.RegisterHandler("compute", *m_Self); + RegisterRoutes(); + } +}; + 
+////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::RegisterRoutes() +{ + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + m_Router.AddMatcher("queue", DecimalMatcher); + m_Router.AddMatcher("oidtoken", OidMatcher); + m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || OidMatcher(Str); }); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "abandon", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + + if (Success) + { + CbObjectWriter Cbo; + Cbo << "state"sv + << "Abandoned"sv; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Abandoned from current state"sv; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = 
Req.ServerRequest(); + + CbObjectWriter Cbo; + m_ComputeService.GetCompleted(Cbo); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Running = m_ComputeService.GetRunningActions(); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + 
if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + // Once we've initiated the response we can mark the result + // as retired, allowing the service to free any associated + // resources. Note that there still needs to be a delay + // to allow the transmission to complete, it would be better + // if we could issue this once the response is fully sent... + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. 
The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if 
(ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersAllGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2))); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return 
HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + // Local-only queue listing and creation + + m_Router.RegisterRoute( + "queues", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObjectWriter Cbo; + Cbo.BeginArray("queues"sv); + + for (const int QueueId : m_ComputeService.GetQueueIds()) + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + continue; + } + + Cbo.BeginObject(); + WriteQueueDescription(Cbo, QueueId, Status); + Cbo.EndObject(); + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kPost: + { + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + ComputeServiceSession::CreateQueueResult Result = + m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + + CbObjectWriter Cbo; + Cbo << "queue_id"sv << Result.QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + // Queue creation routes — these remain separate since local creates a 
plain queue + // while remote additionally generates an OID token for external access. + + m_Router.RegisterRoute( + "queues/remote", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // Extract optional fields from the request body. + // idempotency_key: when present, we return the existing remote queue token for this + // key rather than creating a new queue, making the endpoint safe to call concurrently. + // hostname: human-readable origin context stored alongside the queue for diagnostics. + // metadata: arbitrary CbObject metadata propagated from the originating queue. + // config: arbitrary CbObject config propagated from the originating queue. + std::string IdempotencyKey; + std::string ClientHostname; + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + IdempotencyKey = std::string(Body["idempotency_key"sv].AsString()); + ClientHostname = std::string(Body["hostname"sv].AsString()); + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + // Stamp the forwarding node's hostname into the metadata so that the + // remote side knows which node originated the queue. 
+ if (!ClientHostname.empty()) + { + CbObjectWriter MetaWriter; + for (auto Field : Metadata) + { + MetaWriter.AddField(Field.GetName(), Field); + } + MetaWriter << "via"sv << ClientHostname; + Metadata = MetaWriter.Save(); + } + + RwLock::ExclusiveLockScope _(m_RemoteQueueLock); + + if (!IdempotencyKey.empty()) + { + if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end()) + { + Ref<RemoteQueueInfo> Existing = It->second; + if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid) + { + CbObjectWriter Cbo; + Cbo << "queue_token"sv << Existing->Token.ToString(); + Cbo << "queue_id"sv << Existing->QueueId; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // Queue has since expired — clean up stale entries and fall through to create a new one + m_RemoteQueuesByToken.erase(Existing->Token); + m_RemoteQueuesByQueueId.erase(Existing->QueueId); + m_RemoteQueuesByTag.erase(It); + } + } + + ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + Ref<RemoteQueueInfo> InfoRef(new RemoteQueueInfo()); + InfoRef->QueueId = Result.QueueId; + InfoRef->Token = Oid::NewOid(); + InfoRef->IdempotencyKey = std::move(IdempotencyKey); + InfoRef->ClientHostname = std::move(ClientHostname); + + m_RemoteQueuesByToken[InfoRef->Token] = InfoRef; + m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef; + if (!InfoRef->IdempotencyKey.empty()) + { + m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef; + } + + CbObjectWriter Cbo; + Cbo << "queue_token"sv << InfoRef->Token.ToString(); + Cbo << "queue_id"sv << InfoRef->QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. 
+ + m_Router.RegisterRoute( + "queues/{queueref}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kDelete: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.CancelQueue(QueueId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "queues/{queueref}/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.DrainQueue(QueueId); + + // Return updated queue status + Status = m_ComputeService.GetQueueStatus(QueueId); + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + 
return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + m_ComputeService.GetQueueCompleted(QueueId, Cbo); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/running", + [this](HttpRouterRequest& Req) { + 
HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + if (QueueId == 0) + { + return; + } + // Filter global running list to this queue + auto AllRunning = m_ComputeService.GetRunningActions(); + std::vector<ComputeServiceSession::RunningActionInfo> Running; + for (auto& Info : AllRunning) + if (Info.QueueId == QueueId) + Running.push_back(Info); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2)); + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + 
Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. 
{} new ({} attachments)", + QueueId, + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = 
IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ZEN_UNUSED(QueueId); + + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); +} + +////////////////////////////////////////////////////////////////////////// + 
+HttpComputeService::HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) +: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +{ +} + +HttpComputeService::~HttpComputeService() +{ + m_Impl->m_StatsService.UnregisterHandler("compute", *this); +} + +void +HttpComputeService::Shutdown() +{ + m_Impl->m_ComputeService.Shutdown(); +} + +ComputeServiceSession::ActionCounts +HttpComputeService::GetActionCounts() +{ + return m_Impl->m_ComputeService.GetActionCounts(); +} + +const char* +HttpComputeService::BaseUri() const +{ + return "/compute/"; +} + +void +HttpComputeService::HandleRequest(HttpServerRequest& Request) +{ + ZEN_TRACE_CPU("HttpComputeService::HandleRequest"); + metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests); + + if (m_Impl->m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpComputeService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + m_Impl->m_ComputeService.EmitStats(Cbo); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status) +{ + Cbo << "queue_id"sv << Status.QueueId; + Cbo << "active_count"sv << Status.ActiveCount; + Cbo << "completed_count"sv << Status.CompletedCount; + Cbo << "failed_count"sv << Status.FailedCount; + Cbo << "abandoned_count"sv << Status.AbandonedCount; + Cbo << "cancelled_count"sv << Status.CancelledCount; + Cbo << "state"sv << ToString(Status.State); + Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled); + Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining); + Cbo << "is_complete"sv << 
Status.IsComplete; + + if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId)) + { + Cbo << "metadata"sv << Meta; + } + + if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId)) + { + Cbo << "config"sv << Cfg; + } + + { + RwLock::SharedLockScope $(m_RemoteQueueLock); + if (auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end()) + { + Cbo << "queue_token"sv << It->second->Token.ToString(); + if (!It->second->ClientHostname.empty()) + { + Cbo << "hostname"sv << It->second->ClientHostname; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + +int +HttpComputeService::Impl::ResolveQueueToken(const Oid& Token) +{ + RwLock::SharedLockScope $(m_RemoteQueueLock); + + auto It = m_RemoteQueuesByToken.find(Token); + + if (It != m_RemoteQueuesByToken.end()) + { + return It->second->QueueId; + } + + return 0; +} + +int +HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture) +{ + if (OidMatcher(Capture)) + { + // Remote OID token — accessible from any client + const Oid Token = Oid::FromHexString(Capture); + const int QueueId = ResolveQueueToken(Token); + + if (QueueId == 0) + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + return QueueId; + } + + // Local integer queue ID — restricted to local machine requests + if (!HttpReq.IsLocalMachineRequest()) + { + HttpReq.WriteResponse(HttpResponseCode::Forbidden); + return 0; + } + + return ParseInt<int>(Capture).value_or(0); +} + +HttpComputeService::Impl::IngestStats +HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +{ + IngestStats Stats; + + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + 
Stats.Bytes += CompressedSize; + ++Stats.Count; + + const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + Stats.NewBytes += CompressedSize; + ++Stats.NewCount; + } + } + + return Stats; +} + +bool +HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList) +{ + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + return NeedList.empty(); +} + +void +HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds()) + { + Cbo << WorkerId; + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq) +{ + std::vector<IoHash> WorkerIds = m_ComputeService.GetKnownWorkerIds(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + for (const IoHash& WorkerId : WorkerIds) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject(); + Cbo.EndObject(); + } + + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId) +{ + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); + } + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject WorkerSpec = 
HttpReq.ReadPayloadObject(); + + HashKeySet ChunkSet; + WorkerSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerSpec); + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + m_ComputeService.RegisterWorker(WorkerPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + ResponseWriter.AddHash(Hash); + }); + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); + CbObject WorkerSpec = WorkerSpecPackage.GetObject(); + + std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + 
AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_ComputeService.RegisterWorker(WorkerSpecPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } +} + +////////////////////////////////////////////////////////////////////////// + +void +httpcomputeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpfunctionservice.cpp b/src/zencompute/httpfunctionservice.cpp deleted file mode 100644 index 09a9684a7..000000000 --- a/src/zencompute/httpfunctionservice.cpp +++ /dev/null @@ -1,709 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zencompute/httpfunctionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarybuilder.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/compress.h> -# include <zencore/except.h> -# include <zencore/filesystem.h> -# include <zencore/fmtutils.h> -# include <zencore/iobuffer.h> -# include <zencore/iohash.h> -# include <zencore/system.h> -# include <zenstore/cidstore.h> - -# include <span> - -using namespace std::literals; - -namespace zen::compute { - -constinit AsciiSet g_DecimalSet("0123456789"); -auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; - -constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); -auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; - -HttpFunctionService::HttpFunctionService(CidStore& InCidStore, - IHttpStatsService& StatsService, - [[maybe_unused]] const std::filesystem::path& BaseDir) -: m_CidStore(InCidStore) -, m_StatsService(StatsService) -, m_Log(logging::Get("apply")) -, m_BaseDir(BaseDir) -, m_FunctionService(InCidStore) -{ - m_FunctionService.AddLocalRunner(InCidStore, m_BaseDir / 
"local"); - - m_StatsService.RegisterHandler("apply", *this); - - m_Router.AddMatcher("lsn", DecimalMatcher); - m_Router.AddMatcher("worker", IoHashMatcher); - m_Router.AddMatcher("action", IoHashMatcher); - - m_Router.RegisterRoute( - "ready", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - if (m_FunctionService.IsHealthy()) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); - } - - return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"sv); - for (const IoHash& WorkerId : m_FunctionService.GetKnownWorkerIds()) - { - Cbo << WorkerId; - } - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - if (WorkerDesc Desc = m_FunctionService.GetWorkerDescriptor(WorkerId)) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); - } - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - - case HttpVerb::kPost: - { - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - CbObject WorkerSpec = HttpReq.ReadPayloadObject(); - - // Determine which pieces are missing and need to be transmitted - - HashKeySet ChunkSet; - - WorkerSpec.IterateAttachments([&](CbFieldView Field) { - const IoHash Hash = Field.AsHash(); - ChunkSet.AddHashToSet(Hash); - }); - - CbPackage WorkerPackage; - WorkerPackage.SetObject(WorkerSpec); - - m_CidStore.FilterChunks(ChunkSet); - - if (ChunkSet.IsEmpty()) - { - ZEN_DEBUG("worker {}: 
all attachments already available", WorkerId); - m_FunctionService.RegisterWorker(WorkerPackage); - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - - CbObjectWriter ResponseWriter; - ResponseWriter.BeginArray("need"); - - ChunkSet.IterateHashes([&](const IoHash& Hash) { - ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); - ResponseWriter.AddHash(Hash); - }); - - ResponseWriter.EndArray(); - - ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); - } - break; - - case HttpContentType::kCbPackage: - { - CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); - CbObject WorkerSpec = WorkerSpecPackage.GetObject(); - - std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - TotalAttachmentBytes += Buffer.GetCompressedSize(); - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += Buffer.GetCompressedSize(); - ++NewAttachmentCount; - } - } - - ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", - WorkerId, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - m_FunctionService.RegisterWorker(WorkerSpecPackage); - - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - break; - - default: - break; - } - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - 
"jobs/completed", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - m_FunctionService.GetCompleted(Cbo); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - Cbo.BeginObject("metrics"); - Describe(Sm, Cbo); - Cbo.EndObject(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/history", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const auto QueryParams = HttpReq.GetQueryParams(); - - int QueryLimit = 50; - - if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) - { - QueryLimit = ParseInt<int>(LimitParam).value_or(50); - } - - CbObjectWriter Cbo; - Cbo.BeginArray("history"); - for (const auto& Entry : m_FunctionService.GetActionHistory(QueryLimit)) - { - Cbo.BeginObject(); - Cbo << "lsn"sv << Entry.Lsn; - Cbo << "actionId"sv << Entry.ActionId; - Cbo << "workerId"sv << Entry.WorkerId; - Cbo << "succeeded"sv << Entry.Succeeded; - Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; - - for (const auto& Timestamp : Entry.Timestamps) - { - Cbo.AddInteger( - fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), - Timestamp); - } - Cbo.EndObject(); - } - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{lsn}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const int ActionLsn = std::stoi(std::string{Req.GetCapture(1)}); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - { - CbPackage Output; - HttpResponseCode ResponseCode = m_FunctionService.GetActionResult(ActionLsn, Output); - - if (ResponseCode == HttpResponseCode::OK) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - } - - return HttpReq.WriteResponse(ResponseCode); - } - break; - - case 
HttpVerb::kPost: - { - // Add support for cancellation, priority changes - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the - // one which uses the scheduled action lsn for lookups - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); - - CbPackage Output; - if (HttpResponseCode ResponseCode = m_FunctionService.FindActionResult(ActionId, /* out */ Output); - ResponseCode != HttpResponseCode::OK) - { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) - - if (ResponseCode == HttpResponseCode::NotFound) - { - return HttpReq.WriteResponse(ResponseCode); - } - - return HttpReq.WriteResponse(ResponseCode); - } - - ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) - - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - WorkerDesc Worker = m_FunctionService.GetWorkerDescriptor(WorkerId); - - if (!Worker) - { - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - } - - const auto QueryParams = Req.ServerRequest().GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); - } - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed 
job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const auto QueryParams = HttpReq.GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); - } - - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "workers/all", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - std::vector<IoHash> WorkerIds = m_FunctionService.GetKnownWorkerIds(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); - - for (const IoHash& WorkerId : WorkerIds) - { - Cbo.BeginObject(); - - Cbo << "id" << WorkerId; - - const auto& Descriptor = m_FunctionService.GetWorkerDescriptor(WorkerId); - - Cbo << "descriptor" << Descriptor.Descriptor.GetObject(); - - Cbo.EndObject(); - } - - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "sysinfo", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - - CbObjectWriter Cbo; - Describe(Sm, Cbo); - - Cbo << "cpu_usage" << Sm.CpuUsagePercent; - Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; - Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; - Cbo << "disk_used" << 100 * 1024; - Cbo << "disk_total" << 100 * 1024 * 1024; - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "record/start", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = 
Req.ServerRequest(); - - m_FunctionService.StartRecording(m_CidStore, m_BaseDir / "recording"); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "record/stop", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - m_FunctionService.StopRecording(); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); -} - -HttpFunctionService::~HttpFunctionService() -{ - m_StatsService.UnregisterHandler("apply", *this); -} - -void -HttpFunctionService::Shutdown() -{ - m_FunctionService.Shutdown(); -} - -const char* -HttpFunctionService::BaseUri() const -{ - return "/apply/"; -} - -void -HttpFunctionService::HandleRequest(HttpServerRequest& Request) -{ - metrics::OperationTiming::Scope $(m_HttpRequests); - - if (m_Router.HandleRequest(Request) == false) - { - ZEN_WARN("No route found for {0}", Request.RelativeUri()); - } -} - -void -HttpFunctionService::HandleStatsRequest(HttpServerRequest& Request) -{ - CbObjectWriter Cbo; - m_FunctionService.EmitStats(Cbo); - - Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); -} - -////////////////////////////////////////////////////////////////////////// - -void -httpfunction_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 39e7e60d7..6cbe01e04 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -2,65 +2,398 @@ #include "zencompute/httporchestrator.h" -#include <zencore/compactbinarybuilder.h> -#include <zencore/logging.h> +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencompute/orchestratorservice.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/system.h> namespace zen::compute { -HttpOrchestratorService::HttpOrchestratorService() : m_Log(logging::Get("orch")) +// 
Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes +static bool +IsValidWorkerId(std::string_view Id) +{ + if (Id.size() < 3 || Id.size() > 64) + { + return false; + } + for (char c : Id) + { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') + { + continue; + } + return false; + } + return true; +} + +// Shared announce payload parser used by both the HTTP POST route and the +// WebSocket message handler. Returns the worker ID on success (empty on +// validation failure). The returned WorkerAnnouncement has string_view +// fields that reference the supplied CbObjectView, so the CbObject must +// outlive the returned announcement. +static std::string_view +ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann) { + Ann.Id = Data["id"].AsString(""); + Ann.Uri = Data["uri"].AsString(""); + + if (!IsValidWorkerId(Ann.Id)) + { + return {}; + } + + if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://")) + { + return {}; + } + + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Platform = Data["platform"].AsString(""); + Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); + Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); + Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); + Ann.BytesReceived = Data["bytes_received"].AsUInt64(0); + Ann.BytesSent = Data["bytes_sent"].AsUInt64(0); + Ann.ActionsPending = Data["actions_pending"].AsInt32(0); + Ann.ActionsRunning = Data["actions_running"].AsInt32(0); + Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0); + Ann.ActiveQueues = Data["active_queues"].AsInt32(0); + Ann.Provisioner = Data["provisioner"].AsString(""); + + if (auto Metrics = Data["metrics"].AsObjectView()) + { + Ann.Cpus = Metrics["lp_count"].AsInt32(0); + if (Ann.Cpus <= 0) + { + Ann.Cpus = 1; + } + } + + return Ann.Id; +} + 
+HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) +, m_Hostname(GetMachineName()) +{ + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + + // dummy endpoint for websocket clients + m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + Cbo << "hostname" << std::string_view(m_Hostname); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + m_Router.RegisterRoute( "provision", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); - m_KnownWorkersLock.WithSharedLock([&] { - for (const auto& [WorkerId, Worker] : m_KnownWorkers) + if (WorkerId.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash " + "characters and uri must start with http:// or https://"); + } + + m_Service->AnnounceWorker(Ann); + + HttpReq.WriteResponse(HttpResponseCode::OK); + +# if ZEN_WITH_WEBSOCKETS + // Notify push thread that state may have changed + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + 
m_Router.RegisterRoute( + "agents", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline/{workerid}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::string_view WorkerId = Req.GetCapture(1); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + auto LimitStr = Params.GetValue("limit"); + + std::optional<DateTime> From; + std::optional<DateTime> To; + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) { - Cbo.BeginObject(); - Cbo << "uri" << Worker.BaseUri; - Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); - Cbo.EndObject(); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } - }); + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + int Limit = !LimitStr.empty() ? 
zen::ParseInt<int>(LimitStr).value_or(0) : 0; - Cbo.EndArray(); + CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit); - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + if (!Result) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); }, - HttpVerb::kPost); + HttpVerb::kGet); m_Router.RegisterRoute( - "announce", + "timeline", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + + DateTime From = DateTime(0); + DateTime To = DateTime::Now(); + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + CbObject Result = m_Service->GetAllTimelines(From, To); + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + // Client tracking endpoints + + m_Router.RegisterRoute( + "clients", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); CbObject Data = HttpReq.ReadPayloadObject(); - std::string_view WorkerId = Data["id"].AsString(""); - std::string_view WorkerUri = Data["uri"].AsString(""); + OrchestratorService::ClientAnnouncement Ann; + Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero); + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Address = HttpReq.GetRemoteAddress(); - if (WorkerId.empty() || WorkerUri.empty()) + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) { - return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + Ann.Metadata = CbObject::Clone(MetadataView); } - 
m_KnownWorkersLock.WithExclusiveLock([&] { - auto& Worker = m_KnownWorkers[std::string(WorkerId)]; - Worker.BaseUri = WorkerUri; - Worker.LastSeen.Reset(); - }); + std::string ClientId = m_Service->AnnounceClient(Ann); - HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter ResponseObj; + ResponseObj << "id" << std::string_view(ClientId); + HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save()); + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif }, HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/update", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + CbObject MetadataObj; + CbObject Data = HttpReq.ReadPayloadObject(); + if (Data) + { + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + MetadataObj = CbObject::Clone(MetadataView); + } + } + + if (m_Service->UpdateClient(ClientId, std::move(MetadataObj))) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/complete", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + if (m_Service->CompleteClient(ClientId)) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "clients/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = 
Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit)); + }, + HttpVerb::kGet); + +# if ZEN_WITH_WEBSOCKETS + + // Start the WebSocket push thread + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); +# endif } HttpOrchestratorService::~HttpOrchestratorService() { + Shutdown(); +} + +void +HttpOrchestratorService::Shutdown() +{ +# if ZEN_WITH_WEBSOCKETS + if (!m_PushEnabled.exchange(false)) + { + return; + } + + // Stop the push thread first, before touching connections. This ensures + // the push thread is no longer reading m_WsConnections or calling into + // m_Service when we start tearing things down. + m_PushEvent.Set(); + if (m_PushThread.joinable()) + { + m_PushThread.join(); + } + + // Clean up worker WebSocket connections — collect IDs under lock, then + // notify the service outside the lock to avoid lock-order inversions. + std::vector<std::string> WorkerIds; + m_WorkerWsLock.WithExclusiveLock([&] { + WorkerIds.reserve(m_WorkerWsMap.size()); + for (const auto& [Conn, Id] : m_WorkerWsMap) + { + WorkerIds.push_back(Id); + } + m_WorkerWsMap.clear(); + }); + for (const auto& Id : WorkerIds) + { + m_Service->SetWorkerWebSocketConnected(Id, false); + } + + // Now that the push thread is gone, release all dashboard connections. 
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); +# endif } const char* @@ -78,4 +411,240 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + if (!m_PushEnabled.load()) + { + return; + } + + ZEN_INFO("WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Wake push thread to send initial state immediately + m_PushEvent.Set(); +} + +void +HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + // Only handle binary messages from workers when the feature is enabled. + if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary) + { + return; + } + + std::string WorkerId = HandleWorkerWebSocketMessage(Msg); + if (WorkerId.empty()) + { + return; + } + + // Check if this is a new worker WebSocket connection + bool IsNewWorkerWs = false; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It == m_WorkerWsMap.end()) + { + m_WorkerWsMap[&Conn] = WorkerId; + IsNewWorkerWs = true; + } + }); + + if (IsNewWorkerWs) + { + m_Service->SetWorkerWebSocketConnected(WorkerId, true); + } + + m_PushEvent.Set(); +} + +std::string +HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) +{ + // Workers send CbObject in native binary format over the WebSocket to + // avoid the lossy CbObject↔JSON round-trip. 
+ CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); + if (!Data) + { + ZEN_WARN("worker WebSocket message is not a valid CbObject"); + return {}; + } + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + if (WorkerId.empty()) + { + ZEN_WARN("invalid worker announcement via WebSocket"); + return {}; + } + + m_Service->AnnounceWorker(Ann); + return std::string(WorkerId); +} + +void +HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, + [[maybe_unused]] uint16_t Code, + [[maybe_unused]] std::string_view Reason) +{ + ZEN_INFO("WebSocket client disconnected (code {})", Code); + + // Check if this was a worker WebSocket connection; collect the ID under + // the worker lock, then notify the service outside the lock. + std::string DisconnectedWorkerId; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It != m_WorkerWsMap.end()) + { + DisconnectedWorkerId = std::move(It->second); + m_WorkerWsMap.erase(It); + } + }); + + if (!DisconnectedWorkerId.empty()) + { + m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false); + m_PushEvent.Set(); + } + + if (!m_PushEnabled.load()) + { + return; + } + + // Remove from dashboard connections + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} +# endif + +////////////////////////////////////////////////////////////////////////// +// +// Push thread +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::PushThreadFunction() +{ + SetCurrentThreadName("orch_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(2000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + // Snapshot current connections + std::vector<Ref<WebSocketConnection>> Connections; + 
m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + continue; + } + + // Build combined JSON with worker list, provisioning history, clients, and client history + CbObject WorkerList = m_Service->GetWorkerList(); + CbObject History = m_Service->GetProvisioningHistory(50); + CbObject ClientList = m_Service->GetClientList(); + CbObject ClientHistory = m_Service->GetClientHistory(50); + + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname)); + + // Emit workers array from worker list + ExtendableStringBuilder<2048> WorkerJson; + WorkerList.ToJson(WorkerJson); + std::string_view WorkerJsonView = WorkerJson.ToView(); + // Strip outer braces: {"workers":[...]} -> "workers":[...] + if (WorkerJsonView.size() >= 2) + { + JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit events array from history + ExtendableStringBuilder<2048> HistoryJson; + History.ToJson(HistoryJson); + std::string_view HistoryJsonView = HistoryJson.ToView(); + if (HistoryJsonView.size() >= 2) + { + JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit clients array from client list + ExtendableStringBuilder<2048> ClientJson; + ClientList.ToJson(ClientJson); + std::string_view ClientJsonView = ClientJson.ToView(); + if (ClientJsonView.size() >= 2) + { + JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit client_events array from client history + ExtendableStringBuilder<2048> ClientHistoryJson; + ClientHistory.ToJson(ClientHistoryJson); + std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView(); + if (ClientHistoryJsonView.size() >= 2) + { + JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); + } + + 
JsonBuilder.Append("}"); + std::string_view Json = JsonBuilder.ToView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } + } +} +# endif + } // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +#include <atomic> +#include <filesystem> +#include <string> +#include <thread> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. 
+ * + * To avoid a ~200ms connect timeout penalty on every startup when running on + * bare-metal or non-cloud machines, failed probes write sentinel files + * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have + * a sentinel present. Delete the sentinel files to force re-detection. + * + * When a provider is detected, a background thread polls for termination + * signals every 5 seconds (spot interruption, autoscaling lifecycle changes, + * scheduled maintenance). The termination state is exposed as an atomic bool + * so the compute server can include it in coordinator announcements and react + * to imminent shutdown. + * + * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a + * shared RwLock; the background monitor thread acquires the exclusive lock + * only when writing the termination reason (a one-time transition). The + * termination-pending flag itself is a lock-free atomic. + * + * Usage: + * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud"); + * if (Cloud->IsTerminationPending()) { ... } + * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB + */ +class CloudMetadata +{ +public: + /** Synchronously probes cloud providers and starts the termination monitor + * if a provider is detected. Creates DataDir if it does not exist. + */ + explicit CloudMetadata(std::filesystem::path DataDir); + + /** Synchronously probes cloud providers at the given IMDS endpoint. + * Intended for testing — allows redirecting all IMDS queries to a local + * mock HTTP server instead of the real 169.254.169.254 endpoint. + */ + CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); + + /** Stops the termination monitor thread and joins it. 
*/ + ~CloudMetadata(); + + CloudMetadata(const CloudMetadata&) = delete; + CloudMetadata& operator=(const CloudMetadata&) = delete; + + CloudProvider GetProvider() const; + CloudInstanceInfo GetInstanceInfo() const; + bool IsTerminationPending() const; + std::string GetTerminationReason() const; + + /** Writes a "cloud" sub-object into the compact binary writer if a provider + * was detected. No-op when running on bare metal. + */ + void Describe(CbWriter& Writer) const; + + /** Executes a single termination-poll cycle for the detected provider. + * Public so tests can drive poll cycles synchronously without relying on + * the background thread's 5-second timer. + */ + void PollTermination(); + + /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure, + * .isNotGCP) from DataDir so subsequent detection probes are not skipped. + * Primarily intended for tests that need to reset state between sub-cases. + */ + void ClearSentinelFiles(); + +private: + /** Tries each provider in order, stops on first successful detection. */ + void DetectProvider(); + bool TryDetectAWS(); + bool TryDetectAzure(); + bool TryDetectGCP(); + + void WriteSentinelFile(const std::filesystem::path& Path); + bool HasSentinelFile(const std::filesystem::path& Path) const; + + void StartTerminationMonitor(); + void TerminationMonitorThread(); + void PollAWSTermination(); + void PollAzureTermination(); + void PollGCPTermination(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_DataDir; + std::string m_ImdsEndpoint; + + mutable RwLock m_InfoLock; + CloudInstanceInfo m_Info; + + std::atomic<bool> m_TerminationPending{false}; + + mutable RwLock m_ReasonLock; + std::string m_TerminationReason; + + // IMDSv2 session token, acquired during AWS detection and reused for + // subsequent termination polling. 
Has a 300s TTL on the AWS side; if it + // expires mid-run the poll requests will get 401s which we treat as + // non-terminal (the monitor simply retries next cycle). + std::string m_AwsToken; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorEnabled{true}; + Event m_MonitorEvent; +}; + +void cloudmetadata_forcelink(); // internal + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h new file mode 100644 index 000000000..65ec5f9ee --- /dev/null +++ b/src/zencompute/include/zencompute/computeservice.h @@ -0,0 +1,262 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/iohash.h> +# include <zenstore/zenstore.h> +# include <zenhttp/httpcommon.h> + +# include <filesystem> + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class ComputeServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. + */ +class ComputeServiceSession final +{ +public: + /** + * Session lifecycle state machine. 
+ * + * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset + * Backward transitions: Draining -> Ready, Paused -> Ready + * Automatic transition: Draining -> Paused (when pending + running reaches 0) + * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset + * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out) + * + * | State | Accept new actions | Schedule pending | Finish running | + * |-----------|-------------------|-----------------|----------------| + * | Created | No | No | N/A | + * | Ready | Yes | Yes | Yes | + * | Draining | No | Yes | Yes | + * | Paused | No | No | No | + * | Abandoned | No | No | No (all abandoned) | + * | Sunset | No | No | No | + */ + enum class SessionState + { + Created, // Initial state before WaitUntilReady completes + Ready, // Normal operating state; accepts and schedules work + Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty + Paused, // Idle; no work accepted or scheduled; can resume to Ready + Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out + Sunset // Terminal; triggers full shutdown + }; + + ComputeServiceSession(ChunkResolver& InChunkResolver); + ~ComputeServiceSession(); + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + SessionState GetSessionState() const; + + // Request a state transition. Returns false if the transition is invalid. + // Sunset can be reached from any non-Sunset state. 
+ bool RequestStateTransition(SessionState NewState); + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Queue management + // + // Queues group actions submitted by a single client session. They allow + // cancelling or polling completion of all actions in the group. 
+ + struct CreateQueueResult + { + int QueueId = 0; // 0 if creation failed + }; + + enum class QueueState + { + Active, + Draining, + Cancelled, + }; + + struct QueueStatus + { + bool IsValid = false; + int QueueId = 0; + int ActiveCount = 0; // pending + running (not yet completed) + int CompletedCount = 0; // successfully completed + int FailedCount = 0; // failed + int AbandonedCount = 0; // abandoned + int CancelledCount = 0; // cancelled + QueueState State = QueueState::Active; + bool IsComplete = false; // ActiveCount == 0 + }; + + [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {}); + [[nodiscard]] std::vector<int> GetQueueIds(); + [[nodiscard]] QueueStatus GetQueueStatus(int QueueId); + [[nodiscard]] CbObject GetQueueMetadata(int QueueId); + [[nodiscard]] CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DrainQueue(int QueueId); + void DeleteQueue(int QueueId); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + + // Queue-scoped action submission. Actions submitted via these methods are + // tracked under the given queue in addition to the global LSN-based tracking. 
+ + [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + // Action rescheduling + + struct RescheduleResult + { + bool Success = false; + std::string Error; + int RetryCount = 0; + }; + + [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + + void GetCompleted(CbWriter&); + + // Running action tracking + + struct RunningActionInfo + { + int Lsn; + int QueueId; + IoHash ActionId; + float CpuUsagePercent; // -1.0 if not yet sampled + float CpuSeconds; // 0.0 if not yet sampled + }; + + [[nodiscard]] std::vector<RunningActionInfo> GetRunningActions(); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + int QueueId = 0; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + std::string ExecutionLocation; + bool Succeeded; + float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled + int RetryCount = 0; // number of times this action was rescheduled + // sized to match RunnerAction::State::_Count but we can't use the enum here + // for dependency reasons, so just use a fixed size array and static assert in + // the implementation file + uint64_t Timestamps[8] = {}; + }; + + [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); + [[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100); + + // Stats reporting + + struct ActionCounts + { + int 
Pending = 0; + int Running = 0; + int Completed = 0; + int ActiveQueues = 0; + }; + + [[nodiscard]] ActionCounts GetActionCounts(); + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void computeservice_forcelink(); + +} // namespace zen::compute + +namespace zen { +const char* ToString(compute::ComputeServiceSession::SessionState State); +const char* ToString(compute::ComputeServiceSession::QueueState State); +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/functionservice.h b/src/zencompute/include/zencompute/functionservice.h deleted file mode 100644 index 1deb99fd5..000000000 --- a/src/zencompute/include/zencompute/functionservice.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zenstore/zenstore.h> -# include <zenhttp/httpcommon.h> - -# include <filesystem> - -namespace zen { -class ChunkResolver; -class CbObjectWriter; -} // namespace zen - -namespace zen::compute { - -class ActionRecorder; -class FunctionServiceSession; -class IActionResultHandler; -class LocalProcessRunner; -class RemoteHttpRunner; -struct RunnerAction; -struct SubmitResult; - -struct WorkerDesc -{ - CbPackage Descriptor; - IoHash WorkerId{IoHash::Zero}; - - inline operator bool() const { return WorkerId != IoHash::Zero; } -}; - -/** - * Lambda style compute function service - * - * The responsibility of this class is to accept function execution requests, and - * schedule them using one or more FunctionRunner instances. It will basically always - * accept requests, queueing them if necessary, and then hand them off to runners - * as they become available. - * - * This is typically fronted by an API service that handles communication with clients. 
- */ -class FunctionServiceSession final -{ -public: - FunctionServiceSession(ChunkResolver& InChunkResolver); - ~FunctionServiceSession(); - - void Shutdown(); - bool IsHealthy(); - - // Worker registration and discovery - - void RegisterWorker(CbPackage Worker); - [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); - - // Action runners - - void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath); - void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); - - // Action submission - - struct EnqueueResult - { - int Lsn; - CbObject ResponseMessage; - - inline operator bool() const { return Lsn != 0; } - }; - - [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); - [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - - // Completed action tracking - - [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - void GetCompleted(CbWriter&); - - // Action history tracking (note that this is separate from completed action tracking, and - // will include actions which have been retired and no longer have their results available) - - struct ActionHistoryEntry - { - int Lsn; - IoHash ActionId; - IoHash WorkerId; - CbObject ActionDescriptor; - bool Succeeded; - uint64_t Timestamps[5] = {}; - }; - - [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); - - // Stats reporting - - void EmitStats(CbObjectWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); - -private: - void PostUpdate(RunnerAction* Action); - - friend class FunctionRunner; - friend struct RunnerAction; - - struct Impl; 
- std::unique_ptr<Impl> m_Impl; -}; - -void function_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h new file mode 100644 index 000000000..ee1cd2614 --- /dev/null +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/computeservice.h" + +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <memory> + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** + * HTTP interface for compute service + */ +class HttpComputeService : public HttpService, public IHttpStatsProvider +{ +public: + HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions = 0); + ~HttpComputeService(); + + void Shutdown(); + + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void httpcomputeservice_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpfunctionservice.h b/src/zencompute/include/zencompute/httpfunctionservice.h deleted file mode 100644 index 6e2344ae6..000000000 --- a/src/zencompute/include/zencompute/httpfunctionservice.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "zencompute/functionservice.h" - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zencore/logging.h> -# include <zentelemetry/stats.h> -# include <zenhttp/httpserver.h> - -# include <deque> -# include <filesystem> -# include <unordered_map> - -namespace zen { -class CidStore; -} - -namespace zen::compute { - -class HttpFunctionService; -class FunctionService; - -/** - * HTTP interface for compute function service - */ -class HttpFunctionService : public HttpService, public IHttpStatsProvider -{ -public: - HttpFunctionService(CidStore& InCidStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir); - ~HttpFunctionService(); - - void Shutdown(); - - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; - - // IHttpStatsProvider - - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - -protected: - CidStore& m_CidStore; - IHttpStatsService& m_StatsService; - LoggerRef Log() { return m_Log; } - -private: - LoggerRef m_Log; - std::filesystem ::path m_BaseDir; - HttpRequestRouter m_Router; - FunctionServiceSession m_FunctionService; - - // Metrics - - metrics::OperationTiming m_HttpRequests; -}; - -void httpfunction_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 168c6d7fe..da5c5dfc3 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,43 +2,100 @@ #pragma once +#include <zencompute/zencompute.h> + #include <zencore/logging.h> #include <zencore/thread.h> -#include <zencore/timer.h> 
#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> +#include <filesystem> +#include <memory> +#include <string> +#include <thread> #include <unordered_map> +#include <vector> + +#define ZEN_WITH_WEBSOCKETS 1 namespace zen::compute { +class OrchestratorService; + +// Experimental helper, to see if we can get rid of some error-prone +// boilerplate when declaring loggers as class members. + +class LoggerHelper +{ +public: + LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {} + + LoggerRef operator()() { return m_Log; } + +private: + LoggerRef m_Log; +}; + /** - * Mock orchestrator service, for testing dynamic provisioning + * Orchestrator HTTP service with WebSocket push support + * + * Normal HTTP requests are routed through the HttpRequestRouter as before. + * WebSocket clients connecting to /orch/ws receive periodic state broadcasts + * from a dedicated push thread, eliminating the need for polling. */ class HttpOrchestratorService : public HttpService +#if ZEN_WITH_WEBSOCKETS +, + public IWebSocketHandler +#endif { public: - HttpOrchestratorService(); + explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); ~HttpOrchestratorService(); HttpOrchestratorService(const HttpOrchestratorService&) = delete; HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + /** + * Gracefully shut down the WebSocket push thread and release connections. + * Must be called while the ASIO io_context is still alive. The destructor + * also calls this, so it is safe (but not ideal) to omit the explicit call. 
+ */ + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; + // IWebSocketHandler +#if ZEN_WITH_WEBSOCKETS + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; +#endif + private: - HttpRequestRouter m_Router; - LoggerRef m_Log; + HttpRequestRouter m_Router; + LoggerHelper Log{"orch"}; + std::unique_ptr<OrchestratorService> m_Service; + std::string m_Hostname; + + // WebSocket push - struct KnownWorker - { - std::string_view BaseUri; - Stopwatch LastSeen; - }; +#if ZEN_WITH_WEBSOCKETS + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::thread m_PushThread; + std::atomic<bool> m_PushEnabled{false}; + Event m_PushEvent; + void PushThreadFunction(); - RwLock m_KnownWorkersLock; - std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + // Worker WebSocket connections (worker→orchestrator persistent links) + RwLock m_WorkerWsLock; + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); +#endif }; } // namespace zen::compute diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h new file mode 100644 index 000000000..521722e63 --- /dev/null +++ b/src/zencompute/include/zencompute/mockimds.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/cloudmetadata.h> +#include <zenhttp/httpserver.h> + +#include <string> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. 
+ * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). + * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. */ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). 
+ std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. + std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. + */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h new file mode 100644 index 000000000..071e902b3 --- /dev/null +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -0,0 +1,177 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/uid.h> + +# include <deque> +# include <optional> +# include <filesystem> +# include <memory> +# include <string> +# include <string_view> +# include <thread> +# include <unordered_map> + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. "windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. 
"horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + 
std::unique_ptr<WorkerTimelineStore> m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque<ProvisioningEvent> m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map<std::string, KnownClient> m_KnownClients; + + RwLock m_ClientLogLock; + std::deque<ClientEvent> m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic<bool> m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h index bf1aff125..3f233fae0 100644 --- a/src/zencompute/include/zencompute/recordingreader.h +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -2,7 +2,9 @@ #pragma once -#include <zencompute/functionservice.h> +#include <zencompute/zencompute.h> + +#include <zencompute/computeservice.h> #include <zencompute/zencompute.h> #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h index 6dc32eeea..00be4d4a0 100644 --- a/src/zencompute/include/zencompute/zencompute.h +++ b/src/zencompute/include/zencompute/zencompute.h @@ -4,6 +4,10 @@ #include <zencore/zencore.h> +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + namespace zen { void zencompute_forcelinktests(); diff --git 
a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp new file mode 100644 index 000000000..9ea695305 --- /dev/null +++ b/src/zencompute/orchestratorservice.cpp @@ -0,0 +1,710 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/orchestratorservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zenhttp/httpclient.h> + +# include "timeline/workertimeline.h" + +namespace zen::compute { + +OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_TimelineStore(std::make_unique<WorkerTimelineStore>(DataDir / "timelines")) +, m_EnableWorkerWebSocket(EnableWorkerWebSocket) +{ + m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this}; +} + +OrchestratorService::~OrchestratorService() +{ + m_ProbeThreadEnabled = false; + m_ProbeThreadEvent.Set(); + if (m_ProbeThread.joinable()) + { + m_ProbeThread.join(); + } +} + +CbObject +OrchestratorService::GetWorkerList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "uri" << Worker.BaseUri; + Cbo << "hostname" << Worker.Hostname; + if (!Worker.Platform.empty()) + { + Cbo << "platform" << std::string_view(Worker.Platform); + } + Cbo << "cpus" << Worker.Cpus; + Cbo << "cpu_usage" << Worker.CpuUsagePercent; + Cbo << "memory_total" << Worker.MemoryTotalBytes; + Cbo << "memory_used" << Worker.MemoryUsedBytes; + Cbo << "bytes_received" << Worker.BytesReceived; + Cbo << "bytes_sent" << Worker.BytesSent; + Cbo << "actions_pending" << Worker.ActionsPending; + Cbo << "actions_running" << Worker.ActionsRunning; + Cbo << "actions_completed" << Worker.ActionsCompleted; + Cbo << "active_queues" << 
Worker.ActiveQueues; + if (!Worker.Provisioner.empty()) + { + Cbo << "provisioner" << std::string_view(Worker.Provisioner); + } + if (Worker.Reachable != ReachableState::Unknown) + { + Cbo << "reachable" << (Worker.Reachable == ReachableState::Reachable); + } + if (Worker.WsConnected) + { + Cbo << "ws_connected" << true; + } + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker"); + + bool IsNew = false; + std::string EvictedId; + std::string EvictedHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end()); + + // If a different worker ID already maps to the same URI, the old entry + // is stale (e.g. a previous Horde lease on the same machine). Remove it + // so the dashboard doesn't show duplicates. + if (IsNew) + { + for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It) + { + if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id) + { + EvictedId = It->first; + EvictedHostname = It->second.Hostname; + m_KnownWorkers.erase(It); + break; + } + } + } + + auto& Worker = m_KnownWorkers[std::string(Ann.Id)]; + Worker.BaseUri = Ann.Uri; + Worker.Hostname = Ann.Hostname; + if (!Ann.Platform.empty()) + { + Worker.Platform = Ann.Platform; + } + Worker.Cpus = Ann.Cpus; + Worker.CpuUsagePercent = Ann.CpuUsagePercent; + Worker.MemoryTotalBytes = Ann.MemoryTotalBytes; + Worker.MemoryUsedBytes = Ann.MemoryUsedBytes; + Worker.BytesReceived = Ann.BytesReceived; + Worker.BytesSent = Ann.BytesSent; + Worker.ActionsPending = Ann.ActionsPending; + Worker.ActionsRunning = Ann.ActionsRunning; + Worker.ActionsCompleted = Ann.ActionsCompleted; + Worker.ActiveQueues = Ann.ActiveQueues; + if (!Ann.Provisioner.empty()) + { + Worker.Provisioner = Ann.Provisioner; + } + Worker.LastSeen.Reset(); + 
}); + + if (!EvictedId.empty()) + { + ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id); + RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname); + } + + if (IsNew) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname); + } +} + +bool +OrchestratorService::IsWorkerWebSocketEnabled() const +{ + return m_EnableWorkerWebSocket; +} + +void +OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected) +{ + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(std::string(WorkerId)); + if (It == m_KnownWorkers.end()) + { + return; + } + + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.WsConnected = Connected; + It->second.Reachable = Connected ? ReachableState::Reachable : ReachableState::Unreachable; + + if (Connected) + { + ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + } + else + { + ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId); + } + }); + + // Record provisioning events for state transitions outside the lock + if (Connected && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname); + } + else if (!Connected && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname); + } +} + +CbObject +OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline"); + + Ref<WorkerTimeline> Timeline = m_TimelineStore->Find(WorkerId); + if (!Timeline) + { + return {}; + } + + std::vector<WorkerTimeline::Event> Events; + + if (From || To) + { + DateTime StartTime = From.value_or(DateTime(0)); + 
DateTime EndTime = To.value_or(DateTime::Now()); + Events = Timeline->QueryTimeline(StartTime, EndTime); + } + else if (Limit > 0) + { + Events = Timeline->QueryRecent(Limit); + } + else + { + Events = Timeline->QueryRecent(); + } + + WorkerTimeline::TimeRange Range = Timeline->GetTimeRange(); + + CbObjectWriter Cbo; + Cbo << "worker_id" << WorkerId; + Cbo << "event_count" << static_cast<int32_t>(Timeline->GetEventCount()); + + if (Range) + { + Cbo.AddDateTime("time_first", Range.First); + Cbo.AddDateTime("time_last", Range.Last); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : Events) + { + Cbo.BeginObject(); + Cbo << "type" << WorkerTimeline::ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged) + { + Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState); + Cbo << "state" << RunnerAction::ToString(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetAllTimelines(DateTime From, DateTime To) +{ + ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines"); + + DateTime StartTime = From; + DateTime EndTime = To; + + auto AllInfo = m_TimelineStore->GetAllWorkerInfo(); + + CbObjectWriter Cbo; + Cbo.AddDateTime("from", StartTime); + Cbo.AddDateTime("to", EndTime); + + Cbo.BeginArray("workers"); + for (const auto& Info : AllInfo) + { + if (!Info.Range || Info.Range.Last < StartTime || Info.Range.First > EndTime) + { + continue; + } + + Cbo.BeginObject(); + Cbo << "worker_id" << Info.WorkerId; + Cbo.AddDateTime("time_first", Info.Range.First); + Cbo.AddDateTime("time_last", Info.Range.Last); + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void 
+OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname) +{ + ProvisioningEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .WorkerId = std::string(WorkerId), + .Hostname = std::string(Hostname), + }; + + m_ProvisioningLogLock.WithExclusiveLock([&] { + m_ProvisioningLog.push_back(std::move(Evt)); + while (m_ProvisioningLog.size() > kMaxProvisioningEvents) + { + m_ProvisioningLog.pop_front(); + } + }); +} + +CbObject +OrchestratorService::GetProvisioningHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("events"); + + m_ProvisioningLogLock.WithSharedLock([&] { + // Return last N events, newest first + int Count = 0; + for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ProvisioningEvent::Type::Joined: + Cbo << "type" + << "joined"; + break; + case ProvisioningEvent::Type::Left: + Cbo << "type" + << "left"; + break; + case ProvisioningEvent::Type::Returned: + Cbo << "type" + << "returned"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "worker_id" << std::string_view(Evt.WorkerId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +std::string +OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceClient"); + + std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString()); + + bool IsNew = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(ClientId); + IsNew = (It == m_KnownClients.end()); + + auto& Client = m_KnownClients[ClientId]; + Client.SessionId = Ann.SessionId; + Client.Hostname = Ann.Hostname; + if 
(!Ann.Address.empty()) + { + Client.Address = Ann.Address; + } + if (Ann.Metadata) + { + Client.Metadata = Ann.Metadata; + } + Client.LastSeen.Reset(); + }); + + if (IsNew) + { + RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname); + } + else + { + RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname); + } + + return ClientId; +} + +bool +OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata) +{ + ZEN_TRACE_CPU("OrchestratorService::UpdateClient"); + + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + if (Metadata) + { + It->second.Metadata = std::move(Metadata); + } + It->second.LastSeen.Reset(); + } + }); + + return Found; +} + +bool +OrchestratorService::CompleteClient(std::string_view ClientId) +{ + ZEN_TRACE_CPU("OrchestratorService::CompleteClient"); + + std::string Hostname; + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + Hostname = It->second.Hostname; + m_KnownClients.erase(It); + } + }); + + if (Found) + { + RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname); + } + + return Found; +} + +CbObject +OrchestratorService::GetClientList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientList"); + CbObjectWriter Cbo; + Cbo.BeginArray("clients"); + + m_KnownClientsLock.WithSharedLock([&] { + for (const auto& [ClientId, Client] : m_KnownClients) + { + Cbo.BeginObject(); + Cbo << "id" << ClientId; + if (Client.SessionId) + { + Cbo << "session_id" << Client.SessionId; + } + Cbo << "hostname" << std::string_view(Client.Hostname); + if (!Client.Address.empty()) + { + Cbo << "address" << std::string_view(Client.Address); + } + Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs(); + if (Client.Metadata) + { + Cbo << "metadata" << 
Client.Metadata; + } + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetClientHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("client_events"); + + m_ClientLogLock.WithSharedLock([&] { + int Count = 0; + for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ClientEvent::Type::Connected: + Cbo << "type" + << "connected"; + break; + case ClientEvent::Type::Disconnected: + Cbo << "type" + << "disconnected"; + break; + case ClientEvent::Type::Updated: + Cbo << "type" + << "updated"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "client_id" << std::string_view(Evt.ClientId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname) +{ + ClientEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .ClientId = std::string(ClientId), + .Hostname = std::string(Hostname), + }; + + m_ClientLogLock.WithExclusiveLock([&] { + m_ClientLog.push_back(std::move(Evt)); + while (m_ClientLog.size() > kMaxClientEvents) + { + m_ClientLog.pop_front(); + } + }); +} + +void +OrchestratorService::ProbeThreadFunction() +{ + ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction"); + SetCurrentThreadName("orch_probe"); + + bool IsFirstProbe = true; + + do + { + if (!IsFirstProbe) + { + m_ProbeThreadEvent.Wait(5'000); + m_ProbeThreadEvent.Reset(); + } + else + { + IsFirstProbe = false; + } + + if (m_ProbeThreadEnabled == false) + { + return; + } + + m_ProbeThreadEvent.Reset(); + + // Snapshot worker IDs and URIs under shared lock + struct WorkerSnapshot + { + std::string Id; 
+ std::string Uri; + bool WsConnected = false; + }; + std::vector<WorkerSnapshot> Snapshots; + + m_KnownWorkersLock.WithSharedLock([&] { + Snapshots.reserve(m_KnownWorkers.size()); + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected}); + } + }); + + // Probe each worker outside the lock + for (const auto& Snap : Snapshots) + { + if (m_ProbeThreadEnabled == false) + { + return; + } + + // Workers with an active WebSocket connection are known-reachable; + // skip the HTTP health probe for them. + if (Snap.WsConnected) + { + continue; + } + + ReachableState NewState = ReachableState::Unreachable; + + try + { + HttpClient Client(Snap.Uri, + {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}}); + HttpClient::Response Response = Client.Get("/health/"); + if (Response.IsSuccess()) + { + NewState = ReachableState::Reachable; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } + + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(Snap.Id); + if (It != m_KnownWorkers.end()) + { + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.Reachable = NewState; + It->second.LastProbed.Reset(); + + if (PrevState != NewState) + { + if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri); + } + else if (NewState == ReachableState::Reachable) + { + ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); + } + else if (PrevState == ReachableState::Reachable) + { + ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); + } + else + { + ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri); + } + } + } + }); + + // 
Record provisioning events for state transitions outside the lock + if (PrevState != NewState) + { + if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname); + } + else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname); + } + } + } + + // Sweep expired clients (5-minute timeout) + static constexpr int64_t kClientTimeoutMs = 5 * 60 * 1000; + + struct ExpiredClient + { + std::string Id; + std::string Hostname; + }; + std::vector<ExpiredClient> ExpiredClients; + + m_KnownClientsLock.WithExclusiveLock([&] { + for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();) + { + if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs) + { + ExpiredClients.push_back({It->first, It->second.Hostname}); + It = m_KnownClients.erase(It); + } + else + { + ++It; + } + } + }); + + for (const auto& Expired : ExpiredClients) + { + ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id); + RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname); + } + } while (m_ProbeThreadEnabled); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp index 04c4b5141..90141ca55 100644 --- a/src/zencompute/actionrecorder.cpp +++ b/src/zencompute/recording/actionrecorder.cpp @@ -2,7 +2,7 @@ #include "actionrecorder.h" -#include "functionrunner.h" +#include "../runners/functionrunner.h" #include <zencore/compactbinary.h> #include <zencore/compactbinaryfile.h> diff --git a/src/zencompute/actionrecorder.h b/src/zencompute/recording/actionrecorder.h index 9cc2b44a2..2827b6ac7 100644 --- a/src/zencompute/actionrecorder.h +++ b/src/zencompute/recording/actionrecorder.h @@ -2,7 +2,7 @@ #pragma once -#include 
<zencompute/functionservice.h> +#include <zencompute/computeservice.h> #include <zencompute/zencompute.h> #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> diff --git a/src/zencompute/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp index 1c1a119cf..1c1a119cf 100644 --- a/src/zencompute/recordingreader.cpp +++ b/src/zencompute/recording/recordingreader.cpp diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp new file mode 100644 index 000000000..4fad2cf70 --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "deferreddeleter.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/thread.h> + +# include <algorithm> +# include <chrono> + +namespace zen::compute { + +using namespace std::chrono_literals; + +using Clock = std::chrono::steady_clock; + +// Default deferral: how long to wait before attempting deletion. +// This gives memory-mapped file handles time to close naturally. +static constexpr auto DeferralPeriod = 60s; + +// Shortened deferral after MarkReady(): the client has collected results +// so handles should be released soon, but we still wait briefly. +static constexpr auto ReadyGracePeriod = 5s; + +// Interval between retry attempts for directories that failed deletion. 
+static constexpr auto RetryInterval = 5s; + +static constexpr int MaxRetries = 10; + +DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this) +{ +} + +DeferredDirectoryDeleter::~DeferredDirectoryDeleter() +{ + Shutdown(); +} + +void +DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path) +{ + { + std::lock_guard Lock(m_Mutex); + m_Queue.push_back({ActionLsn, std::move(Path)}); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::MarkReady(int ActionLsn) +{ + { + std::lock_guard Lock(m_Mutex); + m_ReadyLsns.push_back(ActionLsn); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::Shutdown() +{ + { + std::lock_guard Lock(m_Mutex); + m_Done = true; + } + m_Cv.notify_one(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } +} + +void +DeferredDirectoryDeleter::ThreadFunction() +{ + SetCurrentThreadName("ZenDirCleanup"); + + struct PendingEntry + { + int ActionLsn; + std::filesystem::path Path; + Clock::time_point ReadyTime; + int Attempts = 0; + }; + + std::vector<PendingEntry> PendingList; + + auto TryDelete = [](PendingEntry& Entry) -> bool { + std::error_code Ec; + std::filesystem::remove_all(Entry.Path, Ec); + return !Ec; + }; + + for (;;) + { + bool Shutting = false; + + // Drain the incoming queue and process MarkReady signals + + { + std::unique_lock Lock(m_Mutex); + + if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done) + { + if (PendingList.empty()) + { + m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + else + { + auto NextReady = PendingList.front().ReadyTime; + for (const auto& Entry : PendingList) + { + if (Entry.ReadyTime < NextReady) + { + NextReady = Entry.ReadyTime; + } + } + + m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + } + + // Move new items into PendingList with the full deferral deadline + auto Now = Clock::now(); + for (auto& 
Entry : m_Queue) + { + PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0}); + } + m_Queue.clear(); + + // Apply MarkReady: shorten ReadyTime for matching entries + for (int Lsn : m_ReadyLsns) + { + for (auto& Entry : PendingList) + { + if (Entry.ActionLsn == Lsn) + { + auto NewReady = Now + ReadyGracePeriod; + if (NewReady < Entry.ReadyTime) + { + Entry.ReadyTime = NewReady; + } + } + } + } + m_ReadyLsns.clear(); + + Shutting = m_Done; + } + + // Process items whose deferral period has elapsed (or all items on shutdown) + + auto Now = Clock::now(); + + for (size_t i = 0; i < PendingList.size();) + { + auto& Entry = PendingList[i]; + + if (!Shutting && Now < Entry.ReadyTime) + { + ++i; + continue; + } + + if (TryDelete(Entry)) + { + if (Entry.Attempts > 0) + { + ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path); + } + + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ++Entry.Attempts; + + if (Entry.Attempts >= MaxRetries) + { + ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts); + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts); + Entry.ReadyTime = Now + RetryInterval; + ++i; + } + } + } + + // Exit once shutdown is requested and nothing remains + + if (Shutting && PendingList.empty()) + { + return; + } + } +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS + +# include <zencore/testing.h> + +namespace zen::compute { + +void +deferreddeleter_forcelink() +{ +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/testutils.h> + +namespace zen::compute { + +TEST_SUITE_BEGIN("compute.deferreddeleter"); + +TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path 
DirToDelete = TempDir.Path() / "subdir"; + CreateDirectories(DirToDelete / "nested"); + + CHECK(std::filesystem::exists(DirToDelete)); + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(1, DirToDelete); + } + + CHECK(!std::filesystem::exists(DirToDelete)); +} + +TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories") +{ + ScopedTemporaryDirectory TempDir; + + constexpr int NumDirs = 10; + std::vector<std::filesystem::path> Dirs; + + for (int i = 0; i < NumDirs; ++i) + { + auto Dir = TempDir.Path() / std::to_string(i); + CreateDirectories(Dir / "child"); + Dirs.push_back(std::move(Dir)); + } + + { + DeferredDirectoryDeleter Deleter; + for (int i = 0; i < NumDirs; ++i) + { + CHECK(std::filesystem::exists(Dirs[i])); + Deleter.Enqueue(100 + i, Dirs[i]); + } + } + + for (const auto& Dir : Dirs) + { + CHECK(!std::filesystem::exists(Dir)); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "idempotent"; + CreateDirectories(Dir); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(42, Dir); + Deleter.Shutdown(); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path NoSuchDir = TempDir.Path() / "does_not_exist"; + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(99, NoSuchDir); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "explicit"; + CreateDirectories(Dir / "inner"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(7, Dir); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "markready"; + 
CreateDirectories(Dir / "child"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(50, Dir); + + // Without MarkReady the full deferral (60s) would apply. + // MarkReady shortens it to 5s, and shutdown bypasses even that. + Deleter.MarkReady(50); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_SUITE_END(); + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h new file mode 100644 index 000000000..9b010aa0f --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <condition_variable> +# include <deque> +# include <filesystem> +# include <mutex> +# include <thread> +# include <vector> + +namespace zen::compute { + +/// Deletes directories on a background thread to avoid blocking callers. +/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics). +/// +/// Enqueued directories wait for a deferral period before deletion, giving +/// file handles time to close. Call MarkReady() with the ActionLsn to shorten +/// the wait to a brief grace period (e.g. once a client has collected results). +/// On shutdown, all pending directories are deleted immediately. +class DeferredDirectoryDeleter +{ + DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete; + DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete; + +public: + DeferredDirectoryDeleter(); + ~DeferredDirectoryDeleter(); + + /// Enqueue a directory for deferred deletion, associated with an action LSN. + void Enqueue(int ActionLsn, std::filesystem::path Path); + + /// Signal that the action result has been consumed and the directory + /// can be deleted after a short grace period instead of the full deferral. 
+ void MarkReady(int ActionLsn); + + /// Drain the queue and join the background thread. Idempotent. + void Shutdown(); + +private: + struct QueueEntry + { + int ActionLsn; + std::filesystem::path Path; + }; + + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::deque<QueueEntry> m_Queue; + std::vector<int> m_ReadyLsns; + bool m_Done = false; + std::thread m_Thread; + void ThreadFunction(); +}; + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS +namespace zen::compute { +void deferreddeleter_forcelink(); // internal +} // namespace zen::compute +#endif diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp new file mode 100644 index 000000000..768cdf1e1 --- /dev/null +++ b/src/zencompute/runners/functionrunner.cpp @@ -0,0 +1,365 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/filesystem.h> +# include <zencore/trace.h> + +# include <fmt/format.h> +# include <vector> + +namespace zen::compute { + +FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") +{ +} + +FunctionRunner::~FunctionRunner() = default; + +size_t +FunctionRunner::QueryCapacity() +{ + return 1; +} + +std::vector<SubmitResult> +FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +void +FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) +{ + if (m_DumpActions) + { + std::string UniqueId = fmt::format("{}.ddb", ActionLsn); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); + } +} + 
+////////////////////////////////////////////////////////////////////////// + +void +BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner) +{ + m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); +} + +size_t +BaseRunnerGroup::QueryCapacity() +{ + size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + for (const auto& Runner : m_Runners) + { + TotalCapacity += Runner->QueryCapacity(); + } + }); + return TotalCapacity; +} + +SubmitResult +BaseRunnerGroup::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); + int Index = InitialIndex; + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return {.IsAccepted = false, .Reason = "No runners available"}; + } + + do + { + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + + auto& Runner = m_Runners[Index++]; + + SubmitResult Result = Runner->SubmitAction(Action); + + if (Result.IsAccepted == true) + { + m_NextSubmitIndex = Index % RunnerCount; + + return Result; + } + + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + } while (Index != InitialIndex); + + return {.IsAccepted = false}; +} + +std::vector<SubmitResult> +BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); + } + + // Query capacity per runner and compute total + std::vector<size_t> Capacities(RunnerCount); + size_t TotalCapacity = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = m_Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + 
+ if (TotalCapacity == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); + } + + // Distribute actions across runners proportionally to their available capacity + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + if (Capacities[i] == 0) + { + continue; + } + + size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity; + Share = std::min(Share, Capacities[i]); + + for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + } + } + + // Assign any remaining actions to runners with capacity (round-robin) + for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + { + if (Capacities[i] > PerRunnerActions[i].size()) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + ++ActionIdx; + } + } + + // Submit batches per runner + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + } + } + + // Reassemble results in original action order + std::vector<SubmitResult> Results(Actions.size()); + std::vector<size_t> PerRunnerIdx(RunnerCount, 0); + + for (size_t i = 0; i < Actions.size(); ++i) + { + size_t RunnerIdx = ActionRunnerIndex[i]; + size_t Idx = PerRunnerIdx[RunnerIdx]++; + Results[i] = std::move(PerRunnerResults[RunnerIdx][Idx]); + } + + return Results; +} + +size_t +BaseRunnerGroup::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + size_t TotalCount = 0; + + for (const auto& Runner : m_Runners) + { + TotalCount += Runner->GetSubmittedActionCount(); + } + + 
return TotalCount; +} + +void +BaseRunnerGroup::RegisterWorker(CbPackage Worker) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->RegisterWorker(Worker); + } +} + +void +BaseRunnerGroup::Shutdown() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->Shutdown(); + } +} + +bool +BaseRunnerGroup::CancelAction(int ActionLsn) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + if (Runner->CancelAction(ActionLsn)) + { + return true; + } + } + + return false; +} + +void +BaseRunnerGroup::CancelRemoteQueue(int QueueId) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->CancelRemoteQueue(QueueId); + } +} + +////////////////////////////////////////////////////////////////////////// + +RunnerAction::RunnerAction(ComputeServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) +{ + this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); +} + +RunnerAction::~RunnerAction() +{ +} + +bool +RunnerAction::ResetActionStateToPending() +{ + // Only allow reset from Failed or Abandoned states + State CurrentState = m_ActionState.load(); + + if (CurrentState != State::Failed && CurrentState != State::Abandoned) + { + return false; + } + + if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending)) + { + return false; + } + + // Clear timestamps from Submitting through _Count + for (int i = static_cast<int>(State::Submitting); i < static_cast<int>(State::_Count); ++i) + { + this->Timestamps[i] = 0; + } + + // Record new Pending timestamp + this->Timestamps[static_cast<int>(State::Pending)] = DateTime::Now().GetTicks(); + + // Clear execution fields + ExecutionLocation.clear(); + CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); + CpuSeconds.store(0.0f, std::memory_order_relaxed); + + // Increment retry count + RetryCount.fetch_add(1, std::memory_order_relaxed); + + // 
Re-enter the scheduler pipeline + m_OwnerSession->PostUpdate(this); + + return true; +} + +void +RunnerAction::SetActionState(State NewState) +{ + ZEN_ASSERT(NewState < State::_Count); + this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); + + do + { + if (State CurrentState = m_ActionState.load(); CurrentState == NewState) + { + // No state change + return; + } + else + { + if (NewState <= CurrentState) + { + // Cannot transition to an earlier or same state + return; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) + { + // Successful state change + + m_OwnerSession->PostUpdate(this); + + return; + } + } + } while (true); +} + +void +RunnerAction::SetResult(CbPackage&& Result) +{ + m_Result = std::move(Result); +} + +CbPackage& +RunnerAction::GetResult() +{ + ZEN_ASSERT(IsCompleted()); + return m_Result; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h new file mode 100644 index 000000000..f67414dbb --- /dev/null +++ b/src/zencompute/runners/functionrunner.h @@ -0,0 +1,214 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/computeservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <atomic> +# include <filesystem> +# include <vector> + +namespace zen::compute { + +struct SubmitResult +{ + bool IsAccepted = false; + std::string Reason; +}; + +/** Base interface for classes implementing a remote execution "runner" + */ +class FunctionRunner : public RefCounted +{ + FunctionRunner(FunctionRunner&&) = delete; + FunctionRunner& operator=(FunctionRunner&&) = delete; + +public: + FunctionRunner(std::filesystem::path BasePath); + virtual ~FunctionRunner() = 0; + + virtual void Shutdown() = 0; + virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; + [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; + [[nodiscard]] virtual bool IsHealthy() = 0; + [[nodiscard]] virtual size_t QueryCapacity(); + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + + // Best-effort cancellation of a specific in-flight action. Returns true if the + // cancellation signal was successfully sent. The action will transition to Cancelled + // asynchronously once the platform-level termination completes. + virtual bool CancelAction(int /*ActionLsn*/) { return false; } + + // Cancel the remote queue corresponding to the given local QueueId. + // Only meaningful for remote runners; local runners ignore this. 
+ virtual void CancelRemoteQueue(int /*QueueId*/) {} + +protected: + std::filesystem::path m_ActionsPath; + bool m_DumpActions = false; + void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); +}; + +/** Base class for RunnerGroup that operates on generic FunctionRunner references. + * All scheduling, capacity, and lifecycle logic lives here. + */ +class BaseRunnerGroup +{ +public: + size_t QueryCapacity(); + SubmitResult SubmitAction(Ref<RunnerAction> Action); + std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + size_t GetSubmittedActionCount(); + void RegisterWorker(CbPackage Worker); + void Shutdown(); + bool CancelAction(int ActionLsn); + void CancelRemoteQueue(int QueueId); + + size_t GetRunnerCount() + { + return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); + } + +protected: + void AddRunnerInternal(FunctionRunner* Runner); + + RwLock m_RunnersLock; + std::vector<Ref<FunctionRunner>> m_Runners; + std::atomic<int> m_NextSubmitIndex{0}; +}; + +/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. + */ +template<typename RunnerType> +struct RunnerGroup : public BaseRunnerGroup +{ + void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); } + + template<typename Predicate> + size_t RemoveRunnerIf(Predicate&& Pred) + { + size_t RemovedCount = 0; + m_RunnersLock.WithExclusiveLock([&] { + auto It = m_Runners.begin(); + while (It != m_Runners.end()) + { + if (Pred(static_cast<RunnerType&>(**It))) + { + (*It)->Shutdown(); + It = m_Runners.erase(It); + ++RemovedCount; + } + else + { + ++It; + } + } + }); + return RemovedCount; + } +}; + +/** + * This represents an action going through different stages of scheduling and execution. 
+ */ +struct RunnerAction : public RefCounted +{ + explicit RunnerAction(ComputeServiceSession* OwnerSession); + ~RunnerAction(); + + int ActionLsn = 0; + int QueueId = 0; + WorkerDesc Worker; + IoHash ActionId; + CbObject ActionObj; + int Priority = 0; + std::string ExecutionLocation; // "local" or remote hostname + + // CPU usage and total CPU time of the running process, sampled periodically by the local runner. + // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. + // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled. + std::atomic<float> CpuUsagePercent{-1.0f}; + std::atomic<float> CpuSeconds{0.0f}; + std::atomic<int> RetryCount{0}; + + enum class State + { + New, + Pending, + Submitting, + Running, + Completed, + Failed, + Abandoned, + Cancelled, + _Count + }; + + static const char* ToString(State _) + { + switch (_) + { + case State::New: + return "New"; + case State::Pending: + return "Pending"; + case State::Submitting: + return "Submitting"; + case State::Running: + return "Running"; + case State::Completed: + return "Completed"; + case State::Failed: + return "Failed"; + case State::Abandoned: + return "Abandoned"; + case State::Cancelled: + return "Cancelled"; + default: + return "Unknown"; + } + } + + static State FromString(std::string_view Name, State Default = State::Failed) + { + for (int i = 0; i < static_cast<int>(State::_Count); ++i) + { + if (Name == ToString(static_cast<State>(i))) + { + return static_cast<State>(i); + } + } + return Default; + } + + uint64_t Timestamps[static_cast<int>(State::_Count)] = {}; + + State ActionState() const { return m_ActionState; } + void SetActionState(State NewState); + + bool IsSuccess() const { return ActionState() == State::Completed; } + bool ResetActionStateToPending(); + bool IsCompleted() const + { + return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == 
State::Abandoned || + ActionState() == State::Cancelled; + } + + void SetResult(CbPackage&& Result); + CbPackage& GetResult(); + + ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; } + +private: + std::atomic<State> m_ActionState = State::New; + ComputeServiceSession* m_OwnerSession = nullptr; + CbPackage m_Result; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp new file mode 100644 index 000000000..e79a6c90f --- /dev/null +++ b/src/zencompute/runners/linuxrunner.cpp @@ -0,0 +1,734 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "linuxrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <sched.h> +# include <signal.h> +# include <sys/mount.h> +# include <sys/stat.h> +# include <sys/syscall.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. 
+ + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + int MkdirIfNeeded(const char* Path, mode_t Mode) + { + if (mkdir(Path, Mode) != 0 && errno != EEXIST) + { + return -1; + } + return 0; + } + + int BindMountReadOnly(const char* Src, const char* Dst) + { + if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + return -1; + } + + // Remount read-only + if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0) + { + return -1; + } + + return 0; + } + + // Set up namespace-based sandbox isolation in the child process. + // This is called after fork(), before execve(). All operations must be + // async-signal-safe. + // + // The sandbox layout after pivot_root: + // / -> the sandbox directory (tmpfs-like, was SandboxPath) + // /usr -> bind-mount of host /usr (read-only) + // /lib -> bind-mount of host /lib (read-only) + // /lib64 -> bind-mount of host /lib64 (read-only, optional) + // /etc -> bind-mount of host /etc (read-only) + // /worker -> bind-mount of worker directory (read-only) + // /proc -> proc filesystem + // /dev -> tmpfs with null, zero, urandom + void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd) + { + // 1. 
Unshare user, mount, and network namespaces + if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno); + } + + // 2. Write UID/GID mappings + // Must deny setgroups first (required by kernel for unprivileged user namespaces) + { + int Fd = open("/proc/self/setgroups", O_WRONLY); + if (Fd >= 0) + { + WriteToFd(Fd, "deny", 4); + close(Fd); + } + // setgroups file may not exist on older kernels; not fatal + } + + { + // uid_map: map our UID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Uid)); + + int Fd = open("/proc/self/uid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + { + // gid_map: map our GID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Gid)); + + int Fd = open("/proc/self/gid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast<size_t>(Len)); + close(Fd); + } + + // 3. Privatize the entire mount tree so our mounts don't propagate + if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno); + } + + // 4. 
Create mount points inside the sandbox and bind-mount system directories + + // Helper macro-like pattern for building paths inside sandbox + // We use stack buffers since we can't allocate heap memory safely + char MountPoint[4096]; + + auto BuildPath = [&](const char* Suffix) -> const char* { + snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix); + return MountPoint; + }; + + // /usr (required) + if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno); + } + if (BindMountReadOnly("/usr", BuildPath("usr")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno); + } + + // /lib (required) + if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno); + } + if (BindMountReadOnly("/lib", BuildPath("lib")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); + } + + // /lib64 (optional — not all distros have it) + { + struct stat St; + if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) + { + if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0) + { + BindMountReadOnly("/lib64", BuildPath("lib64")); + // Failure is non-fatal for lib64 + } + } + } + + // /etc (required — for resolv.conf, ld.so.cache, etc.) + if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); + } + if (BindMountReadOnly("/etc", BuildPath("etc")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); + } + + // /worker — bind-mount worker directory (contains the executable) + if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); + } + if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno); + } + + // 5. 
Mount /proc inside sandbox + if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno); + } + if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno); + } + + // 6. Mount tmpfs /dev and bind-mount essential device nodes + if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno); + } + if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno); + } + + // Bind-mount /dev/null, /dev/zero, /dev/urandom + { + char DevSrc[64]; + char DevDst[4096]; + + auto BindDev = [&](const char* Name) { + snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name); + snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name); + + // Create the file to mount over + int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666); + if (Fd >= 0) + { + close(Fd); + } + mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr); + // Non-fatal if individual devices fail + }; + + BindDev("null"); + BindDev("zero"); + BindDev("urandom"); + } + + // 7. pivot_root to sandbox + // pivot_root requires the new root and put_old to be mount points. + // Bind-mount sandbox onto itself to make it a mount point. + if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno); + } + + // Create .pivot_old inside sandbox + char PivotOld[4096]; + snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath); + if (MkdirIfNeeded(PivotOld, 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno); + } + + if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno); + } + + // 8. Now inside new root. 
Clean up old root. + if (chdir("/") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno); + } + + if (umount2("/.pivot_old", MNT_DETACH) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno); + } + + rmdir("/.pivot_old"); + } + +} // anonymous namespace + +LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("namespace sandboxing enabled for child processes"); + } +} + +SubmitResult +LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + // Pre-compute all path strings before fork() for async-signal-safety. 
+ + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::string ExePathStr; + std::string SandboxedExePathStr; + + if (m_Sandboxed) + { + // After pivot_root, the worker dir is at /worker inside the new root + std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath); + SandboxedExePathStr = SandboxedExePath.string(); + // We still need the real path for logging + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + else + { + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + + std::string BuildArg = "-Build=build.action"; + + // argv[0] should be the path the child will see + const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr; + + std::vector<char*> ArgV; + ArgV.push_back(const_cast<char*>(ChildExePath.data())); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: get uid/gid for namespace mapping, create error pipe + uid_t CurrentUid = 0; + gid_t CurrentGid = 0; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + CurrentUid = getuid(); + CurrentGid = getgid(); + + if (pipe2(ErrorPipe, O_CLOEXEC) != 0) + { + throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno)); + } + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, 
WorkerPathStr.c_str(), ErrorPipe[1]); + + // After pivot_root, CWD is "/" which is the sandbox root. + // execve with the sandboxed path. + execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. + char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +LinuxProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + 
+ m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +LinuxProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. 
+ + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +LinuxProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +static uint64_t +ReadProcStatCpuTicks(pid_t Pid) +{ + char Path[64]; + snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast<int>(Pid)); + + char Buf[256]; + int Fd = open(Path, O_RDONLY); + if (Fd < 0) + { + return 0; + } + + ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1); + close(Fd); + + if (Len <= 0) + { + return 0; + } + + Buf[Len] = '\0'; + + // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + const char* P = strrchr(Buf, ')'); + if (!P) + { + return 0; + } + + P += 2; // skip ') ' + + // Remaining fields (space-separated, 0-indexed from here): + // 0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags + // 7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime + unsigned long UTime = 0; + unsigned long STime = 0; + sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime); + return UTime + STime; +} + +void +LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + static const long ClkTck = sysconf(_SC_CLK_TCK); + + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + const uint64_t NowTicks = GetHifreqTimerValue(); + const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid); + + if (CurrentOsTicks == 0) + { + // Process gone or /proc entry unreadable — record timestamp without updating usage + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = 0; + return; + } + + // Cumulative CPU seconds (absolute, available from first sample) + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / ClkTck), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) * 1000.0 
/ ClkTck / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h new file mode 100644 index 000000000..266de366b --- /dev/null +++ b/src/zencompute/runners/linuxrunner.h @@ -0,0 +1,44 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +namespace zen::compute { + +/** Native Linux process runner for executing Linux worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using Linux namespaces: + user, mount, and network namespaces are unshared so the child has no network + access and can only see the sandbox directory (with system libraries bind-mounted + read-only). This requires no special privileges thanks to user namespaces. 
+ */ +class LinuxProcessRunner : public LocalProcessRunner +{ +public: + LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 9a27f3f3d..7aaefb06e 100644 --- a/src/zencompute/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -8,7 +8,7 @@ # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> # include <zencore/compress.h> -# include <zencore/except.h> +# include <zencore/except_fmt.h> # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/iobuffer.h> @@ -16,6 +16,7 @@ # include <zencore/system.h> # include <zencore/scopeguard.h> # include <zencore/timer.h> +# include <zencore/trace.h> # include <zenstore/cidstore.h> # include <span> @@ -24,17 +25,28 @@ namespace zen::compute { using namespace std::literals; -LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir) +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) : FunctionRunner(BaseDir) , m_Log(logging::Get("local_exec")) , m_ChunkResolver(Resolver) , m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) , m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) { 
SystemMetrics Sm = GetSystemMetricsForReporting(); m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); bool DidCleanup = false; @@ -116,6 +128,7 @@ LocalProcessRunner::~LocalProcessRunner() void LocalProcessRunner::Shutdown() { + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); m_AcceptNewActions = false; m_MonitorThreadEnabled = false; @@ -131,6 +144,7 @@ LocalProcessRunner::Shutdown() std::filesystem::path LocalProcessRunner::CreateNewSandbox() { + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); std::string UniqueId = std::to_string(++m_SandboxCounter); std::filesystem::path Path = m_SandboxPath / UniqueId; zen::CreateDirectories(Path); @@ -141,6 +155,7 @@ LocalProcessRunner::CreateNewSandbox() void LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) { + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); if (m_DumpActions) { CbObject WorkerDescriptor = WorkerPackage.GetObject(); @@ -172,32 +187,84 @@ LocalProcessRunner::QueryCapacity() return 0; } - size_t RunningCount = m_RunningMap.size(); + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); - if (RunningCount >= size_t(m_MaxRunningActions)) + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) { return 0; } - - return m_MaxRunningActions - RunningCount; + else + { + return MaxRunningActions - InFlightCount; + } } std::vector<SubmitResult> LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { - std::vector<SubmitResult> Results; + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; - for (const Ref<RunnerAction>& Action : Actions) + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity 
upfront and accept what fits. + // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. + + size_t Available = QueryCapacity(); + + std::vector<SubmitResult> Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) { - Results.push_back(SubmitAction(Action)); + const Ref<RunnerAction>& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. 
failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; } return Results; } -SubmitResult -LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +std::optional<LocalProcessRunner::PreparedAction> +LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) { + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + // Verify whether we can accept more work { @@ -205,29 +272,29 @@ LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (!m_AcceptNewActions) { - return SubmitResult{.IsAccepted = false}; + return std::nullopt; } if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) { - return SubmitResult{.IsAccepted = false}; + return std::nullopt; } } - using namespace std::literals; - // Each enqueued action is assigned an integer index (logical sequence number), // which we use as a key for tracking data structures and as an opaque id which // may be used by clients to reference the scheduled action const int32_t ActionLsn = Action->ActionLsn; const CbObject& ActionObj = Action->ActionObj; - const IoHash ActionId = ActionObj.GetHash(); MaybeDumpAction(ActionLsn, ActionObj); std::filesystem::path SandboxPath = CreateNewSandbox(); + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + CbPackage WorkerPackage = Action->Worker.Descriptor; std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); @@ -251,89 +318,24 @@ LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) zen::WriteFile(FilePath, DataBuffer); }); -# if ZEN_PLATFORM_WINDOWS - // Set up environment variables + Action->ExecutionLocation = "local"; - 
StringBuilder<1024> EnvironmentBlock; + SandboxGuard.Dismiss(); - CbObject WorkerDescription = WorkerPackage.GetObject(); - - for (auto& It : WorkerDescription["environment"sv]) - { - EnvironmentBlock.Append(It.AsString()); - EnvironmentBlock.Append('\0'); - } - EnvironmentBlock.Append('\0'); - EnvironmentBlock.Append('\0'); - - // Execute process - this spawns the child process immediately without waiting - // for completion - - std::string_view ExecPath = WorkerDescription["path"sv].AsString(); - std::filesystem::path ExePath = WorkerPath / std::filesystem::path(ExecPath).make_preferred(); - - ExtendableWideStringBuilder<512> CommandLine; - CommandLine.Append(L'"'); - CommandLine.Append(ExePath.c_str()); - CommandLine.Append(L'"'); - CommandLine.Append(L" -Build=build.action"); - - LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; - LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; - BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = 0; - - STARTUPINFO StartupInfo{}; - StartupInfo.cb = sizeof StartupInfo; - - PROCESS_INFORMATION ProcessInformation{}; - - ZEN_DEBUG("Executing: {}", WideToUtf8(CommandLine.c_str())); - - CommandLine.EnsureNulTerminated(); - - BOOL Success = CreateProcessW(nullptr, - CommandLine.Data(), - lpProcessAttributes, - lpThreadAttributes, - bInheritHandles, - dwCreationFlags, - (LPVOID)EnvironmentBlock.Data(), // Environment block - SandboxPath.c_str(), // Current directory - &StartupInfo, - /* out */ &ProcessInformation); - - if (!Success) - { - // TODO: this is probably not the best way to report failure. 
The return - // object should include a failure state and context - - zen::ThrowLastError("Unable to launch process" /* TODO: Add context */); - } - - CloseHandle(ProcessInformation.hThread); - - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(SandboxPath); - - { - RwLock::ExclusiveLockScope _(m_RunningLock); - - m_RunningMap[ActionLsn] = std::move(NewAction); - } - - Action->SetActionState(RunnerAction::State::Running); -# else - ZEN_UNUSED(ActionId); - - ZEN_NOT_IMPLEMENTED(); - - int ExitCode = 0; -# endif + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} - return SubmitResult{.IsAccepted = true}; +SubmitResult +LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; } size_t @@ -346,6 +348,7 @@ LocalProcessRunner::GetSubmittedActionCount() std::filesystem::path LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) { + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); RwLock::SharedLockScope _(m_WorkerLock); std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); @@ -405,6 +408,23 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. 
+ { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + SharedBuffer Decompressed = Compressed.Decompress(); zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); } @@ -421,12 +441,34 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["executables"sv]) { DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif } for (auto& It : WorkerDescription["dirs"sv]) { std::string_view Name = It.AsString(); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + zen::CreateDirectories(DirPath); } @@ -441,6 
+483,7 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, CbPackage LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) { + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); std::filesystem::path OutputFile = SandboxPath / "build.output"; FileContents OutputData = zen::ReadFile(OutputFile); @@ -542,134 +585,53 @@ LocalProcessRunner::MonitorThreadFunction() } SweepRunningActions(); + SampleRunningProcessCpu(); } // Signal received SweepRunningActions(); + SampleRunningProcessCpu(); } while (m_MonitorThreadEnabled); } void LocalProcessRunner::CancelRunningActions() { - Stopwatch Timer; - std::unordered_map<int, Ref<RunningAction>> RunningMap; - - m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); - - if (RunningMap.empty()) - { - return; - } - - ZEN_INFO("cancelling all running actions"); - - // For expedience we initiate the process termination for all known - // processes before attempting to wait for them to exit. 
- - std::vector<int> TerminatedLsnList; - - for (const auto& Kv : RunningMap) - { - Ref<RunningAction> Action = Kv.second; - - // Terminate running process + // Base class is not directly usable — platform subclasses override this +} -# if ZEN_PLATFORM_WINDOWS - BOOL Success = TerminateProcess(Action->ProcessHandle, 222); +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; - if (Success) - { - TerminatedLsnList.push_back(Kv.first); - } - else + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) { - DWORD LastError = GetLastError(); - - if (LastError != ERROR_ACCESS_DENIED) + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) { - ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Action->Action->ActionLsn, GetSystemErrorAsString(LastError)); + SampleProcessCpu(*Running); } } -# else - ZEN_NOT_IMPLEMENTED("need to implement process termination"); -# endif - } - - // We only post results for processes we have terminated, in order - // to avoid multiple results getting posted for the same action - - for (int Lsn : TerminatedLsnList) - { - if (auto It = RunningMap.find(Lsn); It != RunningMap.end()) - { - Ref<RunningAction> Running = It->second; - -# if ZEN_PLATFORM_WINDOWS - if (Running->ProcessHandle != INVALID_HANDLE_VALUE) - { - DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); - - if (WaitResult != WAIT_OBJECT_0) - { - ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); - } - else - { - ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); - } - } -# endif - - // Clean up and post error result - - DeleteDirectories(Running->SandboxPath); - 
Running->Action->SetActionState(RunnerAction::State::Failed); - } - } - - ZEN_INFO("DONE - cancelled {} running processes (took {})", TerminatedLsnList.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); } void LocalProcessRunner::SweepRunningActions() { - std::vector<Ref<RunningAction>> CompletedActions; - - m_RunningLock.WithExclusiveLock([&] { - // TODO: It would be good to not hold the exclusive lock while making - // system calls and other expensive operations. - - for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) - { - Ref<RunningAction> Action = It->second; - -# if ZEN_PLATFORM_WINDOWS - DWORD ExitCode = 0; - BOOL IsSuccess = GetExitCodeProcess(Action->ProcessHandle, &ExitCode); - - if (IsSuccess && ExitCode != STILL_ACTIVE) - { - CloseHandle(Action->ProcessHandle); - Action->ProcessHandle = INVALID_HANDLE_VALUE; - - CompletedActions.push_back(std::move(Action)); - It = m_RunningMap.erase(It); - } - else - { - ++It; - } -# else - // TODO: implement properly for Mac/Linux - - ZEN_UNUSED(Action); -# endif - } - }); + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} - // Notify outer. Note that this has to be done without holding any local locks +void +LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. + // Note that this must be called without holding any local locks // otherwise we may end up with deadlocks. 
for (Ref<RunningAction> Running : CompletedActions) @@ -687,11 +649,9 @@ LocalProcessRunner::SweepRunningActions() Running->Action->SetResult(std::move(OutputPackage)); Running->Action->SetActionState(RunnerAction::State::Completed); - // We can delete the files at this point - if (!DeleteDirectories(Running->SandboxPath)) - { - ZEN_WARN("Unable to delete directory '{}', this will continue to exist until service restart", Running->SandboxPath); - } + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); // Success -- continue with next iteration of the loop continue; @@ -702,17 +662,9 @@ LocalProcessRunner::SweepRunningActions() } } - // Failed - for now this is indicated with an empty package in - // the results map. We can clean out the sandbox directory immediately. - - std::error_code Ec; - DeleteDirectories(Running->SandboxPath, Ec); - - if (Ec) - { - ZEN_WARN("Unable to delete sandbox directory '{}': {}", Running->SandboxPath, Ec.message()); - } + // Failed - clean up the sandbox in the background. 
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); Running->Action->SetActionState(RunnerAction::State::Failed); } } diff --git a/src/zencompute/localrunner.h b/src/zencompute/runners/localrunner.h index 35f464805..7493e980b 100644 --- a/src/zencompute/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -2,7 +2,7 @@ #pragma once -#include "zencompute/functionservice.h" +#include "zencompute/computeservice.h" #if ZEN_WITH_COMPUTE_SERVICES @@ -14,8 +14,13 @@ # include <zencore/compactbinarypackage.h> # include <zencore/logging.h> +# include "deferreddeleter.h" + +# include <zencore/workthreadpool.h> + # include <atomic> # include <filesystem> +# include <optional> # include <thread> namespace zen { @@ -38,7 +43,11 @@ class LocalProcessRunner : public FunctionRunner LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; public: - LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir); + LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); ~LocalProcessRunner(); virtual void Shutdown() override; @@ -60,6 +69,10 @@ protected: void* ProcessHandle = nullptr; int ExitCode = 0; std::filesystem::path SandboxPath; + + // State for periodic CPU usage sampling + uint64_t LastCpuSampleTicks = 0; // hifreq timer value at last sample + uint64_t LastCpuOsTicks = 0; // OS CPU ticks (platform-specific units) at last sample }; std::atomic_bool m_AcceptNewActions; @@ -75,12 +88,37 @@ protected: RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; + std::atomic<int32_t> m_SubmittingCount = 0; + DeferredDirectoryDeleter& m_DeferredDeleter; + WorkerThreadPool& m_WorkerPool; + std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; void MonitorThreadFunction(); - void SweepRunningActions(); - void 
CancelRunningActions(); + virtual void SweepRunningActions(); + virtual void CancelRunningActions(); + + // Sample CPU usage for all currently running processes (throttled per-action). + void SampleRunningProcessCpu(); + + // Override in platform runners to sample one process. Called under a shared RunningLock. + virtual void SampleProcessCpu(RunningAction& /*Running*/) {} + + // Shared preamble for SubmitAction: capacity check, sandbox creation, + // worker manifesting, action writing, input manifesting. + struct PreparedAction + { + int32_t ActionLsn; + std::filesystem::path SandboxPath; + std::filesystem::path WorkerPath; + CbPackage WorkerPackage; + }; + std::optional<PreparedAction> PrepareActionSubmission(Ref<RunnerAction> Action); + + // Shared post-processing for SweepRunningActions: gather outputs, + // set state, clean sandbox. + void ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions); std::filesystem::path CreateNewSandbox(); void ManifestWorker(const CbPackage& WorkerPackage, diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp new file mode 100644 index 000000000..5cec90699 --- /dev/null +++ b/src/zencompute/runners/macrunner.cpp @@ -0,0 +1,491 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "macrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <libproc.h> +# include <sandbox.h> +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). 
They use only raw syscalls and avoid
 + // heap allocation and stdio. NOTE(review): strerror() used below is not on
 + // the POSIX async-signal-safe list; it is tolerated here because the forked
 + // child makes no further library calls before _exit(), but confirm this is
 + // acceptable for a fork of a multithreaded parent.
+ std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath) + { + std::string Profile; + Profile.reserve(1024); + + Profile += "(version 1)\n"; + Profile += "(deny default)\n"; + Profile += "(allow process*)\n"; + Profile += "(allow sysctl-read)\n"; + Profile += "(allow file-read-metadata)\n"; + + // System library paths needed for dynamic linker and runtime + Profile += "(allow file-read* (subpath \"/usr\"))\n"; + Profile += "(allow file-read* (subpath \"/System\"))\n"; + Profile += "(allow file-read* (subpath \"/Library\"))\n"; + Profile += "(allow file-read* (subpath \"/dev\"))\n"; + Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n"; + Profile += "(allow file-read* (subpath \"/etc\"))\n"; + + // Worker directory: read-only + Profile += "(allow file-read* (subpath \""; + Profile += WorkerPath; + Profile += "\"))\n"; + + // Sandbox directory: read+write + Profile += "(allow file-read* file-write* (subpath \""; + Profile += SandboxPath; + Profile += "\"))\n"; + + return Profile; + } + +} // anonymous namespace + +MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("Seatbelt sandboxing enabled for child processes"); + } +} + +SubmitResult +MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("MacProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: build sandbox profile and create error pipe + std::string SandboxProfile; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr); + + if (pipe(ErrorPipe) != 0) + { + throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno)); + } + fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC); + fcntl(ErrorPipe[1], 
F_SETFD, FD_CLOEXEC); + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + // Apply Seatbelt sandbox profile + char* ErrorBuf = nullptr; + if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) + { + // sandbox_init failed — write error to pipe and exit + if (ErrorBuf) + { + WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0); + // WriteErrorAndExit does not return, but sandbox_free_error + // is not needed since we _exit + } + WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno); + } + if (ErrorBuf) + { + sandbox_free_error(ErrorBuf); + } + + if (chdir(SandboxPathStr.c_str()) != 0) + { + WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. 
+ char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +MacProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } 
+ }); + + ProcessCompletedActions(CompletedActions); +} + +void +MacProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +MacProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via 
waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +MacProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + struct proc_taskinfo Info; + if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0) + { + return; + } + + // pti_total_user and pti_total_system are in nanoseconds + const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // ns → ms: divide by 1,000,000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); + 
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h new file mode 100644 index 000000000..d653b923a --- /dev/null +++ b/src/zencompute/runners/macrunner.h @@ -0,0 +1,43 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +namespace zen::compute { + +/** Native macOS process runner for executing Mac worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using macOS Seatbelt + (sandbox_init): no network access and no filesystem access outside the + explicitly allowed sandbox and worker directories. This requires no elevation. 
+ */ +class MacProcessRunner : public LocalProcessRunner +{ +public: + MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index 98ced5fe8..672636d06 100644 --- a/src/zencompute/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -14,6 +14,8 @@ # include <zencore/iobuffer.h> # include <zencore/iohash.h> # include <zencore/scopeguard.h> +# include <zencore/system.h> +# include <zencore/trace.h> # include <zenhttp/httpcommon.h> # include <zenstore/cidstore.h> @@ -27,12 +29,18 @@ using namespace std::literals; ////////////////////////////////////////////////////////////////////////// -RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName) +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool) : FunctionRunner(BaseDir) , m_Log(logging::Get("http_exec")) , m_ChunkResolver{InChunkResolver} -, m_BaseUrl{fmt::format("{}/apply", HostName)} +, m_WorkerPool{InWorkerPool} +, m_HostName{HostName} +, m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) +, m_InstanceId(Oid::NewOid()) { m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } @@ -58,6 +66,7 @@ RemoteHttpRunner::Shutdown() void RemoteHttpRunner::RegisterWorker(const 
CbPackage& WorkerPackage) { + ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); const IoHash WorkerId = WorkerPackage.GetObjectHash(); CbPackage WorkerDesc = WorkerPackage; @@ -168,11 +177,38 @@ RemoteHttpRunner::QueryCapacity() std::vector<SubmitResult> RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { - std::vector<SubmitResult> Results; + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For larger batches, submit HTTP requests in parallel via the shared worker pool + + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); for (const Ref<RunnerAction>& Action : Actions) { - Results.push_back(SubmitAction(Action)); + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); } return Results; @@ -181,6 +217,8 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) SubmitResult RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) { + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction"); + // Verify whether we can accept more work { @@ -197,18 +235,53 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) // which we use as a key for tracking data structures and as an opaque id which // may be used by clients to reference the scheduled action + Action->ExecutionLocation = m_HostName; + const int32_t ActionLsn = Action->ActionLsn; const CbObject& ActionObj = Action->ActionObj; const IoHash ActionId = ActionObj.GetHash(); MaybeDumpAction(ActionLsn, ActionObj); - // Enqueue job + // Determine 
the submission URL. If the action belongs to a queue, ensure a + // corresponding remote queue exists on the target node and submit via it. + + std::string SubmitUrl = "/jobs"; + if (const int QueueId = Action->QueueId; QueueId != 0) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } - CbObject Result; + // Enqueue job. If the remote returns FailedDependency (424), it means it + // cannot resolve the worker/function — re-register the worker and retry once. - HttpClient::Response WorkResponse = m_Http.Post("/jobs", ActionObj); - HttpResponseCode WorkResponseCode = WorkResponse.StatusCode; + CbObject Result; + HttpClient::Response WorkResponse; + HttpResponseCode WorkResponseCode{}; + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); + + RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } if (WorkResponseCode == HttpResponseCode::OK) { @@ -250,11 +323,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) // Post resulting package - HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg); + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); if (!PayloadResponse) { - ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri()); + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); // TODO: include more information about the failure in the response @@ 
-270,17 +343,19 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})", + ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", ActionId, m_Http.GetBaseUri(), + SubmitUrl, ResponseStatusCode, ToString(ResponseStatusCode)); return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}/jobs", + .Reason = fmt::format("unexpected response code {} {} from {}{}", ResponseStatusCode, ToString(ResponseStatusCode), - m_Http.GetBaseUri())}; + m_Http.GetBaseUri(), + SubmitUrl)}; } } @@ -309,6 +384,82 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) return {}; } +Oid +RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) +{ + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + return It->second; + } + } + + // Build a stable idempotency key that uniquely identifies this (runner instance, local queue) + // pair. The server uses this to return the same remote queue token for concurrent or redundant + // requests, preventing orphaned remote queues when multiple threads race through here. + // Also send hostname so the server can associate the queue with its origin for diagnostics. 
+ CbObjectWriter Body; + Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId); + Body << "hostname"sv << GetMachineName(); + if (Metadata) + { + Body << "metadata"sv << Metadata; + } + if (Config) + { + Body << "config"sv << Config; + } + + HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save()); + if (!Resp) + { + ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName); + return Oid::Zero; + } + + Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString()); + if (Token == Oid::Zero) + { + return Oid::Zero; + } + + ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName); + + RwLock::ExclusiveLockScope _(m_QueueTokenLock); + auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token); + return It->second; +} + +void +RemoteHttpRunner::CancelRemoteQueue(int QueueId) +{ + Oid Token; + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + Token = It->second; + } + } + + if (Token == Oid::Zero) + { + return; + } + + HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token)); + + if (Resp.StatusCode == HttpResponseCode::NoContent) + { + ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName); + } + else + { + ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode)); + } +} + bool RemoteHttpRunner::IsHealthy() { @@ -337,7 +488,7 @@ RemoteHttpRunner::MonitorThreadFunction() do { - const int NormalWaitingTime = 1000; + const int NormalWaitingTime = 200; int WaitTimeMs = NormalWaitingTime; auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; auto SweepOnce = [&] { @@ -376,6 +527,7 @@ RemoteHttpRunner::MonitorThreadFunction() size_t RemoteHttpRunner::SweepRunningActions() { + ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions"); 
std::vector<HttpRunningAction> CompletedActions; // Poll remote for list of completed actions @@ -386,29 +538,38 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - const int32_t CompleteLsn = FieldIt.AsInt32(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); - if (HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn))) - { - m_RunningLock.WithExclusiveLock([&] { - if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) - { - HttpRunningAction CompletedAction = std::move(CompleteIt->second); - CompletedAction.ActionResults = ResponseJob.AsPackage(); - CompletedAction.Success = true; + RunnerAction::State RemoteState = RunnerAction::FromString(StateName); - CompletedActions.push_back(std::move(CompletedAction)); - m_RemoteRunningMap.erase(CompleteIt); - } - else + // Always fetch to drain the result from the remote's results map, + // but only keep the result package for successfully completed actions. 
+ HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn)); + + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.RemoteState = RemoteState; + + if (RemoteState == RunnerAction::State::Completed && ResponseJob) { - // we received a completion notice for an action we don't know about, - // this can happen if the runner is used by multiple upstream schedulers, - // or if this compute node was recently restarted and lost track of - // previously scheduled actions + CompletedAction.ActionResults = ResponseJob.AsPackage(); } - }); - } + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); } if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) @@ -435,18 +596,18 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - if (HttpAction.Success) - { - ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", HttpAction.Action->ActionId, ActionLsn, HttpAction.RemoteActionLsn); - - HttpAction.Action->SetActionState(RunnerAction::State::Completed); + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); - HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); - } - else + if (HttpAction.RemoteState == RunnerAction::State::Completed) { - HttpAction.Action->SetActionState(RunnerAction::State::Failed); + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + + 
HttpAction.Action->SetActionState(HttpAction.RemoteState); } return CompletedActions.size(); diff --git a/src/zencompute/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index 1e885da3d..9119992a9 100644 --- a/src/zencompute/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -2,7 +2,7 @@ #pragma once -#include "zencompute/functionservice.h" +#include "zencompute/computeservice.h" #if ZEN_WITH_COMPUTE_SERVICES @@ -10,12 +10,15 @@ # include <zencore/compactbinarypackage.h> # include <zencore/logging.h> +# include <zencore/uid.h> +# include <zencore/workthreadpool.h> # include <zencore/zencore.h> # include <zenhttp/httpclient.h> # include <atomic> # include <filesystem> # include <thread> +# include <unordered_map> namespace zen { class CidStore; @@ -35,7 +38,10 @@ class RemoteHttpRunner : public FunctionRunner RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; public: - RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName); + RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool); ~RemoteHttpRunner(); virtual void Shutdown() override; @@ -45,24 +51,29 @@ public: [[nodiscard]] virtual size_t GetSubmittedActionCount() override; [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; + virtual void CancelRemoteQueue(int QueueId) override; + + std::string_view GetHostName() const { return m_HostName; } protected: LoggerRef Log() { return m_Log; } private: - LoggerRef m_Log; - ChunkResolver& m_ChunkResolver; - std::string m_BaseUrl; - HttpClient m_Http; + LoggerRef m_Log; + ChunkResolver& m_ChunkResolver; + WorkerThreadPool& m_WorkerPool; + std::string m_HostName; + std::string m_BaseUrl; + HttpClient m_Http; int32_t m_MaxRunningActions = 256; // arbitrary limit 
for testing struct HttpRunningAction { - Ref<RunnerAction> Action; - int RemoteActionLsn = 0; // Remote LSN - bool Success = false; - CbPackage ActionResults; + Ref<RunnerAction> Action; + int RemoteActionLsn = 0; // Remote LSN + RunnerAction::State RemoteState = RunnerAction::State::Failed; + CbPackage ActionResults; }; RwLock m_RunningLock; @@ -73,6 +84,15 @@ private: Event m_MonitorThreadEvent; void MonitorThreadFunction(); size_t SweepRunningActions(); + + RwLock m_QueueTokenLock; + std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token + + // Stable identity for this runner instance, used as part of the idempotency key when + // creating remote queues. Generated once at construction and never changes. + Oid m_InstanceId; + + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); }; } // namespace zen::compute diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp new file mode 100644 index 000000000..e9a1ae8b6 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.cpp @@ -0,0 +1,460 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "windowsrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/trace.h> +# include <zencore/system.h> +# include <zencore/timer.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <userenv.h> +# include <aclapi.h> +# include <sddl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + if (!m_Sandboxed) + { + return; + } + + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); + + PSID Sid = nullptr; + + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); + + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } + + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); +} + +WindowsProcessRunner::~WindowsProcessRunner() +{ + if (m_AppContainerSid) + { + FreeSid(m_AppContainerSid); + m_AppContainerSid = nullptr; + } + + if (!m_AppContainerName.empty()) + { + 
DeleteAppContainerProfile(m_AppContainerName.c_str()); + } +} + +void +WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask) +{ + PACL ExistingDacl = nullptr; + PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr; + + DWORD Result = GetNamedSecurityInfoW(Path.c_str(), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + &ExistingDacl, + nullptr, + &SecurityDescriptor); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); }); + + EXPLICIT_ACCESSW Access{}; + Access.grfAccessPermissions = AccessMask; + Access.grfAccessMode = SET_ACCESS; + Access.grfInheritance = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE; + Access.Trustee.TrusteeForm = TRUSTEE_IS_SID; + Access.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP; + Access.Trustee.ptstrName = static_cast<LPWSTR>(m_AppContainerSid); + + PACL NewDacl = nullptr; + + Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl); + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $1 = MakeGuard([&] { LocalFree(NewDacl); }); + + Result = SetNamedSecurityInfoW(const_cast<LPWSTR>(Path.c_str()), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + NewDacl, + nullptr); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } +} + +SubmitResult +WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Set up environment variables + + CbObject WorkerDescription = 
Prepared->WorkerPackage.GetObject(); + + StringBuilder<1024> EnvironmentBlock; + + for (auto& It : WorkerDescription["environment"sv]) + { + EnvironmentBlock.Append(It.AsString()); + EnvironmentBlock.Append('\0'); + } + EnvironmentBlock.Append('\0'); + EnvironmentBlock.Append('\0'); + + // Execute process - this spawns the child process immediately without waiting + // for completion + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + ExtendableWideStringBuilder<512> CommandLine; + CommandLine.Append(L'"'); + CommandLine.Append(ExePath.c_str()); + CommandLine.Append(L'"'); + CommandLine.Append(L" -Build=build.action"); + + LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; + BOOL bInheritHandles = FALSE; + DWORD dwCreationFlags = 0; + + ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + + CommandLine.EnsureNulTerminated(); + + PROCESS_INFORMATION ProcessInformation{}; + + if (m_Sandboxed) + { + // Grant AppContainer access to sandbox and worker directories + GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS); + GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE); + + // Set up extended startup info with AppContainer security capabilities + SECURITY_CAPABILITIES SecurityCapabilities{}; + SecurityCapabilities.AppContainerSid = m_AppContainerSid; + SecurityCapabilities.Capabilities = nullptr; + SecurityCapabilities.CapabilityCount = 0; + + SIZE_T AttrListSize = 0; + InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize); + + auto AttrList = static_cast<PPROC_THREAD_ATTRIBUTE_LIST>(malloc(AttrListSize)); + auto $0 = MakeGuard([&] { free(AttrList); }); + + if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize)) + { + zen::ThrowLastError("InitializeProcThreadAttributeList 
failed"); + } + + auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); }); + + if (!UpdateProcThreadAttribute(AttrList, + 0, + PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES, + &SecurityCapabilities, + sizeof(SecurityCapabilities), + nullptr, + nullptr)) + { + zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed"); + } + + STARTUPINFOEXW StartupInfoEx{}; + StartupInfoEx.StartupInfo.cb = sizeof(STARTUPINFOEXW); + StartupInfoEx.lpAttributeList = AttrList; + + dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfoEx.StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch sandboxed process"); + } + } + else + { + STARTUPINFO StartupInfo{}; + StartupInfo.cb = sizeof StartupInfo; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch process"); + } + } + + CloseHandle(ProcessInformation.hThread); + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WindowsProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + 
m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + DWORD ExitCode = 0; + BOOL IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode); + + if (IsSuccess && ExitCode != STILL_ACTIVE) + { + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + Running->ExitCode = ExitCode; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WindowsProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // For expedience we initiate the process termination for all known + // processes before attempting to wait for them to exit. + + // Initiate termination for all known processes before waiting for them to exit. + + for (const auto& Kv : RunningMap) + { + Ref<RunningAction> Running = Kv.second; + + BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError)); + } + } + } + + // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded. 
+ + for (auto& [Lsn, Running] : RunningMap) + { + if (Running->ProcessHandle != INVALID_HANDLE_VALUE) + { + DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); + + if (WaitResult != WAIT_OBJECT_0) + { + ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); + } + else + { + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + } + + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +WindowsProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction"); + + // Hold the shared lock while terminating to prevent the sweep thread from + // closing the handle between our lookup and TerminateProcess call. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (Target->ProcessHandle == INVALID_HANDLE_VALUE) + { + return; + } + + BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError)); + } + + return; + } + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +void +WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + FILETIME CreationTime, ExitTime, KernelTime, UserTime; + if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime)) + { + return; + } + + auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast<uint64_t>(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; }; + + // FILETIME values are in 100-nanosecond intervals + const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // 100ns → ms: divide by 10000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h new file mode 100644 index 000000000..9f2385cc4 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/windows.h> + +# include <string> + +namespace zen::compute { + +/** Windows process runner using CreateProcessW for executing worker executables. 
+ + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using a Windows AppContainer: + no network access (AppContainer blocks network by default when no capabilities are + granted) and no filesystem access outside explicitly granted sandbox and worker + directories. This requires no elevation. + */ +class WindowsProcessRunner : public LocalProcessRunner +{ +public: + WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + ~WindowsProcessRunner(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask); + + bool m_Sandboxed = false; + PSID m_AppContainerSid = nullptr; + std::wstring m_AppContainerName; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp new file mode 100644 index 000000000..506bec73b --- /dev/null +++ b/src/zencompute/runners/winerunner.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "winerunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); +} + +SubmitResult +WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WineProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: wine <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string WinePathStr = m_WinePath; + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(WinePathStr.data()); + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno))); + } + + if (ChildPid == 0) + { + // Child process + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(WinePathStr.c_str(), ArgV.data(), Envp.data()); + + // execve only returns on failure + _exit(127); + } + + // Parent: store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new 
RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WineProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WineProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", 
Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h new file mode 100644 index 000000000..7df62e7c0 --- /dev/null +++ b/src/zencompute/runners/winerunner.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <string> + +namespace zen::compute { + +/** Wine-based process runner for executing Windows worker executables on Linux. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. 
+ */ +class WineProcessRunner : public LocalProcessRunner +{ +public: + WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + +private: + std::string m_WinePath = "wine"; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp new file mode 100644 index 000000000..dd09312df --- /dev/null +++ b/src/zencompute/testing/mockimds.cpp @@ -0,0 +1,205 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// 
--------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) 
+ if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp new 
file mode 100644 index 000000000..88ef5b62d --- /dev/null +++ b/src/zencompute/timeline/workertimeline.cpp @@ -0,0 +1,430 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "workertimeline.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/basicfile.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryfile.h> + +# include <algorithm> + +namespace zen::compute { + +WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId) +{ +} + +WorkerTimeline::~WorkerTimeline() +{ +} + +void +WorkerTimeline::RecordProvisioned() +{ + AppendEvent({ + .Type = EventType::WorkerProvisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordDeprovisioned() +{ + AppendEvent({ + .Type = EventType::WorkerDeprovisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId) +{ + AppendEvent({ + .Type = EventType::ActionAccepted, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + }); +} + +void +WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason) +{ + AppendEvent({ + .Type = EventType::ActionRejected, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .Reason = std::string(Reason), + }); +} + +void +WorkerTimeline::RecordActionStateChanged(int ActionLsn, + const IoHash& ActionId, + RunnerAction::State PreviousState, + RunnerAction::State NewState) +{ + AppendEvent({ + .Type = EventType::ActionStateChanged, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .ActionState = NewState, + .PreviousState = PreviousState, + }); +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + for (const auto& Evt : m_Events) + { + if 
(Evt.Timestamp >= StartTime && Evt.Timestamp <= EndTime) + { + Result.push_back(Evt); + } + } + }); + + return Result; +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryRecent(int Limit) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + const int Count = std::min(Limit, gsl::narrow<int>(m_Events.size())); + auto It = m_Events.end() - Count; + Result.assign(It, m_Events.end()); + }); + + return Result; +} + +size_t +WorkerTimeline::GetEventCount() const +{ + size_t Count = 0; + m_EventsLock.WithSharedLock([&] { Count = m_Events.size(); }); + return Count; +} + +WorkerTimeline::TimeRange +WorkerTimeline::GetTimeRange() const +{ + TimeRange Range; + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Range.First = m_Events.front().Timestamp; + Range.Last = m_Events.back().Timestamp; + } + }); + return Range; +} + +void +WorkerTimeline::AppendEvent(Event&& Evt) +{ + m_EventsLock.WithExclusiveLock([&] { + while (m_Events.size() >= m_MaxEvents) + { + m_Events.pop_front(); + } + + m_Events.push_back(std::move(Evt)); + }); +} + +const char* +WorkerTimeline::ToString(EventType Type) +{ + switch (Type) + { + case EventType::WorkerProvisioned: + return "provisioned"; + case EventType::WorkerDeprovisioned: + return "deprovisioned"; + case EventType::ActionAccepted: + return "accepted"; + case EventType::ActionRejected: + return "rejected"; + case EventType::ActionStateChanged: + return "state_changed"; + default: + return "unknown"; + } +} + +static WorkerTimeline::EventType +EventTypeFromString(std::string_view Str) +{ + if (Str == "provisioned") + return WorkerTimeline::EventType::WorkerProvisioned; + if (Str == "deprovisioned") + return WorkerTimeline::EventType::WorkerDeprovisioned; + if (Str == "accepted") + return WorkerTimeline::EventType::ActionAccepted; + if (Str == "rejected") + return WorkerTimeline::EventType::ActionRejected; + if (Str == "state_changed") + return WorkerTimeline::EventType::ActionStateChanged; + 
return WorkerTimeline::EventType::WorkerProvisioned; +} + +void +WorkerTimeline::WriteTo(const std::filesystem::path& Path) const +{ + CbObjectWriter Cbo; + Cbo << "worker_id" << m_WorkerId; + + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Cbo.AddDateTime("time_first", m_Events.front().Timestamp); + Cbo.AddDateTime("time_last", m_Events.back().Timestamp); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : m_Events) + { + Cbo.BeginObject(); + Cbo << "type" << ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == EventType::ActionStateChanged) + { + Cbo << "prev_state" << static_cast<int32_t>(Evt.PreviousState); + Cbo << "state" << static_cast<int32_t>(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + }); + + CbObject Obj = Cbo.Save(); + + BasicFile File(Path, BasicFile::Mode::kTruncate); + File.Write(Obj.GetBuffer().GetView(), 0); +} + +void +WorkerTimeline::ReadFrom(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + CbObject Root = std::move(Loaded.Object); + + if (!Root) + { + return; + } + + std::deque<Event> LoadedEvents; + + for (CbFieldView Field : Root["events"].AsArrayView()) + { + CbObjectView EventObj = Field.AsObjectView(); + + Event Evt; + Evt.Type = EventTypeFromString(EventObj["type"].AsString()); + Evt.Timestamp = EventObj["ts"].AsDateTime(); + + Evt.ActionLsn = EventObj["lsn"].AsInt32(); + Evt.ActionId = EventObj["action_id"].AsHash(); + + if (Evt.Type == EventType::ActionStateChanged) + { + Evt.PreviousState = static_cast<RunnerAction::State>(EventObj["prev_state"].AsInt32()); + Evt.ActionState = static_cast<RunnerAction::State>(EventObj["state"].AsInt32()); + } + + std::string_view Reason = EventObj["reason"].AsString(); + if 
(!Reason.empty()) + { + Evt.Reason = std::string(Reason); + } + + LoadedEvents.push_back(std::move(Evt)); + } + + m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); }); +} + +WorkerTimeline::TimeRange +WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + + if (!Loaded.Object) + { + return {}; + } + + return { + .First = Loaded.Object["time_first"].AsDateTime(), + .Last = Loaded.Object["time_last"].AsDateTime(), + }; +} + +// WorkerTimelineStore + +static constexpr std::string_view kTimelineExtension = ".ztimeline"; + +WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir)) +{ + std::error_code Ec; + std::filesystem::create_directories(m_PersistenceDir, Ec); +} + +Ref<WorkerTimeline> +WorkerTimelineStore::GetOrCreate(std::string_view WorkerId) +{ + // Fast path: check if it already exists in memory + { + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + } + + // Slow path: create under exclusive lock, loading from disk if available + RwLock::ExclusiveLockScope _(m_Lock); + + auto& Entry = m_Timelines[std::string(WorkerId)]; + if (!Entry) + { + Entry = Ref<WorkerTimeline>(new WorkerTimeline(WorkerId)); + + std::filesystem::path Path = TimelinePath(WorkerId); + std::error_code Ec; + if (std::filesystem::is_regular_file(Path, Ec)) + { + Entry->ReadFrom(Path); + } + } + return Entry; +} + +Ref<WorkerTimeline> +WorkerTimelineStore::Find(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + return {}; +} + +std::vector<std::string> +WorkerTimelineStore::GetActiveWorkerIds() const +{ + std::vector<std::string> Result; + + RwLock::SharedLockScope $(m_Lock); + 
Result.reserve(m_Timelines.size()); + for (const auto& [Id, _] : m_Timelines) + { + Result.push_back(Id); + } + + return Result; +} + +std::vector<WorkerTimelineStore::WorkerTimelineInfo> +WorkerTimelineStore::GetAllWorkerInfo() const +{ + std::unordered_map<std::string, WorkerTimeline::TimeRange> InfoMap; + + { + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + InfoMap[Id] = Timeline->GetTimeRange(); + } + } + + std::error_code Ec; + for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec)) + { + if (!Entry.is_regular_file()) + { + continue; + } + + const auto& Path = Entry.path(); + if (Path.extension().string() != kTimelineExtension) + { + continue; + } + + std::string Id = Path.stem().string(); + if (InfoMap.find(Id) == InfoMap.end()) + { + InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path); + } + } + + std::vector<WorkerTimelineInfo> Result; + Result.reserve(InfoMap.size()); + for (auto& [Id, Range] : InfoMap) + { + Result.push_back({.WorkerId = std::move(Id), .Range = Range}); + } + return Result; +} + +void +WorkerTimelineStore::Save(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + It->second->WriteTo(TimelinePath(WorkerId)); + } +} + +void +WorkerTimelineStore::SaveAll() +{ + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + Timeline->WriteTo(TimelinePath(Id)); + } +} + +std::filesystem::path +WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const +{ + return m_PersistenceDir / (std::string(WorkerId) + std::string(kTimelineExtension)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h new file mode 100644 index 000000000..87e19bc28 --- /dev/null +++ b/src/zencompute/timeline/workertimeline.h @@ -0,0 +1,169 @@ +// 
Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../runners/functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/refcount.h> +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/thread.h> +# include <zencore/timer.h> + +# include <deque> +# include <filesystem> +# include <string> +# include <string_view> +# include <unordered_map> +# include <vector> + +namespace zen::compute { + +struct RunnerAction; + +/** Worker activity timeline for tracking and visualizing worker activity over time. + * + * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle + * events (accept, reject, state changes) with timestamps, enabling time-range queries + * for dashboard visualization. + */ +class WorkerTimeline : public RefCounted +{ +public: + explicit WorkerTimeline(std::string_view WorkerId); + ~WorkerTimeline() override; + + struct TimeRange + { + DateTime First = DateTime(0); + DateTime Last = DateTime(0); + + explicit operator bool() const { return First.GetTicks() != 0; } + }; + + enum class EventType + { + WorkerProvisioned, + WorkerDeprovisioned, + ActionAccepted, + ActionRejected, + ActionStateChanged + }; + + static const char* ToString(EventType Type); + + struct Event + { + EventType Type; + DateTime Timestamp = DateTime(0); + + // Action context (only set for action events) + int ActionLsn = 0; + IoHash ActionId; + RunnerAction::State ActionState = RunnerAction::State::New; + RunnerAction::State PreviousState = RunnerAction::State::New; + + // Optional reason (e.g. rejection reason) + std::string Reason; + }; + + /** Record that this worker has been provisioned and is available for work. */ + void RecordProvisioned(); + + /** Record that this worker has been deprovisioned and is no longer available. */ + void RecordDeprovisioned(); + + /** Record that an action was accepted by this worker. 
*/ + void RecordActionAccepted(int ActionLsn, const IoHash& ActionId); + + /** Record that an action was rejected by this worker. */ + void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason); + + /** Record an action state transition on this worker. */ + void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState); + + /** Query events within a time range (inclusive). Returns events ordered by timestamp. */ + [[nodiscard]] std::vector<Event> QueryTimeline(DateTime StartTime, DateTime EndTime) const; + + /** Query the most recent N events. */ + [[nodiscard]] std::vector<Event> QueryRecent(int Limit = 100) const; + + /** Return the total number of recorded events. */ + [[nodiscard]] size_t GetEventCount() const; + + /** Return the time range covered by the events in this timeline. */ + [[nodiscard]] TimeRange GetTimeRange() const; + + [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; } + + /** Write the timeline to a file at the given path. */ + void WriteTo(const std::filesystem::path& Path) const; + + /** Read the timeline from a file at the given path. Replaces current in-memory events. */ + void ReadFrom(const std::filesystem::path& Path); + + /** Read only the time range from a persisted timeline file, without loading events. */ + [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path); + +private: + void AppendEvent(Event&& Evt); + + std::string m_WorkerId; + mutable RwLock m_EventsLock; + std::deque<Event> m_Events; + size_t m_MaxEvents = 10'000; +}; + +/** Manages a set of WorkerTimeline instances, keyed by worker ID. + * + * Provides thread-safe lookup and on-demand creation of timelines, backed by + * a persistence directory. Each timeline is stored as a separate file named + * {WorkerId}.ztimeline within the directory. 
+ */ +class WorkerTimelineStore +{ +public: + explicit WorkerTimelineStore(std::filesystem::path PersistenceDir); + ~WorkerTimelineStore() = default; + + WorkerTimelineStore(const WorkerTimelineStore&) = delete; + WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete; + + /** Get the timeline for a worker, creating one if it does not exist. + * If a persisted file exists on disk it will be loaded on first access. */ + Ref<WorkerTimeline> GetOrCreate(std::string_view WorkerId); + + /** Get the timeline for a worker, or null ref if it does not exist in memory. */ + [[nodiscard]] Ref<WorkerTimeline> Find(std::string_view WorkerId); + + /** Return the worker IDs of currently loaded (in-memory) timelines. */ + [[nodiscard]] std::vector<std::string> GetActiveWorkerIds() const; + + struct WorkerTimelineInfo + { + std::string WorkerId; + WorkerTimeline::TimeRange Range; + }; + + /** Return info for all known timelines (in-memory and on-disk), including time range. */ + [[nodiscard]] std::vector<WorkerTimelineInfo> GetAllWorkerInfo() const; + + /** Persist a single worker's timeline to disk. */ + void Save(std::string_view WorkerId); + + /** Persist all in-memory timelines to disk. 
*/ + void SaveAll(); + +private: + [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const; + + std::filesystem::path m_PersistenceDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, Ref<WorkerTimeline>> m_Timelines; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua index c710b662d..ed0af66a5 100644 --- a/src/zencompute/xmake.lua +++ b/src/zencompute/xmake.lua @@ -6,6 +6,14 @@ target('zencompute') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) + add_includedirs(".", {private=true}) add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") - add_packages("vcpkg::gsl-lite") - add_packages("vcpkg::spdlog", "vcpkg::cxxopts") + add_packages("json11") + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end + + if is_plat("windows") then + add_syslinks("Userenv") + end diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp index 633250f4e..1f3f6d3f9 100644 --- a/src/zencompute/zencompute.cpp +++ b/src/zencompute/zencompute.cpp @@ -2,11 +2,20 @@ #include "zencompute/zencompute.h" +#if ZEN_WITH_TESTS +# include "runners/deferreddeleter.h" +# include <zencompute/cloudmetadata.h> +#endif + namespace zen { void zencompute_forcelinktests() { +#if ZEN_WITH_TESTS + compute::cloudmetadata_forcelink(); + compute::deferreddeleter_forcelink(); +#endif } } // namespace zen |