diff options
Diffstat (limited to 'src')
97 files changed, 5494 insertions, 985 deletions
diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 93108dd47..cc8315e0b 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -1784,6 +1784,7 @@ namespace builds_impl { { OptionalStructuredOutput->AddString("path"sv, fmt::format("{}", Path)); OptionalStructuredOutput->AddInteger("rawSize"sv, RawSize); + OptionalStructuredOutput->AddHash("rawHash"sv, RawHash); switch (Platform) { case SourcePlatform::Windows: diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 30e860a3f..9719fce77 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -1119,6 +1119,7 @@ ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent) { + m_SubOptions.add_option("managed", "", "managed", "Use managed local runner (if supported)", cxxopts::value(m_Managed), "<bool>"); } void @@ -1130,7 +1131,16 @@ ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) zen::compute::ComputeServiceSession ComputeSession(Resolver); std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - ComputeSession.AddLocalRunner(Resolver, TempPath); + if (m_Managed) + { + ZEN_CONSOLE_INFO("using managed local runner"); + ComputeSession.AddManagedLocalRunner(Resolver, TempPath); + } + else + { + ZEN_CONSOLE_INFO("using local runner"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + } Stopwatch ExecTimer; int ReturnValue = m_Parent.RunSession(ComputeSession); diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index c55412780..a0bf201a1 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -61,6 +61,7 @@ public: private: ExecCommand& m_Parent; + bool m_Managed = false; }; class ExecBeaconSubCmd : public ZenSubCmdBase diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index a1a39fc3c..750879d5a 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. -**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit. **Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. @@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra - `ActiveLsns` — for cancellation lookup (under `m_Lock`) - `FinishedLsns` — moved here when actions complete - `IdleSince` — used for 15-minute automatic expiry -- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit +- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default **Queue state machine (`QueueState` enum):** ``` @@ -216,11 +216,7 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han ## Concurrency Model -**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: -1. `m_ResultsLock` -2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") -3. `m_PendingLock` -4. `m_QueueLock` +**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 92901de64..aaf34cbe2 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -8,6 +8,8 @@ # include "recording/actionrecorder.h" # include "runners/localrunner.h" # include "runners/remotehttprunner.h" +# include "runners/managedrunner.h" +# include "pathvalidation.h" # if ZEN_PLATFORM_LINUX # include "runners/linuxrunner.h" # elif ZEN_PLATFORM_WINDOWS @@ -195,13 +197,9 @@ struct ComputeServiceSession::Impl std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; + RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap + std::map<int, Ref<RunnerAction>> m_PendingActions; std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock m_ResultsLock; std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; metrics::Meter m_ResultRate; std::atomic<uint64_t> m_RetiredCount{0}; @@ -343,9 +341,12 @@ struct ComputeServiceSession::Impl ActionCounts GetActionCounts() { ActionCounts Counts; - Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + m_ActionMapLock.WithSharedLock([&] { + Counts.Pending = (int)m_PendingActions.size(); + Counts.Running = (int)m_RunningMap.size(); + Counts.Completed = (int)m_ResultsMap.size(); + }); + Counts.Completed += (int)m_RetiredCount.load(); Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { size_t Count = 0; for (const auto& [Id, Queue] : m_Queues) @@ -364,8 +365,10 @@ struct ComputeServiceSession::Impl { Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + m_ActionMapLock.WithSharedLock([&] { + Cbo << "actions_complete"sv << m_ResultsMap.size(); + Cbo << "actions_pending"sv << m_PendingActions.size(); + }); Cbo << "actions_submitted"sv << GetSubmittedActionCount(); EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); @@ -450,27 +453,17 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) void ComputeServiceSession::Impl::AbandonAllActions() { - // Collect all pending actions and mark them as Abandoned + // Collect all pending and running actions under a single lock scope std::vector<Ref<RunnerAction>> PendingToAbandon; + std::vector<Ref<RunnerAction>> RunningToAbandon; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { PendingToAbandon.reserve(m_PendingActions.size()); for (auto& [Lsn, Action] : m_PendingActions) { PendingToAbandon.push_back(Action); } - }); - - for (auto& Action : PendingToAbandon) - { - Action->SetActionState(RunnerAction::State::Abandoned); - } - // Collect all running actions and mark them as Abandoned, then - // best-effort cancel via the local runner group - std::vector<Ref<RunnerAction>> RunningToAbandon; - - m_RunningLock.WithSharedLock([&] { RunningToAbandon.reserve(m_RunningMap.size()); for (auto& [Lsn, Action] : m_RunningMap) { @@ -478,6 +471,11 @@ ComputeServiceSession::Impl::AbandonAllActions() } }); + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + for (auto& Action : RunningToAbandon) { Action->SetActionState(RunnerAction::State::Abandoned); @@ -742,7 +740,7 @@ std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::Impl::GetRunningActions() { std::vector<ComputeServiceSession::RunningActionInfo> Result; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { Result.reserve(m_RunningMap.size()); for (const auto& [Lsn, Action] : m_RunningMap) { @@ -810,6 +808,11 @@ void ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) { ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + + // Validate all paths in the worker description upfront, before the worker is + // distributed to runners. This rejects malicious packages early at ingestion time. + ValidateWorkerDescriptionPaths(Worker.GetObject()); + RwLock::ExclusiveLockScope _(m_WorkerLock); const IoHash& WorkerId = Worker.GetObject().GetHash(); @@ -994,10 +997,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke Pending->ActionObj = ActionObj; Pending->Priority = RequestPriority; - // For now simply put action into pending state, so we can do batch scheduling + // Insert into the pending map immediately so the action is visible to + // FindActionResult/GetActionResult right away. SetActionState will call + // PostUpdate which adds the action to m_UpdatedActions and signals the + // scheduler, but the scheduler's HandleActionUpdates inserts with + // std::map::insert which is a no-op for existing keys. ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); }); Pending->SetActionState(RunnerAction::State::Pending); if (m_Recorder) @@ -1043,11 +1051,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount() HttpResponseCode ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) { @@ -1058,25 +1062,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult return HttpResponseCode::OK; } + if (m_PendingActions.find(ActionLsn) != m_PendingActions.end()) { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } return HttpResponseCode::NotFound; @@ -1085,11 +1078,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult HttpResponseCode ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) { @@ -1103,30 +1092,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& } } + for (const auto& [K, Pending] : m_PendingActions) { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) + if (Pending->ActionId == ActionId) { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + for (const auto& [K, v] : m_RunningMap) { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) + if (v->ActionId == ActionId) { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } @@ -1144,7 +1122,7 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) { Cbo.BeginArray("completed"); - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [Lsn, Action] : m_ResultsMap) { Cbo.BeginObject(); @@ -1275,20 +1253,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) std::vector<Ref<RunnerAction>> PendingActionsToCancel; std::vector<int> RunningLsnsToCancel; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (int Lsn : LsnsToCancel) { if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) { PendingActionsToCancel.push_back(It->second); } - } - }); - - m_RunningLock.WithSharedLock([&] { - for (int Lsn : LsnsToCancel) - { - if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + else if (m_RunningMap.find(Lsn) != m_RunningMap.end()) { RunningLsnsToCancel.push_back(Lsn); } @@ -1307,7 +1279,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) // transition from the runner is blocked (Cancelled > Failed in the enum). for (int Lsn : RunningLsnsToCancel) { - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) { It->second->SetActionState(RunnerAction::State::Cancelled); @@ -1445,7 +1417,7 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { Queue->m_Lock.WithSharedLock([&] { - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1475,15 +1447,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run return; } + bool WasActive = false; Queue->m_Lock.WithExclusiveLock([&] { - Queue->ActiveLsns.erase(Lsn); + WasActive = Queue->ActiveLsns.erase(Lsn) > 0; Queue->FinishedLsns.insert(Lsn); }); - const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); - if (PreviousActive == 1) + if (WasActive) { - Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } } switch (ActionState) @@ -1541,9 +1517,15 @@ ComputeServiceSession::Impl::SchedulePendingActions() { ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + size_t RunningCount = 0; + size_t PendingCount = 0; + size_t ResultCount = 0; + + m_ActionMapLock.WithSharedLock([&] { + RunningCount = m_RunningMap.size(); + PendingCount = m_PendingActions.size(); + ResultCount = m_ResultsMap.size(); + }); static Stopwatch DumpRunningTimer; @@ -1560,7 +1542,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() DumpRunningTimer.Reset(); std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [K, V] : m_RunningMap) { RunningList.insert(K); @@ -1602,7 +1584,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_PendingLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; @@ -1701,7 +1683,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() { int TimeoutMs = 500; - auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); }); if (PendingCount) { @@ -1720,22 +1702,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() m_SchedulingThreadEvent.Reset(); } - ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", - m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), - PendingCount, - m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), - m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), - TimeoutMs); + m_ActionMapLock.WithSharedLock([&] { + ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}", + m_PendingActions.size(), + m_RunningMap.size(), + m_ResultsMap.size(), + TimeoutMs); + }); HandleActionUpdates(); // Auto-transition Draining → Paused when all work is done if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) { - size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); }); - if (Pending == 0 && Running == 0) + if (AllDrained) { SessionState Expected = SessionState::Draining; if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) @@ -1776,9 +1758,9 @@ ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) if (Config) { - int Value = Config["max_retries"].AsInt32(0); + int Value = Config["max_retries"].AsInt32(-1); - if (Value > 0) + if (Value >= 0) { return Value; } @@ -1797,7 +1779,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) // Find, validate, and remove atomically under a single lock scope to prevent // concurrent RescheduleAction calls from double-removing the same action. - m_ResultsLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { auto It = m_ResultsMap.find(ActionLsn); if (It == m_ResultsMap.end()) { @@ -1871,26 +1853,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) bool WasRunning = false; // Look for the action in pending or running maps - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) { Action = It->second; WasRunning = true; } + else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end()) + { + Action = PIt->second; + } }); if (!Action) { - m_PendingLock.WithSharedLock([&] { - if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) - { - Action = It->second; - } - }); - } - - if (!Action) - { return {.Success = false, .Error = "Action not found in pending or running maps"}; } @@ -1912,18 +1888,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) void ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) { - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + // Caller must hold m_ActionMapLock exclusively. + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } } void @@ -1973,7 +1946,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() } else { - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); } break; @@ -1983,11 +1956,9 @@ ComputeServiceSession::Impl::HandleActionUpdates() // Dispatched to a runner — move from pending to running case RunnerAction::State::Running: - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); }); ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; @@ -1995,7 +1966,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() // Retracted — pull back for rescheduling without counting against retry limit case RunnerAction::State::Retracted: { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); Action->ResetActionStateToPending(); ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); break; @@ -2019,7 +1993,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -2034,16 +2011,16 @@ ComputeServiceSession::Impl::HandleActionUpdates() } } - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + RemoveActionFromActiveMaps(ActionLsn); - // Update queue counters BEFORE publishing the result into - // m_ResultsMap. GetActionResult erases from m_ResultsMap - // under m_ResultsLock, so if we updated counters after - // releasing that lock, a caller could observe ActiveCount - // still at 1 immediately after GetActionResult returned OK. - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. GetActionResult erases from m_ResultsMap + // under m_ActionMapLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); - m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; // Append to bounded action history ring @@ -2282,6 +2259,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files } void +ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner"); + + auto* NewRunner = + new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, MaxConcurrentActions); + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) { ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index bdfd9d197..1d28e7137 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -93,16 +93,17 @@ struct HttpComputeService::Impl uint64_t NewBytes = 0; }; - IngestStats IngestPackageAttachments(const CbPackage& Package); - bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); - void HandleWorkersGet(HttpServerRequest& HttpReq); - void HandleWorkersAllGet(HttpServerRequest& HttpReq); - void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); - void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); - void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); // WebSocket / observer - void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); @@ -373,7 +374,7 @@ HttpComputeService::Impl::RegisterRoutes() if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); ResponseCode != HttpResponseCode::OK) { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) if (ResponseCode == HttpResponseCode::NotFound) { @@ -1167,35 +1168,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return ParseInt<int>(Capture).value_or(0); } -HttpComputeService::Impl::IngestStats -HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +bool +HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment) { - IngestStats Stats; + const IoHash ClaimedHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + const IoHash HeaderHash = Buffer.DecodeRawHash(); + + if (HeaderHash != ClaimedHash) + { + ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + IoHashStream Hasher; + + bool DecompressOk = Buffer.DecompressToStream( + 0, + Buffer.DecodeRawSize(), + [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool { + for (const SharedBuffer& Segment : Range.GetSegments()) + { + Hasher.Append(Segment.GetView()); + } + return true; + }); + + if (!DecompressOk) + { + ZEN_WARN("attachment {}: failed to decompress", ClaimedHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + const IoHash ActualHash = Hasher.GetHash(); + + if (ActualHash != ClaimedHash) + { + ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + return true; +} +bool +HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats) +{ for (const CbAttachment& Attachment : Package.GetAttachments()) { ZEN_ASSERT(Attachment.IsCompressedBinary()); - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return false; + } - const uint64_t CompressedSize = DataView.GetCompressedSize(); + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + const uint64_t CompressedSize = DataView.GetCompressedSize(); - Stats.Bytes += CompressedSize; - ++Stats.Count; + OutStats.Bytes += CompressedSize; + ++OutStats.Count; const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { - Stats.NewBytes += CompressedSize; - ++Stats.NewCount; + OutStats.NewBytes += CompressedSize; + ++OutStats.NewCount; } } - return Stats; + return true; } bool @@ -1253,7 +1300,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { CbPackage Package = HttpReq.ReadPayloadPackage(); Body = Package.GetObject(); - Stats = IngestPackageAttachments(Package); + if (!IngestPackageAttachments(HttpReq, Package, Stats)) + { + return; // validation failed, response already written + } break; } @@ -1268,8 +1318,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { // --- Batch path --- - // For CbObject payloads, check all attachments upfront before enqueuing anything - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + // Verify all action attachment references exist in the store { std::vector<IoHash> NeedList; @@ -1345,7 +1394,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que // --- Single-action path: Body is the action itself --- - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) { std::vector<IoHash> NeedList; @@ -1491,10 +1539,14 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const { ZEN_ASSERT(Attachment.IsCompressedBinary()); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return; + } + const IoHash DataHash = Attachment.GetHash(); CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - ZEN_UNUSED(DataHash); TotalAttachmentBytes += Buffer.GetCompressedSize(); ++AttachmentCount; @@ -1537,9 +1589,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const // void -HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { - m_Impl->OnWebSocketOpen(std::move(Connection)); + m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri); } void @@ -1566,8 +1618,9 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati // void -HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("compute WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 6cbe01e04..d92af8716 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -418,8 +418,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) # if ZEN_WITH_WEBSOCKETS void -HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); if (!m_PushEnabled.load()) { return; diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 1ca78738a..ad556f546 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -167,6 +167,7 @@ public: // Action runners void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); // Action submission diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index b58e73a0d..de85a295f 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -45,7 +45,7 @@ public: // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index da5c5dfc3..58b2c9152 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -70,7 +70,7 @@ public: // IWebSocketHandler #if ZEN_WITH_WEBSOCKETS - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; #endif diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h new file mode 100644 index 000000000..c2e30183a --- /dev/null +++ b/src/zencompute/pathvalidation.h @@ -0,0 +1,118 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/except_fmt.h> +#include <zencore/string.h> + +#include <filesystem> +#include <string_view> + +namespace zen::compute { + +// Validate that a single path component contains only characters that are valid +// file/directory names on all supported platforms. Uses Windows rules as the most +// restrictive superset, since packages may be built on one platform and consumed +// on another. +inline void +ValidatePathComponent(std::string_view Component, std::string_view FullPath) +{ + // Reject control characters (0x00-0x1F) and characters forbidden on Windows + for (char Ch : Component) + { + if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' || + Ch == '*') + { + throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath); + } + } + + // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion) + if (Component.empty() || Component.back() == '.' || Component.back() == ' ') + { + throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath); + } + + // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9) + // These are reserved with or without an extension (e.g. "CON.txt" is still reserved). + std::string_view Stem = Component.substr(0, Component.find('.')); + + static constexpr std::string_view ReservedNames[] = { + "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", + "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", + }; + + for (std::string_view Reserved : ReservedNames) + { + if (zen::StrCaseCompare(Stem, Reserved) == 0) + { + throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved); + } + } +} + +// Validate that a path extracted from a package is a safe relative path. +// Rejects absolute paths, ".." components, and invalid platform filenames. +inline void +ValidateSandboxRelativePath(std::string_view Name) +{ + if (Name.empty()) + { + throw zen::invalid_argument("path traversal detected: empty path name"); + } + + std::filesystem::path Parsed(Name); + + if (Parsed.is_absolute()) + { + throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name); + } + + for (const auto& Component : Parsed) + { + std::string ComponentStr = Component.string(); + + if (ComponentStr == "..") + { + throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name); + } + + // Skip "." (current directory) — harmless in relative paths + if (ComponentStr != ".") + { + ValidatePathComponent(ComponentStr, Name); + } + } +} + +// Validate all path entries in a worker description CbObject. +// Checks path, executables[].name, dirs[], and files[].name fields. +// Throws an exception if any invalid paths are found. +inline void +ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription) +{ + using namespace std::literals; + + if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue()) + { + ValidateSandboxRelativePath(PathField.AsString()); + } + + for (auto& It : WorkerDescription["executables"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + ValidateSandboxRelativePath(It.AsString()); + } + + for (auto& It : WorkerDescription["files"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } +} + +} // namespace zen::compute diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 4f116e7d8..67e12b84e 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -164,8 +164,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Assign any remaining actions to runners with capacity (round-robin) - for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + // Assign any remaining actions to runners with capacity (round-robin). + // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept. + for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount) { if (Capacities[i] > PerRunnerActions[i].size()) { @@ -186,11 +187,12 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Reassemble results in original action order - std::vector<SubmitResult> Results(Actions.size()); + // Reassemble results in original action order. + // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity). + std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); std::vector<size_t> PerRunnerIdx(RunnerCount, 0); - for (size_t i = 0; i < Actions.size(); ++i) + for (size_t i = 0; i < ActionIdx; ++i) { size_t RunnerIdx = ActionRunnerIndex[i]; size_t Idx = PerRunnerIdx[RunnerIdx]++; diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index e79a6c90f..9055005d9 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("namespace sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index b61e0a46f..1b748c0e5 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -4,6 +4,8 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include "pathvalidation.h" + # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> @@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, ZEN_INFO("Cleanup complete"); } - m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; - # if ZEN_PLATFORM_WINDOWS // Suppress any error dialogs caused by missing dependencies UINT OldMode = ::SetErrorMode(0); @@ -382,6 +382,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); const uint64_t Size = FileEntry["size"sv].AsUInt64(); + ValidateSandboxRelativePath(Name); + CompressedBuffer Compressed; if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) @@ -457,7 +459,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["dirs"sv]) { - std::string_view Name = It.AsString(); + std::string_view Name = It.AsString(); + ValidateSandboxRelativePath(Name); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; // Validate dir path stays within sandbox @@ -482,6 +485,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, } WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); + + ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath); } CbPackage @@ -540,6 +545,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) } void +LocalProcessRunner::StartMonitorThread() +{ + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; +} + +void LocalProcessRunner::MonitorThreadFunction() { SetCurrentThreadName("LocalProcessRunner_Monitor"); diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index b8cff6826..d6589db43 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -67,6 +67,7 @@ protected: { Ref<RunnerAction> Action; void* ProcessHandle = nullptr; + int Pid = 0; int ExitCode = 0; std::filesystem::path SandboxPath; @@ -83,8 +84,6 @@ protected: std::filesystem::path m_SandboxPath; int32_t m_MaxRunningActions = 64; // arbitrary limit for testing - // if used in conjuction with m_ResultsLock, this lock must be taken *after* - // m_ResultsLock to avoid deadlocks RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; @@ -95,6 +94,7 @@ protected: std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; + void StartMonitorThread(); void MonitorThreadFunction(); virtual void SweepRunningActions(); virtual void CancelRunningActions(); diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index 5cec90699..c2ccca9a6 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("Seatbelt sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp new file mode 100644 index 000000000..e4a7ba388 --- /dev/null +++ b/src/zencompute/runners/managedrunner.cpp @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "managedrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio/io_context.hpp> +# include <asio/executor_work_guard.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_IoContext(std::make_unique<asio::io_context>()) +, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext)) +{ + m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers"); + + // Run the io_context on a small thread pool so that exit callbacks and + // metrics sampling are dispatched without blocking each other. + for (int i = 0; i < kIoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i] { + SetCurrentThreadName(fmt::format("mrunner_{}", i)); + + // work_guard keeps run() alive even when there is no pending work yet + auto WorkGuard = asio::make_work_guard(*m_IoContext); + + m_IoContext->run(); + }); + } +} + +ManagedProcessRunner::~ManagedProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what()); + } +} + +void +ManagedProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + CancelRunningActions(); + + // Tear down the SubprocessManager before stopping the io_context so that + // any in-flight callbacks are drained cleanly. + if (m_SubprocessManager) + { + m_SubprocessManager->DestroyGroup("compute-workers"); + m_ProcessGroup = nullptr; + m_SubprocessManager.reset(); + } + + if (m_IoContext) + { + m_IoContext->stop(); + } + + for (std::thread& Thread : m_IoThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_IoThreads.clear(); +} + +SubmitResult +ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + // Parse environment variables from worker descriptor ("KEY=VALUE" strings) + // into the key-value pairs expected by CreateProcOptions. + std::vector<std::pair<std::string, std::string>> EnvPairs; + for (auto& It : WorkerDescription["environment"sv]) + { + std::string_view Str = It.AsString(); + size_t Eq = Str.find('='); + if (Eq != std::string_view::npos) + { + EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1))); + } + } + + // Build command line + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string()); + + ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath); + + CreateProcOptions Options; + Options.WorkingDirectory = &Prepared->SandboxPath; + Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.Environment = std::move(EnvPairs); + + const int32_t ActionLsn = Prepared->ActionLsn; + + ManagedProcess* Proc = nullptr; + + try + { + Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) { + OnProcessExit(ActionLsn, ExitCode); + }); + } + catch (std::exception& Ex) + { + ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what()); + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath)); + return SubmitResult{.IsAccepted = false}; + } + + { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = static_cast<void*>(Proc); + NewAction->Pid = Proc->Pid(); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid()); + + return SubmitResult{.IsAccepted = true}; +} + +void +ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit"); + + Ref<RunningAction> Running; + + m_RunningLock.WithExclusiveLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end()) + { + Running = std::move(It->second); + m_RunningMap.erase(It); + } + }); + + if (!Running) + { + return; + } + + ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode); + + Running->ExitCode = ExitCode; + + // Capture final CPU metrics from the managed process before it is removed. + auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle); + if (Proc) + { + ProcessMetrics Metrics = Proc->GetLatestMetrics(); + float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs); + Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed); + + float CpuPct = Proc->GetCpuUsagePercent(); + if (CpuPct >= 0.0f) + { + Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running->ProcessHandle = nullptr; + + std::vector<Ref<RunningAction>> CompletedActions; + CompletedActions.push_back(std::move(Running)); + ProcessCompletedActions(CompletedActions); +} + +void +ManagedProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions"); + + std::unordered_map<int, Ref<RunningAction>> RunningMap; + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling {} running actions via process group", RunningMap.size()); + + Stopwatch Timer; + + // Kill all processes in the group atomically (TerminateJobObject on Windows, + // SIGTERM+SIGKILL on POSIX). + if (m_ProcessGroup) + { + m_ProcessGroup->KillAll(); + } + + for (auto& [Lsn, Running] : RunningMap) + { + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +ManagedProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction"); + + ManagedProcess* Proc = nullptr; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr) + { + Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle); + } + }); + + if (!Proc) + { + return false; + } + + // Terminate the process. The exit callback will handle the rest + // (remove from running map, gather outputs or mark failed). + Proc->Terminate(222); + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + return true; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h new file mode 100644 index 000000000..21a44d43c --- /dev/null +++ b/src/zencompute/runners/managedrunner.h @@ -0,0 +1,64 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenutil/process/subprocessmanager.h> + +# include <memory> + +namespace asio { +class io_context; +} + +namespace zen::compute { + +/** Cross-platform process runner backed by SubprocessManager. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and shared action preparation. Replaces the polling-based + monitor thread with async exit callbacks driven by SubprocessManager, and + delegates CPU/memory metrics sampling to the manager's built-in round-robin + sampler. + + A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is + used for bulk cancellation on shutdown. + + This runner does not perform any platform-specific sandboxing (AppContainer, + namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative + to the platform-specific runners for non-sandboxed workloads. + */ +class ManagedProcessRunner : public LocalProcessRunner +{ +public: + ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~ManagedProcessRunner(); + + void Shutdown() override; + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + [[nodiscard]] bool IsHealthy() override { return true; } + +private: + static constexpr int kIoThreadCount = 4; + + // Exit callback posted on an io_context thread. + void OnProcessExit(int ActionLsn, int ExitCode); + + std::unique_ptr<asio::io_context> m_IoContext; + std::unique_ptr<SubprocessManager> m_SubprocessManager; + ProcessGroup* m_ProcessGroup = nullptr; + std::vector<std::thread> m_IoThreads; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index cd4b646e9..92ee65c2d 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <sddl.h> ZEN_THIRD_PARTY_INCLUDES_END +// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be +// excluded by WIN32_LEAN_AND_MEAN. +# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + namespace zen::compute { using namespace std::literals; @@ -34,38 +40,65 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, : LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) , m_Sandboxed(Sandboxed) { - if (!m_Sandboxed) + // Create a job object shared by all child processes. Restricting the + // error-mode UI prevents crash dialogs (WER / Dr. Watson) from + // blocking the monitor thread when a worker process terminates + // abnormally. + m_JobObject = CreateJobObjectW(nullptr, nullptr); + if (m_JobObject) { - return; + JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; + ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; + SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); + + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on this process so children inherit it. The + // UILIMIT_ERRORMODE restriction above prevents them from clearing + // SEM_NOGPFAULTERRORBOX. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } - // Build a unique profile name per process to avoid collisions - m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + if (m_Sandboxed) + { + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); - // Clean up any stale profile from a previous crash - DeleteAppContainerProfile(m_AppContainerName.c_str()); + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); - PSID Sid = nullptr; + PSID Sid = nullptr; - HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), - m_AppContainerName.c_str(), // display name - m_AppContainerName.c_str(), // description - nullptr, // no capabilities - 0, // capability count - &Sid); + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); - if (FAILED(Hr)) - { - throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); - } + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } - m_AppContainerSid = Sid; + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + } - ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + StartMonitorThread(); } WindowsProcessRunner::~WindowsProcessRunner() { + if (m_JobObject) + { + CloseHandle(m_JobObject); + m_JobObject = nullptr; + } + if (m_AppContainerSid) { FreeSid(m_AppContainerSid); @@ -172,9 +205,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = DETACHED_PROCESS; + DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS; - ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath); CommandLine.EnsureNulTerminated(); @@ -260,14 +293,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) } } - CloseHandle(ProcessInformation.hThread); + if (m_JobObject) + { + AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess); + } - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(Prepared->SandboxPath); + ResumeThread(ProcessInformation.hThread); + CloseHandle(ProcessInformation.hThread); { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->Pid = ProcessInformation.dwProcessId; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + RwLock::ExclusiveLockScope _(m_RunningLock); m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); @@ -275,6 +315,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) Action->SetActionState(RunnerAction::State::Running); + ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId); + return SubmitResult{.IsAccepted = true}; } @@ -294,6 +336,11 @@ WindowsProcessRunner::SweepRunningActions() if (IsSuccess && ExitCode != STILL_ACTIVE) { + ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), + Running->Action->ActionLsn, + Running->Pid, + ExitCode); + CloseHandle(Running->ProcessHandle); Running->ProcessHandle = INVALID_HANDLE_VALUE; Running->ExitCode = ExitCode; diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h index 9f2385cc4..adeaf02fc 100644 --- a/src/zencompute/runners/windowsrunner.h +++ b/src/zencompute/runners/windowsrunner.h @@ -46,6 +46,7 @@ private: bool m_Sandboxed = false; PSID m_AppContainerSid = nullptr; std::wstring m_AppContainerName; + HANDLE m_JobObject = nullptr; }; } // namespace zen::compute diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index 506bec73b..b4fafb467 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, sigemptyset(&Action.sa_mask); Action.sa_handler = SIG_DFL; sigaction(SIGCHLD, &Action, nullptr); + + StartMonitorThread(); } SubmitResult diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index 56a292ca6..cd268745c 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -684,14 +684,22 @@ namespace legacy { Writer.Save(Ar); } - bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer InBuffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + bool ValidateHashes) { BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); - return TryLoadCbPackage(Package, Reader, Allocator, Mapper); + return TryLoadCbPackage(Package, Reader, Allocator, Mapper, ValidateHashes); } - bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + bool TryLoadCbPackage(CbPackage& Package, + BinaryReader& Reader, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper, + bool ValidateHashes) { Package = CbPackage(); for (;;) @@ -708,7 +716,11 @@ namespace legacy { if (ValueField.IsBinary()) { const MemoryView View = ValueField.AsBinaryView(); - if (View.GetSize() > 0) + if (View.GetSize() == 0) + { + return false; + } + else { SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); CbField HashField = LoadCompactBinary(Reader, Allocator); @@ -748,7 +760,11 @@ namespace legacy { { const IoHash Hash = ValueField.AsHash(); - ZEN_ASSERT(Mapper); + if (!Mapper) + { + return false; + } + if (SharedBuffer AttachmentData = (*Mapper)(Hash)) { IoHash RawHash; @@ -763,6 +779,10 @@ namespace legacy { } else { + if (ValidateHashes && IoHash::HashBuffer(AttachmentData) != Hash) + { + return false; + } const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All); if (ValidationResult != CbValidateError::None) { @@ -801,13 +821,13 @@ namespace legacy { #if ZEN_WITH_TESTS void -usonpackage_forcelink() +cbpackage_forcelink() { } TEST_SUITE_BEGIN("core.compactbinarypackage"); -TEST_CASE("usonpackage") +TEST_CASE("cbpackage") { using namespace std::literals; @@ -997,7 +1017,7 @@ TEST_CASE("usonpackage") } } -TEST_CASE("usonpackage.serialization") +TEST_CASE("cbpackage.serialization") { using namespace std::literals; @@ -1303,7 +1323,7 @@ TEST_CASE("usonpackage.serialization") } } -TEST_CASE("usonpackage.invalidpackage") +TEST_CASE("cbpackage.invalidpackage") { const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) { const MemoryView RawView = MakeMemoryView(RawData); @@ -1345,6 +1365,90 @@ TEST_CASE("usonpackage.invalidpackage") } } +TEST_CASE("cbpackage.legacy.invalidpackage") +{ + const auto TestLegacyLoad = [](std::initializer_list<uint8_t> RawData) { + const MemoryView RawView = MakeMemoryView(RawData); + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(RawView.GetData()), RawView.GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + }; + + SUBCASE("Empty") { TestLegacyLoad({}); } + + SUBCASE("Zero size binary rejects") + { + // A binary field with zero payload size should be rejected (would desync the reader) + BinaryWriter Writer; + CbWriter Cb; + Cb.AddBinary(MemoryView()); // zero-size binary + Cb.Save(Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage Package; + CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc)); + } +} + +TEST_CASE("cbpackage.legacy.hashresolution") +{ + // Build a valid legacy package with an object, then round-trip it + CbObjectWriter RootWriter; + RootWriter.AddString("name", "test"); + CbObject RootObject = RootWriter.Save(); + + CbAttachment ObjectAttach(RootObject); + + CbPackage OriginalPkg; + OriginalPkg.SetObject(RootObject); + OriginalPkg.AddAttachment(ObjectAttach); + + BinaryWriter Writer; + legacy::SaveCbPackage(OriginalPkg, Writer); + + IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize()); + CbPackage LoadedPkg; + CHECK(legacy::TryLoadCbPackage(LoadedPkg, Buffer, &UniqueBuffer::Alloc)); + + // The hash-only path requires a mapper — without one it should fail + CbWriter HashOnlyCb; + HashOnlyCb.AddHash(ObjectAttach.GetHash()); + HashOnlyCb.AddNull(); + BinaryWriter HashOnlyWriter; + HashOnlyCb.Save(HashOnlyWriter); + + IoBuffer HashOnlyBuffer(IoBuffer::Wrap, + const_cast<void*>(MakeMemoryView(HashOnlyWriter).GetData()), + MakeMemoryView(HashOnlyWriter).GetSize()); + CbPackage HashOnlyPkg; + CHECK_FALSE(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, nullptr)); + + // With a mapper that returns valid data, it should succeed + CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer { + if (Hash == ObjectAttach.GetHash()) + { + return RootObject.GetBuffer(); + } + return {}; + }; + CHECK(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &Resolver)); + + // Build a different but structurally valid CbObject to use as mismatched data + CbObjectWriter DifferentWriter; + DifferentWriter.AddString("name", "different"); + CbObject DifferentObject = DifferentWriter.Save(); + + CbPackage::AttachmentResolver BadResolver = [&](const IoHash&) -> SharedBuffer { return DifferentObject.GetBuffer(); }; + CbPackage BadPkg; + + // With ValidateHashes enabled and a mapper that returns mismatched data, it should fail + CHECK_FALSE(legacy::TryLoadCbPackage(BadPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ true)); + + // Without ValidateHashes, the mismatched data is accepted (structure is still valid CB) + CbPackage UncheckedPkg; + CHECK(legacy::TryLoadCbPackage(UncheckedPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ false)); +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 0d361801f..a63594be9 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3275,14 +3275,25 @@ MakeSafeAbsolutePathInPlace(std::filesystem::path& Path) { if (!Path.empty()) { - std::filesystem::path AbsolutePath = std::filesystem::absolute(Path).make_preferred(); + Path = std::filesystem::absolute(Path).make_preferred(); #if ZEN_PLATFORM_WINDOWS - const std::string_view Prefix = "\\\\?\\"; - const std::u8string PrefixU8(Prefix.begin(), Prefix.end()); - std::u8string PathString = AbsolutePath.u8string(); - if (!PathString.empty() && !PathString.starts_with(PrefixU8)) + const std::u8string_view LongPathPrefix = u8"\\\\?\\"; + const std::u8string_view UncPrefix = u8"\\\\"; + const std::u8string_view LongPathUncPrefix = u8"\\\\?\\UNC\\"; + + std::u8string PathString = Path.u8string(); + if (!PathString.empty() && !PathString.starts_with(LongPathPrefix)) { - PathString.insert(0, PrefixU8); + if (PathString.starts_with(UncPrefix)) + { + // UNC path: \\server\share → \\?\UNC\server\share + PathString.replace(0, UncPrefix.size(), LongPathUncPrefix); + } + else + { + // Local path: C:\foo → \\?\C:\foo + PathString.insert(0, LongPathPrefix); + } Path = PathString; } #endif // ZEN_PLATFORM_WINDOWS @@ -4049,6 +4060,54 @@ TEST_CASE("SharedMemory") CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } +TEST_CASE("filesystem.MakeSafeAbsolutePath") +{ +# if ZEN_PLATFORM_WINDOWS + // Local path gets \\?\ prefix + { + std::filesystem::path Local = MakeSafeAbsolutePath("C:\\Users\\test"); + CHECK(Local.u8string().starts_with(u8"\\\\?\\")); + CHECK(Local.u8string().find(u8"C:\\Users\\test") != std::u8string::npos); + } + + // UNC path gets \\?\UNC\ prefix + { + std::filesystem::path Unc = MakeSafeAbsolutePath("\\\\server\\share\\path"); + std::u8string UncStr = Unc.u8string(); + CHECK_MESSAGE(UncStr.starts_with(u8"\\\\?\\UNC\\"), fmt::format("Expected \\\\?\\UNC\\ prefix, got '{}'", Unc)); + CHECK_MESSAGE(UncStr.find(u8"server\\share\\path") != std::u8string::npos, + fmt::format("Expected server\\share\\path in '{}'", Unc)); + // Must NOT produce \\?\\\server (double backslash after \\?\) + CHECK_MESSAGE(UncStr.find(u8"\\\\?\\\\\\") == std::u8string::npos, + fmt::format("Path contains invalid double-backslash after prefix: '{}'", Unc)); + } + + // Already-prefixed path is not double-prefixed + { + std::filesystem::path Already = MakeSafeAbsolutePath("\\\\?\\C:\\already\\prefixed"); + size_t Count = 0; + std::u8string Str = Already.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } + + // Already-prefixed UNC path is not double-prefixed + { + std::filesystem::path AlreadyUnc = MakeSafeAbsolutePath("\\\\?\\UNC\\server\\share"); + size_t Count = 0; + std::u8string Str = AlreadyUnc.u8string(); + for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1)) + { + ++Count; + } + CHECK_EQ(Count, 1); + } +# endif // ZEN_PLATFORM_WINDOWS +} + TEST_SUITE_END(); #endif diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h index 64b62e2c0..148c0d3fd 100644 --- a/src/zencore/include/zencore/compactbinarypackage.h +++ b/src/zencore/include/zencore/compactbinarypackage.h @@ -278,10 +278,10 @@ public: * @return The attachment, or null if the attachment is not found. * @note The returned pointer is only valid until the attachments on this package are modified. */ - const CbAttachment* FindAttachment(const IoHash& Hash) const; + [[nodiscard]] const CbAttachment* FindAttachment(const IoHash& Hash) const; /** Find an attachment if it exists in the package. */ - inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } + [[nodiscard]] const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } /** Add the attachment to this package. */ inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); } @@ -336,17 +336,26 @@ private: IoHash ObjectHash; }; +/** In addition to the above, we also support a legacy format which is used by + * the HTTP project store for historical reasons. Don't use the below functions + * for new code. + */ namespace legacy { void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, CbWriter& Writer); void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar); - bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr); + bool TryLoadCbPackage(CbPackage& Package, + IoBuffer Buffer, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, - CbPackage::AttachmentResolver* Mapper = nullptr); + CbPackage::AttachmentResolver* Mapper = nullptr, + bool ValidateHashes = false); } // namespace legacy -void usonpackage_forcelink(); // internal +void cbpackage_forcelink(); // internal } // namespace zen diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index 3427991d2..1608ad523 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -90,6 +90,34 @@ using zen::ConsoleLog; using zen::ErrorLog; using zen::Log; +//////////////////////////////////////////////////////////////////////// +// Color helpers + +#define ZEN_RED(str) "\033[31m" str "\033[0m" +#define ZEN_GREEN(str) "\033[32m" str "\033[0m" +#define ZEN_YELLOW(str) "\033[33m" str "\033[0m" +#define ZEN_BLUE(str) "\033[34m" str "\033[0m" +#define ZEN_MAGENTA(str) "\033[35m" str "\033[0m" +#define ZEN_CYAN(str) "\033[36m" str "\033[0m" +#define ZEN_WHITE(str) "\033[37m" str "\033[0m" + +#define ZEN_BRIGHT_RED(str) "\033[91m" str "\033[0m" +#define ZEN_BRIGHT_GREEN(str) "\033[92m" str "\033[0m" +#define ZEN_BRIGHT_YELLOW(str) "\033[93m" str "\033[0m" +#define ZEN_BRIGHT_BLUE(str) "\033[94m" str "\033[0m" +#define ZEN_BRIGHT_MAGENTA(str) "\033[95m" str "\033[0m" +#define ZEN_BRIGHT_CYAN(str) "\033[96m" str "\033[0m" +#define ZEN_BRIGHT_WHITE(str) "\033[97m" str "\033[0m" + +#define ZEN_BOLD(str) "\033[1m" str "\033[0m" +#define ZEN_UNDERLINE(str) "\033[4m" str "\033[0m" +#define ZEN_DIM(str) "\033[2m" str "\033[0m" +#define ZEN_ITALIC(str) "\033[3m" str "\033[0m" +#define ZEN_STRIKETHROUGH(str) "\033[9m" str "\033[0m" +#define ZEN_INVERSE(str) "\033[7m" str "\033[0m" + +//////////////////////////////////////////////////////////////////////// + #if ZEN_BUILD_DEBUG # define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \ while (false) \ diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index 5ae7fad68..8cbed781d 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -174,9 +174,11 @@ struct CreateProcOptions // allocated and no conhost.exe is spawned. Stdout/stderr still work when redirected // via pipes. Prefer this for headless worker processes. Flag_NoConsole = 1 << 3, - // Create the child in a new process group (CREATE_NEW_PROCESS_GROUP on Windows). - // Allows sending CTRL_BREAK_EVENT to the child group without affecting the parent. - Flag_Windows_NewProcessGroup = 1 << 4, + // Spawn the child as a new process group leader (its pgid = its own pid). + // On Windows: CREATE_NEW_PROCESS_GROUP, enables CTRL_BREAK_EVENT targeting. + // On POSIX: child calls setpgid(0,0) / posix_spawn with POSIX_SPAWN_SETPGROUP+pgid=0. + // Mutually exclusive with ProcessGroupId > 0. + Flag_NewProcessGroup = 1 << 4, // Allocate a hidden console for the child (CREATE_NO_WINDOW on Windows). Unlike // Flag_NoConsole the child still gets a console (and a conhost.exe) but no visible // window. Use this when the child needs a console for stdio but should not show a window. @@ -197,9 +199,9 @@ struct CreateProcOptions #if ZEN_PLATFORM_WINDOWS JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed #else - /// POSIX process group id. When > 0, the child is placed into this process - /// group via setpgid() before exec. Use the pid of the first child as the - /// pgid to create a group, then pass the same pgid for subsequent children. + /// When > 0, child joins this existing process group. Mutually exclusive with + /// Flag_NewProcessGroup; use that flag on the first spawn to create the group, + /// then pass the resulting pid here for subsequent spawns to join it. int ProcessGroupId = 0; #endif }; diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index e7baa3f8e..ee821944a 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -37,7 +37,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #endif #if ZEN_PLATFORM_MAC +# include <crt_externs.h> # include <libproc.h> +# include <spawn.h> # include <sys/types.h> # include <sys/sysctl.h> #endif @@ -135,8 +137,68 @@ IsZombieProcess(int pid, std::error_code& OutEc) } return false; } + +static char** +GetEnviron() +{ + return *_NSGetEnviron(); +} #endif // ZEN_PLATFORM_MAC +#if ZEN_PLATFORM_LINUX +static char** +GetEnviron() +{ + return environ; +} +#endif // ZEN_PLATFORM_LINUX + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +// Holds a null-terminated envp array built by merging the current process environment with +// a set of overrides. When Overrides is empty, Data points directly to environ (no allocation). +// Must outlive any posix_spawn / execve call that receives Data. +struct EnvpHolder +{ + char** Data = GetEnviron(); + + explicit EnvpHolder(const std::vector<std::pair<std::string, std::string>>& Overrides) + { + if (Overrides.empty()) + { + return; + } + std::map<std::string, std::string> EnvMap; + for (char** E = GetEnviron(); *E; ++E) + { + std::string_view Entry(*E); + const size_t EqPos = Entry.find('='); + if (EqPos != std::string_view::npos) + { + EnvMap[std::string(Entry.substr(0, EqPos))] = std::string(Entry.substr(EqPos + 1)); + } + } + for (const auto& [Key, Value] : Overrides) + { + EnvMap[Key] = Value; + } + for (const auto& [Key, Value] : EnvMap) + { + m_Strings.push_back(Key + "=" + Value); + } + for (std::string& S : m_Strings) + { + m_Ptrs.push_back(S.data()); + } + m_Ptrs.push_back(nullptr); + Data = m_Ptrs.data(); + } + +private: + std::vector<std::string> m_Strings; + std::vector<char*> m_Ptrs; +}; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ////////////////////////////////////////////////////////////////////////// // Pipe creation for child process stdout capture @@ -691,6 +753,7 @@ BuildArgV(std::vector<char*>& Out, char* CommandLine) ++Cursor; } } + #endif // !WINDOWS || TESTS #if ZEN_PLATFORM_WINDOWS @@ -766,7 +829,7 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma { CreationFlags |= CREATE_NO_WINDOW; } - if (Options.Flags & CreateProcOptions::Flag_Windows_NewProcessGroup) + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) { CreationFlags |= CREATE_NEW_PROCESS_GROUP; } @@ -1070,23 +1133,30 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine } return CreateProcNormal(Executable, CommandLine, Options); -#else +#elif ZEN_PLATFORM_LINUX + // vfork uses CLONE_VM|CLONE_VFORK: the child shares the parent's address space and the + // parent is suspended until the child calls exec or _exit. This avoids page-table duplication + // and the ENOMEM that fork() produces on systems with strict overcommit (vm.overcommit_memory=2). + // All child-side setup uses only syscalls that do not modify user-space memory. + // Environment overrides are merged into envp before vfork so that setenv() is never called + // from the child (which would corrupt the shared address space). std::vector<char*> ArgV; std::string CommandLineZ(CommandLine); BuildArgV(ArgV, CommandLineZ.data()); ArgV.push_back(nullptr); - int ChildPid = fork(); + EnvpHolder Envp(Options.Environment); + + int ChildPid = vfork(); if (ChildPid < 0) { - ThrowLastError("Failed to fork a new child process"); + ThrowLastError("Failed to vfork a new child process"); } else if (ChildPid == 0) { if (Options.WorkingDirectory != nullptr) { - int Result = chdir(Options.WorkingDirectory->c_str()); - ZEN_UNUSED(Result); + chdir(Options.WorkingDirectory->c_str()); } if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0) @@ -1118,23 +1188,99 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine } } - if (Options.ProcessGroupId > 0) + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) + { + setpgid(0, 0); + } + else if (Options.ProcessGroupId > 0) { setpgid(0, Options.ProcessGroupId); } - for (const auto& [Key, Value] : Options.Environment) + execve(Executable.c_str(), ArgV.data(), Envp.Data); + _exit(127); + } + + return ChildPid; +#else // macOS + std::vector<char*> ArgV; + std::string CommandLineZ(CommandLine); + BuildArgV(ArgV, CommandLineZ.data()); + ArgV.push_back(nullptr); + + posix_spawn_file_actions_t FileActions; + posix_spawnattr_t Attr; + + int Err = posix_spawn_file_actions_init(&FileActions); + if (Err != 0) + { + ThrowSystemError(Err, "posix_spawn_file_actions_init failed"); + } + auto FileActionsGuard = MakeGuard([&] { posix_spawn_file_actions_destroy(&FileActions); }); + + Err = posix_spawnattr_init(&Attr); + if (Err != 0) + { + ThrowSystemError(Err, "posix_spawnattr_init failed"); + } + auto AttrGuard = MakeGuard([&] { posix_spawnattr_destroy(&Attr); }); + + if (Options.WorkingDirectory != nullptr) + { + Err = posix_spawn_file_actions_addchdir_np(&FileActions, Options.WorkingDirectory->c_str()); + if (Err != 0) { - setenv(Key.c_str(), Value.c_str(), 1); + ThrowSystemError(Err, "posix_spawn_file_actions_addchdir_np failed"); } + } + + if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0) + { + const int StdoutWriteFd = Options.StdoutPipe->WriteFd; + ZEN_ASSERT(StdoutWriteFd > STDERR_FILENO); + posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDOUT_FILENO); - if (execv(Executable.c_str(), ArgV.data()) < 0) + if (Options.StderrPipe != nullptr && Options.StderrPipe->WriteFd >= 0) + { + const int StderrWriteFd = Options.StderrPipe->WriteFd; + ZEN_ASSERT(StderrWriteFd > STDERR_FILENO && StderrWriteFd != StdoutWriteFd); + posix_spawn_file_actions_adddup2(&FileActions, StderrWriteFd, STDERR_FILENO); + posix_spawn_file_actions_addclose(&FileActions, StderrWriteFd); + } + else { - ThrowLastError("Failed to exec() a new process image"); + posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDERR_FILENO); } + + posix_spawn_file_actions_addclose(&FileActions, StdoutWriteFd); + } + else if (!Options.StdoutFile.empty()) + { + posix_spawn_file_actions_addopen(&FileActions, STDOUT_FILENO, Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + posix_spawn_file_actions_adddup2(&FileActions, STDOUT_FILENO, STDERR_FILENO); } - return ChildPid; + if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup) + { + posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP); + posix_spawnattr_setpgroup(&Attr, 0); + } + else if (Options.ProcessGroupId > 0) + { + posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP); + posix_spawnattr_setpgroup(&Attr, Options.ProcessGroupId); + } + + EnvpHolder Envp(Options.Environment); + + pid_t ChildPid = 0; + Err = posix_spawn(&ChildPid, Executable.c_str(), &FileActions, &Attr, ArgV.data(), Envp.Data); + if (Err != 0) + { + ThrowSystemError(Err, "Failed to posix_spawn a new child process"); + } + + return int(ChildPid); #endif } @@ -1252,14 +1398,28 @@ JobObject::Initialize() } JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {}; - LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo))) { ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError()); CloseHandle(m_JobHandle); m_JobHandle = nullptr; + return; } + + // Prevent child processes from clearing SEM_NOGPFAULTERRORBOX, which + // suppresses WER/Dr. Watson crash dialogs. Without this, a crashing + // child can pop a modal dialog and block the monitor thread. +# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobHandle, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on the current process so children inherit it. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } bool diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index 8c29a8962..c1ac63621 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -273,7 +273,7 @@ zencore_forcelinktests() zen::uid_forcelink(); zen::uson_forcelink(); zen::usonbuilder_forcelink(); - zen::usonpackage_forcelink(); + zen::cbpackage_forcelink(); zen::cbjson_forcelink(); zen::cbyaml_forcelink(); zen::workthreadpool_forcelink(); diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index e05c9815f..38021be16 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -479,6 +479,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return Ref<IHttpPackageHandler>(); } +bool +HttpService::AcceptsLocalFileReferences() const +{ + return false; +} + +const ILocalRefPolicy* +HttpService::GetLocalRefPolicy() const +{ + return nullptr; +} + ////////////////////////////////////////////////////////////////////////// HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) @@ -705,7 +717,10 @@ HttpServerRequest::ReadPayloadPackage() { if (IoBuffer Payload = ReadPayload()) { - return ParsePackageMessage(std::move(Payload)); + ParseFlags Flags = + (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault; + const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr; + return ParsePackageMessage(std::move(Payload), {}, Flags, Policy); } return {}; @@ -1273,7 +1288,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP return PackageHandlerRef->CreateTarget(Cid, Size); }; - CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences()) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr; + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy); PackageHandlerRef->OnRequestComplete(); } diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 5eaed6004..76f219f04 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -12,6 +12,7 @@ #include <zencore/string.h> #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zenhttp/localrefpolicy.h> #include <zentelemetry/hyperloglog.h> #include <zentelemetry/stats.h> @@ -193,9 +194,16 @@ public: HttpService() = default; virtual ~HttpService() = default; - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + [[nodiscard]] virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + /// Whether this service accepts local file references in inbound packages from local clients. + [[nodiscard]] virtual bool AcceptsLocalFileReferences() const; + + /// Returns the local ref policy for validating file paths in inbound local references. + /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed). + [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const; // Internals diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index bce771c75..b20bc3a36 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -43,7 +43,7 @@ public: virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h new file mode 100644 index 000000000..0b37f9dc7 --- /dev/null +++ b/src/zenhttp/include/zenhttp/localrefpolicy.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> + +namespace zen { + +/// Policy interface for validating local file reference paths in inbound CbPackage messages. +/// Implementations should throw std::invalid_argument if the path is not allowed. +class ILocalRefPolicy +{ +public: + virtual ~ILocalRefPolicy() = default; + + /// Validate that a local file reference path is allowed. + /// Throws std::invalid_argument if the path escapes the allowed root. + virtual void ValidatePath(const std::filesystem::path& Path) const = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index 1a5068580..66e3f6e55 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -5,6 +5,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> +#include <zenhttp/localrefpolicy.h> #include <functional> #include <gsl/gsl-lite.hpp> @@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); -CbPackage ParsePackageMessage( - IoBuffer Payload, - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + +enum class ParseFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags); + +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; - }); + }, + ParseFlags Flags = ParseFlags::kDefault, + const ILocalRefPolicy* Policy = nullptr); bool IsPackageMessage(IoBuffer Payload); bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); @@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe class CbPackageReader { public: - CbPackageReader(); + CbPackageReader(ParseFlags Flags = ParseFlags::kDefault); ~CbPackageReader(); void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + void SetLocalRefPolicy(const ILocalRefPolicy* Policy); /** Process compact binary package data stream @@ -149,6 +162,8 @@ private: kReadingBuffers } m_CurrentState = State::kInitialState; + ParseFlags m_Flags; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; std::vector<IoBuffer> m_PayloadBuffers; std::vector<CbAttachmentEntry> m_AttachmentEntries; diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 710579faa..2d25515d3 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -59,7 +59,7 @@ class IWebSocketHandler public: virtual ~IWebSocketHandler() = default; - virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0; virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; }; diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index 7e6207e56..5ad5ebcc7 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) // void -HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); ZEN_INFO("Stats WebSocket client connected"); diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 9c62c1f2d..267ce386c 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:"); typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t; +/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security. +/// If no policy is set, file-path local refs are rejected (fail-closed). +static void +ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path) +{ + if (Policy) + { + Policy->ValidatePath(Path); + } + else + { + throw std::invalid_argument("local file reference rejected: no validation policy"); + } +} + +// Validates the CbPackageHeader magic and attachment count. Returns the total +// chunk count (AttachmentCount + 1, including the implicit root object). +static uint32_t +ValidatePackageHeader(const CbPackageHeader& Hdr) +{ + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::invalid_argument( + fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic)); + } + // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against + // UINT32_MAX wrapping to 0, which would bypass subsequent size checks. + if (Hdr.AttachmentCount == UINT32_MAX) + { + throw std::invalid_argument("invalid CbPackage, attachment count overflow"); + } + return Hdr.AttachmentCount + 1; +} + +struct ValidatedLocalRef +{ + bool Valid = false; + const CbAttachmentReferenceHeader* Header = nullptr; + std::string_view Path; + std::string Error; +}; + +// Validates that the attachment buffer contains a well-formed local reference +// header and path. On failure, Valid is false and Error contains the reason. +static ValidatedLocalRef +ValidateLocalRef(const IoBuffer& AttachmentBuffer) +{ + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader)) + { + return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())}; + } + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength) + { + return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})", + sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength, + AttachmentBuffer.Size())}; + } + + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)}; +} + IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle); std::vector<IoBuffer> @@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload) } CbPackage -ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +ParsePackageMessage(IoBuffer Payload, + std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer, + ParseFlags Flags, + const ILocalRefPolicy* Policy) { ZEN_TRACE_CPU("ParsePackageMessage"); @@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint BinaryReader Reader(Payload); - const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); - if (Hdr->HeaderMagic != kCbPkgMagic) - { - throw std::invalid_argument( - fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic)); - } + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + const uint32_t ChunkCount = ValidatePackageHeader(*Hdr); Reader.Skip(sizeof(CbPackageHeader)); - const uint32_t ChunkCount = Hdr->AttachmentCount + 1; - - if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount) + // Widen to uint64_t so the multiplication cannot wrap on 32-bit. + const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount; + if (Reader.Remaining() < AttachmentTableSize) { throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)", sizeof(CbAttachmentEntry) * ChunkCount, @@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) { - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i)); + } - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash))); + continue; + } + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::string_view PathView = LocalRef.Path; IoBuffer FullFileBuffer; @@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint } else { + ApplyLocalRefPolicy(Policy, Path); FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; } } if (FullFileBuffer) { - IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + // Guard against offset+size overflow or exceeding the file bounds. + const uint64_t FileSize = FullFileBuffer.GetSize(); + if (AttachRefHdr->PayloadByteOffset > FileSize || + AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset) + { + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}", + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + FileSize, + Entry.AttachmentHash))); + continue; + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize ? FullFileBuffer : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa return OutPackage.TryLoad(Response); } -CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +CbPackageReader::CbPackageReader(ParseFlags Flags) +: m_Flags(Flags) +, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) { } @@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci m_CreateBuffer = CreateBuffer; } +void +CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy) +{ + m_LocalRefPolicy = Policy; +} + uint64_t CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) { @@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) return sizeof m_PackageHeader; case State::kReadingHeader: - ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); - memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); - ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); - m_CurrentState = State::kReadingAttachmentEntries; - m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); - return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + { + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(ChunkCount); + return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry); + } case State::kReadingAttachmentEntries: ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); @@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) { // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + throw std::invalid_argument(LocalRef.Error); + } - std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::filesystem::path Path(Utf8ToWide(LocalRef.Path)); - std::filesystem::path Path{PathView}; + if (!LocalRef.Path.starts_with(HandlePrefix)) + { + ApplyLocalRefPolicy(m_LocalRefPolicy, Path); + } IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) AttachRefHdr->PayloadByteSize)); } + // MakeFromFile silently clamps offset+size to the file size. Detect this + // to avoid returning a short buffer that could cause subtle downstream issues. + if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize) + { + throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + ChunkReference.GetSize())); + } + return ChunkReference; }; @@ -732,6 +843,13 @@ CbPackageReader::Finalize() { IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", + CurrentAttachmentIndex)); + } + if (CurrentAttachmentIndex == 0) { // Root object @@ -815,6 +933,13 @@ CbPackageReader::Finalize() TEST_SUITE_BEGIN("http.packageformat"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. +struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef") RemainingBytes -= ByteCount; }; + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackageReader Reader(ParseFlags::kAllowLocalReferences); + Reader.SetLocalRefPolicy(&AllowAllPolicy); + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.Validation.TruncatedHeader") +{ + // Payload too small for a CbPackageHeader + uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa}; + IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagic") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xDEADBEEF; + Hdr.AttachmentCount = 0; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflow") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable") +{ + // Valid header but not enough data for the attachment entries + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = 10; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentData") +{ + // Valid header + one attachment entry claiming more data than available + std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry)); + + CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data()); + Hdr->HeaderMagic = kCbPkgMagic; + Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object) + + CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader)); + Entry->PayloadSize = 9999; // way more than available + Entry->Flags = CbAttachmentEntry::kIsObject; + Entry->AttachmentHash = IoHash(); + + IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size()); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault") +{ + // Build a valid package with local refs backed by compressed-format files, + // then verify it's rejected with default ParseFlags and accepted when allowed. + ScopedTemporaryDirectory TempDir; + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + // Compress data and write to disk, then create file-backed compressed attachments. + // The files must contain compressed-format data because ParsePackageMessage expects it + // when resolving local refs. + CompressedBuffer Comp1 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None); + CompressedBuffer Comp2 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None); + + IoHash Hash1 = Comp1.DecodeRawHash(); + IoHash Hash2 = Comp2.DecodeRawHash(); + + { + IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer(); + IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(Path1, Buf1); + WriteFile(Path2, Buf2); + } + + // Create attachments from file-backed buffers so FormatPackageMessage uses local refs + CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1}; + CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Default flags should reject local refs + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); + + // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip + // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader) + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 2); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader") +{ + // Same test but via CbPackageReader + ScopedTemporaryDirectory TempDir; + auto FilePath = TempDir.Path() / "testdata"; + + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata")); + WriteFile(FilePath, Buf); + } + + IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath); + CbAttachment Attach{SharedBuffer(FileBuffer)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + // Default flags should reject CbPackageReader Reader; uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); @@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef") CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); } - Reader.Finalize(); + CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagicViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xBADCAFE; + Hdr.AttachmentCount = 0; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot") +{ + // A file outside the allowed root should be rejected by the policy + ScopedTemporaryDirectory AllowedRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "secret.dat"; + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret")); + WriteFile(OutsidePath, Buf); + } + + // Create file-backed compressed attachment from outside root + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Policy rooted at AllowedRoot should reject the file in OutsideDir + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot") +{ + // A file inside the allowed root should be accepted by the policy + ScopedTemporaryDirectory TempRoot; + + auto FilePath = TempRoot.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 1); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal") +{ + // A file path containing ".." that resolves outside root should be rejected + ScopedTemporaryDirectory TempRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "evil.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + // Root is TempRoot, but the file lives in OutsideDir + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed") +{ + // When local refs are allowed but no policy is provided, file-path refs should be rejected + ScopedTemporaryDirectory TempDir; + + auto FilePath = TempDir.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // kAllowLocalReferences but nullptr policy => fail-closed + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument); } TEST_SUITE_END(); diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 7972777b8..6cda84875 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -1275,7 +1275,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::buffer(ResponseStr->data(), ResponseStr->size()), asio::bind_executor( m_Strand, - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()]( + const asio::error_code& Ec, + std::size_t) { if (Ec) { ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); @@ -1287,7 +1289,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + std::string_view FullUrl = Conn->m_RequestData.Url(); + std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); })); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 2cad97725..1b722940d 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -2595,7 +2595,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT &Transaction().Server())); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + ExtendableStringBuilder<128> UrlUtf8; + WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath, + gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))}, + UrlUtf8); + int PrefixLen = Service->UriPrefixLength(); + std::string_view RelativeUri{UrlUtf8.ToView()}; + RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); return nullptr; diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 59c46a418..363c478ae 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -335,8 +335,9 @@ namespace { } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_OpenCount.fetch_add(1); m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp index 1ae968fb7..deecdef05 100644 --- a/src/zennomad/nomadprocess.cpp +++ b/src/zennomad/nomadprocess.cpp @@ -37,7 +37,7 @@ struct NomadProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index a04063c4c..ca226444a 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -1149,8 +1149,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp); - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -1252,10 +1251,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) ScavengedLookups, ScavengedPaths, WriteCache); - WritePartsComplete++; + bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount; if (!m_AbortFlag) { - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsDone) { FilteredWrittenBytesPerSecond.Stop(); } @@ -1334,9 +1333,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) Ec.message()); } - WritePartsComplete++; - - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -1389,25 +1386,20 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) BlockRangeStartIndex, RangeCount, ExistsResult, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, [this, &RemoteChunkIndexNeedsCopyFromSourceFlags, &SequenceIndexChunksLeftToWriteCounters, &WritePartsComplete, &WriteCache, &Work, - TotalRequestCount, TotalPartWriteCount, - &FilteredDownloadedBytesPerSecond, &FilteredWrittenBytesPerSecond, &PartialBlocks](IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - if (!m_AbortFlag) { Work.ScheduleWork( @@ -1483,8 +1475,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); } - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -1571,8 +1562,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) uint64_t BlockSize = BlockBuffer.GetSize(); m_DownloadStats.DownloadedBlockCount++; m_DownloadStats.DownloadedBlockByteCount += BlockSize; - m_DownloadStats.RequestsCompleteCount++; - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) { FilteredDownloadedBytesPerSecond.Stop(); } @@ -1683,9 +1673,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } } - WritePartsComplete++; - - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -2987,8 +2975,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash); if (!ExistingCompressedChunkPath.empty()) { - m_DownloadStats.RequestsCompleteCount++; - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) { FilteredDownloadedBytesPerSecond.Stop(); } @@ -3027,11 +3014,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart)); - WritePartsComplete++; + bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount; if (!AbortFlag) { - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsDone) { FilteredWrittenBytesPerSecond.Stop(); } @@ -3085,6 +3072,8 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd DownloadBuildBlob(RemoteChunkIndex, ExistsResult, Work, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, [this, &ExistsResult, SequenceIndexChunksLeftToWriteCounters, @@ -3092,15 +3081,9 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd &Work, &WritePartsComplete, TotalPartWriteCount, - TotalRequestCount, RemoteChunkIndex, - &FilteredDownloadedBytesPerSecond, &FilteredWrittenBytesPerSecond, ChunkTargetPtrs = std::move(ChunkTargetPtrs)](IoBuffer&& Payload) mutable { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } IoBufferFileReference FileRef; bool EnableBacklog = Payload.GetFileReference(FileRef); AsyncWriteDownloadedChunk(m_Options.ZenFolderPath, @@ -3125,6 +3108,8 @@ void BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkIndex, const BlobsExistsResult& ExistsResult, ParallelWork& Work, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& Payload)>&& OnDownloaded) { const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; @@ -3140,37 +3125,48 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde uint64_t BlobSize = BuildBlob.GetSize(); m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.DownloadedChunkByteCount += BlobSize; - m_DownloadStats.RequestsCompleteCount++; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } OnDownloaded(std::move(BuildBlob)); } else { if (m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] >= m_Options.LargeAttachmentSize) { - DownloadLargeBlob( - *m_Storage.BuildStorage, - m_TempDownloadFolderPath, - m_BuildId, - ChunkHash, - m_Options.PreferredMultipartChunkSize, - Work, - m_NetworkPool, - m_DownloadStats.DownloadedChunkByteCount, - m_DownloadStats.MultipartAttachmentCount, - [this, &Work, ChunkHash, RemoteChunkIndex, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable { - m_DownloadStats.DownloadedChunkCount++; - m_DownloadStats.RequestsCompleteCount++; + DownloadLargeBlob(*m_Storage.BuildStorage, + m_TempDownloadFolderPath, + m_BuildId, + ChunkHash, + m_Options.PreferredMultipartChunkSize, + Work, + m_NetworkPool, + m_DownloadStats.DownloadedChunkByteCount, + m_DownloadStats.MultipartAttachmentCount, + [this, + &Work, + &FilteredDownloadedBytesPerSecond, + ChunkHash, + RemoteChunkIndex, + TotalRequestCount, + OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable { + m_DownloadStats.DownloadedChunkCount++; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } - if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache) - { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); - } + if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); + } - OnDownloaded(std::move(Payload)); - }); + OnDownloaded(std::move(Payload)); + }); } else { @@ -3193,7 +3189,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde uint64_t BlobSize = BuildBlob.GetSize(); m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.DownloadedChunkByteCount += BlobSize; - m_DownloadStats.RequestsCompleteCount++; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } OnDownloaded(std::move(BuildBlob)); } @@ -3208,6 +3207,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( size_t BlockRangeStartIndex, size_t BlockRangeCount, const BlobsExistsResult& ExistsResult, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, @@ -3222,6 +3223,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( IoBuffer&& BlockRangeBuffer, size_t BlockRangeStartIndex, std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, const std::function<void(IoBuffer && InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, @@ -3229,7 +3232,11 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize(); m_DownloadStats.DownloadedBlockCount++; m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize; - m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size(); + if (m_DownloadStats.RequestsCompleteCount.fetch_add(BlockOffsetAndLengths.size()) + BlockOffsetAndLengths.size() == + TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } std::filesystem::path BlockChunkPath; @@ -3337,6 +3344,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(PayloadBuffer), SubRangeStartIndex, std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)}, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3361,6 +3370,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3371,6 +3382,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); SubRangeCountComplete += SubRangeCount; continue; @@ -3413,6 +3426,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); } else @@ -3428,6 +3443,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, OnDownloaded); } } @@ -4244,8 +4261,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart)); if (!m_AbortFlag) { - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) + if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } @@ -6111,8 +6127,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co TempUploadStats.BlockCount++; - UploadedBlockCount++; - if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) + if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount) { FilteredUploadedBytesPerSecond.Stop(); } @@ -6192,8 +6207,8 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co if (IsComplete) { TempUploadStats.ChunkCount++; - UploadedChunkCount++; - if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount) + if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && + UploadedBlockCount == UploadBlockCount) { FilteredUploadedBytesPerSecond.Stop(); } @@ -6227,8 +6242,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co TempUploadStats.ChunkCount++; UploadedCompressedChunkSize += Payload.GetSize(); UploadedRawChunkSize += RawSize; - UploadedChunkCount++; - if (UploadedChunkCount == UploadChunkCount) + if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount) { FilteredUploadedBytesPerSecond.Stop(); } @@ -6237,8 +6251,6 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co }); }; - std::vector<size_t> GenerateBlockIndexes; - std::atomic<uint64_t> GeneratedBlockCount = 0; std::atomic<uint64_t> GeneratedBlockByteCount = 0; @@ -6260,9 +6272,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co &Lookup, &NewBlocks, &NewBlockChunks, - &GenerateBlockIndexes, &GeneratedBlockCount, &GeneratedBlockByteCount, + GenerateBlockCount = BlockIndexes.size(), &AsyncUploadBlock, &QueuedPendingInMemoryBlocksForUpload](std::atomic<bool>&) { if (!m_AbortFlag) @@ -6293,8 +6305,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co } GeneratedBlockByteCount += NewBlocks.BlockSizes[BlockIndex]; - GeneratedBlockCount++; - if (GeneratedBlockCount == GenerateBlockIndexes.size()) + if (GeneratedBlockCount.fetch_add(1) + 1 == GenerateBlockCount) { FilteredGenerateBlockBytesPerSecond.Stop(); } @@ -7005,8 +7016,7 @@ BuildsOperationPrimeCache::Execute() CompositeBuffer(SharedBuffer(Payload))); } } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) + if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) { FilteredDownloadedBytesPerSecond.Stop(); } @@ -7029,8 +7039,7 @@ BuildsOperationPrimeCache::Execute() CompositeBuffer(SharedBuffer(std::move(Payload)))); } } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) + if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) { FilteredDownloadedBytesPerSecond.Stop(); } diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp index ad4c4bc89..d837ce07f 100644 --- a/src/zenremotestore/builds/jupiterbuildstorage.cpp +++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp @@ -263,7 +263,7 @@ public: std::vector<std::function<void()>> WorkList; for (auto& WorkItem : WorkItems) { - WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes]() { + WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes = std::move(OnSentBytes)]() { Stopwatch ExecutionTimer; auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); bool IsComplete = false; @@ -444,11 +444,13 @@ public: virtual bool GetExtendedStatistics(ExtendedStatistics& OutStats) override { - OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size()); - for (auto& It : m_ReceivedBytesPerSource) - { - OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second]); - } + m_SourceLock.WithSharedLock([this, &OutStats]() { + OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size()); + for (auto& It : m_ReceivedBytesPerSource) + { + OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second].load(std::memory_order_relaxed)); + } + }); return true; } @@ -521,15 +523,29 @@ private: } if (!Result.Source.empty()) { - if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); - It != m_ReceivedBytesPerSource.end()) - { - m_SourceBytes[It->second] += Result.ReceivedBytes; - } - else + if (!m_SourceLock.WithSharedLock([&]() { + if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); + It != m_ReceivedBytesPerSource.end()) + { + m_SourceBytes[It->second] += Result.ReceivedBytes; + return true; + } + return false; + })) { - m_ReceivedBytesPerSource.insert_or_assign(Result.Source, m_SourceBytes.size()); - m_SourceBytes.push_back(Result.ReceivedBytes); + m_SourceLock.WithExclusiveLock([&]() { + if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source); + It != m_ReceivedBytesPerSource.end()) + { + m_SourceBytes[It->second] += Result.ReceivedBytes; + } + else if (m_SourceCount < MaxSourceCount) + { + size_t Index = m_SourceCount++; + m_ReceivedBytesPerSource.insert_or_assign(Result.Source, Index); + m_SourceBytes[Index] += Result.ReceivedBytes; + } + }); } } } @@ -540,8 +556,11 @@ private: const std::string m_Bucket; const std::filesystem::path m_TempFolderPath; - tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource; - std::vector<uint64_t> m_SourceBytes; + RwLock m_SourceLock; + tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource; + static constexpr size_t MaxSourceCount = 8u; + std::array<std::atomic<uint64_t>, MaxSourceCount> m_SourceBytes; + size_t m_SourceCount = 0; }; std::unique_ptr<BuildStorageBase> diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 0d2eded58..27dc9de86 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -261,12 +261,16 @@ private: void DownloadBuildBlob(uint32_t RemoteChunkIndex, const BlobsExistsResult& ExistsResult, ParallelWork& Work, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& Payload)>&& OnDownloaded); - void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, - size_t BlockRangeIndex, - size_t BlockRangeCount, - const BlobsExistsResult& ExistsResult, + void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges, + size_t BlockRangeIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + uint64_t TotalRequestCount, + FilteredRate& FilteredDownloadedBytesPerSecond, std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath, size_t BlockRangeStartIndex, diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index a9788cb4e..d610d1fc8 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -673,7 +673,7 @@ JupiterSession::PutMultipartBuildBlob(std::string_view Namespace, size_t RetryPartIndex = PartNameToIndex.at(RetryPartId); const MultipartUploadResponse::Part& RetryPart = Workload->PartDescription.Parts[RetryPartIndex]; IoBuffer RetryPartPayload = - Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte - 1); + Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte); std::string RetryMultipartUploadResponseRequestString = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}/uploadMultipart{}&supportsRedirect={}", Namespace, diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 1a9dc10ef..2076adb70 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -929,7 +929,6 @@ namespace remotestore_impl { { return; } - ZEN_ASSERT(UploadAttachment->Size != 0); if (!UploadAttachment->RawPath.empty()) { if (UploadAttachment->Size > (MaxChunkEmbedSize * 2)) @@ -7008,6 +7007,60 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn") } } +TEST_CASE("buildcontainer.zero_byte_file_attachment") +{ + // A zero-byte file on disk is a valid attachment. BuildContainer must process + // it without hitting ZEN_ASSERT(UploadAttachment->Size != 0) in + // ResolveAttachments. The empty file flows through the compress-inline path + // and becomes a LooseUploadAttachment with raw size 0. + using namespace projectstore_testutils; + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + std::unique_ptr<ProjectStore> ProjectStoreDummy; + Ref<ProjectStore::Project> Project = MakeTestProject(CidStore, Gc, TempDir.Path(), ProjectStoreDummy); + + std::filesystem::path RootDir = TempDir.Path() / "root"; + auto FileAtts = CreateFileAttachments(RootDir, std::initializer_list<size_t>{512}); + + Ref<ProjectStore::Oplog> Oplog = Project->NewOplog("bc_zero_byte_file", {}); + REQUIRE(Oplog); + Oplog->AppendNewOplogEntry(CreateFilesOplogPackage(Oid::NewOid(), RootDir, FileAtts)); + + // Truncate the file to zero bytes after the oplog entry is created. + // The file still exists on disk so RewriteOplog's IsFile() check passes, + // but MakeFromFile returns a zero-size buffer. + std::filesystem::resize_file(FileAtts[0].second, 0); + + WorkerThreadPool WorkerPool(GetWorkerCount()); + + CbObject Container = BuildContainer( + CidStore, + *Project, + *Oplog, + WorkerPool, + 64u * 1024u, + 1000, + 32u * 1024u, + 64u * 1024u * 1024u, + /*BuildBlocks=*/true, + /*IgnoreMissingAttachments=*/false, + /*AllowChunking=*/true, + [](CompressedBuffer&&, ChunkBlockDescription&&) {}, + [](const IoHash&, TGetAttachmentBufferFunc&&) {}, + [](std::vector<std::pair<IoHash, FetchChunkFunc>>&&) {}, + /*EmbedLooseFiles=*/true); + + CHECK(Container.GetSize() > 0); + + // The zero-byte attachment is packed into a block via the compress-inline path. + CbArrayView Blocks = Container["blocks"sv].AsArrayView(); + CHECK(Blocks.Num() > 0); +} + TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite") { // EmbedLooseFiles=false: RewriteOp is skipped for file-op entries; they pass through diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp index 4cd6b411f..1543c4d7c 100644 --- a/src/zens3-testbed/main.cpp +++ b/src/zens3-testbed/main.cpp @@ -110,7 +110,7 @@ CreateClient(const cxxopts::ParseResult& Args) if (Args.count("timeout")) { - Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000); + Options.HttpSettings.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000); } return S3Client(Options); diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp index 14748e214..986dc67e0 100644 --- a/src/zenserver-test/cache-tests.cpp +++ b/src/zenserver-test/cache-tests.cpp @@ -9,6 +9,7 @@ # include <zencore/compactbinarypackage.h> # include <zencore/compress.h> # include <zencore/fmtutils.h> +# include <zenhttp/localrefpolicy.h> # include <zenhttp/packageformat.h> # include <zenstore/cache/cachepolicy.h> # include <zencore/filesystem.h> @@ -25,6 +26,13 @@ namespace zen::tests { TEST_SUITE_BEGIN("server.cache"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. +struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("zcache.basic") { using namespace std::literals; @@ -743,7 +751,11 @@ TEST_CASE("zcache.rpc") if (Result.StatusCode == HttpResponseCode::OK) { - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); + ParseFlags PFlags = EnumHasAllFlags(AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); CHECK(!Response.IsNull()); OutResult.Response = std::move(Response); CHECK(OutResult.Result.Parse(OutResult.Response)); @@ -1745,8 +1757,13 @@ TEST_CASE("zcache.rpc.partialchunks") CHECK(Result.StatusCode == HttpResponseCode::OK); - CbPackage Response = ParsePackageMessage(Result.ResponsePayload); - bool Loaded = !Response.IsNull(); + ParseFlags PFlags = EnumHasAllFlags(Options.AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + PermissiveLocalRefPolicy AllowAllPolicy; + const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr; + CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy); + bool Loaded = !Response.IsNull(); CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load."); cacherequests::GetCacheChunksResult GetCacheChunksResult; CHECK(GetCacheChunksResult.Parse(Response)); diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp index 021052a3b..835d72713 100644 --- a/src/zenserver-test/compute-tests.cpp +++ b/src/zenserver-test/compute-tests.cpp @@ -21,6 +21,7 @@ # include <zenhttp/httpserver.h> # include <zenhttp/websocket.h> # include <zencompute/computeservice.h> +# include <zencore/fmtutils.h> # include <zenstore/zenstore.h> # include <zenutil/zenserverprocess.h> @@ -36,6 +37,8 @@ using namespace std::literals; static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; +static constexpr std::string_view kFailVersion = "fa11fa11-fa11-fa11-fa11-fa11fa11fa11"; +static constexpr std::string_view kCrashVersion = "c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"; // In-memory implementation of ChunkResolver for test use. // Stores compressed data keyed by decompressed content hash. @@ -104,6 +107,16 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -115,7 +128,7 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage)); REQUIRE_MESSAGE(RegisterResp, - fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText())); + fmt::format("Worker registration failed: status={}, body={}", RegisterResp.StatusCode, RegisterResp.ToText())); return WorkerId; } @@ -220,6 +233,83 @@ BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemor return ActionWriter.Save(); } +// Build a Fail action CbPackage. The worker exits with the given exit code. +static CbPackage +BuildFailActionPackage(int ExitCode) +{ + // The Fail function throws before reading inputs, but the action structure + // still requires a valid input attachment for the runner to manifest. + std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Fail"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kFailVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "ExitCode"sv << static_cast<uint64_t>(ExitCode); + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Crash action CbPackage. The worker process crashes hard. +// Mode: "abort" (default) or "nullptr" (null pointer dereference). +static CbPackage +BuildCrashActionPackage(std::string_view Mode = "abort"sv) +{ + std::string_view Dummy = "x"sv; + + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Dummy.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Crash"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kCrashVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "Mode"sv << Mode; + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + static HttpClient::Response PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000) { @@ -340,8 +430,9 @@ public: } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } @@ -469,6 +560,16 @@ BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver) << "Sleep"sv; WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Fail"sv; + WorkerWriter << "version"sv << Guid::FromString(kFailVersion); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Crash"sv; + WorkerWriter << "version"sv << Guid::FromString(kCrashVersion); + WorkerWriter.EndObject(); WorkerWriter.EndArray(); CbPackage WorkerPackage; @@ -526,7 +627,7 @@ TEST_CASE("function.rot13") // Submit action via legacy /jobs/{worker} endpoint const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); @@ -536,7 +637,7 @@ TEST_CASE("function.rot13") HttpClient::Response ResultResp = PollForResult(Client, ResultUrl); REQUIRE_MESSAGE( ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -581,7 +682,7 @@ TEST_CASE("function.workers") // GET /workers/{worker} — descriptor should match what was registered const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); HttpClient::Response DescResp = Client.Get(WorkerUrl); - REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode))); + REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", DescResp.StatusCode)); CbObject Desc = DescResp.AsObject(); CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion)); @@ -627,7 +728,7 @@ TEST_CASE("function.queues.lifecycle") // Create a queue HttpClient::Response CreateResp = Client.Post("/queues"sv); - REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); @@ -651,8 +752,7 @@ TEST_CASE("function.queues.lifecycle") // Submit action via queue-scoped endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission"); @@ -668,9 +768,8 @@ TEST_CASE("function.queues.lifecycle") // Retrieve result via queue-scoped /jobs/{lsn} endpoint const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); HttpClient::Response ResultResp = Client.Get(ResultUrl); - REQUIRE_MESSAGE( - ResultResp.StatusCode == HttpResponseCode::OK, - fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" CbPackage ResultPackage = ResultResp.AsPackage(); @@ -712,13 +811,13 @@ TEST_CASE("function.queues.cancel") // Submit a job const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); // Cancel the queue const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // Verify queue status shows cancelled HttpClient::Response StatusResp = Client.Get(QueueUrl); @@ -743,7 +842,7 @@ TEST_CASE("function.queues.remote") // Create a remote queue — response includes both an integer queue_id and an OID queue_token HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); CbObject CreateObj = CreateResp.AsObject(); const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString()); @@ -753,7 +852,7 @@ TEST_CASE("function.queues.remote") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); REQUIRE_MESSAGE(SubmitResp, - fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + fmt::format("Remote queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission"); @@ -769,7 +868,7 @@ TEST_CASE("function.queues.remote") HttpClient::Response ResultResp = Client.Get(ResultUrl); REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}", - int(ResultResp.StatusCode), + ResultResp.StatusCode, Instance.GetLogOutput())); // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" @@ -801,8 +900,7 @@ TEST_CASE("function.queues.cancel_running") // Submit a Sleep job long enough that it will still be running when we cancel const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); @@ -814,7 +912,7 @@ TEST_CASE("function.queues.cancel_running") const std::string QueueUrl = fmt::format("/queues/{}", QueueId); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the /completed endpoint once the process exits const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); @@ -849,7 +947,7 @@ TEST_CASE("function.queues.remote_cancel") // Create a remote queue to obtain an OID token for token-addressed cancellation HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); REQUIRE_MESSAGE(CreateResp, - fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText())); const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString()); REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); @@ -857,8 +955,7 @@ TEST_CASE("function.queues.remote_cancel") // Submit a long-running Sleep job via the token-addressed endpoint const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, - fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); @@ -870,7 +967,7 @@ TEST_CASE("function.queues.remote_cancel") const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); HttpClient::Response CancelResp = Client.Delete(QueueUrl); REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, - fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + fmt::format("Remote queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText())); // The cancelled job should appear in the token-addressed /completed endpoint const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); @@ -910,13 +1007,13 @@ TEST_CASE("function.queues.drain") // Submit a long-running job so we can verify it completes even after drain const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000)); - REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode))); + REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", Submit1.StatusCode)); const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32(); // Drain the queue const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId); HttpClient::Response DrainResp = Client.Post(DrainUrl); - REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText())); + REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", DrainResp.StatusCode, DrainResp.ToText())); CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining"); // Second submission should be rejected with 424 @@ -965,7 +1062,7 @@ TEST_CASE("function.priority") // jobs by priority when the slot becomes free. const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", BlockerResp.StatusCode)); // Submit 3 low-priority Rot13 jobs const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); @@ -1104,6 +1201,432 @@ TEST_CASE("function.priority") } ////////////////////////////////////////////////////////////////////////// +// Process exit code tests +// +// These tests exercise how the compute service handles worker processes +// that exit with non-zero exit codes, including retry behaviour and +// final failure reporting. + +TEST_CASE("function.exit_code.failed_action") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 so the action fails immediately + // without being rescheduled. + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action with exit code 42 + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(42)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Fail job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + // Verify action history records the failure + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); + + // GET /jobs/{lsn} for a failed action should return OK but with an empty result package + const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + CHECK_EQ(ResultResp.StatusCode, HttpResponseCode::OK); +} + +TEST_CASE("function.exit_code.auto_retry") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=2 so the action is retried twice before + // being reported as failed (3 total attempts: initial + 2 retries). + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 2; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action — the worker process will exit with code 1 on + // every attempt, eventually exhausting retries. + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Poll for the LSN to appear in the completed list — this only + // happens after all retries are exhausted. + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the action history records the retry count + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 2); + break; + } + } + + // Queue should show 1 failed, 0 completed + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.exit_code.reschedule_failed") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=1 so we have room for one manual reschedule + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 1; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Fail action — auto-retry will fire once, then it lands in results as Failed + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(7)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for the action to exhaust its auto-retry and land in completed + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue completed list\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Try to manually reschedule — should fail because retry limit is reached + const std::string RescheduleUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response RescheduleResp = Client.Post(RescheduleUrl); + CHECK_EQ(RescheduleResp.StatusCode, HttpResponseCode::Conflict); +} + +TEST_CASE("function.exit_code.mixed_success_and_failure") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 for fast failure + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit one Rot13 (success) and one Fail (failure) + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + + HttpClient::Response SuccessResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv)); + REQUIRE_MESSAGE(SuccessResp, "Rot13 job submission failed"); + const int LsnSuccess = SuccessResp.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response FailResp = Client.Post(JobUrl, BuildFailActionPackage(1)); + REQUIRE_MESSAGE(FailResp, "Fail job submission failed"); + const int LsnFail = FailResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for both to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnSuccess), + fmt::format("Success LSN {} did not complete\nServer log:\n{}", LsnSuccess, Instance.GetLogOutput())); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnFail), + fmt::format("Fail LSN {} did not complete\nServer log:\n{}", LsnFail, Instance.GetLogOutput())); + + // Verify queue counters + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.crash.abort") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 so we don't wait through retries + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action that calls std::abort() + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + + // Verify action history records the failure + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + bool FoundInHistory = false; + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + FoundInHistory = true; + break; + } + } + CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn)); +} + +TEST_CASE("function.crash.nullptr") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=0 + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 0; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action that dereferences a null pointer + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("nullptr"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission"); + + // Poll for the LSN to appear in the completed list + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify queue status reflects the failure + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.crash.auto_retry") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue with max_retries=1 — the crash should be retried once + // before being reported as permanently failed. + CbObjectWriter ConfigWriter; + ConfigWriter << "max_retries"sv << 1; + + CbObjectWriter BodyWriter; + BodyWriter << "config"sv << ConfigWriter.Save(); + + HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save()); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode)); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a Crash action — will crash on every attempt + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}", SubmitResp.StatusCode)); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + + // Poll for the LSN to appear in the completed list after retries exhaust + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000), + fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the action history records the retry count + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + for (auto& Item : HistoryResp.AsObject()["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false); + CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 1); + break; + } + } + + // Queue should show 1 failed, 0 completed + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +////////////////////////////////////////////////////////////////////////// // Remote worker synchronization tests // // These tests exercise the orchestrator discovery path where new compute @@ -1162,9 +1685,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); @@ -1349,9 +1871,8 @@ TEST_CASE("function.remote.queue_association") Sleep(200); } - REQUIRE_MESSAGE( - ResultCode == HttpResponseCode::OK, - fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput())); REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); @@ -1481,7 +2002,7 @@ TEST_CASE("function.abandon_running_http") const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); @@ -1498,7 +2019,7 @@ TEST_CASE("function.abandon_running_http") // Trigger abandon via the HTTP endpoint HttpClient::Response AbandonResp = Client.Post("/abandon"sv); REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, - fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + fmt::format("Abandon request failed: status={}, body={}", AbandonResp.StatusCode, AbandonResp.ToText())); // Ready endpoint should now return 503 { @@ -1529,7 +2050,7 @@ TEST_CASE("function.abandon_running_http") CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state"); } -TEST_CASE("function.session.abandon_pending") +TEST_CASE("function.session.abandon_pending" * doctest::skip()) { // Create a session with no runners so actions stay pending InMemoryChunkResolver Resolver; @@ -1577,7 +2098,7 @@ TEST_CASE("function.session.abandon_pending") } Sleep(100); } - CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, Code)); } // Queue should show 0 active, 3 abandoned @@ -1979,11 +2500,11 @@ TEST_CASE("function.retract_http") // Submit a long-running Sleep action to occupy the single execution slot const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000)); - REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode))); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", BlockerResp.StatusCode)); // Submit a second action — it will stay pending because the slot is occupied HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv)); - REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode))); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", SubmitResp.StatusCode)); const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); @@ -1996,7 +2517,7 @@ TEST_CASE("function.retract_http") HttpClient::Response RetractResp = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK, fmt::format("Retract failed: status={}, body={}\nServer log:\n{}", - int(RetractResp.StatusCode), + RetractResp.StatusCode, RetractResp.ToText(), Instance.GetLogOutput())); @@ -2011,7 +2532,42 @@ TEST_CASE("function.retract_http") Sleep(500); HttpClient::Response RetractResp2 = Client.Post(RetractUrl); CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK, - fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText())); + fmt::format("Second retract failed: status={}, body={}", RetractResp2.StatusCode, RetractResp2.ToText())); +} + +TEST_CASE("function.session.immediate_query_after_enqueue") +{ + // Verify that actions are immediately visible to GetActionResult and + // FindActionResult right after enqueue, without waiting for the + // scheduler thread to process the update. + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + CbObject ActionObj = BuildRot13ActionForSession("immediate-query"sv, Resolver); + + auto EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Failed to enqueue action"); + + // Query by LSN immediately — must not return NotFound + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(EnqueueRes.Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::Accepted, + fmt::format("GetActionResult returned {} immediately after enqueue, expected Accepted", Code)); + + // Query by ActionId immediately — must not return NotFound + const IoHash ActionId = ActionObj.GetHash(); + CbPackage FindResult; + HttpResponseCode FindCode = Session.FindActionResult(ActionId, FindResult); + CHECK_MESSAGE(FindCode == HttpResponseCode::Accepted, + fmt::format("FindActionResult returned {} immediately after enqueue, expected Accepted", FindCode)); + + Session.Shutdown(); } TEST_SUITE_END(); diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp index b2da552fc..82dfd7e91 100644 --- a/src/zenserver-test/hub-tests.cpp +++ b/src/zenserver-test/hub-tests.cpp @@ -377,7 +377,7 @@ TEST_CASE("hub.consul.kv") consul::ConsulProcess ConsulProc; ConsulProc.SpawnConsulAgent(); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); Client.SetKeyValue("zen/hub/testkey", "testvalue"); std::string RetrievedValue = Client.GetKeyValue("zen/hub/testkey"); @@ -399,7 +399,7 @@ TEST_CASE("hub.consul.hub.registration") "--consul-health-interval-seconds=5 --consul-deregister-after-seconds=60"); REQUIRE(PortNumber != 0); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); // Verify custom intervals flowed through to the registered check @@ -480,7 +480,7 @@ TEST_CASE("hub.consul.hub.registration.token") // Use a plain client -- dev-mode Consul doesn't enforce ACLs, but the // server has exercised the ConsulTokenEnv -> GetEnvVariable -> ConsulClient path. - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); @@ -501,7 +501,7 @@ TEST_CASE("hub.consul.provision.registration") Instance.SpawnServerAndWaitUntilReady("--consul-endpoint=http://localhost:8500/ --instance-id=test-instance"); REQUIRE(PortNumber != 0); - consul::ConsulClient Client("http://localhost:8500/"); + consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"}); REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000)); diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp index a37ecb6be..cec453511 100644 --- a/src/zenserver-test/projectstore-tests.cpp +++ b/src/zenserver-test/projectstore-tests.cpp @@ -22,6 +22,7 @@ ZEN_THIRD_PARTY_INCLUDES_START ZEN_THIRD_PARTY_INCLUDES_END # include <random> +# include <thread> namespace zen::tests { @@ -340,6 +341,102 @@ TEST_CASE("project.basic") ZEN_INFO("+++++++"); } + SUBCASE("snapshot zero byte file") + { + // A zero-byte file referenced in an oplog entry must survive a + // snapshot: the file is read, compressed, stored in CidStore, and + // the oplog is rewritten with a BinaryAttachment reference. After + // the snapshot the chunk must be retrievable and decompress to an + // empty payload. + + std::filesystem::path EmptyFileRelPath = std::filesystem::path("zerobyte_snapshot_test") / "empty.bin"; + std::filesystem::path EmptyFileAbsPath = RootPath / EmptyFileRelPath; + CreateDirectories(MakeSafeAbsolutePath(EmptyFileAbsPath.parent_path())); + // Create a zero-byte file on disk. + WriteFile(MakeSafeAbsolutePath(EmptyFileAbsPath), IoBuffer{}); + REQUIRE(IsFile(MakeSafeAbsolutePath(EmptyFileAbsPath))); + + const std::string_view EmptyChunkId{ + "00000000" + "00000000" + "00030000"}; + auto EmptyFileOid = zen::Oid::FromHexString(EmptyChunkId); + + zen::CbObjectWriter OpWriter; + OpWriter << "key" + << "zero_byte_test"; + OpWriter.BeginArray("files"); + OpWriter.BeginObject(); + OpWriter << "id" << EmptyFileOid; + OpWriter << "clientpath" + << "/{engine}/empty_file"; + OpWriter << "serverpath" << EmptyFileRelPath.c_str(); + OpWriter.EndObject(); + OpWriter.EndArray(); + + zen::CbObject Op = OpWriter.Save(); + zen::CbPackage OpPackage(Op); + + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + + HttpClient Http{BaseUri}; + + { + auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView())); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::Created); + } + + // Read file data before snapshot - raw and uncompressed, 0 bytes. + // http.sys converts a 200 OK with empty body to 204 No Content, so + // accept either status code. + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << EmptyChunkId; + auto Response = Http.Get(ChunkGetUri); + + REQUIRE(Response); + CHECK((Response.StatusCode == HttpResponseCode::OK || Response.StatusCode == HttpResponseCode::NoContent)); + CHECK(Response.ResponsePayload.GetSize() == 0); + } + + // Trigger snapshot. + { + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); }); + auto Response = Http.Post("/rpc"sv, Payload); + REQUIRE(Response); + CHECK(Response.StatusCode == HttpResponseCode::OK); + } + + // Read chunk after snapshot - compressed, decompresses to 0 bytes. + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << "/" << EmptyChunkId; + auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); + + REQUIRE(Response); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); + + IoBuffer Data = Response.ResponsePayload; + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); + REQUIRE(Compressed); + CHECK(RawSize == 0); + IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); + CHECK(DataDecompressed.GetSize() == 0); + } + + // Cleanup + { + std::error_code Ec; + DeleteDirectories(MakeSafeAbsolutePath(RootPath / "zerobyte_snapshot_test"), Ec); + } + + ZEN_INFO("+++++++"); + } + SUBCASE("test chunk not found error") { HttpClient Http{BaseUri}; @@ -1154,6 +1251,412 @@ TEST_CASE("project.rpcappendop") } } +TEST_CASE("project.file.data.transitions") +{ + using namespace utils; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(); + + zen::StringBuilder<64> ServerBaseUri; + ServerBaseUri << fmt::format("http://localhost:{}", PortNumber); + + // Set up a root directory with a test file on disk for path-referenced serving + std::filesystem::path RootDir = TestDir / "root"; + std::filesystem::path TestFilePath = RootDir / "content" / "testfile.bin"; + std::filesystem::path RelServerPath = std::filesystem::path("content") / "testfile.bin"; + CreateDirectories(TestFilePath.parent_path()); + IoBuffer FileBlob = CreateRandomBlob(4096); + WriteFile(TestFilePath, FileBlob); + + // Create a compressed blob to use as a CAS-referenced attachment (content differs from FileBlob) + CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(CreateRandomBlob(2048))); + + // Fixed chunk IDs for the file entry across sub-tests + const std::string_view FileChunkIdStr{ + "aa000000" + "bb000000" + "cc000001"}; + Oid FileOid = Oid::FromHexString(FileChunkIdStr); + + HttpClient Http{ServerBaseUri}; + + auto MakeProject = [&](std::string_view ProjectName) { + CbObjectWriter Project; + Project.AddString("id"sv, ProjectName); + Project.AddString("root"sv, PathToUtf8(RootDir.c_str())); + Project.AddString("engine"sv, ""sv); + Project.AddString("project"sv, ""sv); + Project.AddString("projectfile"sv, ""sv); + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), Project.Save()); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeProject")); + }; + + auto MakeOplog = [&](std::string_view ProjectName, std::string_view OplogName) { + HttpClient::Response Response = + Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeOplog")); + }; + + auto PostOplogEntry = [&](std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) { + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + IoBuffer Body{IoBuffer::Wrap, MemOut.GetData(), MemOut.GetSize()}; + Body.SetContentType(HttpContentType::kCbPackage); + HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("PostOplogEntry")); + }; + + auto GetChunk = [&](std::string_view ProjectName) -> HttpClient::Response { + return Http.Get(fmt::format("/prj/{}/oplog/oplog/{}", ProjectName, FileChunkIdStr)); + }; + + // Extract the raw decompressed bytes from a chunk response, handling both compressed and uncompressed payloads + auto GetDecompressedPayload = [](const HttpClient::Response& Response) -> IoBuffer { + if (Response.ResponsePayload.GetContentType() == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Response.ResponsePayload), RawHash, RawSize); + REQUIRE(Compressed); + return Compressed.Decompress().AsIoBuffer(); + } + return Response.ResponsePayload; + }; + + auto TriggerGcAndWait = [&]() { + HttpClient::Response TriggerResponse = Http.Post("/admin/gc?smallobjects=true"sv, IoBuffer{}); + REQUIRE_MESSAGE(TriggerResponse.IsSuccess(), TriggerResponse.ErrorMessage("TriggerGc")); + + for (int Attempt = 0; Attempt < 100; ++Attempt) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + HttpClient::Response StatusResponse = Http.Get("/admin/gc"sv); + REQUIRE_MESSAGE(StatusResponse.IsSuccess(), StatusResponse.ErrorMessage("GcStatus")); + CbObject StatusObj = StatusResponse.AsObject(); + if (StatusObj["Status"sv].AsString() == "Idle"sv) + { + return; + } + } + FAIL("GC did not complete within timeout"); + }; + + auto BuildPathReferencedFileOp = [&](const Oid& KeyId) -> CbPackage { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(KeyId); + Object.BeginArray("files"sv); + Object.BeginObject(); + Object << "id"sv << FileOid; + Object << "serverpath"sv << RelServerPath.string(); + Object << "clientpath"sv + << "/{engine}/testfile.bin"sv; + Object.EndObject(); + Object.EndArray(); + Package.SetObject(Object.Save()); + return Package; + }; + + auto BuildHashReferencedFileOp = [&](const Oid& KeyId, const CompressedBuffer& Blob) -> CbPackage { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(KeyId); + CbAttachment Attach(Blob, Blob.DecodeRawHash()); + Object.BeginArray("files"sv); + Object.BeginObject(); + Object << "id"sv << FileOid; + Object << "data"sv << Attach; + Object << "clientpath"sv + << "/{engine}/testfile.bin"sv; + Object.EndObject(); + Object.EndArray(); + Package.AddAttachment(Attach); + Package.SetObject(Object.Save()); + return Package; + }; + + SUBCASE("path-referenced file is retrievable") + { + MakeProject("proj_path"sv); + MakeOplog("proj_path"sv, "oplog"sv); + + CbPackage Op = BuildPathReferencedFileOp(Oid::NewOid()); + PostOplogEntry("proj_path"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_path"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + SUBCASE("hash-referenced file is retrievable") + { + MakeProject("proj_hash"sv); + MakeOplog("proj_hash"sv, "oplog"sv); + + CbPackage Op = BuildHashReferencedFileOp(Oid::NewOid(), CompressedBlob); + PostOplogEntry("proj_hash"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_hash"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize()); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + SUBCASE("hash-referenced to path-referenced transition with different content") + { + MakeProject("proj_hash_to_path_diff"sv); + MakeOplog("proj_hash_to_path_diff"sv, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey; + bool RunGcAfterTransition = false; + + SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); } + SUBCASE("same op key") { SecondOpKey = FirstOpKey; } + SUBCASE("new op key with gc") + { + SecondOpKey = Oid::NewOid(); + RunGcAfterTransition = true; + } + SUBCASE("same op key with gc") + { + SecondOpKey = FirstOpKey; + RunGcAfterTransition = true; + } + + // First op: file with CAS hash (content differs from the on-disk file) + { + CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, CompressedBlob); + PostOplogEntry("proj_hash_to_path_diff"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_hash_to_path_diff"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + // Second op: same FileId transitions to serverpath (different data) + { + CbPackage Op = BuildPathReferencedFileOp(SecondOpKey); + PostOplogEntry("proj_hash_to_path_diff"sv, "oplog"sv, Op); + } + + if (RunGcAfterTransition) + { + TriggerGcAndWait(); + } + + // Must serve the on-disk file content, not the old CAS blob + HttpClient::Response Response = GetChunk("proj_hash_to_path_diff"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + SUBCASE("hash-referenced to path-referenced transition with identical content") + { + // Compress the same on-disk file content as a CAS blob so both references yield identical data + CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView())); + + MakeProject("proj_hash_to_path_same"sv); + MakeOplog("proj_hash_to_path_same"sv, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey; + bool RunGcAfterTransition = false; + + SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); } + SUBCASE("same op key") { SecondOpKey = FirstOpKey; } + SUBCASE("new op key with gc") + { + SecondOpKey = Oid::NewOid(); + RunGcAfterTransition = true; + } + SUBCASE("same op key with gc") + { + SecondOpKey = FirstOpKey; + RunGcAfterTransition = true; + } + + // First op: file with CAS hash (content matches the on-disk file) + { + CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, MatchingBlob); + PostOplogEntry("proj_hash_to_path_same"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_hash_to_path_same"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + // Second op: same FileId transitions to serverpath (same data) + { + CbPackage Op = BuildPathReferencedFileOp(SecondOpKey); + PostOplogEntry("proj_hash_to_path_same"sv, "oplog"sv, Op); + } + + if (RunGcAfterTransition) + { + TriggerGcAndWait(); + } + + // Must still resolve successfully after the transition + HttpClient::Response Response = GetChunk("proj_hash_to_path_same"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + SUBCASE("path-referenced to hash-referenced transition with different content") + { + MakeProject("proj_path_to_hash_diff"sv); + MakeOplog("proj_path_to_hash_diff"sv, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey; + bool RunGcAfterTransition = false; + + SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); } + SUBCASE("same op key") { SecondOpKey = FirstOpKey; } + SUBCASE("new op key with gc") + { + SecondOpKey = Oid::NewOid(); + RunGcAfterTransition = true; + } + SUBCASE("same op key with gc") + { + SecondOpKey = FirstOpKey; + RunGcAfterTransition = true; + } + + // First op: file with serverpath + { + CbPackage Op = BuildPathReferencedFileOp(FirstOpKey); + PostOplogEntry("proj_path_to_hash_diff"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_path_to_hash_diff"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + // Second op: same FileId transitions to CAS hash (different data) + { + CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, CompressedBlob); + PostOplogEntry("proj_path_to_hash_diff"sv, "oplog"sv, Op); + } + + if (RunGcAfterTransition) + { + TriggerGcAndWait(); + } + + // Must serve the CAS blob content, not the old on-disk file + HttpClient::Response Response = GetChunk("proj_path_to_hash_diff"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer(); + CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize()); + CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView())); + } + } + + SUBCASE("path-referenced to hash-referenced transition with identical content") + { + // Compress the same on-disk file content as a CAS blob so both references yield identical data + CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView())); + + MakeProject("proj_path_to_hash_same"sv); + MakeOplog("proj_path_to_hash_same"sv, "oplog"sv); + + Oid FirstOpKey = Oid::NewOid(); + Oid SecondOpKey; + bool RunGcAfterTransition = false; + + SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); } + SUBCASE("same op key") { SecondOpKey = FirstOpKey; } + SUBCASE("new op key with gc") + { + SecondOpKey = Oid::NewOid(); + RunGcAfterTransition = true; + } + SUBCASE("same op key with gc") + { + SecondOpKey = FirstOpKey; + RunGcAfterTransition = true; + } + + // First op: file with serverpath + { + CbPackage Op = BuildPathReferencedFileOp(FirstOpKey); + PostOplogEntry("proj_path_to_hash_same"sv, "oplog"sv, Op); + + HttpClient::Response Response = GetChunk("proj_path_to_hash_same"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } + + // Second op: same FileId transitions to CAS hash (same data) + { + CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, MatchingBlob); + PostOplogEntry("proj_path_to_hash_same"sv, "oplog"sv, Op); + } + + if (RunGcAfterTransition) + { + TriggerGcAndWait(); + } + + // Must still resolve successfully after the transition + HttpClient::Response Response = GetChunk("proj_path_to_hash_same"sv); + CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition")); + if (Response.IsSuccess()) + { + IoBuffer Payload = GetDecompressedPayload(Response); + CHECK_EQ(Payload.GetSize(), FileBlob.GetSize()); + CHECK(Payload.GetView().EqualBytes(FileBlob.GetView())); + } + } +} + TEST_SUITE_END(); } // namespace zen::tests diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html index b15b34577..41c80d3a3 100644 --- a/src/zenserver/frontend/html/compute/hub.html +++ b/src/zenserver/frontend/html/compute/hub.html @@ -83,7 +83,7 @@ } async function fetchStats() { - var data = await fetchJSON('/hub/stats'); + var data = await fetchJSON('/stats/hub'); var current = data.currentInstanceCount || 0; var max = data.maxInstanceCount || 0; diff --git a/src/zenserver/frontend/html/pages/builds.js b/src/zenserver/frontend/html/pages/builds.js index 095f0bf29..6b3426378 100644 --- a/src/zenserver/frontend/html/pages/builds.js +++ b/src/zenserver/frontend/html/pages/builds.js @@ -16,7 +16,7 @@ export class Page extends ZenPage this.set_title("build store"); // Build Store Stats - const stats_section = this.add_section("Build Store Stats"); + const stats_section = this._collapsible_section("Build Store Service Stats"); stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => { window.open("/stats/builds.yaml", "_blank"); }); diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js index 1fc8227c8..e0f6f73b6 100644 --- a/src/zenserver/frontend/html/pages/cache.js +++ b/src/zenserver/frontend/html/pages/cache.js @@ -95,39 +95,6 @@ export class Page extends ZenPage } } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - _render_stats(stats) { const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj); diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js index ab3d49c27..2eb4d4e9b 100644 --- a/src/zenserver/frontend/html/pages/compute.js +++ b/src/zenserver/frontend/html/pages/compute.js @@ -24,6 +24,12 @@ function formatTime(date) return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit", second: "2-digit" }); } +function truncateHash(hash) +{ + if (!hash || hash.length <= 15) return hash; + return hash.slice(0, 6) + "\u2026" + hash.slice(-6); +} + function formatDuration(startDate, endDate) { if (!startDate || !endDate) return "-"; @@ -100,39 +106,6 @@ export class Page extends ZenPage }, 2000); } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - async _load_chartjs() { if (window.Chart) @@ -338,11 +311,7 @@ export class Page extends ZenPage { const workerIds = data.workers || []; - if (this._workers_table) - { - this._workers_table.clear(); - } - else + if (!this._workers_table) { this._workers_table = this._workers_host.add_widget( Table, @@ -353,6 +322,7 @@ export class Page extends ZenPage if (workerIds.length === 0) { + this._workers_table.clear(); return; } @@ -382,6 +352,9 @@ export class Page extends ZenPage id, ); + // Worker ID column: monospace for hex readability + row.get_cell(5).style("fontFamily", "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"); + // Make name clickable to expand detail const cell = row.get_cell(0); cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc)); @@ -579,6 +552,11 @@ export class Page extends ZenPage ["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"], Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1 ); + + // Right-align hash column headers to match data cells + const hdr = this._history_table.inner().firstElementChild; + hdr.children[7].style.textAlign = "right"; + hdr.children[8].style.textAlign = "right"; } // Entries arrive oldest-first; reverse to show newest at top @@ -593,7 +571,10 @@ export class Page extends ZenPage const startDate = filetimeToDate(entry.time_Running); const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); - this._history_table.add_row( + const workerId = entry.workerId || "-"; + const actionId = entry.actionId || "-"; + + const row = this._history_table.add_row( lsn, queueId, status, @@ -601,9 +582,15 @@ export class Page extends ZenPage formatTime(startDate), formatTime(endDate), formatDuration(startDate, endDate), - entry.workerId || "-", - entry.actionId || "-", + truncateHash(workerId), + truncateHash(actionId), ); + + // Hash columns: force right-align (AlignNumeric misses hex strings starting with a-f), + // use monospace for readability, and show full value on hover + const mono = "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace"; + row.get_cell(7).style("textAlign", "right").style("fontFamily", mono).attr("title", workerId); + row.get_cell(8).style("textAlign", "right").style("fontFamily", mono).attr("title", actionId); } } diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 1e4c82e3f..e381f4a71 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -168,7 +168,7 @@ export class Page extends ZenPage if (key === "cook.artifacts") { action_tb.left().add("view-raw").on_click(() => { - window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/"); + window.open("/" + ["prj", project, "oplog", oplog, value+".json"].join("/"), "_self"); }); } diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js index 78e3a090c..c9652f31e 100644 --- a/src/zenserver/frontend/html/pages/hub.js +++ b/src/zenserver/frontend/html/pages/hub.js @@ -82,7 +82,7 @@ export class Page extends ZenPage this.set_title("hub"); // Capacity - const stats_section = this.add_section("Capacity"); + const stats_section = this._collapsible_section("Hub Service Stats"); this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); // Modules @@ -152,6 +152,8 @@ export class Page extends ZenPage this._btn_next.className = "module-pager-btn"; this._btn_next.textContent = "Next \u2192"; this._btn_next.addEventListener("click", () => this._go_page(this._page + 1)); + this._btn_provision = _make_bulk_btn("+", "Provision", () => this._show_provision_modal()); + pager.appendChild(this._btn_provision); pager.appendChild(this._btn_prev); pager.appendChild(this._pager_label); pager.appendChild(this._btn_next); @@ -203,27 +205,47 @@ export class Page extends ZenPage { const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Active Modules"); + tile.tag().classify("card-title").text("Instances"); const body = tile.tag().classify("tile-metrics"); this._metric(body, Friendly.sep(current), "currently provisioned", true); + this._metric(body, Friendly.sep(max), "high watermark"); + this._metric(body, Friendly.sep(limit), "maximum allowed"); + if (limit > 0) + { + const pct = ((current / limit) * 100).toFixed(0) + "%"; + this._metric(body, pct, "utilization"); + } } + const machine = data.machine || {}; + const limits = data.resource_limits || {}; + if (machine.disk_total_bytes > 0 || machine.memory_total_mib > 0) { - const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Peak Modules"); - const body = tile.tag().classify("tile-metrics"); - this._metric(body, Friendly.sep(max), "high watermark", true); - } + const disk_used = Math.max(0, (machine.disk_total_bytes || 0) - (machine.disk_free_bytes || 0)); + const mem_used = Math.max(0, (machine.memory_total_mib || 0) - (machine.memory_avail_mib || 0)) * 1024 * 1024; + const vmem_used = Math.max(0, (machine.virtual_memory_total_mib || 0) - (machine.virtual_memory_avail_mib || 0)) * 1024 * 1024; + const disk_limit = limits.disk_bytes || 0; + const mem_limit = limits.memory_bytes || 0; + const disk_over = disk_limit > 0 && disk_used > disk_limit; + const mem_over = mem_limit > 0 && mem_used > mem_limit; - { const tile = grid.tag().classify("card").classify("stats-tile"); - tile.tag().classify("card-title").text("Instance Limit"); - const body = tile.tag().classify("tile-metrics"); - this._metric(body, Friendly.sep(limit), "maximum allowed", true); - if (limit > 0) + if (disk_over || mem_over) { tile.inner().setAttribute("data-over", "true"); } + tile.tag().classify("card-title").text("Resources"); + const columns = tile.tag().classify("tile-columns"); + + const left = columns.tag().classify("tile-metrics"); + this._metric(left, Friendly.bytes(disk_used), "disk used", true); + this._metric(left, Friendly.bytes(machine.disk_total_bytes), "disk total"); + if (disk_limit > 0) { this._metric(left, Friendly.bytes(disk_limit), "disk limit"); } + + const right = columns.tag().classify("tile-metrics"); + this._metric(right, Friendly.bytes(mem_used), "memory used", true); + if (mem_limit > 0) { this._metric(right, Friendly.bytes(mem_limit), "memory limit"); } + if (machine.virtual_memory_total_mib > 0) { - const pct = ((current / limit) * 100).toFixed(0) + "%"; - this._metric(body, pct, "utilization"); + this._metric(right, Friendly.bytes(vmem_used), "vmem used", true); + this._metric(right, Friendly.bytes(machine.virtual_memory_total_mib * 1024 * 1024), "vmem total"); } } } @@ -284,6 +306,14 @@ export class Page extends ZenPage } row.state_text.nodeValue = state; row.port_text.nodeValue = m.port ? String(m.port) : ""; + if (m.state_change_time) + { + const state_label = state.charAt(0).toUpperCase() + state.slice(1); + row.state_since_label.textContent = state_label + " since"; + row.state_age_label.textContent = state_label + " for"; + row.state_since_node.nodeValue = m.state_change_time; + row.state_age_node.nodeValue = Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime()); + } row.btn_open.disabled = state !== "provisioned"; row.btn_hibernate.disabled = !_btn_enabled(state, "hibernate"); row.btn_wake.disabled = !_btn_enabled(state, "wake"); @@ -373,7 +403,7 @@ export class Page extends ZenPage const td_action = document.createElement("td"); td_action.className = "module-action-cell"; const [wrap_o, btn_o] = _make_action_btn("\u2197", "Open dashboard", () => { - window.open(`${window.location.protocol}//${window.location.hostname}:${port}`, "_blank"); + window.open(`/hub/proxy/${port}/dashboard/`, "_blank"); }); btn_o.disabled = state !== "provisioned"; const [wrap_h, btn_h] = _make_action_btn("\u23F8", "Hibernate", () => this._post_module_action(id, "hibernate").then(() => this._update())); @@ -388,7 +418,7 @@ export class Page extends ZenPage td_action.appendChild(wrap_o); tr.appendChild(td_action); - // Build metrics grid from process_metrics keys. + // Build metrics grid: fixed state-time rows followed by process_metrics keys. // Keys are split into two halves and interleaved so the grid fills // top-to-bottom in the left column before continuing in the right column. const metric_nodes = new Map(); @@ -396,6 +426,28 @@ export class Page extends ZenPage metrics_td.colSpan = 6; const metrics_grid = document.createElement("div"); metrics_grid.className = "module-metrics-grid"; + + const _add_fixed_pair = (label, value_str) => { + const label_el = document.createElement("span"); + label_el.className = "module-metrics-label"; + label_el.textContent = label; + const value_node = document.createTextNode(value_str); + const value_el = document.createElement("span"); + value_el.className = "module-metrics-value"; + value_el.appendChild(value_node); + metrics_grid.appendChild(label_el); + metrics_grid.appendChild(value_el); + return { label_el, value_node }; + }; + + const state_label = m.state ? m.state.charAt(0).toUpperCase() + m.state.slice(1) : "State"; + const state_since_str = m.state_change_time || ""; + const state_age_str = m.state_change_time + ? Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime()) + : ""; + const { label_el: state_since_label, value_node: state_since_node } = _add_fixed_pair(state_label + " since", state_since_str); + const { label_el: state_age_label, value_node: state_age_node } = _add_fixed_pair(state_label + " for", state_age_str); + const keys = Object.keys(m.process_metrics || {}); const half = Math.ceil(keys.length / 2); const add_metric_pair = (key) => { @@ -423,7 +475,7 @@ export class Page extends ZenPage metrics_td.appendChild(metrics_grid); metrics_tr.appendChild(metrics_td); - row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes }; + row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes, state_since_node, state_age_node, state_since_label, state_age_label }; this._row_cache.set(id, row); } @@ -614,4 +666,135 @@ export class Page extends ZenPage await fetch(`/hub/modules/${moduleId}/${action}`, { method: "POST" }); } + _show_provision_modal() + { + const MODULE_ID_RE = /^[A-Za-z0-9][A-Za-z0-9-]*$/; + + const overlay = document.createElement("div"); + overlay.className = "zen_modal"; + + const bg = document.createElement("div"); + bg.className = "zen_modal_bg"; + bg.addEventListener("click", () => overlay.remove()); + overlay.appendChild(bg); + + const dialog = document.createElement("div"); + overlay.appendChild(dialog); + + const title = document.createElement("div"); + title.className = "zen_modal_title"; + title.textContent = "Provision Module"; + dialog.appendChild(title); + + const content = document.createElement("div"); + content.className = "zen_modal_message"; + content.style.textAlign = "center"; + + const input = document.createElement("input"); + input.type = "text"; + input.placeholder = "module-name"; + input.style.cssText = "width:100%;font-size:14px;padding:8px 12px;"; + content.appendChild(input); + + const error_div = document.createElement("div"); + error_div.style.cssText = "color:var(--theme_fail);font-size:12px;margin-top:8px;min-height:1.2em;"; + content.appendChild(error_div); + + dialog.appendChild(content); + + const buttons = document.createElement("div"); + buttons.className = "zen_modal_buttons"; + + const btn_cancel = document.createElement("div"); + btn_cancel.textContent = "Cancel"; + btn_cancel.addEventListener("click", () => overlay.remove()); + + const btn_submit = document.createElement("div"); + btn_submit.textContent = "Provision"; + + buttons.appendChild(btn_cancel); + buttons.appendChild(btn_submit); + dialog.appendChild(buttons); + + let submitting = false; + + const set_submit_enabled = (enabled) => { + btn_submit.style.opacity = enabled ? "" : "0.4"; + btn_submit.style.pointerEvents = enabled ? "" : "none"; + }; + + set_submit_enabled(false); + + const validate = () => { + if (submitting) { return false; } + const val = input.value.trim(); + if (val.length === 0) + { + error_div.textContent = ""; + set_submit_enabled(false); + return false; + } + if (!MODULE_ID_RE.test(val)) + { + error_div.textContent = "Only letters, numbers, and hyphens allowed (must start with a letter or number)"; + set_submit_enabled(false); + return false; + } + error_div.textContent = ""; + set_submit_enabled(true); + return true; + }; + + input.addEventListener("input", validate); + + const submit = async () => { + if (submitting) { return; } + const moduleId = input.value.trim(); + if (!MODULE_ID_RE.test(moduleId)) { return; } + + submitting = true; + set_submit_enabled(false); + error_div.textContent = ""; + + try + { + const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}/provision`, { method: "POST" }); + if (resp.ok) + { + overlay.remove(); + await this._update(); + this._navigate_to_module(moduleId); + return; + } + const msg = await resp.text(); + error_div.textContent = msg || ("HTTP " + resp.status); + } + catch (e) + { + error_div.textContent = e.message || "Request failed"; + } + submitting = false; + set_submit_enabled(true); + }; + + btn_submit.addEventListener("click", submit); + input.addEventListener("keydown", (e) => { + if (e.key === "Enter" && validate()) { submit(); } + if (e.key === "Escape") { overlay.remove(); } + }); + + document.body.appendChild(overlay); + input.focus(); + } + + _navigate_to_module(moduleId) + { + const idx = this._modules_data.findIndex(m => m.moduleId === moduleId); + if (idx >= 0) + { + this._page = Math.floor(idx / this._page_size); + this._render_page(); + } + } + } diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js index 4a9290a3c..a280fabdb 100644 --- a/src/zenserver/frontend/html/pages/orchestrator.js +++ b/src/zenserver/frontend/html/pages/orchestrator.js @@ -46,39 +46,6 @@ export class Page extends ZenPage this._connect_ws(); } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - async _fetch_all() { try diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index cf8d3e3dd..ff530ff8e 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -337,4 +337,37 @@ export class ZenPage extends PageBase this._metric(right, Friendly.duration(reqData.t_max), "max"); } } + + _collapsible_section(name) + { + const section = this.add_section(name); + const container = section._parent.inner(); + const heading = container.firstElementChild; + + heading.style.cursor = "pointer"; + heading.style.userSelect = "none"; + + const indicator = document.createElement("span"); + indicator.textContent = " \u25BC"; + indicator.style.fontSize = "0.7em"; + heading.appendChild(indicator); + + let collapsed = false; + heading.addEventListener("click", (e) => { + if (e.target !== heading && e.target !== indicator) + { + return; + } + collapsed = !collapsed; + indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; + let sibling = heading.nextElementSibling; + while (sibling) + { + sibling.style.display = collapsed ? "none" : ""; + sibling = sibling.nextElementSibling; + } + }); + + return section; + } } diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js index 2469bf70b..dfe4faeb8 100644 --- a/src/zenserver/frontend/html/pages/projects.js +++ b/src/zenserver/frontend/html/pages/projects.js @@ -110,39 +110,6 @@ export class Page extends ZenPage } } - _collapsible_section(name) - { - const section = this.add_section(name); - const container = section._parent.inner(); - const heading = container.firstElementChild; - - heading.style.cursor = "pointer"; - heading.style.userSelect = "none"; - - const indicator = document.createElement("span"); - indicator.textContent = " \u25BC"; - indicator.style.fontSize = "0.7em"; - heading.appendChild(indicator); - - let collapsed = false; - heading.addEventListener("click", (e) => { - if (e.target !== heading && e.target !== indicator) - { - return; - } - collapsed = !collapsed; - indicator.textContent = collapsed ? " \u25B6" : " \u25BC"; - let sibling = heading.nextElementSibling; - while (sibling) - { - sibling.style.display = collapsed ? "none" : ""; - sibling = sibling.nextElementSibling; - } - }); - - return section; - } - _clear_param(name) { this._params.delete(name); diff --git a/src/zenserver/frontend/html/pages/workspaces.js b/src/zenserver/frontend/html/pages/workspaces.js index d31fd7373..2442fb35b 100644 --- a/src/zenserver/frontend/html/pages/workspaces.js +++ b/src/zenserver/frontend/html/pages/workspaces.js @@ -13,7 +13,7 @@ export class Page extends ZenPage this.set_title("workspaces"); // Workspace Service Stats - const stats_section = this.add_section("Workspace Service Stats"); + const stats_section = this._collapsible_section("Workspace Service Stats"); this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles"); const stats = await new Fetcher().resource("stats", "ws").json().catch(() => null); diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index d9f7491ea..cb3d78cf2 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -816,6 +816,10 @@ zen-banner + zen-nav::part(nav-bar) { border-color: var(--theme_p0); } +.stats-tile[data-over="true"] { + border-color: var(--theme_fail); +} + .stats-tile-detailed { position: relative; } diff --git a/src/zenserver/hub/httphubservice.cpp b/src/zenserver/hub/httphubservice.cpp index ebefcf2e3..eba816793 100644 --- a/src/zenserver/hub/httphubservice.cpp +++ b/src/zenserver/hub/httphubservice.cpp @@ -2,6 +2,7 @@ #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" #include "storageserverinstance.h" @@ -43,10 +44,11 @@ namespace { } } // namespace -HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService) +HttpHubService::HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService) : m_Hub(Hub) , m_StatsService(StatsService) , m_StatusService(StatusService) +, m_Proxy(Proxy) { using namespace std::literals; @@ -67,6 +69,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta return true; }); + m_Router.AddMatcher("port", [](std::string_view Str) -> bool { + if (Str.empty()) + { + return false; + } + for (const auto C : Str) + { + if (!std::isdigit(C)) + { + return false; + } + } + return true; + }); + + m_Router.AddMatcher("proxypath", [](std::string_view Str) -> bool { return !Str.empty(); }); + m_Router.RegisterRoute( "status", [this](HttpRouterRequest& Req) { @@ -78,6 +97,10 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta Obj << "moduleId" << ModuleId; Obj << "state" << ToString(Info.State); Obj << "port" << Info.Port; + if (Info.StateChangeTime != std::chrono::system_clock::time_point::min()) + { + Obj << "state_change_time" << ToDateTime(Info.StateChangeTime); + } Obj.BeginObject("process_metrics"); { Obj << "MemoryBytes" << Info.Metrics.MemoryBytes; @@ -229,15 +252,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta HttpVerb::kPost); m_Router.RegisterRoute( - "stats", + "proxy/{port}/{proxypath}", [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Obj << "currentInstanceCount" << m_Hub.GetInstanceCount(); - Obj << "maxInstanceCount" << m_Hub.GetMaxInstanceCount(); - Obj << "instanceLimit" << m_Hub.GetConfig().InstanceLimit; - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + std::string_view PortStr = Req.GetCapture(1); + + // Use RelativeUriWithExtension to preserve the file extension that the + // router's URI parser strips (e.g. ".css", ".js") - the upstream server + // needs the full path including the extension. + std::string_view FullUri = Req.ServerRequest().RelativeUriWithExtension(); + std::string_view Prefix = "proxy/"; + + // FullUri is "proxy/{port}/{path...}" - skip past "proxy/{port}/" + size_t PathStart = Prefix.size() + PortStr.size() + 1; + std::string_view PathTail = (PathStart < FullUri.size()) ? FullUri.substr(PathStart) : std::string_view{}; + + m_Proxy.HandleProxyRequest(Req.ServerRequest(), PortStr, PathTail); }, - HttpVerb::kGet); + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); m_StatsService.RegisterHandler("hub", *this); m_StatusService.RegisterHandler("hub", *this); @@ -286,7 +317,37 @@ HttpHubService::HandleStatusRequest(HttpServerRequest& Request) void HttpHubService::HandleStatsRequest(HttpServerRequest& Request) { - Request.WriteResponse(HttpResponseCode::OK, CollectStats()); + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + + Cbo << "currentInstanceCount" << m_Hub.GetInstanceCount(); + Cbo << "maxInstanceCount" << m_Hub.GetMaxInstanceCount(); + Cbo << "instanceLimit" << m_Hub.GetConfig().InstanceLimit; + + SystemMetrics SysMetrics; + DiskSpace Disk; + m_Hub.GetMachineMetrics(SysMetrics, Disk); + Cbo.BeginObject("machine"); + { + Cbo << "disk_free_bytes" << Disk.Free; + Cbo << "disk_total_bytes" << Disk.Total; + Cbo << "memory_avail_mib" << SysMetrics.AvailSystemMemoryMiB; + Cbo << "memory_total_mib" << SysMetrics.SystemMemoryMiB; + Cbo << "virtual_memory_avail_mib" << SysMetrics.AvailVirtualMemoryMiB; + Cbo << "virtual_memory_total_mib" << SysMetrics.VirtualMemoryMiB; + } + Cbo.EndObject(); + + const ResourceMetrics& Limits = m_Hub.GetConfig().ResourceLimits; + Cbo.BeginObject("resource_limits"); + { + Cbo << "disk_bytes" << Limits.DiskUsageBytes; + Cbo << "memory_bytes" << Limits.MemoryUsageBytes; + } + Cbo.EndObject(); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } CbObject @@ -369,4 +430,22 @@ HttpHubService::HandleModuleDelete(HttpServerRequest& Request, std::string_view Request.WriteResponse(HttpResponseCode::OK, Obj.Save()); } +void +HttpHubService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + m_Proxy.OnWebSocketOpen(std::move(Connection), RelativeUri); +} + +void +HttpHubService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + m_Proxy.OnWebSocketMessage(Conn, Msg); +} + +void +HttpHubService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + m_Proxy.OnWebSocketClose(Conn, Code, Reason); +} + } // namespace zen diff --git a/src/zenserver/hub/httphubservice.h b/src/zenserver/hub/httphubservice.h index 1bb1c303e..ff2cb0029 100644 --- a/src/zenserver/hub/httphubservice.h +++ b/src/zenserver/hub/httphubservice.h @@ -2,11 +2,16 @@ #pragma once +#include <zencore/thread.h> #include <zenhttp/httpserver.h> #include <zenhttp/httpstatus.h> +#include <zenhttp/websocket.h> + +#include <memory> namespace zen { +class HttpProxyHandler; class HttpStatsService; class Hub; @@ -16,10 +21,10 @@ class Hub; * use in UEFN content worker style scenarios. * */ -class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider +class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider, public IWebSocketHandler { public: - HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService); + HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService); ~HttpHubService(); HttpHubService(const HttpHubService&) = delete; @@ -32,6 +37,11 @@ public: virtual CbObject CollectStats() override; virtual uint64_t GetActivityCounter() override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId); private: @@ -45,6 +55,8 @@ private: void HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId); void HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId); + + HttpProxyHandler& m_Proxy; }; } // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp new file mode 100644 index 000000000..25842623a --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.cpp @@ -0,0 +1,504 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpproxyhandler.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpwsclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <charconv> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + + std::string InjectProxyScript(std::string_view Html, uint16_t Port) + { + ExtendableStringBuilder<2048> Script; + Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/"); + Script.Append(fmt::format("{}", Port)); + Script.Append( + "\";\n" + " var OF = window.fetch;\n" + " window.fetch = function(u, o) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {\n" + " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n" + " }\n" + " }\n" + " return OF.call(this, u, o);\n" + " };\n" + " var OW = window.WebSocket;\n" + " window.WebSocket = function(u, pr) {\n" + " try {\n" + " var p = new URL(u);\n" + " if (p.hostname === location.hostname\n" + " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n" + " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n" + " && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " return pr !== undefined ? new OW(u, pr) : new OW(u);\n" + " };\n" + " window.WebSocket.prototype = OW.prototype;\n" + " window.WebSocket.CONNECTING = OW.CONNECTING;\n" + " window.WebSocket.OPEN = OW.OPEN;\n" + " window.WebSocket.CLOSING = OW.CLOSING;\n" + " window.WebSocket.CLOSED = OW.CLOSED;\n" + " var OO = window.open;\n" + " window.open = function(u, t, f) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " }\n" + " return OO.call(this, u, t, f);\n" + " };\n" + " document.addEventListener(\"click\", function(e) {\n" + " var t = e.composedPath ? e.composedPath()[0] : e.target;\n" + " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n" + " if (!t || !t.href) return;\n" + " try {\n" + " var h = new URL(t.href);\n" + " if (h.origin === location.origin && !h.pathname.startsWith(P))\n" + " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n" + " } catch(x) {}\n" + " }, true);\n" + "})();\n</script>"); + + std::string ScriptStr = Script.ToString(); + + size_t HeadClose = Html.find("</head>"); + if (HeadClose != std::string_view::npos) + { + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(Html.substr(0, HeadClose)); + Result.append(ScriptStr); + Result.append(Html.substr(HeadClose)); + return Result; + } + + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(ScriptStr); + Result.append(Html); + return Result; + } + +} // namespace + +struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler +{ + Ref<WebSocketConnection> ClientConn; + std::unique_ptr<HttpWsClient> UpstreamClient; + uint16_t Port = 0; + + void OnWsOpen() override {} + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (!ClientConn->IsOpen()) + { + return; + } + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + default: + break; + } + } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + if (ClientConn->IsOpen()) + { + ClientConn->Close(Code, Reason); + } + } +}; + +HttpProxyHandler::HttpProxyHandler() +{ +} + +HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort)) +{ +} + +void +HttpProxyHandler::SetPortValidator(PortValidator ValidatePort) +{ + m_ValidatePort = std::move(ValidatePort); +} + +HttpProxyHandler::~HttpProxyHandler() +{ + try + { + Shutdown(); + } + catch (...) + { + } +} + +HttpClient& +HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port) +{ + HttpClient* Result = nullptr; + m_ProxyClientsLock.WithExclusiveLock([&] { + auto It = m_ProxyClients.find(Port); + if (It == m_ProxyClients.end()) + { + HttpClientSettings Settings; + Settings.LogCategory = "hub-proxy"; + Settings.ConnectTimeout = std::chrono::milliseconds(5000); + Settings.Timeout = std::chrono::milliseconds(30000); + auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings); + Result = Client.get(); + m_ProxyClients.emplace(Port, std::move(Client)); + } + else + { + Result = It->second.get(); + } + }); + return *Result; +} + +void +HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail) +{ + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "target instance not available"); + return; + } + + HttpClient& Client = GetOrCreateProxyClient(Port); + + std::string RequestPath; + RequestPath.reserve(1 + PathTail.size()); + RequestPath.push_back('/'); + RequestPath.append(PathTail); + + std::string_view QueryString = Request.QueryString(); + if (!QueryString.empty()) + { + RequestPath.push_back('?'); + RequestPath.append(QueryString); + } + + HttpClient::KeyValueMap ForwardHeaders; + HttpContentType AcceptType = Request.AcceptContentType(); + if (AcceptType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType))); + } + + std::string_view Auth = Request.GetAuthorizationHeader(); + if (!Auth.empty()) + { + ForwardHeaders->emplace("Authorization", std::string(Auth)); + } + + HttpContentType ReqContentType = Request.RequestContentType(); + if (ReqContentType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType))); + } + + HttpClient::Response Response; + + switch (Request.RequestVerb()) + { + case HttpVerb::kGet: + Response = Client.Get(RequestPath, ForwardHeaders); + break; + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Post(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kPut: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Put(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kDelete: + Response = Client.Delete(RequestPath, ForwardHeaders); + break; + case HttpVerb::kHead: + Response = Client.Head(RequestPath, ForwardHeaders); + break; + default: + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported"); + return; + } + + if (Response.Error) + { + ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage); + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "upstream request failed"); + return; + } + + HttpContentType ContentType = Response.ResponsePayload.GetContentType(); + + if (ContentType == HttpContentType::kHTML) + { + std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string Injected = InjectProxyScript(Html, Port); + Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected)); + } + else + { + Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload)); + } +} + +void +HttpProxyHandler::PrunePort(uint16_t Port) +{ + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); }); + + std::vector<Ref<WsBridge>> Stale; + m_WsBridgesLock.WithExclusiveLock([&] { + for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();) + { + if (It->second->Port == Port) + { + Stale.push_back(std::move(It->second)); + It = m_WsBridges.erase(It); + } + else + { + ++It; + } + } + }); + + for (auto& Bridge : Stale) + { + if (Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(1001, "instance shutting down"); + } + if (Bridge->ClientConn->IsOpen()) + { + Bridge->ClientConn->Close(1001, "instance shutting down"); + } + } +} + +void +HttpProxyHandler::Shutdown() +{ + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); }); + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); }); +} + +////////////////////////////////////////////////////////////////////////// +// +// WebSocket proxy +// + +void +HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + const std::string_view ProxyPrefix = "proxy/"; + if (!RelativeUri.starts_with(ProxyPrefix)) + { + Connection->Close(1008, "unsupported WebSocket endpoint"); + return; + } + + std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size()); + + size_t SlashPos = ProxyTail.find('/'); + std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail; + std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/"; + + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Connection->Close(1008, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Connection->Close(1008, "target instance not available"); + return; + } + + std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path); + + Ref<WsBridge> Bridge(new WsBridge()); + Bridge->ClientConn = Connection; + Bridge->Port = Port; + + Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge); + + try + { + Bridge->UpstreamClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what()); + Connection->Close(1011, "upstream connect failed"); + return; + } + + WebSocketConnection* Key = Connection.Get(); + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); }); +} + +void +HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + Ref<WsBridge> Bridge; + m_WsBridgesLock.WithSharedLock([&] { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Bridge = It->second; + } + }); + + if (!Bridge || !Bridge->UpstreamClient) + { + return; + } + + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + Bridge->UpstreamClient->SendBinary( + std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kClose: + Bridge->UpstreamClient->Close(Msg.CloseCode, {}); + break; + default: + break; + } +} + +void +HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + Ref<WsBridge> Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Ref<WsBridge> Bridge = std::move(It->second); + m_WsBridges.erase(It); + return Bridge; + } + return {}; + }); + + if (Bridge && Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(Code, Reason); + } +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("server.httpproxyhandler"); + +TEST_CASE("server.httpproxyhandler.html_injection") +{ + SUBCASE("injects before </head>") + { + std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + size_t ScriptEnd = Result.find("</script>"); + size_t HeadClose = Result.find("</head>"); + REQUIRE(ScriptEnd != std::string::npos); + REQUIRE(HeadClose != std::string::npos); + CHECK(ScriptEnd < HeadClose); + } + + SUBCASE("prepends when no </head>") + { + std::string Result = InjectProxyScript("<body>content</body>", 21005); + CHECK(Result.find("<script>") == 0); + CHECK(Result.find("<body>content</body>") != std::string::npos); + } + + SUBCASE("empty html") + { + std::string Result = InjectProxyScript("", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + } + + SUBCASE("preserves original content") + { + std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>"; + std::string Result = InjectProxyScript(Html, 21005); + CHECK(Result.find("<title>Test</title>") != std::string::npos); + CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos); + } +} + +TEST_CASE("server.httpproxyhandler.port_embedding") +{ + std::string Result = InjectProxyScript("<head></head>", 80); + CHECK(Result.find("/hub/proxy/80") != std::string::npos); + + Result = InjectProxyScript("<head></head>", 65535); + CHECK(Result.find("/hub/proxy/65535") != std::string::npos); +} + +TEST_SUITE_END(); + +void +httpproxyhandler_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.h b/src/zenserver/hub/httpproxyhandler.h new file mode 100644 index 000000000..8667c0ca1 --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <functional> +#include <memory> +#include <unordered_map> + +namespace zen { + +class HttpClient; + +class HttpProxyHandler +{ +public: + using PortValidator = std::function<bool(uint16_t)>; + + HttpProxyHandler(); + explicit HttpProxyHandler(PortValidator ValidatePort); + ~HttpProxyHandler(); + + void SetPortValidator(PortValidator ValidatePort); + + HttpProxyHandler(const HttpProxyHandler&) = delete; + HttpProxyHandler& operator=(const HttpProxyHandler&) = delete; + + void HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail); + void PrunePort(uint16_t Port); + void Shutdown(); + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason); + +private: + PortValidator m_ValidatePort; + + HttpClient& GetOrCreateProxyClient(uint16_t Port); + + RwLock m_ProxyClientsLock; + std::unordered_map<uint16_t, std::unique_ptr<HttpClient>> m_ProxyClients; + + struct WsBridge; + RwLock m_WsBridgesLock; + std::unordered_map<WebSocketConnection*, Ref<WsBridge>> m_WsBridges; +}; + +} // namespace zen diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp index 6c44e2333..82f4a00ba 100644 --- a/src/zenserver/hub/hub.cpp +++ b/src/zenserver/hub/hub.cpp @@ -19,7 +19,6 @@ ZEN_THIRD_PARTY_INCLUDES_START ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS -# include <zencore/filesystem.h> # include <zencore/testing.h> # include <zencore/testutils.h> #endif @@ -122,6 +121,55 @@ private: ////////////////////////////////////////////////////////////////////////// +ProcessMetrics +Hub::AtomicProcessMetrics::Load() const +{ + return { + .MemoryBytes = MemoryBytes.load(), + .KernelTimeMs = KernelTimeMs.load(), + .UserTimeMs = UserTimeMs.load(), + .WorkingSetSize = WorkingSetSize.load(), + .PeakWorkingSetSize = PeakWorkingSetSize.load(), + .PagefileUsage = PagefileUsage.load(), + .PeakPagefileUsage = PeakPagefileUsage.load(), + }; +} + +void +Hub::AtomicProcessMetrics::Store(const ProcessMetrics& Metrics) +{ + MemoryBytes.store(Metrics.MemoryBytes); + KernelTimeMs.store(Metrics.KernelTimeMs); + UserTimeMs.store(Metrics.UserTimeMs); + WorkingSetSize.store(Metrics.WorkingSetSize); + PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize); + PagefileUsage.store(Metrics.PagefileUsage); + PeakPagefileUsage.store(Metrics.PeakPagefileUsage); +} + +void +Hub::AtomicProcessMetrics::Reset() +{ + MemoryBytes.store(0); + KernelTimeMs.store(0); + UserTimeMs.store(0); + WorkingSetSize.store(0); + PeakWorkingSetSize.store(0); + PagefileUsage.store(0); + PeakPagefileUsage.store(0); +} + +void +Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const +{ + m_Lock.WithSharedLock([&]() { + OutSystemMetrict = m_SystemMetrics; + OutDiskSpace = m_DiskSpace; + }); +} + +////////////////////////////////////////////////////////////////////////// + Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, WorkerThreadPool* OptionalWorkerPool, @@ -134,11 +182,11 @@ Hub::Hub(const Configuration& Config, , m_ActiveInstances(Config.InstanceLimit) , m_FreeActiveInstanceIndexes(Config.InstanceLimit) { - m_HostMetrics = GetSystemMetrics(); - m_ResourceLimits.DiskUsageBytes = 1000ull * 1024 * 1024 * 1024; - m_ResourceLimits.MemoryUsageBytes = 16ull * 1024 * 1024 * 1024; - - if (m_Config.HydrationTargetSpecification.empty()) + if (!m_Config.HydrationTargetSpecification.empty()) + { + m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification; + } + else if (!m_Config.HydrationOptions) { std::filesystem::path FileHydrationPath = m_RunEnvironment.CreateChildDir("hydration_storage"); ZEN_INFO("using file hydration path: '{}'", FileHydrationPath); @@ -146,7 +194,7 @@ Hub::Hub(const Configuration& Config, } else { - m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification; + m_HydrationOptions = m_Config.HydrationOptions; } m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp"); @@ -171,6 +219,9 @@ Hub::Hub(const Configuration& Config, } } #endif + + UpdateMachineMetrics(); + m_WatchDog = std::thread([this]() { WatchDog(); }); } @@ -195,6 +246,9 @@ Hub::Shutdown() { ZEN_INFO("Hub service shutting down, deprovisioning any current instances"); + bool Expected = false; + bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true); + m_WatchDogEvent.Set(); if (m_WatchDog.joinable()) { @@ -203,8 +257,6 @@ Hub::Shutdown() m_WatchDog = {}; - bool Expected = false; - bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true); if (WaitForBackgroundWork && m_WorkerPool) { m_BackgroundWorkLatch.CountDown(); @@ -254,7 +306,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end()) { std::string Reason; - if (!CanProvisionInstance(ModuleId, /* out */ Reason)) + if (!CanProvisionInstanceLocked(ModuleId, /* out */ Reason)) { ZEN_WARN("Cannot provision new storage server instance for module '{}': {}", ModuleId, Reason); @@ -274,6 +326,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex), .HydrationTempPath = m_HydrationTempPath, .HydrationTargetSpecification = m_HydrationTargetSpecification, + .HydrationOptions = m_HydrationOptions, .HttpThreadCount = m_Config.InstanceHttpThreadCount, .CoreLimit = m_Config.InstanceCoreLimit, .ConfigPath = m_Config.InstanceConfigPath}, @@ -289,6 +342,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo) Instance = NewInstance->LockExclusive(/*Wait*/ true); m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance); + m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset(); m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex); // Set Provisioning while both hub lock and instance lock are held so that any // concurrent Deprovision sees the in-flight state, not Unprovisioned. @@ -947,12 +1001,10 @@ Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo) ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size()); const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance; ZEN_ASSERT(Instance); - InstanceInfo Info{ - m_ActiveInstances[ActiveInstanceIndex].State.load(), - std::chrono::system_clock::now() // TODO - }; - Instance->GetProcessMetrics(Info.Metrics); - Info.Port = Instance->GetBasePort(); + InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(), + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()}; + Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load(); + Info.Port = Instance->GetBasePort(); *OutInstanceInfo = Info; } @@ -971,12 +1023,10 @@ Hub::EnumerateModules(std::function<void(std::string_view ModuleId, const Instan { const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance; ZEN_ASSERT(Instance); - InstanceInfo Info{ - m_ActiveInstances[ActiveInstanceIndex].State.load(), - std::chrono::system_clock::now() // TODO - }; - Instance->GetProcessMetrics(Info.Metrics); - Info.Port = Instance->GetBasePort(); + InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(), + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()}; + Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load(); + Info.Port = Instance->GetBasePort(); Infos.push_back(std::make_pair(std::string(Instance->GetModuleId()), Info)); } @@ -994,28 +1044,8 @@ Hub::GetInstanceCount() return m_Lock.WithSharedLock([this]() { return gsl::narrow_cast<int>(m_InstanceLookup.size()); }); } -void -Hub::UpdateCapacityMetrics() -{ - m_HostMetrics = GetSystemMetrics(); - - // TODO: Should probably go into WatchDog and use atomic for update so it can be read without locks... - // Per-instance stats are already refreshed by WatchDog and are readable via the Find and EnumerateModules -} - -void -Hub::UpdateStats() -{ - int CurrentInstanceCount = m_Lock.WithSharedLock([this] { return gsl::narrow_cast<int>(m_InstanceLookup.size()); }); - int CurrentMaxCount = m_MaxInstanceCount.load(); - - int NewMax = Max(CurrentMaxCount, CurrentInstanceCount); - - m_MaxInstanceCount.compare_exchange_weak(CurrentMaxCount, NewMax); -} - bool -Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason) +Hub::CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason) { ZEN_UNUSED(ModuleId); if (m_FreeActiveInstanceIndexes.empty()) @@ -1025,7 +1055,24 @@ Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason) return false; } - // TODO: handle additional resource metrics + const uint64_t DiskUsedBytes = m_DiskSpace.Free <= m_DiskSpace.Total ? m_DiskSpace.Total - m_DiskSpace.Free : 0; + if (m_Config.ResourceLimits.DiskUsageBytes > 0 && DiskUsedBytes > m_Config.ResourceLimits.DiskUsageBytes) + { + OutReason = + fmt::format("disk usage ({}) exceeds ({})", NiceBytes(DiskUsedBytes), NiceBytes(m_Config.ResourceLimits.DiskUsageBytes)); + return false; + } + + const uint64_t RamUsedMiB = m_SystemMetrics.AvailSystemMemoryMiB <= m_SystemMetrics.SystemMemoryMiB + ? m_SystemMetrics.SystemMemoryMiB - m_SystemMetrics.AvailSystemMemoryMiB + : 0; + const uint64_t RamUsedBytes = RamUsedMiB * 1024 * 1024; + if (m_Config.ResourceLimits.MemoryUsageBytes > 0 && RamUsedBytes > m_Config.ResourceLimits.MemoryUsageBytes) + { + OutReason = + fmt::format("ram usage ({}) exceeds ({})", NiceBytes(RamUsedBytes), NiceBytes(m_Config.ResourceLimits.MemoryUsageBytes)); + return false; + } return true; } @@ -1036,6 +1083,21 @@ Hub::GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const return gsl::narrow<uint16_t>(m_Config.BasePortNumber + ActiveInstanceIndex); } +bool +Hub::IsInstancePort(uint16_t Port) const +{ + if (Port < m_Config.BasePortNumber) + { + return false; + } + size_t Index = Port - m_Config.BasePortNumber; + if (Index >= m_ActiveInstances.size()) + { + return false; + } + return m_ActiveInstances[Index].State.load(std::memory_order_relaxed) != HubInstanceState::Unprovisioned; +} + HubInstanceState Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState) { @@ -1065,8 +1127,10 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS } return false; }(m_ActiveInstances[ActiveInstanceIndex].State.load(), NewState)); + const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now(); m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.store(0); - m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now()); + m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(Now); + m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.store(Now); return m_ActiveInstances[ActiveInstanceIndex].State.exchange(NewState); } @@ -1173,14 +1237,14 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, StorageServerInstance::SharedLockedPtr&& LockedInstance, size_t ActiveInstanceIndex) { + const std::string ModuleId(LockedInstance.GetModuleId()); + HubInstanceState InstanceState = m_ActiveInstances[ActiveInstanceIndex].State.load(); if (LockedInstance.IsRunning()) { - LockedInstance.UpdateMetrics(); + m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Store(LockedInstance.GetProcessMetrics()); if (InstanceState == HubInstanceState::Provisioned) { - const std::string ModuleId(LockedInstance.GetModuleId()); - const uint16_t Port = LockedInstance.GetBasePort(); const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load(); const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load(); @@ -1260,8 +1324,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, else if (InstanceState == HubInstanceState::Provisioned) { // Process is not running but state says it should be - instance died unexpectedly. - const std::string ModuleId(LockedInstance.GetModuleId()); - const uint16_t Port = LockedInstance.GetBasePort(); + const uint16_t Port = LockedInstance.GetBasePort(); UpdateInstanceState(LockedInstance, ActiveInstanceIndex, HubInstanceState::Crashed); NotifyStateUpdate(ModuleId, HubInstanceState::Provisioned, HubInstanceState::Crashed, Port, {}); LockedInstance = {}; @@ -1272,7 +1335,6 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, { // Process is not running - no HTTP activity check is possible. // Use a pure time-based check; the margin window does not apply here. - const std::string ModuleId = std::string(LockedInstance.GetModuleId()); const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load(); const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load(); const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now(); @@ -1312,6 +1374,43 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient, } void +Hub::UpdateMachineMetrics() +{ + try + { + bool DiskSpaceOk = false; + DiskSpace Disk; + + std::filesystem::path ChildDir = m_RunEnvironment.GetChildBaseDir(); + if (!ChildDir.empty()) + { + if (DiskSpaceInfo(ChildDir, Disk)) + { + DiskSpaceOk = true; + } + else + { + ZEN_WARN("Failed to query disk space for '{}'; disk-based provisioning limits will not be enforced", ChildDir); + } + } + + SystemMetrics Metrics = GetSystemMetrics(); + + m_Lock.WithExclusiveLock([&]() { + if (DiskSpaceOk) + { + m_DiskSpace = Disk; + } + m_SystemMetrics = Metrics; + }); + } + catch (const std::exception& Ex) + { + ZEN_WARN("Failed to update machine metrics. Reason: {}", Ex.what()); + } +} + +void Hub::WatchDog() { const uint64_t CycleIntervalMs = std::chrono::duration_cast<std::chrono::milliseconds>(m_Config.WatchDog.CycleInterval).count(); @@ -1326,16 +1425,18 @@ Hub::WatchDog() [&]() -> bool { return m_WatchDogEvent.Wait(0); }); size_t CheckInstanceIndex = SIZE_MAX; // first increment wraps to 0 - while (!m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs))) + while (!m_ShutdownFlag.load() && !m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs))) { try { + UpdateMachineMetrics(); + // Snapshot slot count. We iterate all slots (including freed nulls) so // round-robin coverage is not skewed by deprovisioned entries. size_t SlotsRemaining = m_Lock.WithSharedLock([this]() { return m_ActiveInstances.size(); }); Stopwatch Timer; - bool ShuttingDown = false; + bool ShuttingDown = m_ShutdownFlag.load(); while (SlotsRemaining > 0 && Timer.GetElapsedTimeMs() < CycleProcessingBudgetMs && !ShuttingDown) { StorageServerInstance::SharedLockedPtr LockedInstance; @@ -1366,16 +1467,24 @@ Hub::WatchDog() std::string ModuleId(LockedInstance.GetModuleId()); - bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex); - if (InstanceIsOk) + try { - ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs)); + bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex); + if (InstanceIsOk) + { + ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs)); + } + else + { + ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId); + AttemptRecoverInstance(ModuleId); + } } - else + catch (const std::exception& Ex) { - ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId); - AttemptRecoverInstance(ModuleId); + ZEN_WARN("Failed to check status of module {}. Reason: {}", ModuleId, Ex.what()); } + ShuttingDown |= m_ShutdownFlag.load(); } } catch (const std::exception& Ex) @@ -1515,6 +1624,8 @@ TEST_CASE("hub.provision_basic") Hub::InstanceInfo InstanceInfo; REQUIRE(HubInstance->Find("module_a", &InstanceInfo)); CHECK_EQ(InstanceInfo.State, HubInstanceState::Provisioned); + CHECK_NE(InstanceInfo.StateChangeTime, std::chrono::system_clock::time_point::min()); + CHECK_LE(InstanceInfo.StateChangeTime, std::chrono::system_clock::now()); { HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout); @@ -1934,6 +2045,9 @@ TEST_CASE("hub.hibernate_wake") } REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Provisioned); + const std::chrono::system_clock::time_point ProvisionedTime = Info.StateChangeTime; + CHECK_NE(ProvisionedTime, std::chrono::system_clock::time_point::min()); + CHECK_LE(ProvisionedTime, std::chrono::system_clock::now()); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); CHECK(ModClient.Get("/health/")); @@ -1944,6 +2058,8 @@ TEST_CASE("hub.hibernate_wake") REQUIRE_MESSAGE(HibernateResult.ResponseCode == Hub::EResponseCode::Completed, HibernateResult.Message); REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Hibernated); + const std::chrono::system_clock::time_point HibernatedTime = Info.StateChangeTime; + CHECK_GE(HibernatedTime, ProvisionedTime); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); CHECK(!ModClient.Get("/health/")); @@ -1954,6 +2070,7 @@ TEST_CASE("hub.hibernate_wake") REQUIRE_MESSAGE(WakeResult.ResponseCode == Hub::EResponseCode::Completed, WakeResult.Message); REQUIRE(HubInstance->Find("hib_a", &Info)); CHECK_EQ(Info.State, HubInstanceState::Provisioned); + CHECK_GE(Info.StateChangeTime, HibernatedTime); { HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout); CHECK(ModClient.Get("/health/")); @@ -2352,7 +2469,7 @@ TEST_CASE("hub.async_provision_shutdown_waits") TEST_CASE("hub.async_provision_rejected") { - // Rejection from CanProvisionInstance fires synchronously even when a WorkerPool is present. + // Rejection from CanProvisionInstanceLocked fires synchronously even when a WorkerPool is present. ScopedTemporaryDirectory TempDir; Hub::Configuration Config; @@ -2369,7 +2486,7 @@ TEST_CASE("hub.async_provision_rejected") REQUIRE_MESSAGE(FirstResult.ResponseCode == Hub::EResponseCode::Accepted, FirstResult.Message); REQUIRE_NE(Info.Port, 0); - // Second provision: CanProvisionInstance rejects synchronously (limit reached), returns Rejected + // Second provision: CanProvisionInstanceLocked rejects synchronously (limit reached), returns Rejected HubProvisionedInstanceInfo Info2; const Hub::Response SecondResult = HubInstance->Provision("async_r2", Info2); CHECK(SecondResult.ResponseCode == Hub::EResponseCode::Rejected); @@ -2485,6 +2602,55 @@ TEST_CASE("hub.instance.inactivity.deprovision") HubInstance->Shutdown(); } +TEST_CASE("hub.machine_metrics") +{ + ScopedTemporaryDirectory TempDir; + + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}); + + // UpdateMachineMetrics() is called synchronously in the Hub constructor, so metrics + // are available immediately without waiting for a watchdog cycle. + SystemMetrics SysMetrics; + DiskSpace Disk; + HubInstance->GetMachineMetrics(SysMetrics, Disk); + + CHECK_GT(Disk.Total, 0u); + CHECK_LE(Disk.Free, Disk.Total); + + CHECK_GT(SysMetrics.SystemMemoryMiB, 0u); + CHECK_LE(SysMetrics.AvailSystemMemoryMiB, SysMetrics.SystemMemoryMiB); + + CHECK_GT(SysMetrics.VirtualMemoryMiB, 0u); + CHECK_LE(SysMetrics.AvailVirtualMemoryMiB, SysMetrics.VirtualMemoryMiB); +} + +TEST_CASE("hub.provision_rejected_resource_limits") +{ + // The Hub constructor calls UpdateMachineMetrics() synchronously, so CanProvisionInstanceLocked + // can enforce limits immediately without waiting for a watchdog cycle. + ScopedTemporaryDirectory TempDir; + + { + Hub::Configuration Config; + Config.ResourceLimits.DiskUsageBytes = 1; + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); + HubProvisionedInstanceInfo Info; + const Hub::Response Result = HubInstance->Provision("disk_limit", Info); + CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected); + CHECK_NE(Result.Message.find("disk usage"), std::string::npos); + } + + { + Hub::Configuration Config; + Config.ResourceLimits.MemoryUsageBytes = 1; + std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config); + HubProvisionedInstanceInfo Info; + const Hub::Response Result = HubInstance->Provision("mem_limit", Info); + CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected); + CHECK_NE(Result.Message.find("ram usage"), std::string::npos); + } +} + TEST_SUITE_END(); void diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index c343b19e2..8ee9130f6 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -6,6 +6,8 @@ #include "resourcemetrics.h" #include "storageserverinstance.h" +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> #include <zencore/system.h> #include <zenutil/zenserverprocess.h> @@ -66,8 +68,11 @@ public: int InstanceCoreLimit = 0; // Automatic std::filesystem::path InstanceConfigPath; std::string HydrationTargetSpecification; + CbObject HydrationOptions; WatchDogConfiguration WatchDog; + + ResourceMetrics ResourceLimits; }; typedef std::function< @@ -86,7 +91,7 @@ public: struct InstanceInfo { HubInstanceState State = HubInstanceState::Unprovisioned; - std::chrono::system_clock::time_point ProvisionTime; + std::chrono::system_clock::time_point StateChangeTime; ProcessMetrics Metrics; uint16_t Port = 0; }; @@ -160,6 +165,10 @@ public: int GetMaxInstanceCount() const { return m_MaxInstanceCount.load(); } + void GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const; + + bool IsInstancePort(uint16_t Port) const; + const Configuration& GetConfig() const { return m_Config; } #if ZEN_WITH_TESTS @@ -176,14 +185,31 @@ private: AsyncModuleStateChangeCallbackFunc m_ModuleStateChangeCallback; std::string m_HydrationTargetSpecification; + CbObject m_HydrationOptions; std::filesystem::path m_HydrationTempPath; #if ZEN_PLATFORM_WINDOWS JobObject m_JobObject; #endif - RwLock m_Lock; + mutable RwLock m_Lock; std::unordered_map<std::string, size_t> m_InstanceLookup; + // Mirrors ProcessMetrics with atomic fields, enabling lock-free reads alongside watchdog writes. + struct AtomicProcessMetrics + { + std::atomic<uint64_t> MemoryBytes = 0; + std::atomic<uint64_t> KernelTimeMs = 0; + std::atomic<uint64_t> UserTimeMs = 0; + std::atomic<uint64_t> WorkingSetSize = 0; + std::atomic<uint64_t> PeakWorkingSetSize = 0; + std::atomic<uint64_t> PagefileUsage = 0; + std::atomic<uint64_t> PeakPagefileUsage = 0; + + ProcessMetrics Load() const; + void Store(const ProcessMetrics& Metrics); + void Reset(); + }; + struct ActiveInstance { // Invariant: Instance == nullptr if and only if State == Unprovisioned. @@ -192,11 +218,16 @@ private: // without holding the hub lock. std::unique_ptr<StorageServerInstance> Instance; std::atomic<HubInstanceState> State = HubInstanceState::Unprovisioned; - // TODO: We should move current metrics here (from StorageServerInstance) - // Read and updated by WatchDog, updates to State triggers a reset of both + // Process metrics - written by WatchDog (inside instance shared lock), read lock-free. + AtomicProcessMetrics ProcessMetrics; + + // Activity tracking - written by WatchDog, reset on every state transition. std::atomic<uint64_t> LastKnownActivitySum = 0; std::atomic<std::chrono::system_clock::time_point> LastActivityTime = std::chrono::system_clock::time_point::min(); + + // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules. + std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min(); }; // UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive @@ -226,21 +257,20 @@ private: std::vector<ActiveInstance> m_ActiveInstances; std::deque<size_t> m_FreeActiveInstanceIndexes; - ResourceMetrics m_ResourceLimits; - SystemMetrics m_HostMetrics; + SystemMetrics m_SystemMetrics; + DiskSpace m_DiskSpace; std::atomic<int> m_MaxInstanceCount = 0; std::thread m_WatchDog; Event m_WatchDogEvent; void WatchDog(); + void UpdateMachineMetrics(); bool CheckInstanceStatus(HttpClient& ActivityHttpClient, StorageServerInstance::SharedLockedPtr&& LockedInstance, size_t ActiveInstanceIndex); void AttemptRecoverInstance(std::string_view ModuleId); - void UpdateStats(); - void UpdateCapacityMetrics(); - bool CanProvisionInstance(std::string_view ModuleId, std::string& OutReason); + bool CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason); uint16_t GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const; Response InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveInstance& Instance)>&& DeprovisionGate); diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp index 541127590..ed16bfe56 100644 --- a/src/zenserver/hub/hydration.cpp +++ b/src/zenserver/hub/hydration.cpp @@ -10,6 +10,7 @@ #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/system.h> +#include <zencore/timer.h> #include <zenutil/cloud/imdscredentials.h> #include <zenutil/cloud/s3client.h> @@ -60,6 +61,7 @@ namespace { /////////////////////////////////////////////////////////////////////////// constexpr std::string_view FileHydratorPrefix = "file://"; +constexpr std::string_view FileHydratorType = "file"; struct FileHydrator : public HydrationStrategyBase { @@ -77,7 +79,21 @@ FileHydrator::Configure(const HydrationConfig& Config) { m_Config = Config; - std::filesystem::path ConfigPath(Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length()))); + std::filesystem::path ConfigPath; + if (!m_Config.TargetSpecification.empty()) + { + ConfigPath = Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length())); + } + else + { + CbObjectView Settings = m_Config.Options["settings"].AsObjectView(); + std::string_view Path = Settings["path"].AsString(); + if (Path.empty()) + { + throw zen::runtime_error("Hydration config 'file' type requires 'settings.path'"); + } + ConfigPath = Utf8ToWide(std::string(Path)); + } MakeSafeAbsolutePathInPlace(ConfigPath); if (!std::filesystem::exists(ConfigPath)) @@ -95,6 +111,8 @@ FileHydrator::Hydrate() { ZEN_INFO("Hydrating state from '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir); + Stopwatch Timer; + // Ensure target is clean ZEN_DEBUG("Wiping server state at '{}'", m_Config.ServerStateDir); const bool ForceRemoveReadOnlyFiles = true; @@ -120,6 +138,10 @@ FileHydrator::Hydrate() ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir); CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); } + else + { + ZEN_INFO("Hydration complete in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + } } void @@ -127,6 +149,8 @@ FileHydrator::Dehydrate() { ZEN_INFO("Dehydrating state from '{}' to '{}'", m_Config.ServerStateDir, m_StorageModuleRootDir); + Stopwatch Timer; + const std::filesystem::path TargetDir = m_StorageModuleRootDir; // Ensure target is clean. This could be replaced with an atomic copy at a later date @@ -141,7 +165,23 @@ FileHydrator::Dehydrate() try { ZEN_DEBUG("Copying '{}' to '{}'", m_Config.ServerStateDir, TargetDir); - CopyTree(m_Config.ServerStateDir, TargetDir, {.EnableClone = true}); + for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.ServerStateDir)) + { + if (Entry.path().filename() == ".sentry-native") + { + continue; + } + std::filesystem::path Dest = TargetDir / Entry.path().filename(); + if (Entry.is_directory()) + { + CreateDirectories(Dest); + CopyTree(Entry.path(), Dest, {.EnableClone = true}); + } + else + { + CopyFile(Entry.path(), Dest, {.EnableClone = true}); + } + } } catch (std::exception& Ex) { @@ -159,11 +199,17 @@ FileHydrator::Dehydrate() ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir); CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); + + if (CopySuccess) + { + ZEN_INFO("Dehydration complete in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + } } /////////////////////////////////////////////////////////////////////////// constexpr std::string_view S3HydratorPrefix = "s3://"; +constexpr std::string_view S3HydratorType = "s3"; struct S3Hydrator : public HydrationStrategyBase { @@ -182,6 +228,8 @@ private: std::string m_Region; SigV4Credentials m_Credentials; Ref<ImdsCredentialProvider> m_CredentialProvider; + + static constexpr uint64_t MultipartChunkSize = 8 * 1024 * 1024; }; void @@ -189,8 +237,23 @@ S3Hydrator::Configure(const HydrationConfig& Config) { m_Config = Config; - std::string_view Spec = m_Config.TargetSpecification; - Spec.remove_prefix(S3HydratorPrefix.size()); + CbObjectView Settings = m_Config.Options["settings"].AsObjectView(); + std::string_view Spec; + if (!m_Config.TargetSpecification.empty()) + { + Spec = m_Config.TargetSpecification; + Spec.remove_prefix(S3HydratorPrefix.size()); + } + else + { + std::string_view Uri = Settings["uri"].AsString(); + if (Uri.empty()) + { + throw zen::runtime_error("Hydration config 's3' type requires 'settings.uri'"); + } + Spec = Uri; + Spec.remove_prefix(S3HydratorPrefix.size()); + } size_t SlashPos = Spec.find('/'); std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{}; @@ -199,7 +262,11 @@ S3Hydrator::Configure(const HydrationConfig& Config) ZEN_ASSERT(!m_Bucket.empty()); - std::string Region = GetEnvVariable("AWS_DEFAULT_REGION"); + std::string Region = std::string(Settings["region"].AsString()); + if (Region.empty()) + { + Region = GetEnvVariable("AWS_DEFAULT_REGION"); + } if (Region.empty()) { Region = GetEnvVariable("AWS_REGION"); @@ -230,10 +297,12 @@ S3Hydrator::CreateS3Client() const Options.BucketName = m_Bucket; Options.Region = m_Region; - if (!m_Config.S3Endpoint.empty()) + CbObjectView Settings = m_Config.Options["settings"].AsObjectView(); + std::string_view Endpoint = Settings["endpoint"].AsString(); + if (!Endpoint.empty()) { - Options.Endpoint = m_Config.S3Endpoint; - Options.PathStyle = m_Config.S3PathStyle; + Options.Endpoint = std::string(Endpoint); + Options.PathStyle = Settings["path-style"].AsBool(); } if (m_CredentialProvider) @@ -245,6 +314,8 @@ S3Hydrator::CreateS3Client() const Options.Credentials = m_Credentials; } + Options.HttpSettings.MaximumInMemoryDownloadSize = 16u * 1024u; + return S3Client(Options); } @@ -275,11 +346,11 @@ S3Hydrator::Dehydrate() try { - S3Client Client = CreateS3Client(); - std::string FolderName = BuildTimestampFolderName(); - uint64_t TotalBytes = 0; - uint32_t FileCount = 0; - std::chrono::steady_clock::time_point UploadStart = std::chrono::steady_clock::now(); + S3Client Client = CreateS3Client(); + std::string FolderName = BuildTimestampFolderName(); + uint64_t TotalBytes = 0; + uint32_t FileCount = 0; + Stopwatch Timer; DirectoryContent DirContent; GetDirectoryContent(m_Config.ServerStateDir, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive, DirContent); @@ -295,13 +366,20 @@ S3Hydrator::Dehydrate() AbsPath.string(), m_Config.ServerStateDir.string()); } + if (*RelPath.begin() == ".sentry-native") + { + continue; + } std::string Key = MakeObjectKey(FolderName, RelPath); BasicFile File(AbsPath, BasicFile::Mode::kRead); uint64_t FileSize = File.FileSize(); - S3Result UploadResult = - Client.PutObjectMultipart(Key, FileSize, [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); }); + S3Result UploadResult = Client.PutObjectMultipart( + Key, + FileSize, + [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); }, + MultipartChunkSize); if (!UploadResult.IsSuccess()) { throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, UploadResult.Error); @@ -312,8 +390,7 @@ S3Hydrator::Dehydrate() } // Write current-state.json - int64_t UploadDurationMs = - std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - UploadStart).count(); + uint64_t UploadDurationMs = Timer.GetElapsedTimeMs(); UtcTime Now = UtcTime::Now(); std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z", @@ -346,7 +423,7 @@ S3Hydrator::Dehydrate() throw zen::runtime_error("Failed to write current-state.json to '{}': {}", MetaKey, MetaUploadResult.Error); } - ZEN_INFO("Dehydration complete: {} files, {} bytes, {} ms", FileCount, TotalBytes, UploadDurationMs); + ZEN_INFO("Dehydration complete: {} files, {}, {}", FileCount, NiceBytes(TotalBytes), NiceTimeSpanMs(UploadDurationMs)); } catch (std::exception& Ex) { @@ -361,6 +438,7 @@ S3Hydrator::Hydrate() { ZEN_INFO("Hydrating state from s3://{}/{} to '{}'", m_Bucket, m_KeyPrefix, m_Config.ServerStateDir); + Stopwatch Timer; const bool ForceRemoveReadOnlyFiles = true; // Clean temp dir before starting in case of leftover state from a previous failed hydration @@ -374,19 +452,17 @@ S3Hydrator::Hydrate() S3Client Client = CreateS3Client(); std::string MetaKey = m_KeyPrefix + "/current-state.json"; - S3HeadObjectResult HeadResult = Client.HeadObject(MetaKey); - if (HeadResult.Status == HeadObjectResult::NotFound) - { - throw zen::runtime_error("No state found in S3 at '{}'", MetaKey); - } - if (!HeadResult.IsSuccess()) - { - throw zen::runtime_error("Failed to check for state in S3 at '{}': {}", MetaKey, HeadResult.Error); - } - S3GetObjectResult MetaResult = Client.GetObject(MetaKey); if (!MetaResult.IsSuccess()) { + if (MetaResult.Error == S3GetObjectResult::NotFoundErrorText) + { + ZEN_INFO("No state found in S3 at {}", MetaKey); + + ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir); + CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles); + return; + } throw zen::runtime_error("Failed to read current-state.json from '{}': {}", MetaKey, MetaResult.Error); } @@ -426,17 +502,17 @@ S3Hydrator::Hydrate() std::filesystem::path DestPath = MakeSafeAbsolutePath(m_Config.TempDir / std::filesystem::path(RelKey)); CreateDirectories(DestPath.parent_path()); - BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate); - DestFile.SetFileSize(Obj.Size); - - if (Obj.Size > 0) + if (Obj.Size > MultipartChunkSize) { + BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate); + DestFile.SetFileSize(Obj.Size); + BasicFileWriter Writer(DestFile, 64 * 1024); uint64_t Offset = 0; while (Offset < Obj.Size) { - uint64_t ChunkSize = std::min<uint64_t>(8 * 1024 * 1024, Obj.Size - Offset); + uint64_t ChunkSize = std::min<uint64_t>(MultipartChunkSize, Obj.Size - Offset); S3GetObjectResult Chunk = Client.GetObjectRange(Obj.Key, Offset, ChunkSize); if (!Chunk.IsSuccess()) { @@ -453,6 +529,34 @@ S3Hydrator::Hydrate() Writer.Flush(); } + else + { + S3GetObjectResult Chunk = Client.GetObject(Obj.Key, m_Config.TempDir); + if (!Chunk.IsSuccess()) + { + throw zen::runtime_error("Failed to download '{}' from S3: {}", Obj.Key, Chunk.Error); + } + + if (IoBufferFileReference FileRef; Chunk.Content.GetFileReference(FileRef)) + { + std::error_code Ec; + std::filesystem::path ChunkPath = PathFromHandle(FileRef.FileHandle, Ec); + if (Ec) + { + WriteFile(DestPath, Chunk.Content); + } + else + { + Chunk.Content.SetDeleteOnClose(false); + Chunk.Content = {}; + RenameFile(ChunkPath, DestPath, Ec); + } + } + else + { + WriteFile(DestPath, Chunk.Content); + } + } } // Downloaded successfully - swap into ServerStateDir @@ -465,19 +569,20 @@ S3Hydrator::Hydrate() std::mismatch(m_Config.TempDir.begin(), m_Config.TempDir.end(), m_Config.ServerStateDir.begin(), m_Config.ServerStateDir.end()); if (ItTmp != m_Config.TempDir.begin()) { - // Fast path: atomic renames - no data copying needed - for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.TempDir)) + DirectoryContent DirContent; + GetDirectoryContent(m_Config.TempDir, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs, DirContent); + + for (const std::filesystem::path& AbsPath : DirContent.Directories) { - std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / Entry.path().filename()); - if (Entry.is_directory()) - { - RenameDirectory(Entry.path(), Dest); - } - else - { - RenameFile(Entry.path(), Dest); - } + std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / AbsPath.filename()); + RenameDirectory(AbsPath, Dest); + } + for (const std::filesystem::path& AbsPath : DirContent.Files) + { + std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / AbsPath.filename()); + RenameFile(AbsPath, Dest); } + ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir); CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); } @@ -491,7 +596,7 @@ S3Hydrator::Hydrate() CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles); } - ZEN_INFO("Hydration complete from folder '{}'", FolderName); + ZEN_INFO("Hydration complete from folder '{}' in {}", FolderName, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } catch (std::exception& Ex) { @@ -513,19 +618,41 @@ S3Hydrator::Hydrate() std::unique_ptr<HydrationStrategyBase> CreateHydrator(const HydrationConfig& Config) { - if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0) + if (!Config.TargetSpecification.empty()) + { + if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0) + { + std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>(); + Hydrator->Configure(Config); + return Hydrator; + } + if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0) + { + std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>(); + Hydrator->Configure(Config); + return Hydrator; + } + throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification)); + } + + std::string_view Type = Config.Options["type"].AsString(); + if (Type == FileHydratorType) { std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>(); Hydrator->Configure(Config); return Hydrator; } - if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0) + if (Type == S3HydratorType) { std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>(); Hydrator->Configure(Config); return Hydrator; } - throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification)); + if (!Type.empty()) + { + throw zen::runtime_error("Unknown hydration target type '{}'", Type); + } + throw zen::runtime_error("No hydration target configured"); } #if ZEN_WITH_TESTS @@ -607,6 +734,12 @@ namespace { AddFile("file_a.bin", CreateSemiRandomBlob(1024)); AddFile("subdir/file_b.bin", CreateSemiRandomBlob(2048)); AddFile("subdir/nested/file_c.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_d.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_e.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/file_f.bin", CreateSemiRandomBlob(512)); + AddFile("subdir/nested/medium.bulk", CreateSemiRandomBlob(256u * 1024u)); + AddFile("subdir/nested/big.bulk", CreateSemiRandomBlob(512u * 1024u)); + AddFile("subdir/nested/huge.bulk", CreateSemiRandomBlob(9u * 1024u * 1024u)); return Files; } @@ -844,12 +977,16 @@ TEST_CASE("hydration.s3.dehydrate_hydrate") auto TestFiles = CreateTestTree(ServerStateDir); HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = ModuleId; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); // Dehydrate: upload server state to MinIO { @@ -902,12 +1039,18 @@ TEST_CASE("hydration.s3.current_state_json_selects_latest_folder") const std::string ModuleId = "s3test_folder_select"; HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = ModuleId; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = ModuleId; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } // v1: dehydrate without a marker file CreateTestTree(ServerStateDir); @@ -972,13 +1115,19 @@ TEST_CASE("hydration.s3.module_isolation") CreateDirectories(TempPath); ModuleData Data; - Data.Config.ServerStateDir = StateDir; - Data.Config.TempDir = TempPath; - Data.Config.ModuleId = ModuleId; - Data.Config.TargetSpecification = "s3://zen-hydration-test"; - Data.Config.S3Endpoint = Minio.Endpoint(); - Data.Config.S3PathStyle = true; - Data.Files = CreateTestTree(StateDir); + Data.Config.ServerStateDir = StateDir; + Data.Config.TempDir = TempPath; + Data.Config.ModuleId = ModuleId; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Data.Config.Options = std::move(Root).AsObject(); + } + Data.Files = CreateTestTree(StateDir); std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Data.Config); Hydrator->Dehydrate(); @@ -1015,7 +1164,8 @@ TEST_CASE("hydration.s3.concurrent") ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); - constexpr int kModuleCount = 4; + constexpr int kModuleCount = 16; + constexpr int kThreadCount = 4; ScopedTemporaryDirectory TempDir; @@ -1034,18 +1184,24 @@ TEST_CASE("hydration.s3.concurrent") CreateDirectories(StateDir); CreateDirectories(TempPath); - Modules[I].Config.ServerStateDir = StateDir; - Modules[I].Config.TempDir = TempPath; - Modules[I].Config.ModuleId = ModuleId; - Modules[I].Config.TargetSpecification = "s3://zen-hydration-test"; - Modules[I].Config.S3Endpoint = Minio.Endpoint(); - Modules[I].Config.S3PathStyle = true; - Modules[I].Files = CreateTestTree(StateDir); + Modules[I].Config.ServerStateDir = StateDir; + Modules[I].Config.TempDir = TempPath; + Modules[I].Config.ModuleId = ModuleId; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Modules[I].Config.Options = std::move(Root).AsObject(); + } + Modules[I].Files = CreateTestTree(StateDir); } // Concurrent dehydrate { - WorkerThreadPool Pool(kModuleCount, "hydration_s3_dehy"); + WorkerThreadPool Pool(kThreadCount, "hydration_s3_dehy"); std::atomic<bool> AbortFlag{false}; std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); @@ -1063,7 +1219,7 @@ TEST_CASE("hydration.s3.concurrent") // Concurrent hydrate { - WorkerThreadPool Pool(kModuleCount, "hydration_s3_hy"); + WorkerThreadPool Pool(kThreadCount, "hydration_s3_hy"); std::atomic<bool> AbortFlag{false}; std::atomic<bool> PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); @@ -1116,12 +1272,18 @@ TEST_CASE("hydration.s3.no_prior_state") WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256)); HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = "s3test_no_prior"; - Config.TargetSpecification = "s3://zen-hydration-test"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_no_prior"; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); Hydrator->Hydrate(); @@ -1159,12 +1321,71 @@ TEST_CASE("hydration.s3.path_prefix") std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFiles = CreateTestTree(ServerStateDir); HydrationConfig Config; - Config.ServerStateDir = ServerStateDir; - Config.TempDir = HydrationTemp; - Config.ModuleId = "s3test_prefix"; - Config.TargetSpecification = "s3://zen-hydration-test/team/project"; - Config.S3Endpoint = Minio.Endpoint(); - Config.S3PathStyle = true; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_prefix"; + { + std::string ConfigJson = + fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test/team/project","endpoint":"{}","path-style":true}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Dehydrate(); + } + + CleanDirectory(ServerStateDir, true); + + { + std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); + Hydrator->Hydrate(); + } + + VerifyTree(ServerStateDir, TestFiles); +} + +TEST_CASE("hydration.s3.options_region_override") +{ + // Verify that 'region' in Options["settings"] takes precedence over AWS_DEFAULT_REGION env var. + // AWS_DEFAULT_REGION is set to a bogus value; hydration must succeed using the region from Options. + + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19016; + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + Minio.CreateBucket("zen-hydration-test"); + + ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser()); + ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword()); + ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "wrong-region"); + + ScopedTemporaryDirectory TempDir; + + std::filesystem::path ServerStateDir = TempDir.Path() / "server_state"; + std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp"; + CreateDirectories(ServerStateDir); + CreateDirectories(HydrationTemp); + + auto TestFiles = CreateTestTree(ServerStateDir); + + HydrationConfig Config; + Config.ServerStateDir = ServerStateDir; + Config.TempDir = HydrationTemp; + Config.ModuleId = "s3test_region_override"; + { + std::string ConfigJson = fmt::format( + R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"region":"us-east-1"}}}})", + Minio.Endpoint()); + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError); + ZEN_ASSERT(ParseError.empty() && Root.IsObject()); + Config.Options = std::move(Root).AsObject(); + } { std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h index d29ffe5c0..19a96c248 100644 --- a/src/zenserver/hub/hydration.h +++ b/src/zenserver/hub/hydration.h @@ -2,6 +2,8 @@ #pragma once +#include <zencore/compactbinary.h> + #include <filesystem> namespace zen { @@ -16,12 +18,8 @@ struct HydrationConfig std::string ModuleId; // Back-end specific target specification (e.g. S3 bucket, file path, etc) std::string TargetSpecification; - - // Optional S3 endpoint override (e.g. "http://localhost:9000" for MinIO). - std::string S3Endpoint; - // Use path-style S3 URLs (endpoint/bucket/key) instead of virtual-hosted-style - // (bucket.endpoint/key). Required for MinIO and other non-AWS endpoints. - bool S3PathStyle = false; + // Full config object when using --hub-hydration-target-config (mutually exclusive with TargetSpecification) + CbObject Options; }; /** diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp index 6b139dbf1..0c9354990 100644 --- a/src/zenserver/hub/storageserverinstance.cpp +++ b/src/zenserver/hub/storageserverinstance.cpp @@ -57,16 +57,15 @@ StorageServerInstance::SpawnServerProcess() m_ServerInstance.EnableShutdownOnDestroy(); } -void -StorageServerInstance::GetProcessMetrics(ProcessMetrics& OutMetrics) const +ProcessMetrics +StorageServerInstance::GetProcessMetrics() const { - OutMetrics.MemoryBytes = m_MemoryBytes.load(); - OutMetrics.KernelTimeMs = m_KernelTimeMs.load(); - OutMetrics.UserTimeMs = m_UserTimeMs.load(); - OutMetrics.WorkingSetSize = m_WorkingSetSize.load(); - OutMetrics.PeakWorkingSetSize = m_PeakWorkingSetSize.load(); - OutMetrics.PagefileUsage = m_PagefileUsage.load(); - OutMetrics.PeakPagefileUsage = m_PeakPagefileUsage.load(); + ProcessMetrics Metrics; + if (m_ServerInstance.IsRunning()) + { + zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics); + } + return Metrics; } void @@ -158,7 +157,8 @@ StorageServerInstance::Hydrate() HydrationConfig Config{.ServerStateDir = m_BaseDir, .TempDir = m_TempDir, .ModuleId = m_ModuleId, - .TargetSpecification = m_Config.HydrationTargetSpecification}; + .TargetSpecification = m_Config.HydrationTargetSpecification, + .Options = m_Config.HydrationOptions}; std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); @@ -171,7 +171,8 @@ StorageServerInstance::Dehydrate() HydrationConfig Config{.ServerStateDir = m_BaseDir, .TempDir = m_TempDir, .ModuleId = m_ModuleId, - .TargetSpecification = m_Config.HydrationTargetSpecification}; + .TargetSpecification = m_Config.HydrationTargetSpecification, + .Options = m_Config.HydrationOptions}; std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config); @@ -249,25 +250,6 @@ StorageServerInstance::SharedLockedPtr::IsRunning() const return m_Instance->m_ServerInstance.IsRunning(); } -void -StorageServerInstance::UpdateMetricsLocked() -{ - if (m_ServerInstance.IsRunning()) - { - ProcessMetrics Metrics; - zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics); - - m_MemoryBytes.store(Metrics.MemoryBytes); - m_KernelTimeMs.store(Metrics.KernelTimeMs); - m_UserTimeMs.store(Metrics.UserTimeMs); - m_WorkingSetSize.store(Metrics.WorkingSetSize); - m_PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize); - m_PagefileUsage.store(Metrics.PagefileUsage); - m_PeakPagefileUsage.store(Metrics.PeakPagefileUsage); - } - // TODO: Resource metrics... -} - #if ZEN_WITH_TESTS void StorageServerInstance::SharedLockedPtr::TerminateForTesting() const diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h index 94c47630c..1b0078d87 100644 --- a/src/zenserver/hub/storageserverinstance.h +++ b/src/zenserver/hub/storageserverinstance.h @@ -2,8 +2,7 @@ #pragma once -#include "resourcemetrics.h" - +#include <zencore/compactbinary.h> #include <zenutil/zenserverprocess.h> #include <atomic> @@ -26,6 +25,7 @@ public: uint16_t BasePort; std::filesystem::path HydrationTempPath; std::string HydrationTargetSpecification; + CbObject HydrationOptions; uint32_t HttpThreadCount = 0; // Automatic int CoreLimit = 0; // Automatic std::filesystem::path ConfigPath; @@ -34,11 +34,9 @@ public: StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId); ~StorageServerInstance(); - const ResourceMetrics& GetResourceMetrics() const { return m_ResourceMetrics; } - inline std::string_view GetModuleId() const { return m_ModuleId; } inline uint16_t GetBasePort() const { return m_Config.BasePort; } - void GetProcessMetrics(ProcessMetrics& OutMetrics) const; + ProcessMetrics GetProcessMetrics() const; #if ZEN_PLATFORM_WINDOWS void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; } @@ -68,15 +66,10 @@ public: } bool IsRunning() const; - const ResourceMetrics& GetResourceMetrics() const - { - ZEN_ASSERT(m_Instance); - return m_Instance->m_ResourceMetrics; - } - void UpdateMetrics() + ProcessMetrics GetProcessMetrics() const { ZEN_ASSERT(m_Instance); - return m_Instance->UpdateMetricsLocked(); + return m_Instance->GetProcessMetrics(); } #if ZEN_WITH_TESTS @@ -114,12 +107,6 @@ public: } bool IsRunning() const; - const ResourceMetrics& GetResourceMetrics() const - { - ZEN_ASSERT(m_Instance); - return m_Instance->m_ResourceMetrics; - } - void Provision(); void Deprovision(); void Hibernate(); @@ -139,8 +126,6 @@ private: void HibernateLocked(); void WakeLocked(); - void UpdateMetricsLocked(); - mutable RwLock m_Lock; const Configuration m_Config; std::string m_ModuleId; @@ -149,15 +134,6 @@ private: std::filesystem::path m_BaseDir; std::filesystem::path m_TempDir; - ResourceMetrics m_ResourceMetrics; - - std::atomic<uint64_t> m_MemoryBytes = 0; - std::atomic<uint64_t> m_KernelTimeMs = 0; - std::atomic<uint64_t> m_UserTimeMs = 0; - std::atomic<uint64_t> m_WorkingSetSize = 0; - std::atomic<uint64_t> m_PeakWorkingSetSize = 0; - std::atomic<uint64_t> m_PagefileUsage = 0; - std::atomic<uint64_t> m_PeakPagefileUsage = 0; #if ZEN_PLATFORM_WINDOWS JobObject* m_JobObject = nullptr; diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 314031246..d01e5f3f2 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -2,17 +2,24 @@ #include "zenhubserver.h" +#include "config/luaconfig.h" #include "frontend/frontend.h" #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" +#include <zencore/compactbinary.h> #include <zencore/config.h> +#include <zencore/except.h> +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/memory/llm.h> #include <zencore/memory/memorytrace.h> #include <zencore/memory/tagtrace.h> #include <zencore/scopeguard.h> #include <zencore/sentryintegration.h> +#include <zencore/system.h> #include <zencore/windows.h> #include <zenhttp/httpapiservice.h> #include <zenutil/service.h> @@ -53,12 +60,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", "instance-id", - "Instance ID for use in notifications", + "Instance ID for use in notifications (deprecated, use --upstream-notification-instance-id)", cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""), ""); Options.add_option("hub", "", + "upstream-notification-instance-id", + "Instance ID for use in notifications", + cxxopts::value<std::string>(m_ServerOptions.InstanceId), + ""); + + Options.add_option("hub", + "", "consul-endpoint", "Consul endpoint URL for service registration (empty = disabled)", cxxopts::value<std::string>(m_ServerOptions.ConsulEndpoint)->default_value(""), @@ -89,12 +103,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("hub", "", "hub-base-port-number", - "Base port number for provisioned instances", + "Base port number for provisioned instances (deprecated, use --hub-instance-base-port-number)", cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber)->default_value("21000"), ""); Options.add_option("hub", "", + "hub-instance-base-port-number", + "Base port number for provisioned instances", + cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber), + ""); + + Options.add_option("hub", + "", "hub-instance-limit", "Maximum number of provisioned instances for this hub", cxxopts::value<int>(m_ServerOptions.HubInstanceLimit)->default_value("1000"), @@ -139,6 +160,14 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HydrationTargetSpecification), "<hydration-target-spec>"); + Options.add_option("hub", + "", + "hub-hydration-target-config", + "Path to JSON file specifying the hydration target (mutually exclusive with " + "--hub-hydration-target-spec). Supported types: 'file', 's3'.", + cxxopts::value(m_ServerOptions.HydrationTargetConfigPath), + "<path>"); + #if ZEN_PLATFORM_WINDOWS Options.add_option("hub", "", @@ -203,12 +232,103 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Request timeout in milliseconds for instance activity check requests", cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs)->default_value("200"), "<ms>"); + + Options.add_option("hub", + "", + "hub-provision-disk-limit-bytes", + "Reject provisioning when used disk bytes exceed this value (0 = no limit).", + cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionDiskLimitBytes), + "<bytes>"); + + Options.add_option("hub", + "", + "hub-provision-disk-limit-percent", + "Reject provisioning when used disk exceeds this percentage of total disk (0 = no limit).", + cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionDiskLimitPercent), + "<percent>"); + + Options.add_option("hub", + "", + "hub-provision-memory-limit-bytes", + "Reject provisioning when used memory bytes exceed this value (0 = no limit).", + cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionMemoryLimitBytes), + "<bytes>"); + + Options.add_option("hub", + "", + "hub-provision-memory-limit-percent", + "Reject provisioning when used memory exceeds this percentage of total RAM (0 = no limit).", + cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionMemoryLimitPercent), + "<percent>"); } void ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) { - ZEN_UNUSED(Options); + using namespace std::literals; + + Options.AddOption("hub.upstreamnotification.endpoint"sv, + m_ServerOptions.UpstreamNotificationEndpoint, + "upstream-notification-endpoint"sv); + Options.AddOption("hub.upstreamnotification.instanceid"sv, m_ServerOptions.InstanceId, "upstream-notification-instance-id"sv); + + Options.AddOption("hub.consul.endpoint"sv, m_ServerOptions.ConsulEndpoint, "consul-endpoint"sv); + Options.AddOption("hub.consul.tokenenv"sv, m_ServerOptions.ConsulTokenEnv, "consul-token-env"sv); + Options.AddOption("hub.consul.healthintervalseconds"sv, + m_ServerOptions.ConsulHealthIntervalSeconds, + "consul-health-interval-seconds"sv); + Options.AddOption("hub.consul.deregisterafterseconds"sv, + m_ServerOptions.ConsulDeregisterAfterSeconds, + "consul-deregister-after-seconds"sv); + + Options.AddOption("hub.instance.baseportnumber"sv, m_ServerOptions.HubBasePortNumber, "hub-instance-base-port-number"sv); + Options.AddOption("hub.instance.http"sv, m_ServerOptions.HubInstanceHttpClass, "hub-instance-http"sv); + Options.AddOption("hub.instance.httpthreads"sv, m_ServerOptions.HubInstanceHttpThreadCount, "hub-instance-http-threads"sv); + Options.AddOption("hub.instance.corelimit"sv, m_ServerOptions.HubInstanceCoreLimit, "hub-instance-corelimit"sv); + Options.AddOption("hub.instance.config"sv, m_ServerOptions.HubInstanceConfigPath, "hub-instance-config"sv); + Options.AddOption("hub.instance.limits.count"sv, m_ServerOptions.HubInstanceLimit, "hub-instance-limit"sv); + Options.AddOption("hub.instance.limits.disklimitbytes"sv, + m_ServerOptions.HubProvisionDiskLimitBytes, + "hub-provision-disk-limit-bytes"sv); + Options.AddOption("hub.instance.limits.disklimitpercent"sv, + m_ServerOptions.HubProvisionDiskLimitPercent, + "hub-provision-disk-limit-percent"sv); + Options.AddOption("hub.instance.limits.memorylimitbytes"sv, + m_ServerOptions.HubProvisionMemoryLimitBytes, + "hub-provision-memory-limit-bytes"sv); + Options.AddOption("hub.instance.limits.memorylimitpercent"sv, + m_ServerOptions.HubProvisionMemoryLimitPercent, + "hub-provision-memory-limit-percent"sv); + + Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv); + Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv); + + Options.AddOption("hub.watchdog.cycleintervalms"sv, m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv); + Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv, + m_ServerOptions.WatchdogConfig.CycleProcessingBudgetMs, + "hub-watchdog-cycle-processing-budget-ms"sv); + Options.AddOption("hub.watchdog.instancecheckthrottlems"sv, + m_ServerOptions.WatchdogConfig.InstanceCheckThrottleMs, + "hub-watchdog-instance-check-throttle-ms"sv); + Options.AddOption("hub.watchdog.provisionedinactivitytimeoutseconds"sv, + m_ServerOptions.WatchdogConfig.ProvisionedInactivityTimeoutSeconds, + "hub-watchdog-provisioned-inactivity-timeout-seconds"sv); + Options.AddOption("hub.watchdog.hibernatedinactivitytimeoutseconds"sv, + m_ServerOptions.WatchdogConfig.HibernatedInactivityTimeoutSeconds, + "hub-watchdog-hibernated-inactivity-timeout-seconds"sv); + Options.AddOption("hub.watchdog.inactivitycheckmarginseconds"sv, + m_ServerOptions.WatchdogConfig.InactivityCheckMarginSeconds, + "hub-watchdog-inactivity-check-margin-seconds"sv); + Options.AddOption("hub.watchdog.activitycheckconnecttimeoutms"sv, + m_ServerOptions.WatchdogConfig.ActivityCheckConnectTimeoutMs, + "hub-watchdog-activity-check-connect-timeout-ms"sv); + Options.AddOption("hub.watchdog.activitycheckrequesttimeoutms"sv, + m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs, + "hub-watchdog-activity-check-request-timeout-ms"sv); + +#if ZEN_PLATFORM_WINDOWS + Options.AddOption("hub.usejobobject"sv, m_ServerOptions.HubUseJobObject, "hub-use-job-object"sv); +#endif } void @@ -226,6 +346,28 @@ ZenHubServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) void ZenHubServerConfigurator::ValidateOptions() { + if (m_ServerOptions.HubProvisionDiskLimitPercent > 100) + { + throw OptionParseException( + fmt::format("'--hub-provision-disk-limit-percent' ({}) must be in range 0..100", m_ServerOptions.HubProvisionDiskLimitPercent), + {}); + } + if (m_ServerOptions.HubProvisionMemoryLimitPercent > 100) + { + throw OptionParseException(fmt::format("'--hub-provision-memory-limit-percent' ({}) must be in range 0..100", + m_ServerOptions.HubProvisionMemoryLimitPercent), + {}); + } + if (!m_ServerOptions.HydrationTargetSpecification.empty() && !m_ServerOptions.HydrationTargetConfigPath.empty()) + { + throw OptionParseException("'--hub-hydration-target-spec' and '--hub-hydration-target-config' are mutually exclusive", {}); + } + if (!m_ServerOptions.HydrationTargetConfigPath.empty() && !std::filesystem::exists(m_ServerOptions.HydrationTargetConfigPath)) + { + throw OptionParseException( + fmt::format("'--hub-hydration-target-config': file not found: '{}'", m_ServerOptions.HydrationTargetConfigPath.string()), + {}); + } } /////////////////////////////////////////////////////////////////////////// @@ -247,6 +389,15 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, HubInstanceState NewState) { ZEN_UNUSED(PreviousState); + + if (NewState == HubInstanceState::Deprovisioning || NewState == HubInstanceState::Hibernating) + { + if (Info.Port != 0) + { + m_Proxy->PrunePort(Info.Port); + } + } + if (!m_ConsulClient) { return; @@ -294,8 +445,8 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, ZEN_INFO("Deregistered storage server instance for module '{}' at port {} from Consul", ModuleId, Info.Port); } } - // Transitional states (Deprovisioning, Hibernating, Waking, Recovering, Crashed) - // and Hibernated are intentionally ignored. + // Transitional states (Waking, Recovering, Crashed) and stable states + // not handled above (Hibernated) are intentionally ignored by Consul. } int @@ -348,6 +499,11 @@ ZenHubServer::Cleanup() m_Http->Close(); } + if (m_Proxy) + { + m_Proxy->Shutdown(); + } + if (m_Hub) { m_Hub->Shutdown(); @@ -357,6 +513,7 @@ ZenHubServer::Cleanup() m_HubService.reset(); m_ApiService.reset(); m_Hub.reset(); + m_Proxy.reset(); m_ConsulRegistration.reset(); m_ConsulClient.reset(); @@ -373,49 +530,116 @@ ZenHubServer::InitializeState(const ZenHubServerConfig& ServerConfig) ZEN_UNUSED(ServerConfig); } +ResourceMetrics +ZenHubServer::ResolveLimits(const ZenHubServerConfig& ServerConfig) +{ + uint64_t DiskTotal = 0; + uint64_t MemoryTotal = 0; + + if (ServerConfig.HubProvisionDiskLimitPercent > 0) + { + DiskSpace Disk; + if (DiskSpaceInfo(ServerConfig.DataDir, Disk)) + { + DiskTotal = Disk.Total; + } + else + { + ZEN_WARN("Failed to query disk space for '{}'; disk percent limit will not be applied", ServerConfig.DataDir); + } + } + if (ServerConfig.HubProvisionMemoryLimitPercent > 0) + { + MemoryTotal = GetSystemMetrics().SystemMemoryMiB * 1024 * 1024; + } + + auto Resolve = [](uint64_t Bytes, uint32_t Pct, uint64_t Total) -> uint64_t { + const uint64_t PctBytes = Pct > 0 ? (Total * Pct) / 100 : 0; + if (Bytes > 0 && PctBytes > 0) + { + return Min(Bytes, PctBytes); + } + return Bytes > 0 ? Bytes : PctBytes; + }; + + return { + .DiskUsageBytes = Resolve(ServerConfig.HubProvisionDiskLimitBytes, ServerConfig.HubProvisionDiskLimitPercent, DiskTotal), + .MemoryUsageBytes = Resolve(ServerConfig.HubProvisionMemoryLimitBytes, ServerConfig.HubProvisionMemoryLimitPercent, MemoryTotal), + }; +} + void ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) { ZEN_INFO("instantiating Hub"); + Hub::Configuration HubConfig{ + .UseJobObject = ServerConfig.HubUseJobObject, + .BasePortNumber = ServerConfig.HubBasePortNumber, + .InstanceLimit = ServerConfig.HubInstanceLimit, + .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, + .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, + .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, + .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, + .WatchDog = + { + .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), + .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs), + .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs), + .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds), + .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds), + .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds), + .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), + .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), + }, + .ResourceLimits = ResolveLimits(ServerConfig)}; + + if (!ServerConfig.HydrationTargetConfigPath.empty()) + { + FileContents Contents = ReadFile(ServerConfig.HydrationTargetConfigPath); + if (!Contents) + { + throw zen::runtime_error("Failed to read hydration config '{}': {}", + ServerConfig.HydrationTargetConfigPath.string(), + Contents.ErrorCode.message()); + } + IoBuffer Buffer(Contents.Flatten()); + std::string_view JsonText(static_cast<const char*>(Buffer.GetData()), Buffer.GetSize()); + + std::string ParseError; + CbFieldIterator Root = LoadCompactBinaryFromJson(JsonText, ParseError); + if (!ParseError.empty() || !Root.IsObject()) + { + throw zen::runtime_error("Failed to parse hydration config '{}': {}", + ServerConfig.HydrationTargetConfigPath.string(), + ParseError.empty() ? "root must be a JSON object" : ParseError); + } + HubConfig.HydrationOptions = std::move(Root).AsObject(); + } + + m_Proxy = std::make_unique<HttpProxyHandler>(); + m_Hub = std::make_unique<Hub>( - Hub::Configuration{ - .UseJobObject = ServerConfig.HubUseJobObject, - .BasePortNumber = ServerConfig.HubBasePortNumber, - .InstanceLimit = ServerConfig.HubInstanceLimit, - .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, - .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, - .InstanceConfigPath = ServerConfig.HubInstanceConfigPath, - .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification, - .WatchDog = - { - .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs), - .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs), - .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs), - .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds), - .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds), - .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds), - .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs), - .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs), - }}, + std::move(HubConfig), ZenServerEnvironment(ZenServerEnvironment::Hub, ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers", ServerConfig.HubInstanceHttpClass), &GetMediumWorkerPool(EWorkloadType::Background), - m_ConsulClient ? Hub::AsyncModuleStateChangeCallbackFunc{[this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)]( - std::string_view ModuleId, - const HubProvisionedInstanceInfo& Info, - HubInstanceState PreviousState, - HubInstanceState NewState) { - OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); - }} - : Hub::AsyncModuleStateChangeCallbackFunc{}); + Hub::AsyncModuleStateChangeCallbackFunc{ + [this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](std::string_view ModuleId, + const HubProvisionedInstanceInfo& Info, + HubInstanceState PreviousState, + HubInstanceState NewState) { + OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); + }}); + + m_Proxy->SetPortValidator([Hub = m_Hub.get()](uint16_t Port) { return Hub->IsInstancePort(Port); }); ZEN_INFO("instantiating API service"); m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http); ZEN_INFO("instantiating hub service"); - m_HubService = std::make_unique<HttpHubService>(*m_Hub, m_StatsService, m_StatusService); + m_HubService = std::make_unique<HttpHubService>(*m_Hub, *m_Proxy, m_StatsService, m_StatusService); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); @@ -465,12 +689,15 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi } else { - ZEN_INFO("Consul token read from environment variable '{}'", ConsulAccessTokenEnvName); + ZEN_INFO("Consul token will be read from environment variable '{}'", ConsulAccessTokenEnvName); } try { - m_ConsulClient = std::make_unique<consul::ConsulClient>(ServerConfig.ConsulEndpoint, ConsulAccessToken); + m_ConsulClient = std::make_unique<consul::ConsulClient>(consul::ConsulClient::Configuration{ + .BaseUri = ServerConfig.ConsulEndpoint, + .TokenEnvName = ConsulAccessTokenEnvName, + }); m_ConsulHealthIntervalSeconds = ServerConfig.ConsulHealthIntervalSeconds; m_ConsulDeregisterAfterSeconds = ServerConfig.ConsulDeregisterAfterSeconds; @@ -479,7 +706,7 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi Info.ServiceName = "zen-hub"; // Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri? Info.Port = static_cast<uint16_t>(EffectivePort); - Info.HealthEndpoint = "hub/health"; + Info.HealthEndpoint = "health"; Info.Tags = std::vector<std::pair<std::string, std::string>>{ std::make_pair("zen-hub", Info.ServiceId), std::make_pair("version", std::string(ZEN_CFG_VERSION)), diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index 77df3eaa3..d1add7690 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -3,6 +3,7 @@ #pragma once #include "hubinstancestate.h" +#include "resourcemetrics.h" #include "zenserver.h" #include <zenutil/consul.h> @@ -19,6 +20,7 @@ namespace zen { class HttpApiService; class HttpFrontendService; class HttpHubService; +class HttpProxyHandler; struct ZenHubWatchdogConfig { @@ -48,7 +50,12 @@ struct ZenHubServerConfig : public ZenServerConfig int HubInstanceCoreLimit = 0; // Automatic std::filesystem::path HubInstanceConfigPath; // Path to Lua config file std::string HydrationTargetSpecification; // hydration/dehydration target specification + std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification) ZenHubWatchdogConfig WatchdogConfig; + uint64_t HubProvisionDiskLimitBytes = 0; + uint32_t HubProvisionDiskLimitPercent = 0; + uint64_t HubProvisionMemoryLimitBytes = 0; + uint32_t HubProvisionMemoryLimitPercent = 0; }; class Hub; @@ -115,7 +122,8 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr<Hub> m_Hub; + std::unique_ptr<HttpProxyHandler> m_Proxy; + std::unique_ptr<Hub> m_Hub; std::unique_ptr<HttpHubService> m_HubService; std::unique_ptr<HttpApiService> m_ApiService; @@ -126,6 +134,8 @@ private: uint32_t m_ConsulHealthIntervalSeconds = 10; uint32_t m_ConsulDeregisterAfterSeconds = 30; + static ResourceMetrics ResolveLimits(const ZenHubServerConfig& ServerConfig); + void InitializeState(const ZenHubServerConfig& ServerConfig); void InitializeServices(const ZenHubServerConfig& ServerConfig); void RegisterServices(const ZenHubServerConfig& ServerConfig); diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp index fdf2e1f21..c21ae6a5c 100644 --- a/src/zenserver/sessions/httpsessions.cpp +++ b/src/zenserver/sessions/httpsessions.cpp @@ -512,8 +512,9 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) // void -HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("Sessions WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h index 86a23f835..6ebe61c8d 100644 --- a/src/zenserver/sessions/httpsessions.h +++ b/src/zenserver/sessions/httpsessions.h @@ -37,7 +37,7 @@ public: void SetSelfSessionId(const Oid& Id) { m_SelfSessionId = Id; } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index c1727270c..8ad48225b 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -80,7 +80,8 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache) + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("cache")) , m_CacheStore(InCacheStore) , m_StatsService(StatsService) @@ -90,6 +91,7 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach , m_DiskWriteBlocker(InDiskWriteBlocker) , m_OpenProcessCache(InOpenProcessCache) , m_RpcHandler(m_Log, m_CacheStats, UpstreamCache, InCacheStore, InCidStore, InDiskWriteBlocker) +, m_LocalRefPolicy(InLocalRefPolicy) { m_StatsService.RegisterHandler("z$", *this); m_StatusService.RegisterHandler("z$", *this); @@ -114,6 +116,18 @@ HttpStructuredCacheService::BaseUri() const return "/z$/"; } +bool +HttpStructuredCacheService::AcceptsLocalFileReferences() const +{ + return true; +} + +const ILocalRefPolicy* +HttpStructuredCacheService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpStructuredCacheService::Flush() { diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h index fc80b449e..f606126d6 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.h +++ b/src/zenserver/storage/cache/httpstructuredcache.h @@ -76,11 +76,14 @@ public: HttpStatusService& StatusService, UpstreamCache& UpstreamCache, const DiskWriteBlocker* InDiskWriteBlocker, - OpenProcessCache& InOpenProcessCache); + OpenProcessCache& InOpenProcessCache, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpStructuredCacheService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; void Flush(); @@ -125,6 +128,7 @@ private: const DiskWriteBlocker* m_DiskWriteBlocker = nullptr; OpenProcessCache& m_OpenProcessCache; CacheRpcHandler m_RpcHandler; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; void ReplayRequestRecorder(const CacheRequestContext& Context, cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount); diff --git a/src/zenserver/storage/localrefpolicy.cpp b/src/zenserver/storage/localrefpolicy.cpp new file mode 100644 index 000000000..47ef13b28 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.cpp @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrefpolicy.h" + +#include <zencore/except_fmt.h> +#include <zencore/fmtutils.h> + +#include <filesystem> + +namespace zen { + +DataRootLocalRefPolicy::DataRootLocalRefPolicy(const std::filesystem::path& DataRoot) +: m_CanonicalRoot(std::filesystem::weakly_canonical(DataRoot).string()) +{ +} + +void +DataRootLocalRefPolicy::ValidatePath(const std::filesystem::path& Path) const +{ + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(Path); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < m_CanonicalRoot.size() || FileStr.compare(0, m_CanonicalRoot.size(), m_CanonicalRoot) != 0) + { + throw zen::invalid_argument("local file reference '{}' is outside allowed data root", CanonicalFile); + } +} + +} // namespace zen diff --git a/src/zenserver/storage/localrefpolicy.h b/src/zenserver/storage/localrefpolicy.h new file mode 100644 index 000000000..3686d1880 --- /dev/null +++ b/src/zenserver/storage/localrefpolicy.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/localrefpolicy.h> + +#include <filesystem> +#include <string> + +namespace zen { + +/// Local ref policy that restricts file paths to a canonical data root directory. +/// Uses weakly_canonical + string prefix comparison to detect path traversal. +class DataRootLocalRefPolicy : public ILocalRefPolicy +{ +public: + explicit DataRootLocalRefPolicy(const std::filesystem::path& DataRoot); + + void ValidatePath(const std::filesystem::path& Path) const override; + +private: + std::string m_CanonicalRoot; +}; + +} // namespace zen diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index a7c8c66b6..afd0d8f82 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -656,7 +656,8 @@ HttpProjectService::HttpProjectService(CidStore& Store, JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool InAllowExternalOidcTokenExe) + bool InAllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy) : m_Log(logging::Get("project")) , m_CidStore(Store) , m_ProjectStore(Projects) @@ -668,6 +669,7 @@ HttpProjectService::HttpProjectService(CidStore& Store, , m_RestrictContentTypes(InRestrictContentTypes) , m_OidcTokenExePath(InOidcTokenExePath) , m_AllowExternalOidcTokenExe(InAllowExternalOidcTokenExe) +, m_LocalRefPolicy(InLocalRefPolicy) { ZEN_MEMSCOPE(GetProjectHttpTag()); @@ -820,6 +822,18 @@ HttpProjectService::BaseUri() const return "/prj/"; } +bool +HttpProjectService::AcceptsLocalFileReferences() const +{ + return true; +} + +const ILocalRefPolicy* +HttpProjectService::GetLocalRefPolicy() const +{ + return m_LocalRefPolicy; +} + void HttpProjectService::HandleRequest(HttpServerRequest& Request) { @@ -1668,7 +1682,8 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req) CbPackage Package; - if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) + const bool ValidateHashes = false; + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver, ValidateHashes)) { CbValidateError ValidateResult; if (CbObject Core = ValidateAndReadCompactBinaryObject(IoBuffer(Payload), ValidateResult); @@ -2763,7 +2778,11 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) case HttpContentType::kCbPackage: try { - Package = ParsePackageMessage(Payload); + ParseFlags PkgFlags = (HttpReq.IsLocalMachineRequest() && AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? GetLocalRefPolicy() : nullptr; + Package = ParsePackageMessage(Payload, {}, PkgFlags, PkgPolicy); Cb = Package.GetObject(); } catch (const std::invalid_argument& ex) diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h index e3ed02f26..8aa345fa7 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.h +++ b/src/zenserver/storage/projectstore/httpprojectstore.h @@ -47,11 +47,14 @@ public: JobQueue& InJobQueue, bool InRestrictContentTypes, const std::filesystem::path& InOidcTokenExePath, - bool AllowExternalOidcTokenExe); + bool AllowExternalOidcTokenExe, + const ILocalRefPolicy* InLocalRefPolicy = nullptr); ~HttpProjectService(); - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual bool AcceptsLocalFileReferences() const override; + virtual const ILocalRefPolicy* GetLocalRefPolicy() const override; virtual void HandleStatusRequest(HttpServerRequest& Request) override; virtual void HandleStatsRequest(HttpServerRequest& Request) override; @@ -117,6 +120,7 @@ private: bool m_RestrictContentTypes; std::filesystem::path m_OidcTokenExePath; bool m_AllowExternalOidcTokenExe; + const ILocalRefPolicy* m_LocalRefPolicy; Ref<TransferThreadWorkers> GetThreadWorkers(bool BoostWorkers, bool SingleThreaded); }; diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index bc0a8f4ac..6b1da5f12 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -223,6 +223,7 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions ZEN_INFO("instantiating project service"); + m_LocalRefPolicy = std::make_unique<DataRootLocalRefPolicy>(m_DataRoot); m_JobQueue = MakeJobQueue(8, "bgjobs"); m_OpenProcessCache = std::make_unique<OpenProcessCache>(); @@ -236,7 +237,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_JobQueue, ServerOptions.RestrictContentTypes, ServerOptions.OidcTokenExecutable, - ServerOptions.AllowExternalOidcTokenExe}); + ServerOptions.AllowExternalOidcTokenExe, + m_LocalRefPolicy.get()}); if (ServerOptions.WorksSpacesConfig.Enabled) { @@ -713,7 +715,8 @@ ZenStorageServer::InitializeStructuredCache(const ZenStorageServerConfig& Server m_StatusService, *m_UpstreamCache, m_GcManager.GetDiskWriteBlocker(), - *m_OpenProcessCache); + *m_OpenProcessCache, + m_LocalRefPolicy.get()); m_StatsReporter.AddProvider(m_CacheStore.Get()); m_StatsReporter.AddProvider(m_CidStore.get()); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index fad22ad54..e3c6248e6 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -11,6 +11,7 @@ #include <zenstore/cache/structuredcachestore.h> #include <zenstore/gc.h> #include <zenstore/projectstore.h> +#include "localrefpolicy.h" #include "admin/admin.h" #include "buildstore/httpbuildstore.h" @@ -65,15 +66,16 @@ private: void InitializeServices(const ZenStorageServerConfig& ServerOptions); void RegisterServices(); - std::unique_ptr<JobQueue> m_JobQueue; - GcManager m_GcManager; - GcScheduler m_GcScheduler{m_GcManager}; - std::unique_ptr<CidStore> m_CidStore; - Ref<ZenCacheStore> m_CacheStore; - std::unique_ptr<OpenProcessCache> m_OpenProcessCache; - HttpTestService m_TestService; - std::unique_ptr<CidStore> m_BuildCidStore; - std::unique_ptr<BuildStore> m_BuildStore; + std::unique_ptr<DataRootLocalRefPolicy> m_LocalRefPolicy; + std::unique_ptr<JobQueue> m_JobQueue; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<CidStore> m_CidStore; + Ref<ZenCacheStore> m_CacheStore; + std::unique_ptr<OpenProcessCache> m_OpenProcessCache; + HttpTestService m_TestService; + std::unique_ptr<CidStore> m_BuildCidStore; + std::unique_ptr<BuildStore> m_BuildStore; #if ZEN_WITH_TESTS HttpTestingService m_TestingService; diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 6aa02eb87..087b40d6a 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -272,6 +272,8 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co OutOptions << Separator; OutOptions << "ZEN_WITH_MEMTRACK=" << (ZEN_WITH_MEMTRACK ? "1" : "0"); OutOptions << Separator; + OutOptions << "ZEN_WITH_COMPUTE_SERVICES=" << (ZEN_WITH_COMPUTE_SERVICES ? "1" : "0"); + OutOptions << Separator; OutOptions << "ZEN_WITH_TRACE=" << (ZEN_WITH_TRACE ? "1" : "0"); } diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp index 13674da4d..7cd6b9e37 100644 --- a/src/zenstore/projectstore.cpp +++ b/src/zenstore/projectstore.cpp @@ -3180,6 +3180,7 @@ ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&, } else { + m_ChunkMap.erase(FileId); Entry.ServerPath = ServerPath; } @@ -5168,7 +5169,10 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac const bool IsFullRange = (Offset == 0) && ((Size == ~(0ull)) || (Size == ChunkSize)); if (IsFullRange) { - Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk))); + if (ChunkSize > 0) + { + Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk))); + } Result.RawSize = 0; } else diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 54e54edde..73cb7ff2d 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -33,6 +33,13 @@ using namespace zen; // Some basic functions to implement some test "compute" functions +struct ForcedExitException : std::exception +{ + int Code; + explicit ForcedExitException(int InCode) : Code(InCode) {} + const char* what() const noexcept override { return "forced exit"; } +}; + std::string Rot13Function(std::string_view InputString) { @@ -111,6 +118,16 @@ DescribeFunctions() << "Sleep"sv; Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Fail"sv; + Versions << "Version"sv << Guid::FromString("fa11fa11-fa11-fa11-fa11-fa11fa11fa11"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Crash"sv; + Versions << "Version"sv << Guid::FromString("c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"sv); + Versions.EndObject(); Versions.EndArray(); return Versions.Save(); @@ -201,6 +218,38 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) zen::Sleep(static_cast<int>(SleepTimeMs)); return Apply(IdentityFunction); } + else if (Function == "Fail"sv) + { + int FailExitCode = static_cast<int>(Action["Constants"sv].AsObjectView()["ExitCode"sv].AsUInt64()); + if (FailExitCode == 0) + { + FailExitCode = 1; + } + throw ForcedExitException(FailExitCode); + } + else if (Function == "Crash"sv) + { + // Crash modes: + // "abort" - calls std::abort() (SIGABRT / process termination) + // "nullptr" - dereferences a null pointer (SIGSEGV / access violation) + std::string_view Mode = Action["Constants"sv].AsObjectView()["Mode"sv].AsString(); + + printf("[zentest] crashing with mode: %.*s\n", int(Mode.size()), Mode.data()); + fflush(stdout); + + if (Mode == "nullptr"sv) + { + volatile int* Ptr = nullptr; + *Ptr = 42; + } + + // Default crash mode (also reached after nullptr write on platforms + // that don't immediately fault on null dereference) +#if defined(_MSC_VER) + _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); +#endif + std::abort(); + } else { return {}; @@ -421,6 +470,12 @@ main(int argc, char* argv[]) } } } + catch (ForcedExitException& Ex) + { + printf("[zentest] forced exit with code: %d\n", Ex.Code); + + ExitCode = Ex.Code; + } catch (std::exception& Ex) { printf("[zentest] exception caught in main: '%s'\n", Ex.what()); diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp index dde1dc019..5a6cf45d2 100644 --- a/src/zenutil/cloud/imdscredentials.cpp +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -115,7 +115,7 @@ ImdsCredentialProvider::FetchToken() HttpClient::KeyValueMap Headers; Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); - HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers); if (!Response.IsSuccess()) { ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp index 457453bd8..e146f6677 100644 --- a/src/zenutil/cloud/minioprocess.cpp +++ b/src/zenutil/cloud/minioprocess.cpp @@ -45,7 +45,7 @@ struct MinioProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp index 26d1023f4..d9fde05d9 100644 --- a/src/zenutil/cloud/s3client.cpp +++ b/src/zenutil/cloud/s3client.cpp @@ -137,6 +137,8 @@ namespace { } // namespace +std::string_view S3GetObjectResult::NotFoundErrorText = "Not found"; + S3Client::S3Client(const S3ClientOptions& Options) : m_Log(logging::Get("s3")) , m_BucketName(Options.BucketName) @@ -145,13 +147,7 @@ S3Client::S3Client(const S3ClientOptions& Options) , m_PathStyle(Options.PathStyle) , m_Credentials(Options.Credentials) , m_CredentialProvider(Options.CredentialProvider) -, m_HttpClient(BuildEndpoint(), - HttpClientSettings{ - .LogCategory = "s3", - .ConnectTimeout = Options.ConnectTimeout, - .Timeout = Options.Timeout, - .RetryCount = Options.RetryCount, - }) +, m_HttpClient(BuildEndpoint(), Options.HttpSettings) { m_Host = BuildHostHeader(); ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", @@ -347,15 +343,20 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content) } S3GetObjectResult -S3Client::GetObject(std::string_view Key) +S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath) { std::string Path = KeyToPath(Key); HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); - HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers); if (!Response.IsSuccess()) { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + std::string Err = Response.ErrorMessage("S3 GET failed"); ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; @@ -377,6 +378,11 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran HttpClient::Response Response = m_HttpClient.Get(Path, Headers); if (!Response.IsSuccess()) { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}}; + } + std::string Err = Response.ErrorMessage("S3 GET range failed"); ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err); return S3GetObjectResult{S3Result{std::move(Err)}, {}}; @@ -749,7 +755,7 @@ S3Client::PutObjectMultipart(std::string_view Key, return PutObject(Key, TotalSize > 0 ? FetchRange(0, TotalSize) : IoBuffer{}); } - ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); + ZEN_DEBUG("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize); S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); if (!InitResult) @@ -797,7 +803,7 @@ S3Client::PutObjectMultipart(std::string_view Key, throw; } - ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); + ZEN_DEBUG("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize); return {}; } @@ -885,7 +891,10 @@ TEST_CASE("s3client.minio_integration") { using namespace std::literals; - // Spawn a local MinIO server + // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered + // the TEST_CASE from the top, spawning and killing MinIO per subcase — slow and flaky on + // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance + // that is torn down via RAII at scope exit. MinioProcessOptions MinioOpts; MinioOpts.Port = 19000; MinioOpts.RootUser = "testuser"; @@ -893,11 +902,8 @@ TEST_CASE("s3client.minio_integration") MinioProcess Minio(MinioOpts); Minio.SpawnMinioServer(); - - // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) Minio.CreateBucket("integration-test"); - // Configure S3Client for the test bucket S3ClientOptions Opts; Opts.BucketName = "integration-test"; Opts.Region = "us-east-1"; @@ -908,7 +914,7 @@ TEST_CASE("s3client.minio_integration") S3Client Client(Opts); - SUBCASE("put_get_delete") + // -- put_get_delete ------------------------------------------------------- { // PUT std::string_view TestData = "hello, minio integration test!"sv; @@ -937,14 +943,14 @@ TEST_CASE("s3client.minio_integration") CHECK(HeadRes2.Status == HeadObjectResult::NotFound); } - SUBCASE("head_not_found") + // -- head_not_found ------------------------------------------------------- { S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); CHECK(Res.IsSuccess()); CHECK(Res.Status == HeadObjectResult::NotFound); } - SUBCASE("list_objects") + // -- list_objects --------------------------------------------------------- { // Upload several objects with a common prefix for (int i = 0; i < 3; ++i) @@ -979,7 +985,7 @@ TEST_CASE("s3client.minio_integration") } } - SUBCASE("multipart_upload") + // -- multipart_upload ----------------------------------------------------- { // Create a payload large enough to exercise multipart (use minimum part size) constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum @@ -1006,7 +1012,7 @@ TEST_CASE("s3client.minio_integration") Client.DeleteObject("multipart/large.bin"); } - SUBCASE("presigned_urls") + // -- presigned_urls ------------------------------------------------------- { // Upload an object std::string_view TestData = "presigned-url-test-data"sv; @@ -1032,8 +1038,6 @@ TEST_CASE("s3client.minio_integration") // Cleanup Client.DeleteObject("presigned/test.txt"); } - - Minio.StopMinioServer(); } TEST_SUITE_END(); diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp index c9144e589..951beed65 100644 --- a/src/zenutil/consul/consul.cpp +++ b/src/zenutil/consul/consul.cpp @@ -9,9 +9,13 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpserver.h> + #include <fmt/format.h> namespace zen::consul { @@ -31,7 +35,7 @@ struct ConsulProcess::Impl } CreateProcOptions Options; - Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL); CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); @@ -107,7 +111,7 @@ ConsulProcess::StopConsulAgent() ////////////////////////////////////////////////////////////////////////// -ConsulClient::ConsulClient(std::string_view BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri) +ConsulClient::ConsulClient(const Configuration& Config) : m_Config(Config), m_HttpClient(m_Config.BaseUri) { } @@ -193,7 +197,9 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) // when no interval is configured (e.g. during Provisioning). Writer.BeginObject("Check"sv); { - Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint)); + Writer.AddString( + "HTTP"sv, + fmt::format("http://{}:{}/{}", Info.Address.empty() ? "localhost" : Info.Address, Info.Port, Info.HealthEndpoint)); Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds)); if (Info.DeregisterAfterSeconds != 0) { @@ -223,27 +229,112 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info) bool ConsulClient::DeregisterService(std::string_view ServiceId) { + using namespace std::literals; + HttpClient::KeyValueMap AdditionalHeaders; ApplyCommonHeaders(AdditionalHeaders); AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON)); - HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), IoBuffer{}, AdditionalHeaders); + if (Result) + { + return true; + } + + // Agent deregister failed — fall back to catalog deregister. + // This handles cases where the service was registered via a different Consul agent + // (e.g. load-balanced endpoint routing to different agents). + std::string NodeName = GetNodeName(); + if (!NodeName.empty()) + { + CbObjectWriter Writer; + Writer.AddString("Node"sv, NodeName); + Writer.AddString("ServiceID"sv, ServiceId); + + ExtendableStringBuilder<256> SB; + CompactBinaryToJson(Writer.Save(), SB); + + IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()); + PayloadBuffer.SetContentType(HttpContentType::kJSON); + + HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders); + if (CatalogResult) + { + ZEN_INFO("ConsulClient::DeregisterService() deregistered service '{}' via catalog fallback (agent error: {})", + ServiceId, + Result.ErrorMessage("")); + return true; + } + + ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, catalog: {})", + ServiceId, + Result.ErrorMessage(""), + CatalogResult.ErrorMessage("")); + } + else + { + ZEN_WARN( + "ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, could not determine node name for catalog " + "fallback)", + ServiceId, + Result.ErrorMessage("")); + } + + return false; +} + +std::string +ConsulClient::GetNodeName() +{ + using namespace std::literals; + + HttpClient::KeyValueMap AdditionalHeaders; + ApplyCommonHeaders(AdditionalHeaders); + HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders); if (!Result) { - ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage("")); - return false; + return {}; } - return true; + std::string JsonError; + CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError); + if (!Root || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView Field : Root) + { + if (Field.GetName() == "Config"sv) + { + CbObjectView Config = Field.AsObjectView(); + if (Config) + { + return std::string(Config["NodeName"sv].AsString()); + } + } + } + + return {}; } void ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap) { - if (!m_Token.empty()) + std::string Token; + if (!m_Config.StaticToken.empty()) + { + Token = m_Config.StaticToken; + } + else if (!m_Config.TokenEnvName.empty()) + { + Token = GetEnvVariable(m_Config.TokenEnvName); + } + + if (!Token.empty()) { - InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token); + InOutHeaderMap.Entries.emplace("X-Consul-Token", Token); } } @@ -446,4 +537,191 @@ ServiceRegistration::RegistrationLoop() } } +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +consul_forcelink() +{ +} + +struct MockHealthService : public HttpService +{ + std::atomic<bool> FailHealth{false}; + std::atomic<int> HealthCheckCount{0}; + + const char* BaseUri() const override { return "/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + std::string_view Uri = Request.RelativeUri(); + if (Uri == "health/" || Uri == "health") + { + HealthCheckCount.fetch_add(1); + if (FailHealth.load()) + { + Request.WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } +}; + +struct TestHealthServer +{ + MockHealthService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + int Port() const { return m_Port; } + + ~TestHealthServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +static bool +WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200) +{ + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs)) + { + if (Predicate()) + { + return true; + } + Sleep(PollIntervalMs); + } + return Predicate(); +} + +static std::string +GetCheckStatus(ConsulClient& Client, std::string_view ServiceId) +{ + using namespace std::literals; + + std::string JsonError; + CbFieldIterator ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError); + if (!ChecksRoot || !JsonError.empty()) + { + return {}; + } + + for (CbFieldView F : ChecksRoot) + { + if (!F.IsObject()) + { + continue; + } + for (CbFieldView C : F.AsObjectView()) + { + CbObjectView Check = C.AsObjectView(); + if (Check["ServiceID"sv].AsString() == ServiceId) + { + return std::string(Check["Status"sv].AsString()); + } + } + } + return {}; +} + +TEST_SUITE_BEGIN("util.consul"); + +TEST_CASE("util.consul.service_lifecycle") +{ + ConsulProcess ConsulProc; + ConsulProc.SpawnConsulAgent(); + + TestHealthServer HealthServer; + HealthServer.Start(); + + ConsulClient Client({.BaseUri = "http://localhost:8500/"}); + + const std::string ServiceId = "test-health-svc"; + + ServiceRegistrationInfo Info; + Info.ServiceId = ServiceId; + Info.ServiceName = "zen-test-health"; + Info.Address = "127.0.0.1"; + Info.Port = static_cast<uint16_t>(HealthServer.Port()); + Info.HealthEndpoint = "health/"; + Info.HealthIntervalSeconds = 1; + Info.DeregisterAfterSeconds = 60; + + // Phase 1: Register and verify Consul sends health checks to our service + REQUIRE(Client.RegisterService(Info)); + REQUIRE(Client.HasService(ServiceId)); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000)); + CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + // Phase 2: Explicit deregister + REQUIRE(Client.DeregisterService(ServiceId)); + CHECK_FALSE(Client.HasService(ServiceId)); + + // Phase 3: Register again, verify passing, then fail health and verify check goes critical + HealthServer.Mock.HealthCheckCount.store(0); + HealthServer.Mock.FailHealth.store(false); + + REQUIRE(Client.RegisterService(Info)); + REQUIRE(Client.HasService(ServiceId)); + + REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing"); + + HealthServer.Mock.FailHealth.store(true); + + // Wait for Consul to observe the failing check + REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000)); + CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical"); + + // Phase 4: Explicit deregister while critical + REQUIRE(Client.DeregisterService(ServiceId)); + CHECK_FALSE(Client.HasService(ServiceId)); + + // Phase 5: Deregister an already-deregistered service - should not crash + Client.DeregisterService(ServiceId); + CHECK_FALSE(Client.HasService(ServiceId)); + + ConsulProc.StopConsulAgent(); +} + +TEST_SUITE_END(); + +#endif + } // namespace zen::consul diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h index bd30aa8a2..f1f0df0e4 100644 --- a/src/zenutil/include/zenutil/cloud/s3client.h +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -35,9 +35,7 @@ struct S3ClientOptions /// Overrides the static Credentials field. Ref<ImdsCredentialProvider> CredentialProvider; - std::chrono::milliseconds ConnectTimeout{5000}; - std::chrono::milliseconds Timeout{}; - uint8_t RetryCount = 3; + HttpClientSettings HttpSettings = {.LogCategory = "s3", .ConnectTimeout = std::chrono::milliseconds(5000), .RetryCount = 3}; }; struct S3ObjectInfo @@ -70,6 +68,8 @@ struct S3GetObjectResult : S3Result IoBuffer Content; std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); } + + static std::string_view NotFoundErrorText; }; /// Result of HeadObject - carries object metadata and existence status. @@ -119,7 +119,7 @@ public: S3Result PutObject(std::string_view Key, IoBuffer Content); /// Download an object from S3 - S3GetObjectResult GetObject(std::string_view Key); + S3GetObjectResult GetObject(std::string_view Key, const std::filesystem::path& TempFilePath = {}); /// Download a byte range of an object from S3 /// @param RangeStart First byte offset (inclusive) diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h index 4002d5d23..7517ddd1e 100644 --- a/src/zenutil/include/zenutil/consul.h +++ b/src/zenutil/include/zenutil/consul.h @@ -28,7 +28,14 @@ struct ServiceRegistrationInfo class ConsulClient { public: - ConsulClient(std::string_view BaseUri, std::string_view Token = ""); + struct Configuration + { + std::string BaseUri; + std::string StaticToken; + std::string TokenEnvName; + }; + + ConsulClient(const Configuration& Config); ~ConsulClient(); ConsulClient(const ConsulClient&) = delete; @@ -55,9 +62,10 @@ public: private: static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId); void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap); + std::string GetNodeName(); - std::string m_Token; - HttpClient m_HttpClient; + Configuration m_Config; + HttpClient m_HttpClient; }; class ConsulProcess @@ -109,4 +117,6 @@ private: void RegistrationLoop(); }; +void consul_forcelink(); + } // namespace zen::consul diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h index 4a25170df..e16c0c446 100644 --- a/src/zenutil/include/zenutil/process/subprocessmanager.h +++ b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -95,14 +95,19 @@ public: /// Spawn a new child process and begin monitoring it. /// /// If Options.StdoutPipe is set, the pipe is consumed and async reading - /// begins automatically. Similarly for Options.StderrPipe. + /// begins automatically. Similarly for Options.StderrPipe. When providing + /// pipes, pass the corresponding data callback here so it is installed + /// before the first async read completes — setting it later via + /// SetStdoutCallback risks losing early output. /// /// Returns a non-owning pointer valid until Remove() or manager destruction. /// The exit callback fires on an io_context thread when the process terminates. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process by handle. Takes ownership of handle internals. ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); @@ -182,12 +187,6 @@ public: /// yet computed. [[nodiscard]] float GetCpuUsagePercent() const; - /// Set per-process stdout callback (overrides manager default). - void SetStdoutCallback(ProcessDataCallback Callback); - - /// Set per-process stderr callback (overrides manager default). - void SetStderrCallback(ProcessDataCallback Callback); - /// Return all stdout captured so far. When a callback is set, output is /// delivered there instead of being accumulated. [[nodiscard]] std::string GetCapturedStdout() const; @@ -237,11 +236,14 @@ public: /// Group name (as passed to CreateGroup). [[nodiscard]] std::string_view GetName() const; - /// Spawn a process into this group. + /// Spawn a process into this group. See SubprocessManager::Spawn for + /// details on the stdout/stderr callback parameters. ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout = {}, + ProcessDataCallback OnStderr = {}); /// Adopt an already-running process into this group. /// On Windows the process is assigned to the group's JobObject. diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 03d507400..d6f66fbea 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -66,6 +66,7 @@ public: std::filesystem::path CreateNewTestDir(); std::filesystem::path CreateChildDir(std::string_view ChildName); std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; } + std::filesystem::path GetChildBaseDir() const { return m_ChildProcessBaseDir; } std::filesystem::path GetTestRootDir(std::string_view Path); inline bool IsInitialized() const { return m_IsInitialized; } inline bool IsTestEnvironment() const { return m_IsTestInstance; } diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp index 673a03c94..c63ad891e 100644 --- a/src/zenutil/logging/jsonformatter.cpp +++ b/src/zenutil/logging/jsonformatter.cpp @@ -19,8 +19,6 @@ static void WriteEscapedString(MemoryBuffer& Dest, std::string_view Text) { // Strip ANSI SGR sequences before escaping so they don't appear in JSON output - static const auto IsEscapeStart = [](char C) { return C == '\033'; }; - const char* RangeStart = Text.data(); const char* End = Text.data() + Text.size(); diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp index 3a91b0a61..e908dd63a 100644 --- a/src/zenutil/process/subprocessmanager.cpp +++ b/src/zenutil/process/subprocessmanager.cpp @@ -196,18 +196,6 @@ ManagedProcess::GetCpuUsagePercent() const return m_Impl->m_CpuUsagePercent.load(); } -void -ManagedProcess::SetStdoutCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StdoutCallback = std::move(Callback); -} - -void -ManagedProcess::SetStderrCallback(ProcessDataCallback Callback) -{ - m_Impl->m_StderrCallback = std::move(Callback); -} - std::string ManagedProcess::GetCapturedStdout() const { @@ -288,7 +276,9 @@ struct SubprocessManager::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void RemoveAll(); @@ -462,7 +452,9 @@ ManagedProcess* SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -476,6 +468,16 @@ SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable, ImplPtr->m_Handle.Initialize(static_cast<int>(Result)); #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -719,10 +721,12 @@ ManagedProcess* SubprocessManager::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("SubprocessManager::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -835,7 +839,9 @@ struct ProcessGroup::Impl ManagedProcess* Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit); + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr); ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit); void Remove(int Pid); void KillAll(); @@ -884,7 +890,9 @@ ManagedProcess* ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { bool HasStdout = Options.StdoutPipe != nullptr; bool HasStderr = Options.StderrPipe != nullptr; @@ -895,7 +903,11 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, Options.AssignToJob = &m_JobObject; } #else - if (m_Pgid > 0) + if (m_Pgid == 0) + { + Options.Flags |= CreateProcOptions::Flag_NewProcessGroup; + } + else { Options.ProcessGroupId = m_Pgid; } @@ -917,6 +929,16 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable, } #endif + // Install callbacks before starting async readers so no data is missed. + if (OnStdout) + { + ImplPtr->m_StdoutCallback = std::move(OnStdout); + } + if (OnStderr) + { + ImplPtr->m_StderrCallback = std::move(OnStderr); + } + auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr))); ManagedProcess* Ptr = AddProcess(std::move(Proc)); @@ -1077,10 +1099,12 @@ ManagedProcess* ProcessGroup::Spawn(const std::filesystem::path& Executable, std::string_view CommandLine, CreateProcOptions& Options, - ProcessExitCallback OnExit) + ProcessExitCallback OnExit, + ProcessDataCallback OnStdout, + ProcessDataCallback OnStderr) { ZEN_TRACE_CPU("ProcessGroup::Spawn"); - return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit)); + return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr)); } ManagedProcess* @@ -1289,9 +1313,12 @@ TEST_CASE("SubprocessManager.StdoutCallback") std::string ReceivedData; bool Exited = false; - ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; }); - - Proc->SetStdoutCallback([&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); + ManagedProcess* Proc = Manager.Spawn( + AppStub, + CmdLine, + Options, + [&](ManagedProcess&, int) { Exited = true; }, + [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); }); IoContext.run_for(5s); diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 2ca380c75..516eec3a9 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,6 +5,7 @@ #if ZEN_WITH_TESTS # include <zenutil/cloud/imdscredentials.h> +# include <zenutil/consul.h> # include <zenutil/cloud/s3client.h> # include <zenutil/cloud/sigv4.h> # include <zenutil/config/commandlineoptions.h> @@ -20,6 +21,7 @@ zenutil_forcelinktests() { cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + consul::consul_forcelink(); imdscredentials_forcelink(); logstreamlistener_forcelink(); subprocessmanager_forcelink(); |