diff options
| author | Stefan Boberg <[email protected]> | 2026-03-30 15:07:08 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-30 15:07:08 +0200 |
| commit | 3540d676733efaddecf504b30e9a596465bd43f8 (patch) | |
| tree | 7a8d8b3d2da993e30c34e3ff36f659b90a2b228e /src | |
| parent | include rawHash in structure output for builds ls command (#903) (diff) | |
| download | zen-3540d676733efaddecf504b30e9a596465bd43f8.tar.xz zen-3540d676733efaddecf504b30e9a596465bd43f8.zip | |
Request validation and resilience improvements (#864)
### Security: Input validation & path safety
- **Reject local file references by default** in package parsing — only allow when explicitly opted in by the service (`ParseFlags::kAllowLocalReferences`) and validated by an `ILocalRefPolicy` (fail-closed: no policy = rejected)
- **`DataRootLocalRefPolicy`** restricts local ref paths to the server's data root via canonical path prefix matching
- **Validate attachment hashes** in compute HTTP handlers — decompresses and re-hashes each attachment at ingestion time to reject tampered payloads
- **Path traversal validation** for worker descriptions (`pathvalidation.h`) — rejects absolute paths, `..` components, Windows reserved device names, and invalid filename characters
- **Harden CbPackage parsing** against corrupt inputs — overflow-safe attachment count, bounds checks on local ref offset/size, graceful failure instead of `ZEN_ASSERT` for untrusted data
- **Harden legacy package parser** — reject zero-size binary fields, missing mappers, and optionally validate resolved attachment hashes
- **Bounds check in `CbPackageReader::MarshalLocalChunkReference`** — detect when `MakeFromFile` silently clamps offset+size to file size
### Reliability: Lock consolidation & bug fixes
- **Consolidate three action map locks into one** (`m_ActionMapLock`) — eliminates deadlock risk from multi-lock ordering, simplifies state transitions, and fixes a race where newly enqueued actions were briefly invisible to `GetActionResult`/`FindActionResult`
- **Fix infinite loop in `BaseRunnerGroup::SubmitActions`** when actions exceed total runner capacity — cap round-robin at `TotalCapacity` and default unassigned results to "No capacity"
- **Fix `MakeSafeAbsolutePathInPlace` for UNC paths** — `\server\share` now correctly becomes `\?\UNC\server\share` instead of `\?\server\share`
- **Fix `max_retries=0`** — previously fell through to the default of 3; now correctly means "no retries"
### New: ManagedProcessRunner
- Cross-platform process runner backed by `SubprocessManager` — uses async exit callbacks instead of polling, delegates CPU/memory metrics to the manager's built-in sampler
- `ProcessGroup` (JobObject on Windows, process group on POSIX) for bulk cancellation on shutdown
- `--managed` flag on `zen exec inproc` to select this runner
- Refactored monitor thread lifecycle — `StartMonitorThread()` now called from derived constructors to avoid calling virtual functions from base constructor
### Process management
- **Suppress crash dialogs** via `JOB_OBJECT_UILIMIT_ERRORMODE` + `SEM_NOGPFAULTERRORBOX` in both `WindowsProcessRunner` and `JobObject::Initialize` — prevents WER/Dr. Watson modal dialogs from blocking the monitor thread
- **CREATE_SUSPENDED → AssignProcessToJobObject → ResumeThread** pattern in `WindowsProcessRunner` — ensures job object assignment before process execution
- **Move stdout/stderr callbacks to `Spawn()` parameters** in `SubprocessManager` — prevents race where early output could be missed before callback installation
- Consistent PID logging across all runner types
### Test infrastructure
- **`zentest-appstub`**: Added `Fail` (configurable exit code) and `Crash` (abort / nullptr deref) test functions
- **Compute integration tests**: exit code handling, auto-retry exhaustion, manual reschedule after failure, mixed success/failure queues, crash handling (abort + nullptr), crash auto-retry, immediate query visibility after enqueue
- **Package format tests**: truncated header, bad magic, attachment count overflow, truncated data, local ref rejection/acceptance, policy enforcement (inside/outside root, traversal, no-policy fail-closed)
- **Legacy package parser tests**: empty input, zero-size binary, hash resolution with/without mapper, hash mismatch detection
- **UNC path tests** for `MakeSafeAbsolutePath`
### Misc
- ANSI color helper macros (`ZEN_RED`, `ZEN_BRIGHT_WHITE`, etc.) and `ZEN_BOLD`/`ZEN_DIM`/etc.
- Generic `fmt::formatter` for types with free `ToString` functions
- Compute dashboard: truncated hash display with monospace font and hover for full value
- Renamed `usonpackage_forcelink` → `cbpackage_forcelink`
- Compute enabled by default in xmake config (releases still explicitly disable)
Diffstat (limited to 'src')
44 files changed, 2479 insertions, 376 deletions
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 30e860a3f..9719fce77 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -1119,6 +1119,7 @@ ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent) { + m_SubOptions.add_option("managed", "", "managed", "Use managed local runner (if supported)", cxxopts::value(m_Managed), "<bool>"); } void @@ -1130,7 +1131,16 @@ ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) zen::compute::ComputeServiceSession ComputeSession(Resolver); std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - ComputeSession.AddLocalRunner(Resolver, TempPath); + if (m_Managed) + { + ZEN_CONSOLE_INFO("using managed local runner"); + ComputeSession.AddManagedLocalRunner(Resolver, TempPath); + } + else + { + ZEN_CONSOLE_INFO("using local runner"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + } Stopwatch ExecTimer; int ReturnValue = m_Parent.RunSession(ComputeSession); diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index c55412780..a0bf201a1 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -61,6 +61,7 @@ public: private: ExecCommand& m_Parent; + bool m_Managed = false; }; class ExecBeaconSubCmd : public ZenSubCmdBase diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index a1a39fc3c..750879d5a 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. -**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit. **Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. @@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra - `ActiveLsns` — for cancellation lookup (under `m_Lock`) - `FinishedLsns` — moved here when actions complete - `IdleSince` — used for 15-minute automatic expiry -- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit +- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default **Queue state machine (`QueueState` enum):** ``` @@ -216,11 +216,7 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han ## Concurrency Model -**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: -1. `m_ResultsLock` -2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") -3. `m_PendingLock` -4. `m_QueueLock` +**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 92901de64..58761556a 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -8,6 +8,8 @@ # include "recording/actionrecorder.h" # include "runners/localrunner.h" # include "runners/remotehttprunner.h" +# include "runners/managedrunner.h" +# include "pathvalidation.h" # if ZEN_PLATFORM_LINUX # include "runners/linuxrunner.h" # elif ZEN_PLATFORM_WINDOWS @@ -195,13 +197,9 @@ struct ComputeServiceSession::Impl std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; + RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap + std::map<int, Ref<RunnerAction>> m_PendingActions; std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock m_ResultsLock; std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; metrics::Meter m_ResultRate; std::atomic<uint64_t> m_RetiredCount{0}; @@ -343,9 +341,12 @@ struct ComputeServiceSession::Impl ActionCounts GetActionCounts() { ActionCounts Counts; - Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + m_ActionMapLock.WithSharedLock([&] { + Counts.Pending = (int)m_PendingActions.size(); + Counts.Running = (int)m_RunningMap.size(); + Counts.Completed = (int)m_ResultsMap.size(); + }); + Counts.Completed += (int)m_RetiredCount.load(); Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { size_t Count = 0; for (const auto& [Id, Queue] : m_Queues) @@ -364,8 +365,10 @@ struct ComputeServiceSession::Impl { Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + m_ActionMapLock.WithSharedLock([&] { + Cbo << "actions_complete"sv << m_ResultsMap.size(); + Cbo << "actions_pending"sv << m_PendingActions.size(); + }); Cbo << "actions_submitted"sv << GetSubmittedActionCount(); EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); @@ -450,27 +453,17 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) void ComputeServiceSession::Impl::AbandonAllActions() { - // Collect all pending actions and mark them as Abandoned + // Collect all pending and running actions under a single lock scope std::vector<Ref<RunnerAction>> PendingToAbandon; + std::vector<Ref<RunnerAction>> RunningToAbandon; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { PendingToAbandon.reserve(m_PendingActions.size()); for (auto& [Lsn, Action] : m_PendingActions) { PendingToAbandon.push_back(Action); } - }); - - for (auto& Action : PendingToAbandon) - { - Action->SetActionState(RunnerAction::State::Abandoned); - } - // Collect all running actions and mark them as Abandoned, then - // best-effort cancel via the local runner group - std::vector<Ref<RunnerAction>> RunningToAbandon; - - m_RunningLock.WithSharedLock([&] { RunningToAbandon.reserve(m_RunningMap.size()); for (auto& [Lsn, Action] : m_RunningMap) { @@ -478,6 +471,11 @@ ComputeServiceSession::Impl::AbandonAllActions() } }); + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + for (auto& Action : RunningToAbandon) { Action->SetActionState(RunnerAction::State::Abandoned); @@ -742,7 +740,7 @@ std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::Impl::GetRunningActions() { std::vector<ComputeServiceSession::RunningActionInfo> Result; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { Result.reserve(m_RunningMap.size()); for (const auto& [Lsn, Action] : m_RunningMap) { @@ -810,6 +808,11 @@ void ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) { ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + + // Validate all paths in the worker description upfront, before the worker is + // distributed to runners. This rejects malicious packages early at ingestion time. + ValidateWorkerDescriptionPaths(Worker.GetObject()); + RwLock::ExclusiveLockScope _(m_WorkerLock); const IoHash& WorkerId = Worker.GetObject().GetHash(); @@ -994,10 +997,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke Pending->ActionObj = ActionObj; Pending->Priority = RequestPriority; - // For now simply put action into pending state, so we can do batch scheduling + // Insert into the pending map immediately so the action is visible to + // FindActionResult/GetActionResult right away. SetActionState will call + // PostUpdate which adds the action to m_UpdatedActions and signals the + // scheduler, but the scheduler's HandleActionUpdates inserts with + // std::map::insert which is a no-op for existing keys. ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); }); Pending->SetActionState(RunnerAction::State::Pending); if (m_Recorder) @@ -1043,11 +1051,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount() HttpResponseCode ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) { @@ -1058,25 +1062,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult return HttpResponseCode::OK; } + if (m_PendingActions.find(ActionLsn) != m_PendingActions.end()) { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } return HttpResponseCode::NotFound; @@ -1085,11 +1078,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult HttpResponseCode ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) { @@ -1103,30 +1092,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& } } + for (const auto& [K, Pending] : m_PendingActions) { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) + if (Pending->ActionId == ActionId) { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + for (const auto& [K, v] : m_RunningMap) { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) + if (v->ActionId == ActionId) { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } @@ -1144,7 +1122,7 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) { Cbo.BeginArray("completed"); - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [Lsn, Action] : m_ResultsMap) { Cbo.BeginObject(); @@ -1275,20 +1253,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) std::vector<Ref<RunnerAction>> PendingActionsToCancel; std::vector<int> RunningLsnsToCancel; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (int Lsn : LsnsToCancel) { if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) { PendingActionsToCancel.push_back(It->second); } - } - }); - - m_RunningLock.WithSharedLock([&] { - for (int Lsn : LsnsToCancel) - { - if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + else if (m_RunningMap.find(Lsn) != m_RunningMap.end()) { RunningLsnsToCancel.push_back(Lsn); } @@ -1307,7 +1279,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) // transition from the runner is blocked (Cancelled > Failed in the enum). for (int Lsn : RunningLsnsToCancel) { - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) { It->second->SetActionState(RunnerAction::State::Cancelled); @@ -1445,7 +1417,7 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { Queue->m_Lock.WithSharedLock([&] { - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1541,9 +1513,15 @@ ComputeServiceSession::Impl::SchedulePendingActions() { ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + size_t RunningCount = 0; + size_t PendingCount = 0; + size_t ResultCount = 0; + + m_ActionMapLock.WithSharedLock([&] { + RunningCount = m_RunningMap.size(); + PendingCount = m_PendingActions.size(); + ResultCount = m_ResultsMap.size(); + }); static Stopwatch DumpRunningTimer; @@ -1560,7 +1538,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() DumpRunningTimer.Reset(); std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [K, V] : m_RunningMap) { RunningList.insert(K); @@ -1602,7 +1580,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_PendingLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; @@ -1701,7 +1679,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() { int TimeoutMs = 500; - auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); }); if (PendingCount) { @@ -1720,22 +1698,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() m_SchedulingThreadEvent.Reset(); } - ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", - m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), - PendingCount, - m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), - m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), - TimeoutMs); + m_ActionMapLock.WithSharedLock([&] { + ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}", + m_PendingActions.size(), + m_RunningMap.size(), + m_ResultsMap.size(), + TimeoutMs); + }); HandleActionUpdates(); // Auto-transition Draining → Paused when all work is done if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) { - size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); }); - if (Pending == 0 && Running == 0) + if (AllDrained) { SessionState Expected = SessionState::Draining; if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) @@ -1776,9 +1754,9 @@ ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) if (Config) { - int Value = Config["max_retries"].AsInt32(0); + int Value = Config["max_retries"].AsInt32(-1); - if (Value > 0) + if (Value >= 0) { return Value; } @@ -1797,7 +1775,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) // Find, validate, and remove atomically under a single lock scope to prevent // concurrent RescheduleAction calls from double-removing the same action. - m_ResultsLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { auto It = m_ResultsMap.find(ActionLsn); if (It == m_ResultsMap.end()) { @@ -1871,26 +1849,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) bool WasRunning = false; // Look for the action in pending or running maps - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) { Action = It->second; WasRunning = true; } + else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end()) + { + Action = PIt->second; + } }); if (!Action) { - m_PendingLock.WithSharedLock([&] { - if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) - { - Action = It->second; - } - }); - } - - if (!Action) - { return {.Success = false, .Error = "Action not found in pending or running maps"}; } @@ -1912,18 +1884,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) void ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) { - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + // Caller must hold m_ActionMapLock exclusively. + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } } void @@ -1973,7 +1942,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() } else { - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); } break; @@ -1983,11 +1952,9 @@ ComputeServiceSession::Impl::HandleActionUpdates() // Dispatched to a runner — move from pending to running case RunnerAction::State::Running: - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); }); ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; @@ -1995,7 +1962,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() // Retracted — pull back for rescheduling without counting against retry limit case RunnerAction::State::Retracted: { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); Action->ResetActionStateToPending(); ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); break; @@ -2019,7 +1989,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -2034,16 +2007,16 @@ ComputeServiceSession::Impl::HandleActionUpdates() } } - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + RemoveActionFromActiveMaps(ActionLsn); - // Update queue counters BEFORE publishing the result into - // m_ResultsMap. GetActionResult erases from m_ResultsMap - // under m_ResultsLock, so if we updated counters after - // releasing that lock, a caller could observe ActiveCount - // still at 1 immediately after GetActionResult returned OK. - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. GetActionResult erases from m_ResultsMap + // under m_ActionMapLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); - m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; // Append to bounded action history ring @@ -2282,6 +2255,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files } void +ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner"); + + auto* NewRunner = + new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, MaxConcurrentActions); + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) { ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index bdfd9d197..bd3f4e70e 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -93,13 +93,14 @@ struct HttpComputeService::Impl uint64_t NewBytes = 0; }; - IngestStats IngestPackageAttachments(const CbPackage& Package); - bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); - void HandleWorkersGet(HttpServerRequest& HttpReq); - void HandleWorkersAllGet(HttpServerRequest& HttpReq); - void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); - void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); - void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); // WebSocket / observer void OnWebSocketOpen(Ref<WebSocketConnection> Connection); @@ -373,7 +374,7 @@ HttpComputeService::Impl::RegisterRoutes() if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); ResponseCode != HttpResponseCode::OK) { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) if (ResponseCode == HttpResponseCode::NotFound) { @@ -1167,35 +1168,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return ParseInt<int>(Capture).value_or(0); } -HttpComputeService::Impl::IngestStats -HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +bool +HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment) { - IngestStats Stats; + const IoHash ClaimedHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + const IoHash HeaderHash = Buffer.DecodeRawHash(); + + if (HeaderHash != ClaimedHash) + { + ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + IoHashStream Hasher; + + bool DecompressOk = Buffer.DecompressToStream( + 0, + Buffer.DecodeRawSize(), + [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool { + for (const SharedBuffer& Segment : Range.GetSegments()) + { + Hasher.Append(Segment.GetView()); + } + return true; + }); + + if (!DecompressOk) + { + ZEN_WARN("attachment {}: failed to decompress", ClaimedHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + const IoHash ActualHash = Hasher.GetHash(); + + if (ActualHash != ClaimedHash) + { + ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + return true; +} +bool +HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats) +{ for (const CbAttachment& Attachment : Package.GetAttachments()) { ZEN_ASSERT(Attachment.IsCompressedBinary()); - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return false; + } - const uint64_t CompressedSize = DataView.GetCompressedSize(); + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + const uint64_t CompressedSize = DataView.GetCompressedSize(); - Stats.Bytes += CompressedSize; - ++Stats.Count; + OutStats.Bytes += CompressedSize; + ++OutStats.Count; const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { - Stats.NewBytes += CompressedSize; - ++Stats.NewCount; + OutStats.NewBytes += CompressedSize; + ++OutStats.NewCount; } } - return Stats; + return true; } bool @@ -1253,7 +1300,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { CbPackage Package = HttpReq.ReadPayloadPackage(); Body = Package.GetObject(); - Stats = IngestPackageAttachments(Package); + if (!IngestPackageAttachments(HttpReq, Package, Stats)) + { + return; // validation failed, response already written + } break; } @@ -1268,8 +1318,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { // --- Batch path --- - // For CbObject payloads, check all attachments upfront before enqueuing anything - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + // Verify all action attachment references exist in the store { std::vector<IoHash> NeedList; @@ -1345,7 +1394,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que // --- Single-action path: Body is the action itself --- - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) { std::vector<IoHash> NeedList; @@ -1491,10 +1539,14 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const { ZEN_ASSERT(Attachment.IsCompressedBinary()); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return; + } + const IoHash DataHash = Attachment.GetHash(); CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - ZEN_UNUSED(DataHash); TotalAttachmentBytes += Buffer.GetCompressedSize(); ++AttachmentCount; diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 1ca78738a..ad556f546 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -167,6 +167,7 @@ public: // Action runners void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); // Action submission diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h new file mode 100644 index 000000000..c2e30183a --- /dev/null +++ b/src/zencompute/pathvalidation.h @@ -0,0 +1,118 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/except_fmt.h> +#include <zencore/string.h> + +#include <filesystem> +#include <string_view> + +namespace zen::compute { + +// Validate that a single path component contains only characters that are valid +// file/directory names on all supported platforms. Uses Windows rules as the most +// restrictive superset, since packages may be built on one platform and consumed +// on another. +inline void +ValidatePathComponent(std::string_view Component, std::string_view FullPath) +{ + // Reject control characters (0x00-0x1F) and characters forbidden on Windows + for (char Ch : Component) + { + if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' || + Ch == '*') + { + throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath); + } + } + + // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion) + if (Component.empty() || Component.back() == '.' || Component.back() == ' ') + { + throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath); + } + + // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9) + // These are reserved with or without an extension (e.g. "CON.txt" is still reserved). + std::string_view Stem = Component.substr(0, Component.find('.')); + + static constexpr std::string_view ReservedNames[] = { + "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", + "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", + }; + + for (std::string_view Reserved : ReservedNames) + { + if (zen::StrCaseCompare(Stem, Reserved) == 0) + { + throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved); + } + } +} + +// Validate that a path extracted from a package is a safe relative path. +// Rejects absolute paths, ".." components, and invalid platform filenames. +inline void +ValidateSandboxRelativePath(std::string_view Name) +{ + if (Name.empty()) + { + throw zen::invalid_argument("path traversal detected: empty path name"); + } + + std::filesystem::path Parsed(Name); + + if (Parsed.is_absolute()) + { + throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name); + } + + for (const auto& Component : Parsed) + { + std::string ComponentStr = Component.string(); + + if (ComponentStr == "..") + { + throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name); + } + + // Skip "." (current directory) — harmless in relative paths + if (ComponentStr != ".") + { + ValidatePathComponent(ComponentStr, Name); + } + } +} + +// Validate all path entries in a worker description CbObject. +// Checks path, executables[].name, dirs[], and files[].name fields. +// Throws an exception if any invalid paths are found. +inline void +ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription) +{ + using namespace std::literals; + + if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue()) + { + ValidateSandboxRelativePath(PathField.AsString()); + } + + for (auto& It : WorkerDescription["executables"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + ValidateSandboxRelativePath(It.AsString()); + } + + for (auto& It : WorkerDescription["files"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } +} + +} // namespace zen::compute diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 4f116e7d8..67e12b84e 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -164,8 +164,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Assign any remaining actions to runners with capacity (round-robin) - for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + // Assign any remaining actions to runners with capacity (round-robin). + // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept. + for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount) { if (Capacities[i] > PerRunnerActions[i].size()) { @@ -186,11 +187,12 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Reassemble results in original action order - std::vector<SubmitResult> Results(Actions.size()); + // Reassemble results in original action order. + // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity). + std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); std::vector<size_t> PerRunnerIdx(RunnerCount, 0); - for (size_t i = 0; i < Actions.size(); ++i) + for (size_t i = 0; i < ActionIdx; ++i) { size_t RunnerIdx = ActionRunnerIndex[i]; size_t Idx = PerRunnerIdx[RunnerIdx]++; diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index e79a6c90f..9055005d9 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("namespace sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index b61e0a46f..1b748c0e5 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -4,6 +4,8 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include "pathvalidation.h" + # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> @@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, ZEN_INFO("Cleanup complete"); } - m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; - # if ZEN_PLATFORM_WINDOWS // Suppress any error dialogs caused by missing dependencies UINT OldMode = ::SetErrorMode(0); @@ -382,6 +382,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); const uint64_t Size = FileEntry["size"sv].AsUInt64(); + ValidateSandboxRelativePath(Name); + CompressedBuffer Compressed; if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) @@ -457,7 +459,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["dirs"sv]) { - std::string_view Name = It.AsString(); + std::string_view Name = It.AsString(); + ValidateSandboxRelativePath(Name); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; // Validate dir path stays within sandbox @@ -482,6 +485,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, } WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); + + ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath); } CbPackage @@ -540,6 +545,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) } void +LocalProcessRunner::StartMonitorThread() +{ + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; +} + +void LocalProcessRunner::MonitorThreadFunction() { SetCurrentThreadName("LocalProcessRunner_Monitor"); diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index b8cff6826..d6589db43 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -67,6 +67,7 @@ protected: { Ref<RunnerAction> Action; void* ProcessHandle = nullptr; + int Pid = 0; int ExitCode = 0; std::filesystem::path SandboxPath; @@ -83,8 +84,6 @@ protected: std::filesystem::path m_SandboxPath; int32_t m_MaxRunningActions = 64; // arbitrary limit for testing - // if used in conjuction with m_ResultsLock, this lock must be taken *after* - // m_ResultsLock to avoid deadlocks RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; @@ -95,6 +94,7 @@ protected: std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; + void StartMonitorThread(); void MonitorThreadFunction(); virtual void SweepRunningActions(); virtual void CancelRunningActions(); diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index 5cec90699..c2ccca9a6 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("Seatbelt sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp new file mode 100644 index 000000000..e4a7ba388 --- /dev/null +++ b/src/zencompute/runners/managedrunner.cpp @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "managedrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio/io_context.hpp> +# include <asio/executor_work_guard.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_IoContext(std::make_unique<asio::io_context>()) +, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext)) +{ + m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers"); + + // Run the io_context on a small thread pool so that exit callbacks and + // metrics sampling are dispatched without blocking each other. + for (int i = 0; i < kIoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i] { + SetCurrentThreadName(fmt::format("mrunner_{}", i)); + + // work_guard keeps run() alive even when there is no pending work yet + auto WorkGuard = asio::make_work_guard(*m_IoContext); + + m_IoContext->run(); + }); + } +} + +ManagedProcessRunner::~ManagedProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what()); + } +} + +void +ManagedProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + CancelRunningActions(); + + // Tear down the SubprocessManager before stopping the io_context so that + // any in-flight callbacks are drained cleanly. + if (m_SubprocessManager) + { + m_SubprocessManager->DestroyGroup("compute-workers"); + m_ProcessGroup = nullptr; + m_SubprocessManager.reset(); + } + + if (m_IoContext) + { + m_IoContext->stop(); + } + + for (std::thread& Thread : m_IoThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_IoThreads.clear(); +} + +SubmitResult +ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + // Parse environment variables from worker descriptor ("KEY=VALUE" strings) + // into the key-value pairs expected by CreateProcOptions. + std::vector<std::pair<std::string, std::string>> EnvPairs; + for (auto& It : WorkerDescription["environment"sv]) + { + std::string_view Str = It.AsString(); + size_t Eq = Str.find('='); + if (Eq != std::string_view::npos) + { + EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1))); + } + } + + // Build command line + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string()); + + ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath); + + CreateProcOptions Options; + Options.WorkingDirectory = &Prepared->SandboxPath; + Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.Environment = std::move(EnvPairs); + + const int32_t ActionLsn = Prepared->ActionLsn; + + ManagedProcess* Proc = nullptr; + + try + { + Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) { + OnProcessExit(ActionLsn, ExitCode); + }); + } + catch (std::exception& Ex) + { + ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what()); + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath)); + return SubmitResult{.IsAccepted = false}; + } + + { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = static_cast<void*>(Proc); + NewAction->Pid = Proc->Pid(); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid()); + + return SubmitResult{.IsAccepted = true}; +} + +void +ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit"); + + Ref<RunningAction> Running; + + m_RunningLock.WithExclusiveLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end()) + { + Running = std::move(It->second); + m_RunningMap.erase(It); + } + }); + + if (!Running) + { + return; + } + + ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode); + + Running->ExitCode = ExitCode; + + // Capture final CPU metrics from the managed process before it is removed. + auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle); + if (Proc) + { + ProcessMetrics Metrics = Proc->GetLatestMetrics(); + float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs); + Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed); + + float CpuPct = Proc->GetCpuUsagePercent(); + if (CpuPct >= 0.0f) + { + Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running->ProcessHandle = nullptr; + + std::vector<Ref<RunningAction>> CompletedActions; + CompletedActions.push_back(std::move(Running)); + ProcessCompletedActions(CompletedActions); +} + +void +ManagedProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions"); + + std::unordered_map<int, Ref<RunningAction>> RunningMap; + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling {} running actions via process group", RunningMap.size()); + + Stopwatch Timer; + + // Kill all processes in the group atomically (TerminateJobObject on Windows, + // SIGTERM+SIGKILL on POSIX). + if (m_ProcessGroup) + { + m_ProcessGroup->KillAll(); + } + + for (auto& [Lsn, Running] : RunningMap) + { + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +ManagedProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction"); + + ManagedProcess* Proc = nullptr; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr) + { + Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle); + } + }); + + if (!Proc) + { + return false; + } + + // Terminate the process. The exit callback will handle the rest + // (remove from running map, gather outputs or mark failed). + Proc->Terminate(222); + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + return true; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h new file mode 100644 index 000000000..21a44d43c --- /dev/null +++ b/src/zencompute/runners/managedrunner.h @@ -0,0 +1,64 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenutil/process/subprocessmanager.h> + +# include <memory> + +namespace asio { +class io_context; +} + +namespace zen::compute { + +/** Cross-platform process runner backed by SubprocessManager. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and shared action preparation. Replaces the polling-based + monitor thread with async exit callbacks driven by SubprocessManager, and + delegates CPU/memory metrics sampling to the manager's built-in round-robin + sampler. + + A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is + used for bulk cancellation on shutdown. + + This runner does not perform any platform-specific sandboxing (AppContainer, + namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative + to the platform-specific runners for non-sandboxed workloads. + */ +class ManagedProcessRunner : public LocalProcessRunner +{ +public: + ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~ManagedProcessRunner(); + + void Shutdown() override; + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + [[nodiscard]] bool IsHealthy() override { return true; } + +private: + static constexpr int kIoThreadCount = 4; + + // Exit callback posted on an io_context thread. + void OnProcessExit(int ActionLsn, int ExitCode); + + std::unique_ptr<asio::io_context> m_IoContext; + std::unique_ptr<SubprocessManager> m_SubprocessManager; + ProcessGroup* m_ProcessGroup = nullptr; + std::vector<std::thread> m_IoThreads; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index cd4b646e9..92ee65c2d 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <sddl.h> ZEN_THIRD_PARTY_INCLUDES_END +// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be +// excluded by WIN32_LEAN_AND_MEAN. +# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + namespace zen::compute { using namespace std::literals; @@ -34,38 +40,65 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, : LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) , m_Sandboxed(Sandboxed) { - if (!m_Sandboxed) + // Create a job object shared by all child processes. Restricting the + // error-mode UI prevents crash dialogs (WER / Dr. Watson) from + // blocking the monitor thread when a worker process terminates + // abnormally. + m_JobObject = CreateJobObjectW(nullptr, nullptr); + if (m_JobObject) { - return; + JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; + ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; + SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); + + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on this process so children inherit it. The + // UILIMIT_ERRORMODE restriction above prevents them from clearing + // SEM_NOGPFAULTERRORBOX. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } - // Build a unique profile name per process to avoid collisions - m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + if (m_Sandboxed) + { + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); - // Clean up any stale profile from a previous crash - DeleteAppContainerProfile(m_AppContainerName.c_str()); + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); - PSID Sid = nullptr; + PSID Sid = nullptr; - HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), - m_AppContainerName.c_str(), // display name - m_AppContainerName.c_str(), // description - nullptr, // no capabilities - 0, // capability count - &Sid); + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); - if (FAILED(Hr)) - { - throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); - } + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } - m_AppContainerSid = Sid; + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + } - ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + StartMonitorThread(); } WindowsProcessRunner::~WindowsProcessRunner() { + if (m_JobObject) + { + CloseHandle(m_JobObject); + m_JobObject = nullptr; + } + if (m_AppContainerSid) { FreeSid(m_AppContainerSid); @@ -172,9 +205,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = DETACHED_PROCESS; + DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS; - ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath); CommandLine.EnsureNulTerminated(); @@ -260,14 +293,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) } } - CloseHandle(ProcessInformation.hThread); + if (m_JobObject) + { + AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess); + } - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(Prepared->SandboxPath); + ResumeThread(ProcessInformation.hThread); + CloseHandle(ProcessInformation.hThread); { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->Pid = ProcessInformation.dwProcessId; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + RwLock::ExclusiveLockScope _(m_RunningLock); m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); @@ -275,6 +315,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) Action->SetActionState(RunnerAction::State::Running); + ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId); + return SubmitResult{.IsAccepted = true}; } @@ -294,6 +336,11 @@ WindowsProcessRunner::SweepRunningActions() if (IsSuccess && ExitCode != STILL_ACTIVE) { + ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), + Running->Action->ActionLsn, + Running->Pid, + ExitCode); + CloseHandle(Running->ProcessHandle); Running->ProcessHandle = INVALID_HANDLE_VALUE; Running->ExitCode = ExitCode; diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h index 9f2385cc4..adeaf02fc 100644 --- a/src/zencompute/runners/windowsrunner.h +++ b/src/zencompute/runners/windowsrunner.h @@ -46,6 +46,7 @@ private: bool m_Sandboxed = false; PSID m_AppContainerSid = nullptr; std::wstring m_AppContainerName; + HANDLE m_JobObject = nullptr; }; } // namespace zen::compute diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index 506bec73b..b4fafb467 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, sigemptyset(&Action.sa_mask); Action.sa_handler = SIG_DFL; sigaction(SIGCHLD, &Action, nullptr); + + StartMonitorThread(); } SubmitResult diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index 56a292ca6..cd268745c 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -684,14 +684,22 @@ namespace legacy { Writer.Save(Ar); } - bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer InBuffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + bool ValidateHashes) { BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); - return TryLoadCbPackage(Package, Reader, Allocator, Mapper); + return TryLoadCbPackage(Package, Reader, Allocator, Mapper, ValidateHashes); } - bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + BinaryReader& Reader, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + bool ValidateHashes) { Package = CbPackage(); for (;;) @@ -708,7 +716,11 @@ namespace legacy { if (ValueField.IsBinary()) { const MemoryView View = ValueField.AsBinaryView(); - if (View.GetSize() > 0) + if (View.GetSize() == 0) + { + return false; + } + else { SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); CbField HashField = LoadCompactBinary(Reader, Allocator); @@ -748,7 +760,11 @@ namespace legacy { { const IoHash Hash = ValueField.AsHash(); - ZEN_ASSERT(Mapper); + if (!Mapper) + { + return false; + } + if (SharedBuffer AttachmentData = (*Mapper)(Hash)) { IoHash RawHash; @@ -763,6 +779,10 @@ namespace legacy { } else { + if (ValidateHashes && IoHash::HashBuffer(AttachmentData) != Hash) + { + return false; + } const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All); if (ValidationResult != CbValidateError::None) { @@ -801,13 +821,13 @@ namespace legacy { #if ZEN_WITH_TESTS void -usonpackage_forcelink() +cbpackage_forcelink() { } TEST_SUITE_BEGIN("core.compactbinarypackage"); -TEST_CASE("usonpackage") +TEST_CASE("cbpackage") { using namespace std::literals; @@ -997,7 +1017,7 @@ TEST_CASE("usonpackage") } } -TEST_CASE("usonpackage.serialization") +TEST_CASE("cbpackage.serialization") { using namespace std::literals; @@ -1303,7 +1323,7 @@ TEST_CASE("usonpackage.serialization") } } -TEST_CASE("usonpackage.invalidpackage") +TEST_CASE("cbpackage.invalidpackage") { const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) { const MemoryView RawView = MakeMemoryView(RawData); @@ -1345,6 +1365,90 @@ TEST_CASE("usonpackage.invalidpackage") } } +TEST_CASE("cbpackage.legacy.invalidpackage") +{ + const auto TestLegacyLoad = [](std::initializer_list<uint8_t> RawData) { + const MemoryView RawView = MakeMemoryView(RawData); + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(RawView.GetData()), RawView.GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + }; + + SUBCASE("Empty") { TestLegacyLoad({}); } + + SUBCASE("Zero size binary rejects") + { + // A binary field with zero payload size should be rejected (would desync the reader) + BinaryWriter Writer; + CbWriter Cb; + Cb.AddBinary(MemoryView()); // zero-size binary + Cb.Save(Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + } +} + +TEST_CASE("cbpackage.legacy.hashresolution") +{ + // Build a valid legacy package with an object, then round-trip it + CbObjectWriter RootWriter; + RootWriter.AddString("name", "test"); + CbObject RootObject = RootWriter.Save(); + + CbAttachment ObjectAttach(RootObject); + + CbPackage OriginalPkg; + OriginalPkg.SetObject(RootObject); + OriginalPkg.AddAttachment(ObjectAttach); + + BinaryWriter Writer; + legacy::SaveCbPackage(OriginalPkg, Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage LoadedPkg; + CHECK(legacy::TryLoadCbPackage(LoadedPkg, Buffer, &UniqueBuffer::Alloc)); + + // The hash-only path requires a mapper — without one it should fail + CbWriter HashOnlyCb; + HashOnlyCb.AddHash(ObjectAttach.GetHash()); + HashOnlyCb.AddNull(); + BinaryWriter HashOnlyWriter; + HashOnlyCb.Save(HashOnlyWriter); + + IoBuffer HashOnlyBuffer(IoBuffer::Wrap, + const_cast<void*>(MakeMemoryView(HashOnlyWriter).GetData()), + MakeMemoryView(HashOnlyWriter).GetSize()); + CbPackage HashOnlyPkg; + CHECK_FALSE(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, nullptr)); + + // With a mapper that returns valid data, it should succeed + CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer { + if (Hash == ObjectAttach.GetHash()) + { + return RootObject.GetBuffer(); + } + return {}; + }; + CHECK(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &Resolver)); + + // Build a different but structurally valid CbObject to use as mismatched data + CbObjectWriter DifferentWriter; + DifferentWriter.AddString("name", "different"); + CbObject DifferentObject = DifferentWriter.Save(); + + CbPackage::AttachmentResolver BadResolver = [&](const IoHash&) -> SharedBuffer { return DifferentObject.GetBuffer(); }; + CbPackage BadPkg; + + // With ValidateHashes enabled and a mapper that returns mismatched data, it should fail + CHECK_FALSE(legacy::TryLoadCbPackage(BadPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ true)); + + // Without ValidateHashes, the mismatched data is accepted (structure is still valid CB) + CbPackage UncheckedPkg; + CHECK(legacy::TryLoadCbPackage(UncheckedPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ false)); +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 416312cae..a63594be9 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3277,12 +3277,23 @@ MakeSafeAbsolutePathInPlace(std::filesystem::path& Path) { Path = std::filesystem::absolute(Path).make_preferred(); #if ZEN_PLATFORM_WINDOWS - const std::string_view Prefix = "\\\\?\\"; - const std::u8string PrefixU8(Prefix.begin(), Prefix.end()); - std::u8string PathString = Path.u8string(); - if (!PathString.empty() && !PathString.starts_with(PrefixU8)) + const std::u8string_view LongPathPrefix = u8"\\\\?\\"; + const std::u8string_view UncPrefix = u8"\\\\"; + const std::u8string_view LongPathUncPrefix = u8"\\\\?\\UNC\\"; + + std::u8string PathString = Path.u8string(); + if (!PathString.empty() && !PathString.starts_with(LongPathPrefix)) { - PathString.insert(0, PrefixU8); + if (PathString.starts_with(UncPrefix)) + { + // UNC path: \\server\share → \\?\UNC\server\share + PathString.replace(0, UncPrefix.size(), LongPathUncPrefix); + } + else + { + // Local path: C:\foo → \\?\C:\foo + PathString.insert(0, LongPathPrefix); + } Path = PathString; } #endif // ZEN_PLATFORM_WINDOWS @@ -4049,6 +4060,54 @@ TEST_CASE("SharedMemory") CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } +TEST_CASE("filesystem.MakeSafeAbsolutePath") +{ +# if ZEN_PLATFORM_WINDOWS + // Local path gets \\?\ prefix + { + std::filesystem::path Local = MakeSafeAbsolutePath("C:\\Users\\test"); + CHECK(Local.u8string().starts_with(u8"\\\\?\\")); + CHECK(Local.u8string().find(u8"C:\\Users\\test") != std::u8string::npos); + } + + // UNC path gets \\?\UNC\ prefix + { + std::filesystem::path Unc = MakeSafeAbsolutePath("\\\\server\\share\\path"); + std::u8string UncStr = Unc.u8string(); + CHECK_MESSAGE(UncStr.starts_with(u8"\\\\?\\UNC\\"), fmt::format("Expected \\\\?\\UNC\\ prefix, got '{}'", Unc)); + CHECK_MESSAGE(UncStr.find(u8"server\\share\\path") != std::u8string::npos, + fmt::format("Expected server\\share\\path in '{}'", Unc)); + // Must NOT produce \\?\\\server (double backslash after \\?\) + CHECK_MESSAGE(UncStr.find(u8"\\\\?\\\\\\") == std::u8string::npos, + fmt::format("Path contains invalid double-backslash after prefix: '{}'", Unc)); + } + + // Already-prefixed path is not double-prefixed + { + std::filesystem::path Already = MakeSafeAbsolutePath("\\\\?\\C:\\already\\prefixed"); + size_t Count = 0; + std::u8string Str = Already.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } + + // Already-prefixed UNC path is not double-prefixed + { + std::filesystem::path AlreadyUnc = MakeSafeAbsolutePath("\\\\?\\UNC\\server\\share"); + size_t Count = 0; + std::u8string Str = AlreadyUnc.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } +# endif // ZEN_PLATFORM_WINDOWS +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h index 64b62e2c0..148c0d3fd 100644 --- a/src/zencore/include/zencore/compactbinarypackage.h +++ b/src/zencore/include/zencore/compactbinarypackage.h @@ -278,10 +278,10 @@ public: * @return The attachment, or null if the attachment is not found. * @note The returned pointer is only valid until the attachments on this package are modified. */ - const CbAttachment* FindAttachment(const IoHash& Hash) const; + [[nodiscard]] const CbAttachment* FindAttachment(const IoHash& Hash) const; /** Find an attachment if it exists in the package. */ - inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } + [[nodiscard]] const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } /** Add the attachment to this package. */ inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); } @@ -336,17 +336,26 @@ private: IoHash ObjectHash; }; +/** In addition to the above, we also support a legacy format which is used by + * the HTTP project store for historical reasons. Don't use the below functions + * for new code. + */ namespace legacy { void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar); - bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr); + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer Buffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, - CbPackage::AttachmentResolver* Mapper = nullptr); + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); } // namespace legacy -void usonpackage_forcelink(); // internal +void cbpackage_forcelink(); // internal } // namespace zen diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index 3427991d2..1608ad523 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -90,6 +90,34 @@ using zen::ConsoleLog; using zen::ErrorLog; using zen::Log; +//////////////////////////////////////////////////////////////////////// +// Color helpers + +#define ZEN_RED(str) "\033[31m" str "\033[0m" +#define ZEN_GREEN(str) "\033[32m" str "\033[0m" +#define ZEN_YELLOW(str) "\033[33m" str "\033[0m" +#define ZEN_BLUE(str) "\033[34m" str "\033[0m" +#define ZEN_MAGENTA(str) "\033[35m" str "\033[0m" +#define ZEN_CYAN(str) "\033[36m" str "\033[0m" +#define ZEN_WHITE(str) "\033[37m" str "\033[0m" + +#define ZEN_BRIGHT_RED(str) "\033[91m" str "\033[0m" +#define ZEN_BRIGHT_GREEN(str) "\033[92m" str "\033[0m" +#define ZEN_BRIGHT_YELLOW(str) "\033[93m" str "\033[0m" +#define ZEN_BRIGHT_BLUE(str) "\033[94m" str "\033[0m" +#define ZEN_BRIGHT_MAGENTA(str) "\033[95m" str "\033[0m" +#define ZEN_BRIGHT_CYAN(str) "\033[96m" str "\033[0m" +#define ZEN_BRIGHT_WHITE(str) "\033[97m" str "\033[0m" + +#define ZEN_BOLD(str) "\033[1m" str "\033[0m" +#define ZEN_UNDERLINE(str) "\033[4m" str "\033[0m" +#define ZEN_DIM(str) "\033[2m" str "\033[0m" +#define ZEN_ITALIC(str) "\033[3m" str "\033[0m" +#define ZEN_STRIKETHROUGH(str) "\033[9m" str "\033[0m" +#define ZEN_INVERSE(str) "\033[7m" str "\033[0m" + +//////////////////////////////////////////////////////////////////////// + #if ZEN_BUILD_DEBUG # define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \ while (false) \ diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index e7baa3f8e..9cbbfa56a 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -1252,14 +1252,28 @@ JobObject::Initialize() } JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {}; - LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo))) { ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError()); CloseHandle(m_JobHandle); m_JobHandle = nullptr; + return; } + + // Prevent child processes from clearing SEM_NOGPFAULTERRORBOX, which + // suppresses WER/Dr. Watson crash dialogs. Without this, a crashing + // child can pop a modal dialog and block the monitor thread. +# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobHandle, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on the current process so children inherit it. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } bool diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index 8c29a8962..c1ac63621 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -273,7 +273,7 @@ zencore_forcelinktests() zen::uid_forcelink(); zen::uson_forcelink(); zen::usonbuilder_forcelink(); - zen::usonpackage_forcelink(); + zen::cbpackage_forcelink(); zen::cbjson_forcelink(); zen::cbyaml_forcelink(); zen::workthreadpool_forcelink(); diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index e05c9815f..38021be16 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -479,6 +479,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return Ref<IHttpPackageHandler>(); } +bool +HttpService::AcceptsLocalFileReferences() const +{ + return false; +} + +const ILocalRefPolicy* +HttpService::GetLocalRefPolicy() const +{ + return nullptr; +} + ////////////////////////////////////////////////////////////////////////// HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) @@ -705,7 +717,10 @@ HttpServerRequest::ReadPayloadPackage() { if (IoBuffer Payload = ReadPayload()) { - return ParsePackageMessage(std::move(Payload)); + ParseFlags Flags = + (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault; + const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr; + return ParsePackageMessage(std::move(Payload), {}, Flags, Policy); } return {}; @@ -1273,7 +1288,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP return PackageHandlerRef->CreateTarget(Cid, Size); }; - CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences()) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr; + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy); PackageHandlerRef->OnRequestComplete(); } diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 5eaed6004..76f219f04 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -12,6 +12,7 @@ #include <zencore/string.h> #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zenhttp/localrefpolicy.h> #include <zentelemetry/hyperloglog.h> #include <zentelemetry/stats.h> @@ -193,9 +194,16 @@ public: HttpService() = default; virtual ~HttpService() = default; - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + [[nodiscard]] virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + /// Whether this service accepts local file references in inbound packages from local clients. + [[nodiscard]] virtual bool AcceptsLocalFileReferences() const; + + /// Returns the local ref policy for validating file paths in inbound local references. + /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed). + [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const; // Internals diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h new file mode 100644 index 000000000..0b37f9dc7 --- /dev/null +++ b/src/zenhttp/include/zenhttp/localrefpolicy.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> + +namespace zen { + +/// Policy interface for validating local file reference paths in inbound CbPackage messages. +/// Implementations should throw std::invalid_argument if the path is not allowed. +class ILocalRefPolicy +{ +public: + virtual ~ILocalRefPolicy() = default; + + /// Validate that a local file reference path is allowed. + /// Throws std::invalid_argument if the path escapes the allowed root. + virtual void ValidatePath(const std::filesystem::path& Path) const = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index 1a5068580..66e3f6e55 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -5,6 +5,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> +#include <zenhttp/localrefpolicy.h> #include <functional> #include <gsl/gsl-lite.hpp> @@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); -CbPackage ParsePackageMessage( - IoBuffer Payload, - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + +enum class ParseFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags); + +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; - }); + }, + ParseFlags Flags = ParseFlags::kDefault, + const ILocalRefPolicy* Policy = nullptr); bool IsPackageMessage(IoBuffer Payload); bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); @@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe class CbPackageReader { public: - CbPackageReader(); + CbPackageReader(ParseFlags Flags = ParseFlags::kDefault); ~CbPackageReader(); void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + void SetLocalRefPolicy(const ILocalRefPolicy* Policy); /** Process compact binary package data stream @@ -149,6 +162,8 @@ private: kReadingBuffers } m_CurrentState = State::kInitialState; + ParseFlags m_Flags; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; std::vector<IoBuffer> m_PayloadBuffers; std::vector<CbAttachmentEntry> m_AttachmentEntries; diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 9c62c1f2d..267ce386c 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:"); typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t; +/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security. +/// If no policy is set, file-path local refs are rejected (fail-closed). +static void +ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path) +{ + if (Policy) + { + Policy->ValidatePath(Path); + } + else + { + throw std::invalid_argument("local file reference rejected: no validation policy"); + } +} + +// Validates the CbPackageHeader magic and attachment count. Returns the total +// chunk count (AttachmentCount + 1, including the implicit root object). +static uint32_t +ValidatePackageHeader(const CbPackageHeader& Hdr) +{ + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::invalid_argument( + fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic)); + } + // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against + // UINT32_MAX wrapping to 0, which would bypass subsequent size checks. + if (Hdr.AttachmentCount == UINT32_MAX) + { + throw std::invalid_argument("invalid CbPackage, attachment count overflow"); + } + return Hdr.AttachmentCount + 1; +} + +struct ValidatedLocalRef +{ + bool Valid = false; + const CbAttachmentReferenceHeader* Header = nullptr; + std::string_view Path; + std::string Error; +}; + +// Validates that the attachment buffer contains a well-formed local reference +// header and path. On failure, Valid is false and Error contains the reason. +static ValidatedLocalRef +ValidateLocalRef(const IoBuffer& AttachmentBuffer) +{ + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader)) + { + return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())}; + } + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength) + { + return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})", + sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength, + AttachmentBuffer.Size())}; + } + + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)}; +} + IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle); std::vector<IoBuffer> @@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload) } CbPackage -ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +ParsePackageMessage(IoBuffer Payload, + std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer, + ParseFlags Flags, + const ILocalRefPolicy* Policy) { ZEN_TRACE_CPU("ParsePackageMessage"); @@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint BinaryReader Reader(Payload); - const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); - if (Hdr->HeaderMagic != kCbPkgMagic) - { - throw std::invalid_argument( - fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic)); - } + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + const uint32_t ChunkCount = ValidatePackageHeader(*Hdr); Reader.Skip(sizeof(CbPackageHeader)); - const uint32_t ChunkCount = Hdr->AttachmentCount + 1; - - if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount) + // Widen to uint64_t so the multiplication cannot wrap on 32-bit. + const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount; + if (Reader.Remaining() < AttachmentTableSize) { throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)", sizeof(CbAttachmentEntry) * ChunkCount, @@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) { - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i)); + } - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash))); + continue; + } + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::string_view PathView = LocalRef.Path; IoBuffer FullFileBuffer; @@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint } else { + ApplyLocalRefPolicy(Policy, Path); FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; } } if (FullFileBuffer) { - IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + // Guard against offset+size overflow or exceeding the file bounds. + const uint64_t FileSize = FullFileBuffer.GetSize(); + if (AttachRefHdr->PayloadByteOffset > FileSize || + AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset) + { + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}", + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + FileSize, + Entry.AttachmentHash))); + continue; + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize ? FullFileBuffer : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa return OutPackage.TryLoad(Response); } -CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +CbPackageReader::CbPackageReader(ParseFlags Flags) +: m_Flags(Flags) +, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) { } @@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci m_CreateBuffer = CreateBuffer; } +void +CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy) +{ + m_LocalRefPolicy = Policy; +} + uint64_t CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) { @@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) return sizeof m_PackageHeader; case State::kReadingHeader: - ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); - memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); - ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); - m_CurrentState = State::kReadingAttachmentEntries; - m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); - return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + { + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(ChunkCount); + return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry); + } case State::kReadingAttachmentEntries: ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); @@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) { // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + throw std::invalid_argument(LocalRef.Error); + } - std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::filesystem::path Path(Utf8ToWide(LocalRef.Path)); - std::filesystem::path Path{PathView}; + if (!LocalRef.Path.starts_with(HandlePrefix)) + { + ApplyLocalRefPolicy(m_LocalRefPolicy, Path); + } IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) AttachRefHdr->PayloadByteSize)); } + // MakeFromFile silently clamps offset+size to the file size. Detect this + // to avoid returning a short buffer that could cause subtle downstream issues. + if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize) + { + throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + ChunkReference.GetSize())); + } + return ChunkReference; }; @@ -732,6 +843,13 @@ CbPackageReader::Finalize() { IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", + CurrentAttachmentIndex)); + } + if (CurrentAttachmentIndex == 0) { // Root object @@ -815,6 +933,13 @@ CbPackageReader::Finalize() TEST_SUITE_BEGIN("http.packageformat"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. +struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef") RemainingBytes -= ByteCount; }; + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackageReader Reader(ParseFlags::kAllowLocalReferences); + Reader.SetLocalRefPolicy(&AllowAllPolicy); + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.Validation.TruncatedHeader") +{ + // Payload too small for a CbPackageHeader + uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa}; + IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagic") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xDEADBEEF; + Hdr.AttachmentCount = 0; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflow") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable") +{ + // Valid header but not enough data for the attachment entries + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = 10; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentData") +{ + // Valid header + one attachment entry claiming more data than available + std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry)); + + CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data()); + Hdr->HeaderMagic = kCbPkgMagic; + Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object) + + CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader)); + Entry->PayloadSize = 9999; // way more than available + Entry->Flags = CbAttachmentEntry::kIsObject; + Entry->AttachmentHash = IoHash(); + + IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size()); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault") +{ + // Build a valid package with local refs backed by compressed-format files, + // then verify it's rejected with default ParseFlags and accepted when allowed. + ScopedTemporaryDirectory TempDir; + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + // Compress data and write to disk, then create file-backed compressed attachments. + // The files must contain compressed-format data because ParsePackageMessage expects it + // when resolving local refs. + CompressedBuffer Comp1 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None); + CompressedBuffer Comp2 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None); + + IoHash Hash1 = Comp1.DecodeRawHash(); + IoHash Hash2 = Comp2.DecodeRawHash(); + + { + IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer(); + IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(Path1, Buf1); + WriteFile(Path2, Buf2); + } + + // Create attachments from file-backed buffers so FormatPackageMessage uses local refs + CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1}; + CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Default flags should reject local refs + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); + + // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip + // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader) + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 2); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader") +{ + // Same test but via CbPackageReader + ScopedTemporaryDirectory TempDir; + auto FilePath = TempDir.Path() / "testdata"; + + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata")); + WriteFile(FilePath, Buf); + } + + IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath); + CbAttachment Attach{SharedBuffer(FileBuffer)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + // Default flags should reject CbPackageReader Reader; uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); @@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef") CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); } - Reader.Finalize(); + CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagicViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xBADCAFE; + Hdr.AttachmentCount = 0; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot") +{ + // A file outside the allowed root should be rejected by the policy + ScopedTemporaryDirectory AllowedRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "secret.dat"; + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret")); + WriteFile(OutsidePath, Buf); + } + + // Create file-backed compressed attachment from outside root + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Policy rooted at AllowedRoot should reject the file in OutsideDir + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot") +{ + // A file inside the allowed root should be accepted by the policy + ScopedTemporaryDirectory TempRoot; + + auto FilePath = TempRoot.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 1); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal") +{ + // A file path containing ".." that resolves outside root should be rejected + ScopedTemporaryDirectory TempRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "evil.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + // Root is TempRoot, but the file lives in OutsideDir + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed") +{ + // When local refs are allowed but no policy is provided, file-path refs should be rejected + ScopedTemporaryDirectory TempDir; + + auto FilePath = TempDir.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // kAllowLocalReferences but nullptr policy => fail-closed + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument); } TEST_SUITE_END(); diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp index 14748e214..986dc67e0 100644 --- a/src/zenserver-test/cache-tests.cpp +++ b/src/zenserver-test/cache-tests.cpp @@ -9,6 +9,7 @@ # include <zencore/compactbinarypackage.h> # include <zencore/compress.h> # include <zencore/fmtutils.h> +# include <zenhttp/localrefpolicy.h> # include <zenhttp/packageformat.h> # include <zenstore/cache/cachepolicy.h> # include <zencore/filesystem.h> @@ -25,6 +26,13 @@ namespace zen::tests { TEST_SUITE_BEGIN("server.cache"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. +struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("zcache.basic") { using namespace std::literals; @@ -743,7 +751,11 @@ TEST_CASE("zcache.rpc") if (Result.StatusCode == HttpResponseCode::OK) { - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); + ParseFlags PFlags = EnumHasAllFlags(AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); CHECK(!Response.IsNull()); OutResult.Response = std::move(Response); CHECK(OutResult.Result.Parse(OutResult.Response)); @@ -1745,8 +1757,13 @@ TEST_CASE("zcache.rpc.partialchunks") CHECK(Result.StatusCode == HttpResponseCode::OK); - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); - bool Loaded = !Response.IsNull(); + ParseFlags PFlags = EnumHasAllFlags(Options.AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); + bool Loaded = !Response.IsNull(); CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load."); cacherequests::GetCacheChunksResult GetCacheChunksResult; CHECK(GetCacheChunksResult.Parse(Response)); diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp index 021052a3b..95541c3ce 100644 --- a/src/zenserver-test/compute-tests.cpp +++ b/src/zenserver-test/compute-tests.cpp @@ -21,6 +21,7 @@ # include <zenhttp/httpserver.h> # include <zenhttp/websocket.h> # include <zencompute/computeservice.h> +# include <zencore/fmtutils.h> # include <zenstore/zenstore.h> # include <zenutil/zenserverprocess.h> @@ -36,6 +37,8 @@ using namespace std::literals; static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; +static constexpr std::string_view kFailVersion = "fa11fa11-fa11-fa11-fa11-fa11fa11fa11"; +static constexpr std::string_view kCrashVersion = "c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"; // In-memory implementation of ChunkResolver for test use. // Stores compressed data keyed by decompressed content hash. @@ -104,6 +107,16 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -115,7 +128,7 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage)); REQUIRE_MESSAGE(RegisterResp, - fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText())); + fmt::format("Worker registration failed: status={}, body={}", RegisterResp.StatusCode, RegisterResp.ToText())); return WorkerId; } @@ -220,6 +233,83 @@ BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemor return ActionWriter.Save(); } +// Build a Fail action CbPackage. The worker exits with the given exit code. +static CbPackage +BuildFailActionPackage(int ExitCode) +{ + // The Fail function throws before reading inputs, but the action structure + // still requires a valid input attachment for the runner to manifest. + std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Fail"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kFailVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "ExitCode"sv << static_cast<uint64_t>(ExitCode); + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Crash action CbPackage. The worker process crashes hard. +// Mode: "abort" (default) or "nullptr" (null pointer dereference). +static CbPackage +BuildCrashActionPackage(std::string_view Mode = "abort"sv) +{ + std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Crash"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kCrashVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "Mode"sv << Mode; + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + static HttpClient::Response PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000) { @@ -469,6 +559,16 @@ BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -526,7 +626,7 @@ TEST_CASE("function.rot13") // Submit action via legacy /jobs/{worker} endpoint const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); @@ -536,7 +636,7 @@ TEST_CASE("function.rot13") HttpClient::Response ResultResp = PollForResult(Client, ResultUrl); REQUIRE_MESSAGE( ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -581,7 +681,7 @@ TEST_CASE("function.workers") // GET /workers/{worker} — descriptor should match what was registered const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response DescResp = Client.Get(WorkerUrl); - REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode))); + REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", DescResp.StatusCode)); CbObject Desc = DescResp.AsObject(); CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion)); @@ -627,7 +727,7 @@ TEST_CASE("function.queues.lifecycle") // Create a queue HttpClient::Response CreateResp = Client.Post("/queues"sv); - REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); @@ -651,8 +751,7 @@ TEST_CASE("function.queues.lifecycle") // Submit action via queue-scoped endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission"); @@ -668,9 +767,8 @@ TEST_CASE("function.queues.lifecycle") // Retrieve result via queue-scoped /jobs/{lsn} endpoint const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); HttpClient::Response ResultResp = Client.Get(ResultUrl); - REQUIRE_MESSAGE( - ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -712,13 +810,13 @@ TEST_CASE("function.queues.cancel") // Submit a job const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); // Cancel the queue const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // Verify queue status shows cancelled HttpClient::Response StatusResp = Client.Get(QueueUrl); @@ -743,7 +841,7 @@ TEST_CASE("function.queues.remote") // Create a remote queue — response includes both an integer queue_id and an OID queue_token HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); CbObject CreateObj = CreateResp.AsObject(); const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString()); @@ -753,7 +851,7 @@ TEST_CASE("function.queues.remote") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); REQUIRE_MESSAGE(SubmitResp, - fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + fmt::format("Remote queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission"); @@ -769,7 +867,7 @@ TEST_CASE("function.queues.remote") HttpClient::Response ResultResp = Client.Get(ResultUrl); REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}", - int(ResultResp.StatusCode), + ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" @@ -801,8 +899,7 @@ TEST_CASE("function.queues.cancel_running") // Submit a Sleep job long enough that it will still be running when we cancel const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); @@ -814,7 +911,7 @@ TEST_CASE("function.queues.cancel_running") const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the /completed endpoint once the process exits const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); @@ -849,7 +946,7 @@ TEST_CASE("function.queues.remote_cancel") // Create a remote queue to obtain an OID token for token-addressed cancellation HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString()); REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); @@ -857,8 +954,7 @@ TEST_CASE("function.queues.remote_cancel") // Submit a long-running Sleep job via the token-addressed endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); @@ -870,7 +966,7 @@ TEST_CASE("function.queues.remote_cancel") const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Remote queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the token-addressed /completed endpoint const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); @@ -910,13 +1006,13 @@ TEST_CASE("function.queues.drain") // Submit a long-running job so we can verify it completes even after drain const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000)); - REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode))); + REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", Submit1.StatusCode)); const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32(); // Drain the queue const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId); HttpClient::Response DrainResp = Client.Post(DrainUrl); - REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText())); + REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", DrainResp.StatusCode, DrainResp.ToText())); CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining"); // Second submission should be rejected with 424 @@ -965,7 +1061,7 @@ TEST_CASE("function.priority") // jobs by priority when the slot becomes free. const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", BlockerResp.StatusCode)); // Submit 3 low-priority Rot13 jobs const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); @@ -1104,6 +1200,432 @@ TEST_CASE("function.priority") } ////////////////////////////////////////////////////////////////////////// +// Process exit code tests +// +// These tests exercise how the compute service handles worker processes +// that exit with non-zero exit codes, including retry behaviour and +// final failure reporting. + +TEST_CASE("function.exit_code.failed_action") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 so the action fails immediately + // without being rescheduled. + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action with exit code 42 + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(42)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Fail job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + // Verify action history records the failure + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); + + // GET /jobs/{lsn} for a failed action should return OK but with an empty result package + const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + CHECK_EQ(ResultResp.StatusCode, HttpResponseCode::OK); +} + +TEST_CASE("function.exit_code.auto_retry") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=2 so the action is retried twice before + // being reported as failed (3 total attempts: initial + 2 retries). + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 2; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action — the worker process will exit with code 1 on + // every attempt, eventually exhausting retries. + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Poll for the LSN to appear in the completed list — this only + // happens after all retries are exhausted. + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the action history records the retry count + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 2); + break; + } + } + + // Queue should show 1 failed, 0 completed + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.exit_code.reschedule_failed") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=1 so we have room for one manual reschedule + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 1; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action — auto-retry will fire once, then it lands in results as Failed + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(7)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for the action to exhaust its auto-retry and land in completed + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue completed list\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Try to manually reschedule — should fail because retry limit is reached + const std::string RescheduleUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response RescheduleResp = Client.Post(RescheduleUrl); + CHECK_EQ(RescheduleResp.StatusCode, HttpResponseCode::Conflict); +} + +TEST_CASE("function.exit_code.mixed_success_and_failure") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 for fast failure + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit one Rot13 (success) and one Fail (failure) + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + + HttpClient::Response SuccessResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv)); + REQUIRE_MESSAGE(SuccessResp, "Rot13 job submission failed"); + const int LsnSuccess = SuccessResp.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response FailResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(FailResp, "Fail job submission failed"); + const int LsnFail = FailResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for both to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnSuccess), + fmt::format("Success LSN {} did not complete\nServer log:\n{}", LsnSuccess, Instance.GetLogOutput())); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnFail), + fmt::format("Fail LSN {} did not complete\nServer log:\n{}", LsnFail, Instance.GetLogOutput())); + + // Verify queue counters + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.crash.abort") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 so we don't wait through retries + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action that calls std::abort() + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + // Verify action history records the failure + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); +} + +TEST_CASE("function.crash.nullptr") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action that dereferences a null pointer + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("nullptr"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.crash.auto_retry") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=1 — the crash should be retried once + // before being reported as permanently failed. + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 1; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action — will crash on every attempt + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Poll for the LSN to appear in the completed list after retries exhaust + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the action history records the retry count + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 1); + break; + } + } + + // Queue should show 1 failed, 0 completed + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +////////////////////////////////////////////////////////////////////////// // Remote worker synchronization tests // // These tests exercise the orchestrator discovery path where new compute @@ -1162,9 +1684,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); @@ -1349,9 +1870,8 @@ TEST_CASE("function.remote.queue_association") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); @@ -1481,7 +2001,7 @@ TEST_CASE("function.abandon_running_http") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); @@ -1498,7 +2018,7 @@ TEST_CASE("function.abandon_running_http") // Trigger abandon via the HTTP endpoint HttpClient::Response AbandonResp = Client.Post("/abandon"sv); REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, - fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + fmt::format("Abandon request failed: status={}, body={}", AbandonResp.StatusCode, AbandonResp.ToText())); // Ready endpoint should now return 503 { @@ -1577,7 +2097,7 @@ TEST_CASE("function.session.abandon_pending") } Sleep(100); } - CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, Code)); } // Queue should show 0 active, 3 abandoned @@ -1979,11 +2499,11 @@ TEST_CASE("function.retract_http") // Submit a long-running Sleep action to occupy the single execution slot const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", BlockerResp.StatusCode)); // Submit a second action — it will stay pending because the slot is occupied HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); @@ -1996,7 +2516,7 @@ TEST_CASE("function.retract_http") HttpClient::Response RetractResp = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK, fmt::format("Retract failed: status={}, body={}\nServer log:\n{}", - int(RetractResp.StatusCode), + RetractResp.StatusCode, RetractResp.ToText(), Instance.GetLogOutput())); @@ -2011,7 +2531,42 @@ TEST_CASE("function.retract_http") Sleep(500); HttpClient::Response RetractResp2 = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK, - fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText())); + fmt::format("Second retract failed: status={}, body={}", RetractResp2.StatusCode, RetractResp2.ToText())); +} + +TEST_CASE("function.session.immediate_query_after_enqueue") +{ + // Verify that actions are immediately visible to GetActionResult and + // FindActionResult right after enqueue, without waiting for the + // scheduler thread to process the update. + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + CbObject ActionObj = BuildRot13ActionForSession("immediate-query"sv, Resolver); + + auto EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Failed to enqueue action"); + + // Query by LSN immediately — must not return NotFound + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(EnqueueRes.Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::Accepted, + fmt::format("GetActionResult returned {} immediately after enqueue, expected Accepted", Code)); + + // Query by ActionId immediately — must not return NotFound + const IoHash ActionId = ActionObj.GetHash(); + CbPackage FindResult; + HttpResponseCode FindCode = Session.FindActionResult(ActionId, FindResult); + CHECK_MESSAGE(FindCode == HttpResponseCode::Accepted, + fmt::format("FindActionResult returned {} immediately after enqueue, expected Accepted", FindCode)); + + Session.Shutdown(); } TEST_SUITE_END(); diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js index d1a880954..2eb4d4e9b 100644 --- a/src/zenserver/frontend/html/pages/compute.js +++ b/src/zenserver/frontend/html/pages/compute.js @@ -24,6 +24,12 @@ function formatTime(date) return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit", second: "2-digit" }); } +function truncateHash(hash) +{ + if (!hash || hash.length <= 15) return hash; + return hash.slice(0, 6) + "\u2026" + hash.slice(-6); +} + function formatDuration(startDate, endDate) { if (!startDate || !endDate) return "-"; @@ -305,11 +311,7 @@ export class Page extends ZenPage { const workerIds = data.workers || []; - if (this._workers_table) - { - this._workers_table.clear(); - } - else + if (!this._workers_table) { this._workers_table = this._workers_host.add_widget( Table, @@ -320,6 +322,7 @@ export class Page extends ZenPage if (workerIds.length === 0) { + this._workers_table.clear(); return; } @@ -349,6 +352,9 @@ export class Page extends ZenPage id, ); + // Worker ID column: monospace for hex readability + row.get_cell(5).style("fontFamily", "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"); + // Make name clickable to expand detail const cell = row.get_cell(0); cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc)); @@ -546,6 +552,11 @@ export class Page extends ZenPage ["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"], Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 ); + + // Right-align hash column headers to match data cells + const hdr = this._history_table.inner().firstElementChild; + hdr.children[7].style.textAlign = "right"; + hdr.children[8].style.textAlign = "right"; } // Entries arrive oldest-first; reverse to show newest at top @@ -560,7 +571,10 @@ export class Page extends ZenPage const startDate = filetimeToDate(entry.time_Running); const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); - this._history_table.add_row( + const workerId = entry.workerId || "-"; + const actionId = entry.actionId || "-"; + + const row = this._history_table.add_row( lsn, queueId, status, @@ -568,9 +582,15 @@ export class Page extends ZenPage formatTime(startDate), formatTime(endDate), formatDuration(startDate, endDate), - entry.workerId || "-", - entry.actionId || "-", + truncateHash(workerId), + truncateHash(actionId), ); + + // Hash columns: force right-align (AlignNumeric misses hex strings starting with a-f), + // use monospace for readability, and show full value on hover + const mono = "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"; + row.get_cell(7).style("textAlign", "right").style("fontFamily", mono).attr("title", workerId); + row.get_cell(8).style("textAlign", "right").style("fontFamily", mono).attr("title", actionId); } } diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index c1727270c..8ad48225b 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -80,7 +80,8 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache) + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("cache")) , m_CacheStore(InCacheStore) , m_StatsService(StatsService) @@ -90,6 +91,7 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach , m_DiskWriteBlocker(InDiskWriteBlocker) , m_OpenProcessCache(InOpenProcessCache) , m_RpcHandler(m_Log, m_CacheStats, UpstreamCache, InCacheStore, InCidStore, InDiskWriteBlocker) +, m_LocalRefPolicy(InLocalRefPolicy) { m_StatsService.RegisterHandler("z$", *this); m_StatusService.RegisterHandler("z$", *this); @@ -114,6 +116,18 @@ HttpStructuredCacheService::BaseUri() const return "/z$/"; } +bool +HttpStructuredCacheService::AcceptsLocalFileReferences() const +{ + return true; +} + +const ILocalRefPolicy* +HttpStructuredCacheService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpStructuredCacheService::Flush() { diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h index fc80b449e..f606126d6 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.h +++ b/src/zenserver/storage/cache/httpstructuredcache.h @@ -76,11 +76,14 @@ public: HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache); + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpStructuredCacheService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; void Flush(); @@ -125,6 +128,7 @@ private: const DiskWriteBlocker* m_DiskWriteBlocker = nullptr; OpenProcessCache& m_OpenProcessCache; CacheRpcHandler m_RpcHandler; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; void ReplayRequestRecorder(const CacheRequestContext& Context, cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount); diff --git a/src/zenserver/storage/localrefpolicy.cpp b/src/zenserver/storage/localrefpolicy.cpp new file mode 100644 index 000000000..47ef13b28 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.cpp @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrefpolicy.h" + +#include <zencore/except_fmt.h> +#include <zencore/fmtutils.h> + +#include <filesystem> + +namespace zen { + +DataRootLocalRefPolicy::DataRootLocalRefPolicy(const std::filesystem::path& DataRoot) +: m_CanonicalRoot(std::filesystem::weakly_canonical(DataRoot).string()) +{ +} + +void +DataRootLocalRefPolicy::ValidatePath(const std::filesystem::path& Path) const +{ + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(Path); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < m_CanonicalRoot.size() || FileStr.compare(0, m_CanonicalRoot.size(), m_CanonicalRoot) != 0) + { + throw zen::invalid_argument("local file reference '{}' is outside allowed data root", CanonicalFile); + } +} + +} // namespace zen diff --git a/src/zenserver/storage/localrefpolicy.h b/src/zenserver/storage/localrefpolicy.h new file mode 100644 index 000000000..3686d1880 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/localrefpolicy.h> + +#include <filesystem> +#include <string> + +namespace zen { + +/// Local ref policy that restricts file paths to a canonical data root directory. +/// Uses weakly_canonical + string prefix comparison to detect path traversal. +class DataRootLocalRefPolicy : public ILocalRefPolicy +{ +public: + explicit DataRootLocalRefPolicy(const std::filesystem::path& DataRoot); + + void ValidatePath(const std::filesystem::path& Path) const override; + +private: + std::string m_CanonicalRoot; +}; + +} // namespace zen diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index a7c8c66b6..afd0d8f82 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -656,7 +656,8 @@ HttpProjectService::HttpProjectService(CidStore& Store, JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool InAllowExternalOidcTokenExe) + bool InAllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("project")) , m_CidStore(Store) , m_ProjectStore(Projects) @@ -668,6 +669,7 @@ HttpProjectService::HttpProjectService(CidStore& Store, , m_RestrictContentTypes(InRestrictContentTypes) , m_OidcTokenExePath(InOidcTokenExePath) , m_AllowExternalOidcTokenExe(InAllowExternalOidcTokenExe) +, m_LocalRefPolicy(InLocalRefPolicy) { ZEN_MEMSCOPE(GetProjectHttpTag()); @@ -820,6 +822,18 @@ HttpProjectService::BaseUri() const return "/prj/"; } +bool +HttpProjectService::AcceptsLocalFileReferences() const +{ + return true; +} + +const ILocalRefPolicy* +HttpProjectService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpProjectService::HandleRequest(HttpServerRequest& Request) { @@ -1668,7 +1682,8 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req) CbPackage Package; - if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) + const bool ValidateHashes = false; + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver, ValidateHashes)) { CbValidateError ValidateResult; if (CbObject Core = ValidateAndReadCompactBinaryObject(IoBuffer(Payload), ValidateResult); @@ -2763,7 +2778,11 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) case HttpContentType::kCbPackage: try { - Package = ParsePackageMessage(Payload); + ParseFlags PkgFlags = (HttpReq.IsLocalMachineRequest() && AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? GetLocalRefPolicy() : nullptr; + Package = ParsePackageMessage(Payload, {}, PkgFlags, PkgPolicy); Cb = Package.GetObject(); } catch (const std::invalid_argument& ex) diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h index e3ed02f26..8aa345fa7 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.h +++ b/src/zenserver/storage/projectstore/httpprojectstore.h @@ -47,11 +47,14 @@ public: JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool AllowExternalOidcTokenExe); + bool AllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpProjectService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; virtual void HandleStatusRequest(HttpServerRequest& Request) override; virtual void HandleStatsRequest(HttpServerRequest& Request) override; @@ -117,6 +120,7 @@ private: bool m_RestrictContentTypes; std::filesystem::path m_OidcTokenExePath; bool m_AllowExternalOidcTokenExe; + const ILocalRefPolicy* m_LocalRefPolicy; Ref<TransferThreadWorkers> GetThreadWorkers(bool BoostWorkers, bool SingleThreaded); }; diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index bc0a8f4ac..6b1da5f12 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -223,6 +223,7 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions ZEN_INFO("instantiating project service"); + m_LocalRefPolicy = std::make_unique<DataRootLocalRefPolicy>(m_DataRoot); m_JobQueue = MakeJobQueue(8, "bgjobs"); m_OpenProcessCache = std::make_unique<OpenProcessCache>(); @@ -236,7 +237,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_JobQueue, ServerOptions.RestrictContentTypes, ServerOptions.OidcTokenExecutable, - ServerOptions.AllowExternalOidcTokenExe}); + ServerOptions.AllowExternalOidcTokenExe, + m_LocalRefPolicy.get()}); if (ServerOptions.WorksSpacesConfig.Enabled) { @@ -713,7 +715,8 @@ ZenStorageServer::InitializeStructuredCache(const ZenStorageServerConfig& Server m_StatusService, *m_UpstreamCache, m_GcManager.GetDiskWriteBlocker(), - *m_OpenProcessCache); + *m_OpenProcessCache, + m_LocalRefPolicy.get()); m_StatsReporter.AddProvider(m_CacheStore.Get()); m_StatsReporter.AddProvider(m_CidStore.get()); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index fad22ad54..e3c6248e6 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -11,6 +11,7 @@ #include <zenstore/cache/structuredcachestore.h> #include <zenstore/gc.h> #include <zenstore/projectstore.h> +#include "localrefpolicy.h" #include "admin/admin.h" #include "buildstore/httpbuildstore.h" @@ -65,15 +66,16 @@ private: void InitializeServices(const ZenStorageServerConfig& ServerOptions); void RegisterServices(); - std::unique_ptr<JobQueue> m_JobQueue; - GcManager m_GcManager; - GcScheduler m_GcScheduler{m_GcManager}; - std::unique_ptr<CidStore> m_CidStore; - Ref<ZenCacheStore> m_CacheStore; - std::unique_ptr<OpenProcessCache> m_OpenProcessCache; - HttpTestService m_TestService; - std::unique_ptr<CidStore> m_BuildCidStore; - std::unique_ptr<BuildStore> m_BuildStore; + std::unique_ptr<DataRootLocalRefPolicy> m_LocalRefPolicy; + std::unique_ptr<JobQueue> m_JobQueue; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<CidStore> m_CidStore; + Ref<ZenCacheStore> m_CacheStore; + std::unique_ptr<OpenProcessCache> m_OpenProcessCache; + HttpTestService m_TestService; + std::unique_ptr<CidStore> m_BuildCidStore; + std::unique_ptr<BuildStore> m_BuildStore; #if ZEN_WITH_TESTS HttpTestingService m_TestingService; diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 6aa02eb87..087b40d6a 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -272,6 +272,8 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co OutOptions << Separator; OutOptions << "ZEN_WITH_MEMTRACK=" << (ZEN_WITH_MEMTRACK ? "1" : "0"); OutOptions << Separator; + OutOptions << "ZEN_WITH_COMPUTE_SERVICES=" << (ZEN_WITH_COMPUTE_SERVICES ? "1" : "0"); + OutOptions << Separator; OutOptions << "ZEN_WITH_TRACE=" << (ZEN_WITH_TRACE ? "1" : "0"); } diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 54e54edde..73cb7ff2d 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -33,6 +33,13 @@ using namespace zen; // Some basic functions to implement some test "compute" functions +struct ForcedExitException : std::exception +{ + int Code; + explicit ForcedExitException(int InCode) : Code(InCode) {} + const char* what() const noexcept override { return "forced exit"; } +}; + std::string Rot13Function(std::string_view InputString) { @@ -111,6 +118,16 @@ DescribeFunctions() << "Sleep"sv; Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Fail"sv; + Versions << "Version"sv << Guid::FromString("fa11fa11-fa11-fa11-fa11-fa11fa11fa11"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Crash"sv; + Versions << "Version"sv << Guid::FromString("c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"sv); + Versions.EndObject(); Versions.EndArray(); return Versions.Save(); @@ -201,6 +218,38 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) zen::Sleep(static_cast<int>(SleepTimeMs)); return Apply(IdentityFunction); } + else if (Function == "Fail"sv) + { + int FailExitCode = static_cast<int>(Action["Constants"sv].AsObjectView()["ExitCode"sv].AsUInt64()); + if (FailExitCode == 0) + { + FailExitCode = 1; + } + throw ForcedExitException(FailExitCode); + } + else if (Function == "Crash"sv) + { + // Crash modes: + // "abort" - calls std::abort() (SIGABRT / process termination) + // "nullptr" - dereferences a null pointer (SIGSEGV / access violation) + std::string_view Mode = Action["Constants"sv].AsObjectView()["Mode"sv].AsString(); + + printf("[zentest] crashing with mode: %.*s\n", int(Mode.size()), Mode.data()); + fflush(stdout); + + if (Mode == "nullptr"sv) + { + volatile int* Ptr = nullptr; + *Ptr = 42; + } + + // Default crash mode (also reached after nullptr write on platforms + // that don't immediately fault on null dereference) +#if defined(_MSC_VER) + _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); +#endif + std::abort(); + } else { return {}; @@ -421,6 +470,12 @@ main(int argc, char* argv[]) } } } + catch (ForcedExitException& Ex) + { + printf("[zentest] forced exit with code: %d\n", Ex.Code); + + ExitCode = Ex.Code; + } catch (std::exception& Ex) { printf("[zentest] exception caught in main: '%s'\n", Ex.what()); diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h index 4a25170df..e16c0c446 100644 --- a/src/zenutil/include/zenutil/process/subprocessmanager.h +++ b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -95,14 +95,19 @@ public: /// Spawn a new child process and begin monitoring it. /// /// If Options.StdoutPipe is set, the pipe is consumed and async reading - /// begins automatically. Similarly for Options.StderrPipe. + /// begins automatically. Similarly for Options.StderrPipe. When providing + /// pipes, pass the corresponding data callback here so it is installed + /// before the first async read completes — setting it later via + /// SetStdoutCallback risks losing early output. /// /// Returns a non-owning pointer valid until Remove() or manager destruction. /// The exit callback fires on an io_context thread when the process terminates. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process by handle. Takes ownership of handle internals. ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); @@ -182,12 +187,6 @@ public: /// yet computed. [[nodiscard]] float GetCpuUsagePercent() const; - /// Set per-process stdout callback (overrides manager default). - void SetStdoutCallback(ProcessDataCallback Callback); - - /// Set per-process stderr callback (overrides manager default). - void SetStderrCallback(ProcessDataCallback Callback); - /// Return all stdout captured so far. When a callback is set, output is /// delivered there instead of being accumulated. [[nodiscard]] std::string GetCapturedStdout() const; @@ -237,11 +236,14 @@ public: /// Group name (as passed to CreateGroup). [[nodiscard]] std::string_view GetName() const; - /// Spawn a process into this group. + /// Spawn a process into this group. See SubprocessManager::Spawn for + /// details on the stdout/stderr callback parameters. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process into this group. /// On Windows the process is assigned to the group's JobObject. diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp index 673a03c94..c63ad891e 100644 --- a/src/zenutil/logging/jsonformatter.cpp +++ b/src/zenutil/logging/jsonformatter.cpp @@ -19,8 +19,6 @@ static void WriteEscapedString(MemoryBuffer& Dest, std::string_view Text) { // Strip ANSI SGR sequences before escaping so they don't appear in JSON output - static const auto IsEscapeStart = [](char C) { return C == '\033'; }; - const char* RangeStart = Text.data(); const char* End = Text.data() + Text.size(); diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp index 3a91b0a61..b053ac6bd 100644 --- a/src/zenutil/process/subprocessmanager.cpp +++ b/src/zenutil/process/subprocessmanager.cpp @@ -196,18 +196,6 @@ ManagedProcess::GetCpuUsagePercent() const return m_Impl->m_CpuUsagePercent.load(); } -void -ManagedProcess::SetStdoutCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StdoutCallback = std::move(Callback); -} - -void -ManagedProcess::SetStderrCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StderrCallback = std::move(Callback); -} - std::string ManagedProcess::GetCapturedStdout() const { @@ -288,7 +276,9 @@ struct SubprocessManager::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void RemoveAll(); @@ -462,7 +452,9 @@ ManagedProcess* SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -476,6 +468,16 @@ SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, ImplPtr->m_Handle.Initialize(static_cast<int>(Result)); #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -719,10 +721,12 @@ ManagedProcess* SubprocessManager::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("SubprocessManager::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -835,7 +839,9 @@ struct ProcessGroup::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void KillAll(); @@ -884,7 +890,9 @@ ManagedProcess* ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -917,6 +925,16 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, } #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -1077,10 +1095,12 @@ ManagedProcess* ProcessGroup::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("ProcessGroup::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -1289,9 +1309,12 @@ TEST_CASE("SubprocessManager.StdoutCallback") std::string ReceivedData; bool Exited = false; - ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - - Proc->SetStdoutCallback([&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); + ManagedProcess* Proc = Manager.Spawn( + AppStub, + CmdLine, + Options, + [&](ManagedProcess&, int) { Exited = true; }, + [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); IoContext.run_for(5s); |