aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/zen/cmds/builds_cmd.cpp1
-rw-r--r--src/zen/cmds/exec_cmd.cpp12
-rw-r--r--src/zen/cmds/exec_cmd.h1
-rw-r--r--src/zencompute/CLAUDE.md10
-rw-r--r--src/zencompute/computeservice.cpp269
-rw-r--r--src/zencompute/httpcomputeservice.cpp113
-rw-r--r--src/zencompute/httporchestrator.cpp3
-rw-r--r--src/zencompute/include/zencompute/computeservice.h1
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h2
-rw-r--r--src/zencompute/include/zencompute/httporchestrator.h2
-rw-r--r--src/zencompute/pathvalidation.h118
-rw-r--r--src/zencompute/runners/functionrunner.cpp12
-rw-r--r--src/zencompute/runners/linuxrunner.cpp2
-rw-r--r--src/zencompute/runners/localrunner.cpp17
-rw-r--r--src/zencompute/runners/localrunner.h4
-rw-r--r--src/zencompute/runners/macrunner.cpp2
-rw-r--r--src/zencompute/runners/managedrunner.cpp279
-rw-r--r--src/zencompute/runners/managedrunner.h64
-rw-r--r--src/zencompute/runners/windowsrunner.cpp99
-rw-r--r--src/zencompute/runners/windowsrunner.h1
-rw-r--r--src/zencompute/runners/winerunner.cpp2
-rw-r--r--src/zencore/compactbinarypackage.cpp122
-rw-r--r--src/zencore/filesystem.cpp71
-rw-r--r--src/zencore/include/zencore/compactbinarypackage.h19
-rw-r--r--src/zencore/include/zencore/logging.h28
-rw-r--r--src/zencore/include/zencore/process.h14
-rw-r--r--src/zencore/process.cpp186
-rw-r--r--src/zencore/zencore.cpp2
-rw-r--r--src/zenhttp/httpserver.cpp24
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h14
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h2
-rw-r--r--src/zenhttp/include/zenhttp/localrefpolicy.h21
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h25
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h2
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp3
-rw-r--r--src/zenhttp/packageformat.cpp548
-rw-r--r--src/zenhttp/servers/httpasio.cpp8
-rw-r--r--src/zenhttp/servers/httpsys.cpp9
-rw-r--r--src/zenhttp/servers/wstest.cpp3
-rw-r--r--src/zennomad/nomadprocess.cpp2
-rw-r--r--src/zenremotestore/builds/buildstorageoperations.cpp155
-rw-r--r--src/zenremotestore/builds/jupiterbuildstorage.cpp51
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h12
-rw-r--r--src/zenremotestore/jupiter/jupitersession.cpp2
-rw-r--r--src/zenremotestore/projectstore/remoteprojectstore.cpp55
-rw-r--r--src/zens3-testbed/main.cpp2
-rw-r--r--src/zenserver-test/cache-tests.cpp23
-rw-r--r--src/zenserver-test/compute-tests.cpp636
-rw-r--r--src/zenserver-test/hub-tests.cpp8
-rw-r--r--src/zenserver-test/projectstore-tests.cpp503
-rw-r--r--src/zenserver/frontend/html/compute/hub.html2
-rw-r--r--src/zenserver/frontend/html/pages/builds.js2
-rw-r--r--src/zenserver/frontend/html/pages/cache.js33
-rw-r--r--src/zenserver/frontend/html/pages/compute.js69
-rw-r--r--src/zenserver/frontend/html/pages/entry.js2
-rw-r--r--src/zenserver/frontend/html/pages/hub.js217
-rw-r--r--src/zenserver/frontend/html/pages/orchestrator.js33
-rw-r--r--src/zenserver/frontend/html/pages/page.js33
-rw-r--r--src/zenserver/frontend/html/pages/projects.js33
-rw-r--r--src/zenserver/frontend/html/pages/workspaces.js2
-rw-r--r--src/zenserver/frontend/html/zen.css4
-rw-r--r--src/zenserver/hub/httphubservice.cpp97
-rw-r--r--src/zenserver/hub/httphubservice.h16
-rw-r--r--src/zenserver/hub/httpproxyhandler.cpp504
-rw-r--r--src/zenserver/hub/httpproxyhandler.h52
-rw-r--r--src/zenserver/hub/hub.cpp288
-rw-r--r--src/zenserver/hub/hub.h48
-rw-r--r--src/zenserver/hub/hydration.cpp399
-rw-r--r--src/zenserver/hub/hydration.h10
-rw-r--r--src/zenserver/hub/storageserverinstance.cpp42
-rw-r--r--src/zenserver/hub/storageserverinstance.h34
-rw-r--r--src/zenserver/hub/zenhubserver.cpp299
-rw-r--r--src/zenserver/hub/zenhubserver.h12
-rw-r--r--src/zenserver/sessions/httpsessions.cpp3
-rw-r--r--src/zenserver/sessions/httpsessions.h2
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.cpp16
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.h10
-rw-r--r--src/zenserver/storage/localrefpolicy.cpp29
-rw-r--r--src/zenserver/storage/localrefpolicy.h25
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.cpp25
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.h10
-rw-r--r--src/zenserver/storage/zenstorageserver.cpp7
-rw-r--r--src/zenserver/storage/zenstorageserver.h20
-rw-r--r--src/zenserver/zenserver.cpp2
-rw-r--r--src/zenstore/projectstore.cpp6
-rw-r--r--src/zentest-appstub/zentest-appstub.cpp55
-rw-r--r--src/zenutil/cloud/imdscredentials.cpp2
-rw-r--r--src/zenutil/cloud/minioprocess.cpp2
-rw-r--r--src/zenutil/cloud/s3client.cpp48
-rw-r--r--src/zenutil/consul/consul.cpp296
-rw-r--r--src/zenutil/include/zenutil/cloud/s3client.h8
-rw-r--r--src/zenutil/include/zenutil/consul.h16
-rw-r--r--src/zenutil/include/zenutil/process/subprocessmanager.h22
-rw-r--r--src/zenutil/include/zenutil/zenserverprocess.h1
-rw-r--r--src/zenutil/logging/jsonformatter.cpp2
-rw-r--r--src/zenutil/process/subprocessmanager.cpp75
-rw-r--r--src/zenutil/zenutil.cpp2
97 files changed, 5494 insertions, 985 deletions
diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp
index 93108dd47..cc8315e0b 100644
--- a/src/zen/cmds/builds_cmd.cpp
+++ b/src/zen/cmds/builds_cmd.cpp
@@ -1784,6 +1784,7 @@ namespace builds_impl {
{
OptionalStructuredOutput->AddString("path"sv, fmt::format("{}", Path));
OptionalStructuredOutput->AddInteger("rawSize"sv, RawSize);
+ OptionalStructuredOutput->AddHash("rawHash"sv, RawHash);
switch (Platform)
{
case SourcePlatform::Windows:
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp
index 30e860a3f..9719fce77 100644
--- a/src/zen/cmds/exec_cmd.cpp
+++ b/src/zen/cmds/exec_cmd.cpp
@@ -1119,6 +1119,7 @@ ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent)
{
+ m_SubOptions.add_option("managed", "", "managed", "Use managed local runner (if supported)", cxxopts::value(m_Managed), "<bool>");
}
void
@@ -1130,7 +1131,16 @@ ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
zen::compute::ComputeServiceSession ComputeSession(Resolver);
std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
- ComputeSession.AddLocalRunner(Resolver, TempPath);
+ if (m_Managed)
+ {
+ ZEN_CONSOLE_INFO("using managed local runner");
+ ComputeSession.AddManagedLocalRunner(Resolver, TempPath);
+ }
+ else
+ {
+ ZEN_CONSOLE_INFO("using local runner");
+ ComputeSession.AddLocalRunner(Resolver, TempPath);
+ }
Stopwatch ExecTimer;
int ReturnValue = m_Parent.RunSession(ComputeSession);
diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h
index c55412780..a0bf201a1 100644
--- a/src/zen/cmds/exec_cmd.h
+++ b/src/zen/cmds/exec_cmd.h
@@ -61,6 +61,7 @@ public:
private:
ExecCommand& m_Parent;
+ bool m_Managed = false;
};
class ExecBeaconSubCmd : public ZenSubCmdBase
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index a1a39fc3c..750879d5a 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc
**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure.
-**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit.
**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent.
@@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra
- `ActiveLsns` — for cancellation lookup (under `m_Lock`)
- `FinishedLsns` — moved here when actions complete
- `IdleSince` — used for 15-minute automatic expiry
-- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit
+- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default
**Queue state machine (`QueueState` enum):**
```
@@ -216,11 +216,7 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han
## Concurrency Model
-**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks:
-1. `m_ResultsLock`
-2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock")
-3. `m_PendingLock`
-4. `m_QueueLock`
+**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock.
**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`.
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index 92901de64..aaf34cbe2 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -8,6 +8,8 @@
# include "recording/actionrecorder.h"
# include "runners/localrunner.h"
# include "runners/remotehttprunner.h"
+# include "runners/managedrunner.h"
+# include "pathvalidation.h"
# if ZEN_PLATFORM_LINUX
# include "runners/linuxrunner.h"
# elif ZEN_PLATFORM_WINDOWS
@@ -195,13 +197,9 @@ struct ComputeServiceSession::Impl
std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr};
- RwLock m_PendingLock;
- std::map<int, Ref<RunnerAction>> m_PendingActions;
-
- RwLock m_RunningLock;
+ RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap
+ std::map<int, Ref<RunnerAction>> m_PendingActions;
std::unordered_map<int, Ref<RunnerAction>> m_RunningMap;
-
- RwLock m_ResultsLock;
std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap;
metrics::Meter m_ResultRate;
std::atomic<uint64_t> m_RetiredCount{0};
@@ -343,9 +341,12 @@ struct ComputeServiceSession::Impl
ActionCounts GetActionCounts()
{
ActionCounts Counts;
- Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load();
+ m_ActionMapLock.WithSharedLock([&] {
+ Counts.Pending = (int)m_PendingActions.size();
+ Counts.Running = (int)m_RunningMap.size();
+ Counts.Completed = (int)m_ResultsMap.size();
+ });
+ Counts.Completed += (int)m_RetiredCount.load();
Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] {
size_t Count = 0;
for (const auto& [Id, Queue] : m_Queues)
@@ -364,8 +365,10 @@ struct ComputeServiceSession::Impl
{
Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed));
m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
- m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
- m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
+ m_ActionMapLock.WithSharedLock([&] {
+ Cbo << "actions_complete"sv << m_ResultsMap.size();
+ Cbo << "actions_pending"sv << m_PendingActions.size();
+ });
Cbo << "actions_submitted"sv << GetSubmittedActionCount();
EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo);
EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
@@ -450,27 +453,17 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState)
void
ComputeServiceSession::Impl::AbandonAllActions()
{
- // Collect all pending actions and mark them as Abandoned
+ // Collect all pending and running actions under a single lock scope
std::vector<Ref<RunnerAction>> PendingToAbandon;
+ std::vector<Ref<RunnerAction>> RunningToAbandon;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
PendingToAbandon.reserve(m_PendingActions.size());
for (auto& [Lsn, Action] : m_PendingActions)
{
PendingToAbandon.push_back(Action);
}
- });
-
- for (auto& Action : PendingToAbandon)
- {
- Action->SetActionState(RunnerAction::State::Abandoned);
- }
- // Collect all running actions and mark them as Abandoned, then
- // best-effort cancel via the local runner group
- std::vector<Ref<RunnerAction>> RunningToAbandon;
-
- m_RunningLock.WithSharedLock([&] {
RunningToAbandon.reserve(m_RunningMap.size());
for (auto& [Lsn, Action] : m_RunningMap)
{
@@ -478,6 +471,11 @@ ComputeServiceSession::Impl::AbandonAllActions()
}
});
+ for (auto& Action : PendingToAbandon)
+ {
+ Action->SetActionState(RunnerAction::State::Abandoned);
+ }
+
for (auto& Action : RunningToAbandon)
{
Action->SetActionState(RunnerAction::State::Abandoned);
@@ -742,7 +740,7 @@ std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::Impl::GetRunningActions()
{
std::vector<ComputeServiceSession::RunningActionInfo> Result;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
Result.reserve(m_RunningMap.size());
for (const auto& [Lsn, Action] : m_RunningMap)
{
@@ -810,6 +808,11 @@ void
ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
{
ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker");
+
+ // Validate all paths in the worker description upfront, before the worker is
+ // distributed to runners. This rejects malicious packages early at ingestion time.
+ ValidateWorkerDescriptionPaths(Worker.GetObject());
+
RwLock::ExclusiveLockScope _(m_WorkerLock);
const IoHash& WorkerId = Worker.GetObject().GetHash();
@@ -994,10 +997,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke
Pending->ActionObj = ActionObj;
Pending->Priority = RequestPriority;
- // For now simply put action into pending state, so we can do batch scheduling
+ // Insert into the pending map immediately so the action is visible to
+ // FindActionResult/GetActionResult right away. SetActionState will call
+ // PostUpdate which adds the action to m_UpdatedActions and signals the
+ // scheduler, but the scheduler's HandleActionUpdates inserts with
+ // std::map::insert which is a no-op for existing keys.
ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); });
Pending->SetActionState(RunnerAction::State::Pending);
if (m_Recorder)
@@ -1043,11 +1051,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount()
HttpResponseCode
ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
{
@@ -1058,25 +1062,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
return HttpResponseCode::OK;
}
+ if (m_PendingActions.find(ActionLsn) != m_PendingActions.end())
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
return HttpResponseCode::NotFound;
@@ -1085,11 +1078,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
HttpResponseCode
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
{
@@ -1103,30 +1092,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage&
}
}
+ for (const auto& [K, Pending] : m_PendingActions)
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- for (const auto& [K, Pending] : m_PendingActions)
+ if (Pending->ActionId == ActionId)
{
- if (Pending->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ for (const auto& [K, v] : m_RunningMap)
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- for (const auto& [K, v] : m_RunningMap)
+ if (v->ActionId == ActionId)
{
- if (v->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
@@ -1144,7 +1122,7 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
{
Cbo.BeginArray("completed");
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [Lsn, Action] : m_ResultsMap)
{
Cbo.BeginObject();
@@ -1275,20 +1253,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
std::vector<Ref<RunnerAction>> PendingActionsToCancel;
std::vector<int> RunningLsnsToCancel;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (int Lsn : LsnsToCancel)
{
if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end())
{
PendingActionsToCancel.push_back(It->second);
}
- }
- });
-
- m_RunningLock.WithSharedLock([&] {
- for (int Lsn : LsnsToCancel)
- {
- if (m_RunningMap.find(Lsn) != m_RunningMap.end())
+ else if (m_RunningMap.find(Lsn) != m_RunningMap.end())
{
RunningLsnsToCancel.push_back(Lsn);
}
@@ -1307,7 +1279,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
// transition from the runner is blocked (Cancelled > Failed in the enum).
for (int Lsn : RunningLsnsToCancel)
{
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end())
{
It->second->SetActionState(RunnerAction::State::Cancelled);
@@ -1445,7 +1417,7 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
if (Queue)
{
Queue->m_Lock.WithSharedLock([&] {
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (int Lsn : Queue->FinishedLsns)
{
if (m_ResultsMap.contains(Lsn))
@@ -1475,15 +1447,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run
return;
}
+ bool WasActive = false;
Queue->m_Lock.WithExclusiveLock([&] {
- Queue->ActiveLsns.erase(Lsn);
+ WasActive = Queue->ActiveLsns.erase(Lsn) > 0;
Queue->FinishedLsns.insert(Lsn);
});
- const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
- if (PreviousActive == 1)
+ if (WasActive)
{
- Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
+ if (PreviousActive == 1)
+ {
+ Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ }
}
switch (ActionState)
@@ -1541,9 +1517,15 @@ ComputeServiceSession::Impl::SchedulePendingActions()
{
ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
int ScheduledCount = 0;
- size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });
+ size_t RunningCount = 0;
+ size_t PendingCount = 0;
+ size_t ResultCount = 0;
+
+ m_ActionMapLock.WithSharedLock([&] {
+ RunningCount = m_RunningMap.size();
+ PendingCount = m_PendingActions.size();
+ ResultCount = m_ResultsMap.size();
+ });
static Stopwatch DumpRunningTimer;
@@ -1560,7 +1542,7 @@ ComputeServiceSession::Impl::SchedulePendingActions()
DumpRunningTimer.Reset();
std::set<int> RunningList;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [K, V] : m_RunningMap)
{
RunningList.insert(K);
@@ -1602,7 +1584,7 @@ ComputeServiceSession::Impl::SchedulePendingActions()
// Also note that the m_PendingActions list is not maintained
// here, that's done periodically in SchedulePendingActions()
- m_PendingLock.WithExclusiveLock([&] {
+ m_ActionMapLock.WithExclusiveLock([&] {
if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
{
return;
@@ -1701,7 +1683,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
{
int TimeoutMs = 500;
- auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+ auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); });
if (PendingCount)
{
@@ -1720,22 +1702,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
m_SchedulingThreadEvent.Reset();
}
- ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
- m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
- PendingCount,
- m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
- m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
- TimeoutMs);
+ m_ActionMapLock.WithSharedLock([&] {
+ ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}",
+ m_PendingActions.size(),
+ m_RunningMap.size(),
+ m_ResultsMap.size(),
+ TimeoutMs);
+ });
HandleActionUpdates();
// Auto-transition Draining → Paused when all work is done
if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
{
- size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+ bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); });
- if (Pending == 0 && Running == 0)
+ if (AllDrained)
{
SessionState Expected = SessionState::Draining;
if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
@@ -1776,9 +1758,9 @@ ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
if (Config)
{
- int Value = Config["max_retries"].AsInt32(0);
+ int Value = Config["max_retries"].AsInt32(-1);
- if (Value > 0)
+ if (Value >= 0)
{
return Value;
}
@@ -1797,7 +1779,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
// Find, validate, and remove atomically under a single lock scope to prevent
// concurrent RescheduleAction calls from double-removing the same action.
- m_ResultsLock.WithExclusiveLock([&] {
+ m_ActionMapLock.WithExclusiveLock([&] {
auto It = m_ResultsMap.find(ActionLsn);
if (It == m_ResultsMap.end())
{
@@ -1871,26 +1853,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
bool WasRunning = false;
// Look for the action in pending or running maps
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end())
{
Action = It->second;
WasRunning = true;
}
+ else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end())
+ {
+ Action = PIt->second;
+ }
});
if (!Action)
{
- m_PendingLock.WithSharedLock([&] {
- if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end())
- {
- Action = It->second;
- }
- });
- }
-
- if (!Action)
- {
return {.Success = false, .Error = "Action not found in pending or running maps"};
}
@@ -1912,18 +1888,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
void
ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn)
{
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ // Caller must hold m_ActionMapLock exclusively.
+ if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+ {
+ m_PendingActions.erase(ActionLsn);
+ }
+ else
+ {
+ m_RunningMap.erase(FindIt);
+ }
}
void
@@ -1973,7 +1946,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
else
{
- m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
}
break;
@@ -1983,11 +1956,9 @@ ComputeServiceSession::Impl::HandleActionUpdates()
// Dispatched to a runner — move from pending to running
case RunnerAction::State::Running:
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- m_RunningMap.insert({ActionLsn, Action});
- m_PendingActions.erase(ActionLsn);
- });
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.insert({ActionLsn, Action});
+ m_PendingActions.erase(ActionLsn);
});
ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
break;
@@ -1995,7 +1966,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
// Retracted — pull back for rescheduling without counting against retry limit
case RunnerAction::State::Retracted:
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
Action->ResetActionStateToPending();
ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn);
break;
@@ -2019,7 +1993,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
// Reset triggers PostUpdate() which re-enters the action as Pending
Action->ResetActionStateToPending();
@@ -2034,16 +2011,16 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
}
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ RemoveActionFromActiveMaps(ActionLsn);
- // Update queue counters BEFORE publishing the result into
- // m_ResultsMap. GetActionResult erases from m_ResultsMap
- // under m_ResultsLock, so if we updated counters after
- // releasing that lock, a caller could observe ActiveCount
- // still at 1 immediately after GetActionResult returned OK.
- NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
+ // Update queue counters BEFORE publishing the result into
+ // m_ResultsMap. GetActionResult erases from m_ResultsMap
+ // under m_ActionMapLock, so if we updated counters after
+ // releasing that lock, a caller could observe ActiveCount
+ // still at 1 immediately after GetActionResult returned OK.
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
- m_ResultsLock.WithExclusiveLock([&] {
m_ResultsMap[ActionLsn] = Action;
// Append to bounded action history ring
@@ -2282,6 +2259,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files
}
void
+ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions)
+{
+ ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner");
+
+ auto* NewRunner =
+ new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, MaxConcurrentActions);
+
+ m_Impl->SyncWorkersToRunner(*NewRunner);
+ m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner);
+}
+
+void
ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName)
{
ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner");
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index bdfd9d197..1d28e7137 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -93,16 +93,17 @@ struct HttpComputeService::Impl
uint64_t NewBytes = 0;
};
- IngestStats IngestPackageAttachments(const CbPackage& Package);
- bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
- void HandleWorkersGet(HttpServerRequest& HttpReq);
- void HandleWorkersAllGet(HttpServerRequest& HttpReq);
- void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
- void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
- void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
+ bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats);
+ bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
+ bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment);
+ void HandleWorkersGet(HttpServerRequest& HttpReq);
+ void HandleWorkersAllGet(HttpServerRequest& HttpReq);
+ void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
+ void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+ void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
// WebSocket / observer
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection);
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri);
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code);
void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions);
@@ -373,7 +374,7 @@ HttpComputeService::Impl::RegisterRoutes()
if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output);
ResponseCode != HttpResponseCode::OK)
{
- ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
+ ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
if (ResponseCode == HttpResponseCode::NotFound)
{
@@ -1167,35 +1168,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
return ParseInt<int>(Capture).value_or(0);
}
-HttpComputeService::Impl::IngestStats
-HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package)
+bool
+HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment)
{
- IngestStats Stats;
+ const IoHash ClaimedHash = Attachment.GetHash();
+ CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+ const IoHash HeaderHash = Buffer.DecodeRawHash();
+
+ if (HeaderHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ IoHashStream Hasher;
+
+ bool DecompressOk = Buffer.DecompressToStream(
+ 0,
+ Buffer.DecodeRawSize(),
+ [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool {
+ for (const SharedBuffer& Segment : Range.GetSegments())
+ {
+ Hasher.Append(Segment.GetView());
+ }
+ return true;
+ });
+
+ if (!DecompressOk)
+ {
+ ZEN_WARN("attachment {}: failed to decompress", ClaimedHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ const IoHash ActualHash = Hasher.GetHash();
+
+ if (ActualHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ return true;
+}
+bool
+HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats)
+{
for (const CbAttachment& Attachment : Package.GetAttachments())
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return false;
+ }
- const uint64_t CompressedSize = DataView.GetCompressedSize();
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
- Stats.Bytes += CompressedSize;
- ++Stats.Count;
+ OutStats.Bytes += CompressedSize;
+ ++OutStats.Count;
const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
if (InsertResult.New)
{
- Stats.NewBytes += CompressedSize;
- ++Stats.NewCount;
+ OutStats.NewBytes += CompressedSize;
+ ++OutStats.NewCount;
}
}
- return Stats;
+ return true;
}
bool
@@ -1253,7 +1300,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
CbPackage Package = HttpReq.ReadPayloadPackage();
Body = Package.GetObject();
- Stats = IngestPackageAttachments(Package);
+ if (!IngestPackageAttachments(HttpReq, Package, Stats))
+ {
+ return; // validation failed, response already written
+ }
break;
}
@@ -1268,8 +1318,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
// --- Batch path ---
- // For CbObject payloads, check all attachments upfront before enqueuing anything
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ // Verify all action attachment references exist in the store
{
std::vector<IoHash> NeedList;
@@ -1345,7 +1394,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
// --- Single-action path: Body is the action itself ---
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
{
std::vector<IoHash> NeedList;
@@ -1491,10 +1539,14 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return;
+ }
+
const IoHash DataHash = Attachment.GetHash();
CompressedBuffer Buffer = Attachment.AsCompressedBinary();
- ZEN_UNUSED(DataHash);
TotalAttachmentBytes += Buffer.GetCompressedSize();
++AttachmentCount;
@@ -1537,9 +1589,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
//
void
-HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
- m_Impl->OnWebSocketOpen(std::move(Connection));
+ m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri);
}
void
@@ -1566,8 +1618,9 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati
//
void
-HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_INFO("compute WebSocket client connected");
m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
index 6cbe01e04..d92af8716 100644
--- a/src/zencompute/httporchestrator.cpp
+++ b/src/zencompute/httporchestrator.cpp
@@ -418,8 +418,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
# if ZEN_WITH_WEBSOCKETS
void
-HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
if (!m_PushEnabled.load())
{
return;
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index 1ca78738a..ad556f546 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -167,6 +167,7 @@ public:
// Action runners
void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
+ void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName);
// Action submission
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index b58e73a0d..de85a295f 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -45,7 +45,7 @@ public:
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
index da5c5dfc3..58b2c9152 100644
--- a/src/zencompute/include/zencompute/httporchestrator.h
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -70,7 +70,7 @@ public:
// IWebSocketHandler
#if ZEN_WITH_WEBSOCKETS
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
#endif
diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h
new file mode 100644
index 000000000..c2e30183a
--- /dev/null
+++ b/src/zencompute/pathvalidation.h
@@ -0,0 +1,118 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/except_fmt.h>
+#include <zencore/string.h>
+
+#include <filesystem>
+#include <string_view>
+
+namespace zen::compute {
+
+// Validate that a single path component contains only characters that are valid
+// file/directory names on all supported platforms. Uses Windows rules as the most
+// restrictive superset, since packages may be built on one platform and consumed
+// on another.
+inline void
+ValidatePathComponent(std::string_view Component, std::string_view FullPath)
+{
+ // Reject control characters (0x00-0x1F) and characters forbidden on Windows
+ for (char Ch : Component)
+ {
+ if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' ||
+ Ch == '*')
+ {
+ throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath);
+ }
+ }
+
+ // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion)
+ if (Component.empty() || Component.back() == '.' || Component.back() == ' ')
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath);
+ }
+
+ // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9)
+ // These are reserved with or without an extension (e.g. "CON.txt" is still reserved).
+ std::string_view Stem = Component.substr(0, Component.find('.'));
+
+ static constexpr std::string_view ReservedNames[] = {
+ "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
+ "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
+ };
+
+ for (std::string_view Reserved : ReservedNames)
+ {
+ if (zen::StrCaseCompare(Stem, Reserved) == 0)
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved);
+ }
+ }
+}
+
+// Validate that a path extracted from a package is a safe relative path.
+// Rejects absolute paths, ".." components, and invalid platform filenames.
+inline void
+ValidateSandboxRelativePath(std::string_view Name)
+{
+ if (Name.empty())
+ {
+ throw zen::invalid_argument("path traversal detected: empty path name");
+ }
+
+ std::filesystem::path Parsed(Name);
+
+ if (Parsed.is_absolute())
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name);
+ }
+
+ for (const auto& Component : Parsed)
+ {
+ std::string ComponentStr = Component.string();
+
+ if (ComponentStr == "..")
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name);
+ }
+
+ // Skip "." (current directory) — harmless in relative paths
+ if (ComponentStr != ".")
+ {
+ ValidatePathComponent(ComponentStr, Name);
+ }
+ }
+}
+
+// Validate all path entries in a worker description CbObject.
+// Checks path, executables[].name, dirs[], and files[].name fields.
+// Throws an exception if any invalid paths are found.
+inline void
+ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription)
+{
+ using namespace std::literals;
+
+ if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue())
+ {
+ ValidateSandboxRelativePath(PathField.AsString());
+ }
+
+ for (auto& It : WorkerDescription["executables"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+
+ for (auto& It : WorkerDescription["dirs"sv])
+ {
+ ValidateSandboxRelativePath(It.AsString());
+ }
+
+ for (auto& It : WorkerDescription["files"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+}
+
+} // namespace zen::compute
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 4f116e7d8..67e12b84e 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -164,8 +164,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Assign any remaining actions to runners with capacity (round-robin)
- for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount)
+ // Assign any remaining actions to runners with capacity (round-robin).
+ // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept.
+ for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount)
{
if (Capacities[i] > PerRunnerActions[i].size())
{
@@ -186,11 +187,12 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Reassemble results in original action order
- std::vector<SubmitResult> Results(Actions.size());
+ // Reassemble results in original action order.
+ // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity).
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"});
std::vector<size_t> PerRunnerIdx(RunnerCount, 0);
- for (size_t i = 0; i < Actions.size(); ++i)
+ for (size_t i = 0; i < ActionIdx; ++i)
{
size_t RunnerIdx = ActionRunnerIndex[i];
size_t Idx = PerRunnerIdx[RunnerIdx]++;
diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp
index e79a6c90f..9055005d9 100644
--- a/src/zencompute/runners/linuxrunner.cpp
+++ b/src/zencompute/runners/linuxrunner.cpp
@@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("namespace sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index b61e0a46f..1b748c0e5 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -4,6 +4,8 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include "pathvalidation.h"
+
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
@@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver,
ZEN_INFO("Cleanup complete");
}
- m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
-
# if ZEN_PLATFORM_WINDOWS
// Suppress any error dialogs caused by missing dependencies
UINT OldMode = ::SetErrorMode(0);
@@ -382,6 +382,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP
const IoHash ChunkHash = FileEntry["hash"sv].AsHash();
const uint64_t Size = FileEntry["size"sv].AsUInt64();
+ ValidateSandboxRelativePath(Name);
+
CompressedBuffer Compressed;
if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash))
@@ -457,7 +459,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
for (auto& It : WorkerDescription["dirs"sv])
{
- std::string_view Name = It.AsString();
+ std::string_view Name = It.AsString();
+ ValidateSandboxRelativePath(Name);
std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()};
// Validate dir path stays within sandbox
@@ -482,6 +485,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
}
WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer());
+
+ ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath);
}
CbPackage
@@ -540,6 +545,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath)
}
void
+LocalProcessRunner::StartMonitorThread()
+{
+ m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
+}
+
+void
LocalProcessRunner::MonitorThreadFunction()
{
SetCurrentThreadName("LocalProcessRunner_Monitor");
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
index b8cff6826..d6589db43 100644
--- a/src/zencompute/runners/localrunner.h
+++ b/src/zencompute/runners/localrunner.h
@@ -67,6 +67,7 @@ protected:
{
Ref<RunnerAction> Action;
void* ProcessHandle = nullptr;
+ int Pid = 0;
int ExitCode = 0;
std::filesystem::path SandboxPath;
@@ -83,8 +84,6 @@ protected:
std::filesystem::path m_SandboxPath;
int32_t m_MaxRunningActions = 64; // arbitrary limit for testing
- // if used in conjuction with m_ResultsLock, this lock must be taken *after*
- // m_ResultsLock to avoid deadlocks
RwLock m_RunningLock;
std::unordered_map<int, Ref<RunningAction>> m_RunningMap;
@@ -95,6 +94,7 @@ protected:
std::thread m_MonitorThread;
std::atomic<bool> m_MonitorThreadEnabled{true};
Event m_MonitorThreadEvent;
+ void StartMonitorThread();
void MonitorThreadFunction();
virtual void SweepRunningActions();
virtual void CancelRunningActions();
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp
index 5cec90699..c2ccca9a6 100644
--- a/src/zencompute/runners/macrunner.cpp
+++ b/src/zencompute/runners/macrunner.cpp
@@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("Seatbelt sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp
new file mode 100644
index 000000000..e4a7ba388
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.cpp
@@ -0,0 +1,279 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "managedrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/scopeguard.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio/io_context.hpp>
+# include <asio/executor_work_guard.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions)
+: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
+, m_IoContext(std::make_unique<asio::io_context>())
+, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext))
+{
+ m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers");
+
+ // Run the io_context on a small thread pool so that exit callbacks and
+ // metrics sampling are dispatched without blocking each other.
+ for (int i = 0; i < kIoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i] {
+ SetCurrentThreadName(fmt::format("mrunner_{}", i));
+
+ // work_guard keeps run() alive even when there is no pending work yet
+ auto WorkGuard = asio::make_work_guard(*m_IoContext);
+
+ m_IoContext->run();
+ });
+ }
+}
+
+ManagedProcessRunner::~ManagedProcessRunner()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what());
+ }
+}
+
+void
+ManagedProcessRunner::Shutdown()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown");
+ m_AcceptNewActions = false;
+
+ CancelRunningActions();
+
+ // Tear down the SubprocessManager before stopping the io_context so that
+ // any in-flight callbacks are drained cleanly.
+ if (m_SubprocessManager)
+ {
+ m_SubprocessManager->DestroyGroup("compute-workers");
+ m_ProcessGroup = nullptr;
+ m_SubprocessManager.reset();
+ }
+
+ if (m_IoContext)
+ {
+ m_IoContext->stop();
+ }
+
+ for (std::thread& Thread : m_IoThreads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+ m_IoThreads.clear();
+}
+
+SubmitResult
+ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction");
+ std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+ if (!Prepared)
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+ // Parse environment variables from worker descriptor ("KEY=VALUE" strings)
+ // into the key-value pairs expected by CreateProcOptions.
+ std::vector<std::pair<std::string, std::string>> EnvPairs;
+ for (auto& It : WorkerDescription["environment"sv])
+ {
+ std::string_view Str = It.AsString();
+ size_t Eq = Str.find('=');
+ if (Eq != std::string_view::npos)
+ {
+ EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1)));
+ }
+ }
+
+ // Build command line
+ std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+ std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred();
+
+ std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string());
+
+ ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath);
+
+ CreateProcOptions Options;
+ Options.WorkingDirectory = &Prepared->SandboxPath;
+ Options.Flags = CreateProcOptions::Flag_NoConsole;
+ Options.Environment = std::move(EnvPairs);
+
+ const int32_t ActionLsn = Prepared->ActionLsn;
+
+ ManagedProcess* Proc = nullptr;
+
+ try
+ {
+ Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) {
+ OnProcessExit(ActionLsn, ExitCode);
+ });
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what());
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath));
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ {
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = static_cast<void*>(Proc);
+ NewAction->Pid = Proc->Pid();
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RunningMap[ActionLsn] = std::move(NewAction);
+ }
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid());
+
+ return SubmitResult{.IsAccepted = true};
+}
+
+void
+ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit");
+
+ Ref<RunningAction> Running;
+
+ m_RunningLock.WithExclusiveLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end())
+ {
+ Running = std::move(It->second);
+ m_RunningMap.erase(It);
+ }
+ });
+
+ if (!Running)
+ {
+ return;
+ }
+
+ ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode);
+
+ Running->ExitCode = ExitCode;
+
+ // Capture final CPU metrics from the managed process before it is removed.
+ auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle);
+ if (Proc)
+ {
+ ProcessMetrics Metrics = Proc->GetLatestMetrics();
+ float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs);
+ Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed);
+
+ float CpuPct = Proc->GetCpuUsagePercent();
+ if (CpuPct >= 0.0f)
+ {
+ Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
+ }
+ }
+
+ Running->ProcessHandle = nullptr;
+
+ std::vector<Ref<RunningAction>> CompletedActions;
+ CompletedActions.push_back(std::move(Running));
+ ProcessCompletedActions(CompletedActions);
+}
+
+void
+ManagedProcessRunner::CancelRunningActions()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions");
+
+ std::unordered_map<int, Ref<RunningAction>> RunningMap;
+ m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+ if (RunningMap.empty())
+ {
+ return;
+ }
+
+ ZEN_INFO("cancelling {} running actions via process group", RunningMap.size());
+
+ Stopwatch Timer;
+
+ // Kill all processes in the group atomically (TerminateJobObject on Windows,
+ // SIGTERM+SIGKILL on POSIX).
+ if (m_ProcessGroup)
+ {
+ m_ProcessGroup->KillAll();
+ }
+
+ for (auto& [Lsn, Running] : RunningMap)
+ {
+ m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
+ Running->Action->SetActionState(RunnerAction::State::Failed);
+ }
+
+ ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+}
+
+bool
+ManagedProcessRunner::CancelAction(int ActionLsn)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction");
+
+ ManagedProcess* Proc = nullptr;
+
+ m_RunningLock.WithSharedLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr)
+ {
+ Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle);
+ }
+ });
+
+ if (!Proc)
+ {
+ return false;
+ }
+
+ // Terminate the process. The exit callback will handle the rest
+ // (remove from running map, gather outputs or mark failed).
+ Proc->Terminate(222);
+
+ ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn);
+ return true;
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h
new file mode 100644
index 000000000..21a44d43c
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.h
@@ -0,0 +1,64 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zenutil/process/subprocessmanager.h>
+
+# include <memory>
+
+namespace asio {
+class io_context;
+}
+
+namespace zen::compute {
+
+/** Cross-platform process runner backed by SubprocessManager.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and shared action preparation. Replaces the polling-based
+ monitor thread with async exit callbacks driven by SubprocessManager, and
+ delegates CPU/memory metrics sampling to the manager's built-in round-robin
+ sampler.
+
+ A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is
+ used for bulk cancellation on shutdown.
+
+ This runner does not perform any platform-specific sandboxing (AppContainer,
+ namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative
+ to the platform-specific runners for non-sandboxed workloads.
+ */
+class ManagedProcessRunner : public LocalProcessRunner
+{
+public:
+ ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions = 0);
+ ~ManagedProcessRunner();
+
+ void Shutdown() override;
+ [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+ void CancelRunningActions() override;
+ bool CancelAction(int ActionLsn) override;
+ [[nodiscard]] bool IsHealthy() override { return true; }
+
+private:
+ static constexpr int kIoThreadCount = 4;
+
+ // Exit callback posted on an io_context thread.
+ void OnProcessExit(int ActionLsn, int ExitCode);
+
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::unique_ptr<SubprocessManager> m_SubprocessManager;
+ ProcessGroup* m_ProcessGroup = nullptr;
+ std::vector<std::thread> m_IoThreads;
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp
index cd4b646e9..92ee65c2d 100644
--- a/src/zencompute/runners/windowsrunner.cpp
+++ b/src/zencompute/runners/windowsrunner.cpp
@@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
# include <sddl.h>
ZEN_THIRD_PARTY_INCLUDES_END
+// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be
+// excluded by WIN32_LEAN_AND_MEAN.
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE)
+# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400
+# endif
+
namespace zen::compute {
using namespace std::literals;
@@ -34,38 +40,65 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver,
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
- if (!m_Sandboxed)
+ // Create a job object shared by all child processes. Restricting the
+ // error-mode UI prevents crash dialogs (WER / Dr. Watson) from
+ // blocking the monitor thread when a worker process terminates
+ // abnormally.
+ m_JobObject = CreateJobObjectW(nullptr, nullptr);
+ if (m_JobObject)
{
- return;
+ JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{};
+ ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION;
+ SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits));
+
+ JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
+ UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE;
+ SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions));
+
+ // Set error mode on this process so children inherit it. The
+ // UILIMIT_ERRORMODE restriction above prevents them from clearing
+ // SEM_NOGPFAULTERRORBOX.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
}
- // Build a unique profile name per process to avoid collisions
- m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
+ if (m_Sandboxed)
+ {
+ // Build a unique profile name per process to avoid collisions
+ m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
- // Clean up any stale profile from a previous crash
- DeleteAppContainerProfile(m_AppContainerName.c_str());
+ // Clean up any stale profile from a previous crash
+ DeleteAppContainerProfile(m_AppContainerName.c_str());
- PSID Sid = nullptr;
+ PSID Sid = nullptr;
- HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
- m_AppContainerName.c_str(), // display name
- m_AppContainerName.c_str(), // description
- nullptr, // no capabilities
- 0, // capability count
- &Sid);
+ HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
+ m_AppContainerName.c_str(), // display name
+ m_AppContainerName.c_str(), // description
+ nullptr, // no capabilities
+ 0, // capability count
+ &Sid);
- if (FAILED(Hr))
- {
- throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
- }
+ if (FAILED(Hr))
+ {
+ throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
+ }
- m_AppContainerSid = Sid;
+ m_AppContainerSid = Sid;
+
+ ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ }
- ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ StartMonitorThread();
}
WindowsProcessRunner::~WindowsProcessRunner()
{
+ if (m_JobObject)
+ {
+ CloseHandle(m_JobObject);
+ m_JobObject = nullptr;
+ }
+
if (m_AppContainerSid)
{
FreeSid(m_AppContainerSid);
@@ -172,9 +205,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr;
LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr;
BOOL bInheritHandles = FALSE;
- DWORD dwCreationFlags = DETACHED_PROCESS;
+ DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS;
- ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed);
+ ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath);
CommandLine.EnsureNulTerminated();
@@ -260,14 +293,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- CloseHandle(ProcessInformation.hThread);
+ if (m_JobObject)
+ {
+ AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess);
+ }
- Ref<RunningAction> NewAction{new RunningAction()};
- NewAction->Action = Action;
- NewAction->ProcessHandle = ProcessInformation.hProcess;
- NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+ ResumeThread(ProcessInformation.hThread);
+ CloseHandle(ProcessInformation.hThread);
{
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = ProcessInformation.hProcess;
+ NewAction->Pid = ProcessInformation.dwProcessId;
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
RwLock::ExclusiveLockScope _(m_RunningLock);
m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
@@ -275,6 +315,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
Action->SetActionState(RunnerAction::State::Running);
+ ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId);
+
return SubmitResult{.IsAccepted = true};
}
@@ -294,6 +336,11 @@ WindowsProcessRunner::SweepRunningActions()
if (IsSuccess && ExitCode != STILL_ACTIVE)
{
+ ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"),
+ Running->Action->ActionLsn,
+ Running->Pid,
+ ExitCode);
+
CloseHandle(Running->ProcessHandle);
Running->ProcessHandle = INVALID_HANDLE_VALUE;
Running->ExitCode = ExitCode;
diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h
index 9f2385cc4..adeaf02fc 100644
--- a/src/zencompute/runners/windowsrunner.h
+++ b/src/zencompute/runners/windowsrunner.h
@@ -46,6 +46,7 @@ private:
bool m_Sandboxed = false;
PSID m_AppContainerSid = nullptr;
std::wstring m_AppContainerName;
+ HANDLE m_JobObject = nullptr;
};
} // namespace zen::compute
diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp
index 506bec73b..b4fafb467 100644
--- a/src/zencompute/runners/winerunner.cpp
+++ b/src/zencompute/runners/winerunner.cpp
@@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver,
sigemptyset(&Action.sa_mask);
Action.sa_handler = SIG_DFL;
sigaction(SIGCHLD, &Action, nullptr);
+
+ StartMonitorThread();
}
SubmitResult
diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp
index 56a292ca6..cd268745c 100644
--- a/src/zencore/compactbinarypackage.cpp
+++ b/src/zencore/compactbinarypackage.cpp
@@ -684,14 +684,22 @@ namespace legacy {
Writer.Save(Ar);
}
- bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ bool TryLoadCbPackage(CbPackage& Package,
+ IoBuffer InBuffer,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper,
+ bool ValidateHashes)
{
BinaryReader Reader(InBuffer.Data(), InBuffer.Size());
- return TryLoadCbPackage(Package, Reader, Allocator, Mapper);
+ return TryLoadCbPackage(Package, Reader, Allocator, Mapper, ValidateHashes);
}
- bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ bool TryLoadCbPackage(CbPackage& Package,
+ BinaryReader& Reader,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper,
+ bool ValidateHashes)
{
Package = CbPackage();
for (;;)
@@ -708,7 +716,11 @@ namespace legacy {
if (ValueField.IsBinary())
{
const MemoryView View = ValueField.AsBinaryView();
- if (View.GetSize() > 0)
+ if (View.GetSize() == 0)
+ {
+ return false;
+ }
+ else
{
SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned();
CbField HashField = LoadCompactBinary(Reader, Allocator);
@@ -748,7 +760,11 @@ namespace legacy {
{
const IoHash Hash = ValueField.AsHash();
- ZEN_ASSERT(Mapper);
+ if (!Mapper)
+ {
+ return false;
+ }
+
if (SharedBuffer AttachmentData = (*Mapper)(Hash))
{
IoHash RawHash;
@@ -763,6 +779,10 @@ namespace legacy {
}
else
{
+ if (ValidateHashes && IoHash::HashBuffer(AttachmentData) != Hash)
+ {
+ return false;
+ }
const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All);
if (ValidationResult != CbValidateError::None)
{
@@ -801,13 +821,13 @@ namespace legacy {
#if ZEN_WITH_TESTS
void
-usonpackage_forcelink()
+cbpackage_forcelink()
{
}
TEST_SUITE_BEGIN("core.compactbinarypackage");
-TEST_CASE("usonpackage")
+TEST_CASE("cbpackage")
{
using namespace std::literals;
@@ -997,7 +1017,7 @@ TEST_CASE("usonpackage")
}
}
-TEST_CASE("usonpackage.serialization")
+TEST_CASE("cbpackage.serialization")
{
using namespace std::literals;
@@ -1303,7 +1323,7 @@ TEST_CASE("usonpackage.serialization")
}
}
-TEST_CASE("usonpackage.invalidpackage")
+TEST_CASE("cbpackage.invalidpackage")
{
const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) {
const MemoryView RawView = MakeMemoryView(RawData);
@@ -1345,6 +1365,90 @@ TEST_CASE("usonpackage.invalidpackage")
}
}
+TEST_CASE("cbpackage.legacy.invalidpackage")
+{
+ const auto TestLegacyLoad = [](std::initializer_list<uint8_t> RawData) {
+ const MemoryView RawView = MakeMemoryView(RawData);
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(RawView.GetData()), RawView.GetSize());
+ CbPackage Package;
+ CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc));
+ };
+
+ SUBCASE("Empty") { TestLegacyLoad({}); }
+
+ SUBCASE("Zero size binary rejects")
+ {
+ // A binary field with zero payload size should be rejected (would desync the reader)
+ BinaryWriter Writer;
+ CbWriter Cb;
+ Cb.AddBinary(MemoryView()); // zero-size binary
+ Cb.Save(Writer);
+
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize());
+ CbPackage Package;
+ CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc));
+ }
+}
+
+TEST_CASE("cbpackage.legacy.hashresolution")
+{
+ // Build a valid legacy package with an object, then round-trip it
+ CbObjectWriter RootWriter;
+ RootWriter.AddString("name", "test");
+ CbObject RootObject = RootWriter.Save();
+
+ CbAttachment ObjectAttach(RootObject);
+
+ CbPackage OriginalPkg;
+ OriginalPkg.SetObject(RootObject);
+ OriginalPkg.AddAttachment(ObjectAttach);
+
+ BinaryWriter Writer;
+ legacy::SaveCbPackage(OriginalPkg, Writer);
+
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize());
+ CbPackage LoadedPkg;
+ CHECK(legacy::TryLoadCbPackage(LoadedPkg, Buffer, &UniqueBuffer::Alloc));
+
+ // The hash-only path requires a mapper — without one it should fail
+ CbWriter HashOnlyCb;
+ HashOnlyCb.AddHash(ObjectAttach.GetHash());
+ HashOnlyCb.AddNull();
+ BinaryWriter HashOnlyWriter;
+ HashOnlyCb.Save(HashOnlyWriter);
+
+ IoBuffer HashOnlyBuffer(IoBuffer::Wrap,
+ const_cast<void*>(MakeMemoryView(HashOnlyWriter).GetData()),
+ MakeMemoryView(HashOnlyWriter).GetSize());
+ CbPackage HashOnlyPkg;
+ CHECK_FALSE(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, nullptr));
+
+ // With a mapper that returns valid data, it should succeed
+ CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer {
+ if (Hash == ObjectAttach.GetHash())
+ {
+ return RootObject.GetBuffer();
+ }
+ return {};
+ };
+ CHECK(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &Resolver));
+
+ // Build a different but structurally valid CbObject to use as mismatched data
+ CbObjectWriter DifferentWriter;
+ DifferentWriter.AddString("name", "different");
+ CbObject DifferentObject = DifferentWriter.Save();
+
+ CbPackage::AttachmentResolver BadResolver = [&](const IoHash&) -> SharedBuffer { return DifferentObject.GetBuffer(); };
+ CbPackage BadPkg;
+
+ // With ValidateHashes enabled and a mapper that returns mismatched data, it should fail
+ CHECK_FALSE(legacy::TryLoadCbPackage(BadPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ true));
+
+ // Without ValidateHashes, the mismatched data is accepted (structure is still valid CB)
+ CbPackage UncheckedPkg;
+ CHECK(legacy::TryLoadCbPackage(UncheckedPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ false));
+}
+
TEST_SUITE_END();
#endif
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
index 0d361801f..a63594be9 100644
--- a/src/zencore/filesystem.cpp
+++ b/src/zencore/filesystem.cpp
@@ -3275,14 +3275,25 @@ MakeSafeAbsolutePathInPlace(std::filesystem::path& Path)
{
if (!Path.empty())
{
- std::filesystem::path AbsolutePath = std::filesystem::absolute(Path).make_preferred();
+ Path = std::filesystem::absolute(Path).make_preferred();
#if ZEN_PLATFORM_WINDOWS
- const std::string_view Prefix = "\\\\?\\";
- const std::u8string PrefixU8(Prefix.begin(), Prefix.end());
- std::u8string PathString = AbsolutePath.u8string();
- if (!PathString.empty() && !PathString.starts_with(PrefixU8))
+ const std::u8string_view LongPathPrefix = u8"\\\\?\\";
+ const std::u8string_view UncPrefix = u8"\\\\";
+ const std::u8string_view LongPathUncPrefix = u8"\\\\?\\UNC\\";
+
+ std::u8string PathString = Path.u8string();
+ if (!PathString.empty() && !PathString.starts_with(LongPathPrefix))
{
- PathString.insert(0, PrefixU8);
+ if (PathString.starts_with(UncPrefix))
+ {
+ // UNC path: \\server\share → \\?\UNC\server\share
+ PathString.replace(0, UncPrefix.size(), LongPathUncPrefix);
+ }
+ else
+ {
+ // Local path: C:\foo → \\?\C:\foo
+ PathString.insert(0, LongPathPrefix);
+ }
Path = PathString;
}
#endif // ZEN_PLATFORM_WINDOWS
@@ -4049,6 +4060,54 @@ TEST_CASE("SharedMemory")
CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false));
}
+TEST_CASE("filesystem.MakeSafeAbsolutePath")
+{
+# if ZEN_PLATFORM_WINDOWS
+ // Local path gets \\?\ prefix
+ {
+ std::filesystem::path Local = MakeSafeAbsolutePath("C:\\Users\\test");
+ CHECK(Local.u8string().starts_with(u8"\\\\?\\"));
+ CHECK(Local.u8string().find(u8"C:\\Users\\test") != std::u8string::npos);
+ }
+
+ // UNC path gets \\?\UNC\ prefix
+ {
+ std::filesystem::path Unc = MakeSafeAbsolutePath("\\\\server\\share\\path");
+ std::u8string UncStr = Unc.u8string();
+ CHECK_MESSAGE(UncStr.starts_with(u8"\\\\?\\UNC\\"), fmt::format("Expected \\\\?\\UNC\\ prefix, got '{}'", Unc));
+ CHECK_MESSAGE(UncStr.find(u8"server\\share\\path") != std::u8string::npos,
+ fmt::format("Expected server\\share\\path in '{}'", Unc));
+ // Must NOT produce \\?\\\server (double backslash after \\?\)
+ CHECK_MESSAGE(UncStr.find(u8"\\\\?\\\\\\") == std::u8string::npos,
+ fmt::format("Path contains invalid double-backslash after prefix: '{}'", Unc));
+ }
+
+ // Already-prefixed path is not double-prefixed
+ {
+ std::filesystem::path Already = MakeSafeAbsolutePath("\\\\?\\C:\\already\\prefixed");
+ size_t Count = 0;
+ std::u8string Str = Already.u8string();
+ for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1))
+ {
+ ++Count;
+ }
+ CHECK_EQ(Count, 1);
+ }
+
+ // Already-prefixed UNC path is not double-prefixed
+ {
+ std::filesystem::path AlreadyUnc = MakeSafeAbsolutePath("\\\\?\\UNC\\server\\share");
+ size_t Count = 0;
+ std::u8string Str = AlreadyUnc.u8string();
+ for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1))
+ {
+ ++Count;
+ }
+ CHECK_EQ(Count, 1);
+ }
+# endif // ZEN_PLATFORM_WINDOWS
+}
+
TEST_SUITE_END();
#endif
diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h
index 64b62e2c0..148c0d3fd 100644
--- a/src/zencore/include/zencore/compactbinarypackage.h
+++ b/src/zencore/include/zencore/compactbinarypackage.h
@@ -278,10 +278,10 @@ public:
* @return The attachment, or null if the attachment is not found.
* @note The returned pointer is only valid until the attachments on this package are modified.
*/
- const CbAttachment* FindAttachment(const IoHash& Hash) const;
+ [[nodiscard]] const CbAttachment* FindAttachment(const IoHash& Hash) const;
/** Find an attachment if it exists in the package. */
- inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); }
+ [[nodiscard]] const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); }
/** Add the attachment to this package. */
inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); }
@@ -336,17 +336,26 @@ private:
IoHash ObjectHash;
};
+/** In addition to the above, we also support a legacy format which is used by
+ * the HTTP project store for historical reasons. Don't use the below functions
+ * for new code.
+ */
namespace legacy {
void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer);
void SaveCbPackage(const CbPackage& Package, CbWriter& Writer);
void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar);
- bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr);
+ bool TryLoadCbPackage(CbPackage& Package,
+ IoBuffer Buffer,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper = nullptr,
+ bool ValidateHashes = false);
bool TryLoadCbPackage(CbPackage& Package,
BinaryReader& Reader,
BufferAllocator Allocator,
- CbPackage::AttachmentResolver* Mapper = nullptr);
+ CbPackage::AttachmentResolver* Mapper = nullptr,
+ bool ValidateHashes = false);
} // namespace legacy
-void usonpackage_forcelink(); // internal
+void cbpackage_forcelink(); // internal
} // namespace zen
diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h
index 3427991d2..1608ad523 100644
--- a/src/zencore/include/zencore/logging.h
+++ b/src/zencore/include/zencore/logging.h
@@ -90,6 +90,34 @@ using zen::ConsoleLog;
using zen::ErrorLog;
using zen::Log;
+////////////////////////////////////////////////////////////////////////
+// Color helpers
+
+#define ZEN_RED(str) "\033[31m" str "\033[0m"
+#define ZEN_GREEN(str) "\033[32m" str "\033[0m"
+#define ZEN_YELLOW(str) "\033[33m" str "\033[0m"
+#define ZEN_BLUE(str) "\033[34m" str "\033[0m"
+#define ZEN_MAGENTA(str) "\033[35m" str "\033[0m"
+#define ZEN_CYAN(str) "\033[36m" str "\033[0m"
+#define ZEN_WHITE(str) "\033[37m" str "\033[0m"
+
+#define ZEN_BRIGHT_RED(str) "\033[91m" str "\033[0m"
+#define ZEN_BRIGHT_GREEN(str) "\033[92m" str "\033[0m"
+#define ZEN_BRIGHT_YELLOW(str) "\033[93m" str "\033[0m"
+#define ZEN_BRIGHT_BLUE(str) "\033[94m" str "\033[0m"
+#define ZEN_BRIGHT_MAGENTA(str) "\033[95m" str "\033[0m"
+#define ZEN_BRIGHT_CYAN(str) "\033[96m" str "\033[0m"
+#define ZEN_BRIGHT_WHITE(str) "\033[97m" str "\033[0m"
+
+#define ZEN_BOLD(str) "\033[1m" str "\033[0m"
+#define ZEN_UNDERLINE(str) "\033[4m" str "\033[0m"
+#define ZEN_DIM(str) "\033[2m" str "\033[0m"
+#define ZEN_ITALIC(str) "\033[3m" str "\033[0m"
+#define ZEN_STRIKETHROUGH(str) "\033[9m" str "\033[0m"
+#define ZEN_INVERSE(str) "\033[7m" str "\033[0m"
+
+////////////////////////////////////////////////////////////////////////
+
#if ZEN_BUILD_DEBUG
# define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \
while (false) \
diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h
index 5ae7fad68..8cbed781d 100644
--- a/src/zencore/include/zencore/process.h
+++ b/src/zencore/include/zencore/process.h
@@ -174,9 +174,11 @@ struct CreateProcOptions
// allocated and no conhost.exe is spawned. Stdout/stderr still work when redirected
// via pipes. Prefer this for headless worker processes.
Flag_NoConsole = 1 << 3,
- // Create the child in a new process group (CREATE_NEW_PROCESS_GROUP on Windows).
- // Allows sending CTRL_BREAK_EVENT to the child group without affecting the parent.
- Flag_Windows_NewProcessGroup = 1 << 4,
+ // Spawn the child as a new process group leader (its pgid = its own pid).
+ // On Windows: CREATE_NEW_PROCESS_GROUP, enables CTRL_BREAK_EVENT targeting.
+ // On POSIX: child calls setpgid(0,0) / posix_spawn with POSIX_SPAWN_SETPGROUP+pgid=0.
+ // Mutually exclusive with ProcessGroupId > 0.
+ Flag_NewProcessGroup = 1 << 4,
// Allocate a hidden console for the child (CREATE_NO_WINDOW on Windows). Unlike
// Flag_NoConsole the child still gets a console (and a conhost.exe) but no visible
// window. Use this when the child needs a console for stdio but should not show a window.
@@ -197,9 +199,9 @@ struct CreateProcOptions
#if ZEN_PLATFORM_WINDOWS
JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed
#else
- /// POSIX process group id. When > 0, the child is placed into this process
- /// group via setpgid() before exec. Use the pid of the first child as the
- /// pgid to create a group, then pass the same pgid for subsequent children.
+ /// When > 0, child joins this existing process group. Mutually exclusive with
+ /// Flag_NewProcessGroup; use that flag on the first spawn to create the group,
+ /// then pass the resulting pid here for subsequent spawns to join it.
int ProcessGroupId = 0;
#endif
};
diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp
index e7baa3f8e..ee821944a 100644
--- a/src/zencore/process.cpp
+++ b/src/zencore/process.cpp
@@ -37,7 +37,9 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#if ZEN_PLATFORM_MAC
+# include <crt_externs.h>
# include <libproc.h>
+# include <spawn.h>
# include <sys/types.h>
# include <sys/sysctl.h>
#endif
@@ -135,8 +137,68 @@ IsZombieProcess(int pid, std::error_code& OutEc)
}
return false;
}
+
+static char**
+GetEnviron()
+{
+ return *_NSGetEnviron();
+}
#endif // ZEN_PLATFORM_MAC
+#if ZEN_PLATFORM_LINUX
+static char**
+GetEnviron()
+{
+ return environ;
+}
+#endif // ZEN_PLATFORM_LINUX
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+// Holds a null-terminated envp array built by merging the current process environment with
+// a set of overrides. When Overrides is empty, Data points directly to environ (no allocation).
+// Must outlive any posix_spawn / execve call that receives Data.
+struct EnvpHolder
+{
+ char** Data = GetEnviron();
+
+ explicit EnvpHolder(const std::vector<std::pair<std::string, std::string>>& Overrides)
+ {
+ if (Overrides.empty())
+ {
+ return;
+ }
+ std::map<std::string, std::string> EnvMap;
+ for (char** E = GetEnviron(); *E; ++E)
+ {
+ std::string_view Entry(*E);
+ const size_t EqPos = Entry.find('=');
+ if (EqPos != std::string_view::npos)
+ {
+ EnvMap[std::string(Entry.substr(0, EqPos))] = std::string(Entry.substr(EqPos + 1));
+ }
+ }
+ for (const auto& [Key, Value] : Overrides)
+ {
+ EnvMap[Key] = Value;
+ }
+ for (const auto& [Key, Value] : EnvMap)
+ {
+ m_Strings.push_back(Key + "=" + Value);
+ }
+ for (std::string& S : m_Strings)
+ {
+ m_Ptrs.push_back(S.data());
+ }
+ m_Ptrs.push_back(nullptr);
+ Data = m_Ptrs.data();
+ }
+
+private:
+ std::vector<std::string> m_Strings;
+ std::vector<char*> m_Ptrs;
+};
+#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
//////////////////////////////////////////////////////////////////////////
// Pipe creation for child process stdout capture
@@ -691,6 +753,7 @@ BuildArgV(std::vector<char*>& Out, char* CommandLine)
++Cursor;
}
}
+
#endif // !WINDOWS || TESTS
#if ZEN_PLATFORM_WINDOWS
@@ -766,7 +829,7 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma
{
CreationFlags |= CREATE_NO_WINDOW;
}
- if (Options.Flags & CreateProcOptions::Flag_Windows_NewProcessGroup)
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
{
CreationFlags |= CREATE_NEW_PROCESS_GROUP;
}
@@ -1070,23 +1133,30 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine
}
return CreateProcNormal(Executable, CommandLine, Options);
-#else
+#elif ZEN_PLATFORM_LINUX
+ // vfork uses CLONE_VM|CLONE_VFORK: the child shares the parent's address space and the
+ // parent is suspended until the child calls exec or _exit. This avoids page-table duplication
+ // and the ENOMEM that fork() produces on systems with strict overcommit (vm.overcommit_memory=2).
+ // All child-side setup uses only syscalls that do not modify user-space memory.
+ // Environment overrides are merged into envp before vfork so that setenv() is never called
+ // from the child (which would corrupt the shared address space).
std::vector<char*> ArgV;
std::string CommandLineZ(CommandLine);
BuildArgV(ArgV, CommandLineZ.data());
ArgV.push_back(nullptr);
- int ChildPid = fork();
+ EnvpHolder Envp(Options.Environment);
+
+ int ChildPid = vfork();
if (ChildPid < 0)
{
- ThrowLastError("Failed to fork a new child process");
+ ThrowLastError("Failed to vfork a new child process");
}
else if (ChildPid == 0)
{
if (Options.WorkingDirectory != nullptr)
{
- int Result = chdir(Options.WorkingDirectory->c_str());
- ZEN_UNUSED(Result);
+ chdir(Options.WorkingDirectory->c_str());
}
if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0)
@@ -1118,23 +1188,99 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine
}
}
- if (Options.ProcessGroupId > 0)
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
+ {
+ setpgid(0, 0);
+ }
+ else if (Options.ProcessGroupId > 0)
{
setpgid(0, Options.ProcessGroupId);
}
- for (const auto& [Key, Value] : Options.Environment)
+ execve(Executable.c_str(), ArgV.data(), Envp.Data);
+ _exit(127);
+ }
+
+ return ChildPid;
+#else // macOS
+ std::vector<char*> ArgV;
+ std::string CommandLineZ(CommandLine);
+ BuildArgV(ArgV, CommandLineZ.data());
+ ArgV.push_back(nullptr);
+
+ posix_spawn_file_actions_t FileActions;
+ posix_spawnattr_t Attr;
+
+ int Err = posix_spawn_file_actions_init(&FileActions);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "posix_spawn_file_actions_init failed");
+ }
+ auto FileActionsGuard = MakeGuard([&] { posix_spawn_file_actions_destroy(&FileActions); });
+
+ Err = posix_spawnattr_init(&Attr);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "posix_spawnattr_init failed");
+ }
+ auto AttrGuard = MakeGuard([&] { posix_spawnattr_destroy(&Attr); });
+
+ if (Options.WorkingDirectory != nullptr)
+ {
+ Err = posix_spawn_file_actions_addchdir_np(&FileActions, Options.WorkingDirectory->c_str());
+ if (Err != 0)
{
- setenv(Key.c_str(), Value.c_str(), 1);
+ ThrowSystemError(Err, "posix_spawn_file_actions_addchdir_np failed");
}
+ }
+
+ if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0)
+ {
+ const int StdoutWriteFd = Options.StdoutPipe->WriteFd;
+ ZEN_ASSERT(StdoutWriteFd > STDERR_FILENO);
+ posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDOUT_FILENO);
- if (execv(Executable.c_str(), ArgV.data()) < 0)
+ if (Options.StderrPipe != nullptr && Options.StderrPipe->WriteFd >= 0)
+ {
+ const int StderrWriteFd = Options.StderrPipe->WriteFd;
+ ZEN_ASSERT(StderrWriteFd > STDERR_FILENO && StderrWriteFd != StdoutWriteFd);
+ posix_spawn_file_actions_adddup2(&FileActions, StderrWriteFd, STDERR_FILENO);
+ posix_spawn_file_actions_addclose(&FileActions, StderrWriteFd);
+ }
+ else
{
- ThrowLastError("Failed to exec() a new process image");
+ posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDERR_FILENO);
}
+
+ posix_spawn_file_actions_addclose(&FileActions, StdoutWriteFd);
+ }
+ else if (!Options.StdoutFile.empty())
+ {
+ posix_spawn_file_actions_addopen(&FileActions, STDOUT_FILENO, Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
+ posix_spawn_file_actions_adddup2(&FileActions, STDOUT_FILENO, STDERR_FILENO);
}
- return ChildPid;
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
+ {
+ posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP);
+ posix_spawnattr_setpgroup(&Attr, 0);
+ }
+ else if (Options.ProcessGroupId > 0)
+ {
+ posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP);
+ posix_spawnattr_setpgroup(&Attr, Options.ProcessGroupId);
+ }
+
+ EnvpHolder Envp(Options.Environment);
+
+ pid_t ChildPid = 0;
+ Err = posix_spawn(&ChildPid, Executable.c_str(), &FileActions, &Attr, ArgV.data(), Envp.Data);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "Failed to posix_spawn a new child process");
+ }
+
+ return int(ChildPid);
#endif
}
@@ -1252,14 +1398,28 @@ JobObject::Initialize()
}
JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {};
- LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+ LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION;
if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo)))
{
ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError());
CloseHandle(m_JobHandle);
m_JobHandle = nullptr;
+ return;
}
+
+ // Prevent child processes from clearing SEM_NOGPFAULTERRORBOX, which
+ // suppresses WER/Dr. Watson crash dialogs. Without this, a crashing
+ // child can pop a modal dialog and block the monitor thread.
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE)
+# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400
+# endif
+ JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
+ UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE;
+ SetInformationJobObject(m_JobHandle, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions));
+
+ // Set error mode on the current process so children inherit it.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
}
bool
diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp
index 8c29a8962..c1ac63621 100644
--- a/src/zencore/zencore.cpp
+++ b/src/zencore/zencore.cpp
@@ -273,7 +273,7 @@ zencore_forcelinktests()
zen::uid_forcelink();
zen::uson_forcelink();
zen::usonbuilder_forcelink();
- zen::usonpackage_forcelink();
+ zen::cbpackage_forcelink();
zen::cbjson_forcelink();
zen::cbyaml_forcelink();
zen::workthreadpool_forcelink();
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index e05c9815f..38021be16 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -479,6 +479,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
return Ref<IHttpPackageHandler>();
}
+bool
+HttpService::AcceptsLocalFileReferences() const
+{
+ return false;
+}
+
+const ILocalRefPolicy*
+HttpService::GetLocalRefPolicy() const
+{
+ return nullptr;
+}
+
//////////////////////////////////////////////////////////////////////////
HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
@@ -705,7 +717,10 @@ HttpServerRequest::ReadPayloadPackage()
{
if (IoBuffer Payload = ReadPayload())
{
- return ParsePackageMessage(std::move(Payload));
+ ParseFlags Flags =
+ (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault;
+ const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr;
+ return ParsePackageMessage(std::move(Payload), {}, Flags, Policy);
}
return {};
@@ -1273,7 +1288,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
return PackageHandlerRef->CreateTarget(Cid, Size);
};
- CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
+ ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences())
+ ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ const ILocalRefPolicy* PkgPolicy =
+ EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr;
+ CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy);
PackageHandlerRef->OnRequestComplete();
}
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 5eaed6004..76f219f04 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -12,6 +12,7 @@
#include <zencore/string.h>
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zenhttp/localrefpolicy.h>
#include <zentelemetry/hyperloglog.h>
#include <zentelemetry/stats.h>
@@ -193,9 +194,16 @@ public:
HttpService() = default;
virtual ~HttpService() = default;
- virtual const char* BaseUri() const = 0;
- virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
- virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+ [[nodiscard]] virtual const char* BaseUri() const = 0;
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+
+ /// Whether this service accepts local file references in inbound packages from local clients.
+ [[nodiscard]] virtual bool AcceptsLocalFileReferences() const;
+
+ /// Returns the local ref policy for validating file paths in inbound local references.
+ /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed).
+ [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const;
// Internals
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index bce771c75..b20bc3a36 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -43,7 +43,7 @@ public:
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h
new file mode 100644
index 000000000..0b37f9dc7
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/localrefpolicy.h
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <filesystem>
+
+namespace zen {
+
+/// Policy interface for validating local file reference paths in inbound CbPackage messages.
+/// Implementations should throw std::invalid_argument if the path is not allowed.
+class ILocalRefPolicy
+{
+public:
+ virtual ~ILocalRefPolicy() = default;
+
+ /// Validate that a local file reference path is allowed.
+ /// Throws std::invalid_argument if the path escapes the allowed root.
+ virtual void ValidatePath(const std::filesystem::path& Path) const = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index 1a5068580..66e3f6e55 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -5,6 +5,7 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
+#include <zenhttp/localrefpolicy.h>
#include <functional>
#include <gsl/gsl-lite.hpp>
@@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
-CbPackage ParsePackageMessage(
- IoBuffer Payload,
- std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+
+enum class ParseFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags);
+
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
return IoBuffer{Size};
- });
+ },
+ ParseFlags Flags = ParseFlags::kDefault,
+ const ILocalRefPolicy* Policy = nullptr);
bool IsPackageMessage(IoBuffer Payload);
bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
@@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe
class CbPackageReader
{
public:
- CbPackageReader();
+ CbPackageReader(ParseFlags Flags = ParseFlags::kDefault);
~CbPackageReader();
void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+ void SetLocalRefPolicy(const ILocalRefPolicy* Policy);
/** Process compact binary package data stream
@@ -149,6 +162,8 @@ private:
kReadingBuffers
} m_CurrentState = State::kInitialState;
+ ParseFlags m_Flags;
+ const ILocalRefPolicy* m_LocalRefPolicy = nullptr;
std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
std::vector<IoBuffer> m_PayloadBuffers;
std::vector<CbAttachmentEntry> m_AttachmentEntries;
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
index 710579faa..2d25515d3 100644
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -59,7 +59,7 @@ class IWebSocketHandler
public:
virtual ~IWebSocketHandler() = default;
- virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0;
virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
};
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index 7e6207e56..5ad5ebcc7 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
//
void
-HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
ZEN_INFO("Stats WebSocket client connected");
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 9c62c1f2d..267ce386c 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:");
typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t;
+/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security.
+/// If no policy is set, file-path local refs are rejected (fail-closed).
+static void
+ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path)
+{
+ if (Policy)
+ {
+ Policy->ValidatePath(Path);
+ }
+ else
+ {
+ throw std::invalid_argument("local file reference rejected: no validation policy");
+ }
+}
+
+// Validates the CbPackageHeader magic and attachment count. Returns the total
+// chunk count (AttachmentCount + 1, including the implicit root object).
+static uint32_t
+ValidatePackageHeader(const CbPackageHeader& Hdr)
+{
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ throw std::invalid_argument(
+ fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic));
+ }
+ // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against
+ // UINT32_MAX wrapping to 0, which would bypass subsequent size checks.
+ if (Hdr.AttachmentCount == UINT32_MAX)
+ {
+ throw std::invalid_argument("invalid CbPackage, attachment count overflow");
+ }
+ return Hdr.AttachmentCount + 1;
+}
+
+struct ValidatedLocalRef
+{
+ bool Valid = false;
+ const CbAttachmentReferenceHeader* Header = nullptr;
+ std::string_view Path;
+ std::string Error;
+};
+
+// Validates that the attachment buffer contains a well-formed local reference
+// header and path. On failure, Valid is false and Error contains the reason.
+static ValidatedLocalRef
+ValidateLocalRef(const IoBuffer& AttachmentBuffer)
+{
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader))
+ {
+ return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())};
+ }
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)
+ {
+ return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})",
+ sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength,
+ AttachmentBuffer.Size())};
+ }
+
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)};
+}
+
IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle);
std::vector<IoBuffer>
@@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload)
}
CbPackage
-ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+ParsePackageMessage(IoBuffer Payload,
+ std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer,
+ ParseFlags Flags,
+ const ILocalRefPolicy* Policy)
{
ZEN_TRACE_CPU("ParsePackageMessage");
@@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
BinaryReader Reader(Payload);
- const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
- if (Hdr->HeaderMagic != kCbPkgMagic)
- {
- throw std::invalid_argument(
- fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic));
- }
+ const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
+ const uint32_t ChunkCount = ValidatePackageHeader(*Hdr);
Reader.Skip(sizeof(CbPackageHeader));
- const uint32_t ChunkCount = Hdr->AttachmentCount + 1;
-
- if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount)
+ // Widen to uint64_t so the multiplication cannot wrap on 32-bit.
+ const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount;
+ if (Reader.Remaining() < AttachmentTableSize)
{
throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)",
sizeof(CbAttachmentEntry) * ChunkCount,
@@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
{
- // Marshal local reference - a "pointer" to the chunk backing file
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+ if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i));
+ }
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ // Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
- std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash)));
+ continue;
+ }
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::string_view PathView = LocalRef.Path;
IoBuffer FullFileBuffer;
@@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
}
else
{
+ ApplyLocalRefPolicy(Policy, Path);
FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
}
}
if (FullFileBuffer)
{
- IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ // Guard against offset+size overflow or exceeding the file bounds.
+ const uint64_t FileSize = FullFileBuffer.GetSize();
+ if (AttachRefHdr->PayloadByteOffset > FileSize ||
+ AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset)
+ {
+ MalformedAttachments.push_back(
+ std::make_pair(i,
+ fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}",
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ FileSize,
+ Entry.AttachmentHash)));
+ continue;
+ }
+
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize
? FullFileBuffer
: IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa
return OutPackage.TryLoad(Response);
}
-CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+CbPackageReader::CbPackageReader(ParseFlags Flags)
+: m_Flags(Flags)
+, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
{
}
@@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci
m_CreateBuffer = CreateBuffer;
}
+void
+CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy)
+{
+ m_LocalRefPolicy = Policy;
+}
+
uint64_t
CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
{
@@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
return sizeof m_PackageHeader;
case State::kReadingHeader:
- ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
- memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
- ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
- m_CurrentState = State::kReadingAttachmentEntries;
- m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
- return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+ {
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(ChunkCount);
+ return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry);
+ }
case State::kReadingAttachmentEntries:
ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
@@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
{
// Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
-
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ throw std::invalid_argument(LocalRef.Error);
+ }
- std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::filesystem::path Path(Utf8ToWide(LocalRef.Path));
- std::filesystem::path Path{PathView};
+ if (!LocalRef.Path.starts_with(HandlePrefix))
+ {
+ ApplyLocalRefPolicy(m_LocalRefPolicy, Path);
+ }
IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
AttachRefHdr->PayloadByteSize));
}
+ // MakeFromFile silently clamps offset+size to the file size. Detect this
+ // to avoid returning a short buffer that could cause subtle downstream issues.
+ if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize)
+ {
+ throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ ChunkReference.GetSize()));
+ }
+
return ChunkReference;
};
@@ -732,6 +843,13 @@ CbPackageReader::Finalize()
{
IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+ if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed",
+ CurrentAttachmentIndex));
+ }
+
if (CurrentAttachmentIndex == 0)
{
// Root object
@@ -815,6 +933,13 @@ CbPackageReader::Finalize()
TEST_SUITE_BEGIN("http.packageformat");
+/// Permissive policy that allows any path, for use in tests that exercise local ref
+/// functionality but are not testing path validation.
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy
+{
+ void ValidatePath(const std::filesystem::path&) const override {}
+};
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef")
RemainingBytes -= ByteCount;
};
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackageReader Reader(ParseFlags::kAllowLocalReferences);
+ Reader.SetLocalRefPolicy(&AllowAllPolicy);
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedHeader")
+{
+ // Payload too small for a CbPackageHeader
+ uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa};
+ IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagic")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xDEADBEEF;
+ Hdr.AttachmentCount = 0;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflow")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable")
+{
+ // Valid header but not enough data for the attachment entries
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = 10;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentData")
+{
+ // Valid header + one attachment entry claiming more data than available
+ std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry));
+
+ CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data());
+ Hdr->HeaderMagic = kCbPkgMagic;
+ Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object)
+
+ CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader));
+ Entry->PayloadSize = 9999; // way more than available
+ Entry->Flags = CbAttachmentEntry::kIsObject;
+ Entry->AttachmentHash = IoHash();
+
+ IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size());
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault")
+{
+ // Build a valid package with local refs backed by compressed-format files,
+ // then verify it's rejected with default ParseFlags and accepted when allowed.
+ ScopedTemporaryDirectory TempDir;
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ // Compress data and write to disk, then create file-backed compressed attachments.
+ // The files must contain compressed-format data because ParsePackageMessage expects it
+ // when resolving local refs.
+ CompressedBuffer Comp1 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ CompressedBuffer Comp2 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+
+ IoHash Hash1 = Comp1.DecodeRawHash();
+ IoHash Hash2 = Comp2.DecodeRawHash();
+
+ {
+ IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer();
+ IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(Path1, Buf1);
+ WriteFile(Path2, Buf2);
+ }
+
+ // Create attachments from file-backed buffers so FormatPackageMessage uses local refs
+ CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1};
+ CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Default flags should reject local refs
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+
+ // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip
+ // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader)
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 2);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader")
+{
+ // Same test but via CbPackageReader
+ ScopedTemporaryDirectory TempDir;
+ auto FilePath = TempDir.Path() / "testdata";
+
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata"));
+ WriteFile(FilePath, Buf);
+ }
+
+ IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath);
+ CbAttachment Attach{SharedBuffer(FileBuffer)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ // Default flags should reject
CbPackageReader Reader;
uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
@@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef")
CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
}
- Reader.Finalize();
+ CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagicViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xBADCAFE;
+ Hdr.AttachmentCount = 0;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot")
+{
+ // A file outside the allowed root should be rejected by the policy
+ ScopedTemporaryDirectory AllowedRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "secret.dat";
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret"));
+ WriteFile(OutsidePath, Buf);
+ }
+
+ // Create file-backed compressed attachment from outside root
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Policy rooted at AllowedRoot should reject the file in OutsideDir
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot")
+{
+ // A file inside the allowed root should be accepted by the policy
+ ScopedTemporaryDirectory TempRoot;
+
+ auto FilePath = TempRoot.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 1);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal")
+{
+ // A file path containing ".." that resolves outside root should be rejected
+ ScopedTemporaryDirectory TempRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "evil.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ // Root is TempRoot, but the file lives in OutsideDir
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed")
+{
+ // When local refs are allowed but no policy is provided, file-path refs should be rejected
+ ScopedTemporaryDirectory TempDir;
+
+ auto FilePath = TempDir.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // kAllowLocalReferences but nullptr policy => fail-closed
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument);
}
TEST_SUITE_END();
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 7972777b8..6cda84875 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1275,7 +1275,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
asio::buffer(ResponseStr->data(), ResponseStr->size()),
asio::bind_executor(
m_Strand,
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()](
+ const asio::error_code& Ec,
+ std::size_t) {
if (Ec)
{
ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
@@ -1287,7 +1289,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ std::string_view FullUrl = Conn->m_RequestData.Url();
+ std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
}));
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 2cad97725..1b722940d 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -2595,7 +2595,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
&Transaction().Server()));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ ExtendableStringBuilder<128> UrlUtf8;
+ WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath,
+ gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))},
+ UrlUtf8);
+ int PrefixLen = Service->UriPrefixLength();
+ std::string_view RelativeUri{UrlUtf8.ToView()};
+ RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
return nullptr;
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 59c46a418..363c478ae 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -335,8 +335,9 @@ namespace {
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_OpenCount.fetch_add(1);
m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp
index 1ae968fb7..deecdef05 100644
--- a/src/zennomad/nomadprocess.cpp
+++ b/src/zennomad/nomadprocess.cpp
@@ -37,7 +37,7 @@ struct NomadProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options);
diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp
index a04063c4c..ca226444a 100644
--- a/src/zenremotestore/builds/buildstorageoperations.cpp
+++ b/src/zenremotestore/builds/buildstorageoperations.cpp
@@ -1149,8 +1149,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp);
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -1252,10 +1251,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
ScavengedLookups,
ScavengedPaths,
WriteCache);
- WritePartsComplete++;
+ bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount;
if (!m_AbortFlag)
{
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsDone)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -1334,9 +1333,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
Ec.message());
}
- WritePartsComplete++;
-
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -1389,25 +1386,20 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
BlockRangeStartIndex,
RangeCount,
ExistsResult,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
[this,
&RemoteChunkIndexNeedsCopyFromSourceFlags,
&SequenceIndexChunksLeftToWriteCounters,
&WritePartsComplete,
&WriteCache,
&Work,
- TotalRequestCount,
TotalPartWriteCount,
- &FilteredDownloadedBytesPerSecond,
&FilteredWrittenBytesPerSecond,
&PartialBlocks](IoBuffer&& InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) {
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
-
if (!m_AbortFlag)
{
Work.ScheduleWork(
@@ -1483,8 +1475,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
fmt::format("Partial block {} is malformed", BlockDescription.BlockHash));
}
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -1571,8 +1562,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
uint64_t BlockSize = BlockBuffer.GetSize();
m_DownloadStats.DownloadedBlockCount++;
m_DownloadStats.DownloadedBlockByteCount += BlockSize;
- m_DownloadStats.RequestsCompleteCount++;
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
{
FilteredDownloadedBytesPerSecond.Stop();
}
@@ -1683,9 +1673,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
}
}
- WritePartsComplete++;
-
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -2987,8 +2975,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash);
if (!ExistingCompressedChunkPath.empty())
{
- m_DownloadStats.RequestsCompleteCount++;
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
{
FilteredDownloadedBytesPerSecond.Stop();
}
@@ -3027,11 +3014,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
bool NeedHashVerify =
WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart));
- WritePartsComplete++;
+ bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount;
if (!AbortFlag)
{
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsDone)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -3085,6 +3072,8 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
DownloadBuildBlob(RemoteChunkIndex,
ExistsResult,
Work,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
[this,
&ExistsResult,
SequenceIndexChunksLeftToWriteCounters,
@@ -3092,15 +3081,9 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
&Work,
&WritePartsComplete,
TotalPartWriteCount,
- TotalRequestCount,
RemoteChunkIndex,
- &FilteredDownloadedBytesPerSecond,
&FilteredWrittenBytesPerSecond,
ChunkTargetPtrs = std::move(ChunkTargetPtrs)](IoBuffer&& Payload) mutable {
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
IoBufferFileReference FileRef;
bool EnableBacklog = Payload.GetFileReference(FileRef);
AsyncWriteDownloadedChunk(m_Options.ZenFolderPath,
@@ -3125,6 +3108,8 @@ void
BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkIndex,
const BlobsExistsResult& ExistsResult,
ParallelWork& Work,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& Payload)>&& OnDownloaded)
{
const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex];
@@ -3140,37 +3125,48 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
uint64_t BlobSize = BuildBlob.GetSize();
m_DownloadStats.DownloadedChunkCount++;
m_DownloadStats.DownloadedChunkByteCount += BlobSize;
- m_DownloadStats.RequestsCompleteCount++;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
OnDownloaded(std::move(BuildBlob));
}
else
{
if (m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] >= m_Options.LargeAttachmentSize)
{
- DownloadLargeBlob(
- *m_Storage.BuildStorage,
- m_TempDownloadFolderPath,
- m_BuildId,
- ChunkHash,
- m_Options.PreferredMultipartChunkSize,
- Work,
- m_NetworkPool,
- m_DownloadStats.DownloadedChunkByteCount,
- m_DownloadStats.MultipartAttachmentCount,
- [this, &Work, ChunkHash, RemoteChunkIndex, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable {
- m_DownloadStats.DownloadedChunkCount++;
- m_DownloadStats.RequestsCompleteCount++;
+ DownloadLargeBlob(*m_Storage.BuildStorage,
+ m_TempDownloadFolderPath,
+ m_BuildId,
+ ChunkHash,
+ m_Options.PreferredMultipartChunkSize,
+ Work,
+ m_NetworkPool,
+ m_DownloadStats.DownloadedChunkByteCount,
+ m_DownloadStats.MultipartAttachmentCount,
+ [this,
+ &Work,
+ &FilteredDownloadedBytesPerSecond,
+ ChunkHash,
+ RemoteChunkIndex,
+ TotalRequestCount,
+ OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable {
+ m_DownloadStats.DownloadedChunkCount++;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
- if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(Payload)));
- }
+ if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ ChunkHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(Payload)));
+ }
- OnDownloaded(std::move(Payload));
- });
+ OnDownloaded(std::move(Payload));
+ });
}
else
{
@@ -3193,7 +3189,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
uint64_t BlobSize = BuildBlob.GetSize();
m_DownloadStats.DownloadedChunkCount++;
m_DownloadStats.DownloadedChunkByteCount += BlobSize;
- m_DownloadStats.RequestsCompleteCount++;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
OnDownloaded(std::move(BuildBlob));
}
@@ -3208,6 +3207,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
size_t BlockRangeStartIndex,
size_t BlockRangeCount,
const BlobsExistsResult& ExistsResult,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
@@ -3222,6 +3223,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
IoBuffer&& BlockRangeBuffer,
size_t BlockRangeStartIndex,
std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
const std::function<void(IoBuffer && InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
@@ -3229,7 +3232,11 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize();
m_DownloadStats.DownloadedBlockCount++;
m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize;
- m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size();
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(BlockOffsetAndLengths.size()) + BlockOffsetAndLengths.size() ==
+ TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
std::filesystem::path BlockChunkPath;
@@ -3337,6 +3344,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(PayloadBuffer),
SubRangeStartIndex,
std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)},
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3361,6 +3370,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3371,6 +3382,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangeBuffers.Ranges,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3413,6 +3426,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
}
else
@@ -3428,6 +3443,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangeBuffers.Ranges,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
}
}
@@ -4244,8 +4261,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa
bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart));
if (!m_AbortFlag)
{
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -6111,8 +6127,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
TempUploadStats.BlockCount++;
- UploadedBlockCount++;
- if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
+ if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
{
FilteredUploadedBytesPerSecond.Stop();
}
@@ -6192,8 +6207,8 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
if (IsComplete)
{
TempUploadStats.ChunkCount++;
- UploadedChunkCount++;
- if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
+ if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount &&
+ UploadedBlockCount == UploadBlockCount)
{
FilteredUploadedBytesPerSecond.Stop();
}
@@ -6227,8 +6242,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
TempUploadStats.ChunkCount++;
UploadedCompressedChunkSize += Payload.GetSize();
UploadedRawChunkSize += RawSize;
- UploadedChunkCount++;
- if (UploadedChunkCount == UploadChunkCount)
+ if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount)
{
FilteredUploadedBytesPerSecond.Stop();
}
@@ -6237,8 +6251,6 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
});
};
- std::vector<size_t> GenerateBlockIndexes;
-
std::atomic<uint64_t> GeneratedBlockCount = 0;
std::atomic<uint64_t> GeneratedBlockByteCount = 0;
@@ -6260,9 +6272,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
&Lookup,
&NewBlocks,
&NewBlockChunks,
- &GenerateBlockIndexes,
&GeneratedBlockCount,
&GeneratedBlockByteCount,
+ GenerateBlockCount = BlockIndexes.size(),
&AsyncUploadBlock,
&QueuedPendingInMemoryBlocksForUpload](std::atomic<bool>&) {
if (!m_AbortFlag)
@@ -6293,8 +6305,7 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
}
GeneratedBlockByteCount += NewBlocks.BlockSizes[BlockIndex];
- GeneratedBlockCount++;
- if (GeneratedBlockCount == GenerateBlockIndexes.size())
+ if (GeneratedBlockCount.fetch_add(1) + 1 == GenerateBlockCount)
{
FilteredGenerateBlockBytesPerSecond.Stop();
}
@@ -7005,8 +7016,7 @@ BuildsOperationPrimeCache::Execute()
CompositeBuffer(SharedBuffer(Payload)));
}
}
- CompletedDownloadCount++;
- if (CompletedDownloadCount == BlobCount)
+ if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
{
FilteredDownloadedBytesPerSecond.Stop();
}
@@ -7029,8 +7039,7 @@ BuildsOperationPrimeCache::Execute()
CompositeBuffer(SharedBuffer(std::move(Payload))));
}
}
- CompletedDownloadCount++;
- if (CompletedDownloadCount == BlobCount)
+ if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
{
FilteredDownloadedBytesPerSecond.Stop();
}
diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp
index ad4c4bc89..d837ce07f 100644
--- a/src/zenremotestore/builds/jupiterbuildstorage.cpp
+++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp
@@ -263,7 +263,7 @@ public:
std::vector<std::function<void()>> WorkList;
for (auto& WorkItem : WorkItems)
{
- WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes]() {
+ WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes = std::move(OnSentBytes)]() {
Stopwatch ExecutionTimer;
auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); });
bool IsComplete = false;
@@ -444,11 +444,13 @@ public:
virtual bool GetExtendedStatistics(ExtendedStatistics& OutStats) override
{
- OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size());
- for (auto& It : m_ReceivedBytesPerSource)
- {
- OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second]);
- }
+ m_SourceLock.WithSharedLock([this, &OutStats]() {
+ OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size());
+ for (auto& It : m_ReceivedBytesPerSource)
+ {
+ OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second].load(std::memory_order_relaxed));
+ }
+ });
return true;
}
@@ -521,15 +523,29 @@ private:
}
if (!Result.Source.empty())
{
- if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
- It != m_ReceivedBytesPerSource.end())
- {
- m_SourceBytes[It->second] += Result.ReceivedBytes;
- }
- else
+ if (!m_SourceLock.WithSharedLock([&]() {
+ if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
+ It != m_ReceivedBytesPerSource.end())
+ {
+ m_SourceBytes[It->second] += Result.ReceivedBytes;
+ return true;
+ }
+ return false;
+ }))
{
- m_ReceivedBytesPerSource.insert_or_assign(Result.Source, m_SourceBytes.size());
- m_SourceBytes.push_back(Result.ReceivedBytes);
+ m_SourceLock.WithExclusiveLock([&]() {
+ if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
+ It != m_ReceivedBytesPerSource.end())
+ {
+ m_SourceBytes[It->second] += Result.ReceivedBytes;
+ }
+ else if (m_SourceCount < MaxSourceCount)
+ {
+ size_t Index = m_SourceCount++;
+ m_ReceivedBytesPerSource.insert_or_assign(Result.Source, Index);
+ m_SourceBytes[Index] += Result.ReceivedBytes;
+ }
+ });
}
}
}
@@ -540,8 +556,11 @@ private:
const std::string m_Bucket;
const std::filesystem::path m_TempFolderPath;
- tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource;
- std::vector<uint64_t> m_SourceBytes;
+ RwLock m_SourceLock;
+ tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource;
+ static constexpr size_t MaxSourceCount = 8u;
+ std::array<std::atomic<uint64_t>, MaxSourceCount> m_SourceBytes;
+ size_t m_SourceCount = 0;
};
std::unique_ptr<BuildStorageBase>
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
index 0d2eded58..27dc9de86 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
@@ -261,12 +261,16 @@ private:
void DownloadBuildBlob(uint32_t RemoteChunkIndex,
const BlobsExistsResult& ExistsResult,
ParallelWork& Work,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& Payload)>&& OnDownloaded);
- void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
- size_t BlockRangeIndex,
- size_t BlockRangeCount,
- const BlobsExistsResult& ExistsResult,
+ void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
+ size_t BlockRangeIndex,
+ size_t BlockRangeCount,
+ const BlobsExistsResult& ExistsResult,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp
index a9788cb4e..d610d1fc8 100644
--- a/src/zenremotestore/jupiter/jupitersession.cpp
+++ b/src/zenremotestore/jupiter/jupitersession.cpp
@@ -673,7 +673,7 @@ JupiterSession::PutMultipartBuildBlob(std::string_view Namespace,
size_t RetryPartIndex = PartNameToIndex.at(RetryPartId);
const MultipartUploadResponse::Part& RetryPart = Workload->PartDescription.Parts[RetryPartIndex];
IoBuffer RetryPartPayload =
- Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte - 1);
+ Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte);
std::string RetryMultipartUploadResponseRequestString =
fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}/uploadMultipart{}&supportsRedirect={}",
Namespace,
diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp
index 1a9dc10ef..2076adb70 100644
--- a/src/zenremotestore/projectstore/remoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp
@@ -929,7 +929,6 @@ namespace remotestore_impl {
{
return;
}
- ZEN_ASSERT(UploadAttachment->Size != 0);
if (!UploadAttachment->RawPath.empty())
{
if (UploadAttachment->Size > (MaxChunkEmbedSize * 2))
@@ -7008,6 +7007,60 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn")
}
}
+TEST_CASE("buildcontainer.zero_byte_file_attachment")
+{
+ // A zero-byte file on disk is a valid attachment. BuildContainer must process
+ // it without hitting ZEN_ASSERT(UploadAttachment->Size != 0) in
+ // ResolveAttachments. The empty file flows through the compress-inline path
+ // and becomes a LooseUploadAttachment with raw size 0.
+ using namespace projectstore_testutils;
+ using namespace std::literals;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ std::unique_ptr<ProjectStore> ProjectStoreDummy;
+ Ref<ProjectStore::Project> Project = MakeTestProject(CidStore, Gc, TempDir.Path(), ProjectStoreDummy);
+
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ auto FileAtts = CreateFileAttachments(RootDir, std::initializer_list<size_t>{512});
+
+ Ref<ProjectStore::Oplog> Oplog = Project->NewOplog("bc_zero_byte_file", {});
+ REQUIRE(Oplog);
+ Oplog->AppendNewOplogEntry(CreateFilesOplogPackage(Oid::NewOid(), RootDir, FileAtts));
+
+ // Truncate the file to zero bytes after the oplog entry is created.
+ // The file still exists on disk so RewriteOplog's IsFile() check passes,
+ // but MakeFromFile returns a zero-size buffer.
+ std::filesystem::resize_file(FileAtts[0].second, 0);
+
+ WorkerThreadPool WorkerPool(GetWorkerCount());
+
+ CbObject Container = BuildContainer(
+ CidStore,
+ *Project,
+ *Oplog,
+ WorkerPool,
+ 64u * 1024u,
+ 1000,
+ 32u * 1024u,
+ 64u * 1024u * 1024u,
+ /*BuildBlocks=*/true,
+ /*IgnoreMissingAttachments=*/false,
+ /*AllowChunking=*/true,
+ [](CompressedBuffer&&, ChunkBlockDescription&&) {},
+ [](const IoHash&, TGetAttachmentBufferFunc&&) {},
+ [](std::vector<std::pair<IoHash, FetchChunkFunc>>&&) {},
+ /*EmbedLooseFiles=*/true);
+
+ CHECK(Container.GetSize() > 0);
+
+ // The zero-byte attachment is packed into a block via the compress-inline path.
+ CbArrayView Blocks = Container["blocks"sv].AsArrayView();
+ CHECK(Blocks.Num() > 0);
+}
+
TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite")
{
// EmbedLooseFiles=false: RewriteOp is skipped for file-op entries; they pass through
diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp
index 4cd6b411f..1543c4d7c 100644
--- a/src/zens3-testbed/main.cpp
+++ b/src/zens3-testbed/main.cpp
@@ -110,7 +110,7 @@ CreateClient(const cxxopts::ParseResult& Args)
if (Args.count("timeout"))
{
- Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000);
+ Options.HttpSettings.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000);
}
return S3Client(Options);
diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp
index 14748e214..986dc67e0 100644
--- a/src/zenserver-test/cache-tests.cpp
+++ b/src/zenserver-test/cache-tests.cpp
@@ -9,6 +9,7 @@
# include <zencore/compactbinarypackage.h>
# include <zencore/compress.h>
# include <zencore/fmtutils.h>
+# include <zenhttp/localrefpolicy.h>
# include <zenhttp/packageformat.h>
# include <zenstore/cache/cachepolicy.h>
# include <zencore/filesystem.h>
@@ -25,6 +26,13 @@ namespace zen::tests {
TEST_SUITE_BEGIN("server.cache");
+/// Permissive policy that allows any path, for use in tests that exercise local ref
+/// functionality but are not testing path validation.
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy
+{
+ void ValidatePath(const std::filesystem::path&) const override {}
+};
+
TEST_CASE("zcache.basic")
{
using namespace std::literals;
@@ -743,7 +751,11 @@ TEST_CASE("zcache.rpc")
if (Result.StatusCode == HttpResponseCode::OK)
{
- CbPackage Response = ParsePackageMessage(Result.ResponsePayload);
+ ParseFlags PFlags = EnumHasAllFlags(AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr;
+ CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy);
CHECK(!Response.IsNull());
OutResult.Response = std::move(Response);
CHECK(OutResult.Result.Parse(OutResult.Response));
@@ -1745,8 +1757,13 @@ TEST_CASE("zcache.rpc.partialchunks")
CHECK(Result.StatusCode == HttpResponseCode::OK);
- CbPackage Response = ParsePackageMessage(Result.ResponsePayload);
- bool Loaded = !Response.IsNull();
+ ParseFlags PFlags = EnumHasAllFlags(Options.AcceptOptions, RpcAcceptOptions::kAllowLocalReferences)
+ ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr;
+ CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy);
+ bool Loaded = !Response.IsNull();
CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load.");
cacherequests::GetCacheChunksResult GetCacheChunksResult;
CHECK(GetCacheChunksResult.Parse(Response));
diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp
index 021052a3b..835d72713 100644
--- a/src/zenserver-test/compute-tests.cpp
+++ b/src/zenserver-test/compute-tests.cpp
@@ -21,6 +21,7 @@
# include <zenhttp/httpserver.h>
# include <zenhttp/websocket.h>
# include <zencompute/computeservice.h>
+# include <zencore/fmtutils.h>
# include <zenstore/zenstore.h>
# include <zenutil/zenserverprocess.h>
@@ -36,6 +37,8 @@ using namespace std::literals;
static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969";
static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313";
static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888";
+static constexpr std::string_view kFailVersion = "fa11fa11-fa11-fa11-fa11-fa11fa11fa11";
+static constexpr std::string_view kCrashVersion = "c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5";
// In-memory implementation of ChunkResolver for test use.
// Stores compressed data keyed by decompressed content hash.
@@ -104,6 +107,16 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
<< "Sleep"sv;
WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Fail"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kFailVersion);
+ WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Crash"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kCrashVersion);
+ WorkerWriter.EndObject();
WorkerWriter.EndArray();
CbPackage WorkerPackage;
@@ -115,7 +128,7 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage));
REQUIRE_MESSAGE(RegisterResp,
- fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText()));
+ fmt::format("Worker registration failed: status={}, body={}", RegisterResp.StatusCode, RegisterResp.ToText()));
return WorkerId;
}
@@ -220,6 +233,83 @@ BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemor
return ActionWriter.Save();
}
+// Build a Fail action CbPackage. The worker exits with the given exit code.
+static CbPackage
+BuildFailActionPackage(int ExitCode)
+{
+ // The Fail function throws before reading inputs, but the action structure
+ // still requires a valid input attachment for the runner to manifest.
+ std::string_view Dummy = "x"sv;
+
+ CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()),
+ OodleCompressor::Selkie,
+ OodleCompressionLevel::HyperFast4);
+
+ const IoHash InputRawHash = InputCompressed.DecodeRawHash();
+ const uint64_t InputRawSize = Dummy.size();
+
+ CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+ CbObjectWriter ActionWriter;
+ ActionWriter << "Function"sv
+ << "Fail"sv;
+ ActionWriter << "FunctionVersion"sv << Guid::FromString(kFailVersion);
+ ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+ ActionWriter.BeginObject("Inputs"sv);
+ ActionWriter.BeginObject("Source"sv);
+ ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+ ActionWriter << "RawSize"sv << InputRawSize;
+ ActionWriter.EndObject();
+ ActionWriter.EndObject();
+ ActionWriter.BeginObject("Constants"sv);
+ ActionWriter << "ExitCode"sv << static_cast<uint64_t>(ExitCode);
+ ActionWriter.EndObject();
+
+ CbPackage ActionPackage;
+ ActionPackage.SetObject(ActionWriter.Save());
+ ActionPackage.AddAttachment(InputAttachment);
+
+ return ActionPackage;
+}
+
+// Build a Crash action CbPackage. The worker process crashes hard.
+// Mode: "abort" (default) or "nullptr" (null pointer dereference).
+static CbPackage
+BuildCrashActionPackage(std::string_view Mode = "abort"sv)
+{
+ std::string_view Dummy = "x"sv;
+
+ CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()),
+ OodleCompressor::Selkie,
+ OodleCompressionLevel::HyperFast4);
+
+ const IoHash InputRawHash = InputCompressed.DecodeRawHash();
+ const uint64_t InputRawSize = Dummy.size();
+
+ CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+ CbObjectWriter ActionWriter;
+ ActionWriter << "Function"sv
+ << "Crash"sv;
+ ActionWriter << "FunctionVersion"sv << Guid::FromString(kCrashVersion);
+ ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+ ActionWriter.BeginObject("Inputs"sv);
+ ActionWriter.BeginObject("Source"sv);
+ ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+ ActionWriter << "RawSize"sv << InputRawSize;
+ ActionWriter.EndObject();
+ ActionWriter.EndObject();
+ ActionWriter.BeginObject("Constants"sv);
+ ActionWriter << "Mode"sv << Mode;
+ ActionWriter.EndObject();
+
+ CbPackage ActionPackage;
+ ActionPackage.SetObject(ActionWriter.Save());
+ ActionPackage.AddAttachment(InputAttachment);
+
+ return ActionPackage;
+}
+
static HttpClient::Response
PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000)
{
@@ -340,8 +430,9 @@ public:
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
@@ -469,6 +560,16 @@ BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver)
<< "Sleep"sv;
WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Fail"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kFailVersion);
+ WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Crash"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kCrashVersion);
+ WorkerWriter.EndObject();
WorkerWriter.EndArray();
CbPackage WorkerPackage;
@@ -526,7 +627,7 @@ TEST_CASE("function.rot13")
// Submit action via legacy /jobs/{worker} endpoint
const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
@@ -536,7 +637,7 @@ TEST_CASE("function.rot13")
HttpClient::Response ResultResp = PollForResult(Client, ResultUrl);
REQUIRE_MESSAGE(
ResultResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+ fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
CbPackage ResultPackage = ResultResp.AsPackage();
@@ -581,7 +682,7 @@ TEST_CASE("function.workers")
// GET /workers/{worker} — descriptor should match what was registered
const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
HttpClient::Response DescResp = Client.Get(WorkerUrl);
- REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode)));
+ REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", DescResp.StatusCode));
CbObject Desc = DescResp.AsObject();
CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion));
@@ -627,7 +728,7 @@ TEST_CASE("function.queues.lifecycle")
// Create a queue
HttpClient::Response CreateResp = Client.Post("/queues"sv);
- REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
@@ -651,8 +752,7 @@ TEST_CASE("function.queues.lifecycle")
// Submit action via queue-scoped endpoint
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission");
@@ -668,9 +768,8 @@ TEST_CASE("function.queues.lifecycle")
// Retrieve result via queue-scoped /jobs/{lsn} endpoint
const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
HttpClient::Response ResultResp = Client.Get(ResultUrl);
- REQUIRE_MESSAGE(
- ResultResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
+ fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
CbPackage ResultPackage = ResultResp.AsPackage();
@@ -712,13 +811,13 @@ TEST_CASE("function.queues.cancel")
// Submit a job
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
// Cancel the queue
const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// Verify queue status shows cancelled
HttpClient::Response StatusResp = Client.Get(QueueUrl);
@@ -743,7 +842,7 @@ TEST_CASE("function.queues.remote")
// Create a remote queue — response includes both an integer queue_id and an OID queue_token
HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
REQUIRE_MESSAGE(CreateResp,
- fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
CbObject CreateObj = CreateResp.AsObject();
const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString());
@@ -753,7 +852,7 @@ TEST_CASE("function.queues.remote")
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ fmt::format("Remote queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission");
@@ -769,7 +868,7 @@ TEST_CASE("function.queues.remote")
HttpClient::Response ResultResp = Client.Get(ResultUrl);
REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}",
- int(ResultResp.StatusCode),
+ ResultResp.StatusCode,
Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
@@ -801,8 +900,7 @@ TEST_CASE("function.queues.cancel_running")
// Submit a Sleep job long enough that it will still be running when we cancel
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
@@ -814,7 +912,7 @@ TEST_CASE("function.queues.cancel_running")
const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// The cancelled job should appear in the /completed endpoint once the process exits
const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
@@ -849,7 +947,7 @@ TEST_CASE("function.queues.remote_cancel")
// Create a remote queue to obtain an OID token for token-addressed cancellation
HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
REQUIRE_MESSAGE(CreateResp,
- fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString());
REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
@@ -857,8 +955,7 @@ TEST_CASE("function.queues.remote_cancel")
// Submit a long-running Sleep job via the token-addressed endpoint
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
@@ -870,7 +967,7 @@ TEST_CASE("function.queues.remote_cancel")
const std::string QueueUrl = fmt::format("/queues/{}", QueueToken);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Remote queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// The cancelled job should appear in the token-addressed /completed endpoint
const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
@@ -910,13 +1007,13 @@ TEST_CASE("function.queues.drain")
// Submit a long-running job so we can verify it completes even after drain
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000));
- REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode)));
+ REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", Submit1.StatusCode));
const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32();
// Drain the queue
const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId);
HttpClient::Response DrainResp = Client.Post(DrainUrl);
- REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText()));
+ REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", DrainResp.StatusCode, DrainResp.ToText()));
CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining");
// Second submission should be rejected with 424
@@ -965,7 +1062,7 @@ TEST_CASE("function.priority")
// jobs by priority when the slot becomes free.
const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000));
- REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode)));
+ REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", BlockerResp.StatusCode));
// Submit 3 low-priority Rot13 jobs
const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
@@ -1104,6 +1201,432 @@ TEST_CASE("function.priority")
}
//////////////////////////////////////////////////////////////////////////
+// Process exit code tests
+//
+// These tests exercise how the compute service handles worker processes
+// that exit with non-zero exit codes, including retry behaviour and
+// final failure reporting.
+
+TEST_CASE("function.exit_code.failed_action")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=0 so the action fails immediately
+ // without being rescheduled.
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 0;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Fail action with exit code 42
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(42));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Fail job submission");
+
+ // Poll for the LSN to appear in the completed list
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ // Verify queue status reflects the failure
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+
+ // Verify action history records the failure
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ bool FoundInHistory = false;
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ FoundInHistory = true;
+ break;
+ }
+ }
+ CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn));
+
+ // GET /jobs/{lsn} for a failed action should return OK but with an empty result package
+ const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+ HttpClient::Response ResultResp = Client.Get(ResultUrl);
+ CHECK_EQ(ResultResp.StatusCode, HttpResponseCode::OK);
+}
+
+TEST_CASE("function.exit_code.auto_retry")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=2 so the action is retried twice before
+ // being reported as failed (3 total attempts: initial + 2 retries).
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 2;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Fail action — the worker process will exit with code 1 on
+ // every attempt, eventually exhausting retries.
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(1));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ // Poll for the LSN to appear in the completed list — this only
+ // happens after all retries are exhausted.
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ // Verify the action history records the retry count
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 2);
+ break;
+ }
+ }
+
+ // Queue should show 1 failed, 0 completed
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+TEST_CASE("function.exit_code.reschedule_failed")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=1 so we have room for one manual reschedule
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 1;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Fail action — auto-retry will fire once, then it lands in results as Failed
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(7));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ // Wait for the action to exhaust its auto-retry and land in completed
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue completed list\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+ // Try to manually reschedule — should fail because retry limit is reached
+ const std::string RescheduleUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+ HttpClient::Response RescheduleResp = Client.Post(RescheduleUrl);
+ CHECK_EQ(RescheduleResp.StatusCode, HttpResponseCode::Conflict);
+}
+
+TEST_CASE("function.exit_code.mixed_success_and_failure")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=0 for fast failure
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 0;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit one Rot13 (success) and one Fail (failure)
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+
+ HttpClient::Response SuccessResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv));
+ REQUIRE_MESSAGE(SuccessResp, "Rot13 job submission failed");
+ const int LsnSuccess = SuccessResp.AsObject()["lsn"sv].AsInt32();
+
+ HttpClient::Response FailResp = Client.Post(JobUrl, BuildFailActionPackage(1));
+ REQUIRE_MESSAGE(FailResp, "Fail job submission failed");
+ const int LsnFail = FailResp.AsObject()["lsn"sv].AsInt32();
+
+ // Wait for both to appear in the completed list
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnSuccess),
+ fmt::format("Success LSN {} did not complete\nServer log:\n{}", LsnSuccess, Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnFail),
+ fmt::format("Fail LSN {} did not complete\nServer log:\n{}", LsnFail, Instance.GetLogOutput()));
+
+ // Verify queue counters
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0);
+}
+
+TEST_CASE("function.crash.abort")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=0 so we don't wait through retries
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 0;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Crash action that calls std::abort()
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission");
+
+ // Poll for the LSN to appear in the completed list
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ // Verify queue status reflects the failure
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+
+ // Verify action history records the failure
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ bool FoundInHistory = false;
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ FoundInHistory = true;
+ break;
+ }
+ }
+ CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn));
+}
+
+TEST_CASE("function.crash.nullptr")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=0
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 0;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Crash action that dereferences a null pointer
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("nullptr"sv));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission");
+
+ // Poll for the LSN to appear in the completed list
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ // Verify queue status reflects the failure
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+TEST_CASE("function.crash.auto_retry")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Create a queue with max_retries=1 — the crash should be retried once
+ // before being reported as permanently failed.
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << 1;
+
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+ // Submit a Crash action — will crash on every attempt
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ // Poll for the LSN to appear in the completed list after retries exhaust
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ // Verify the action history records the retry count
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 1);
+ break;
+ }
+ }
+
+ // Queue should show 1 failed, 0 completed
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+//////////////////////////////////////////////////////////////////////////
// Remote worker synchronization tests
//
// These tests exercise the orchestrator discovery path where new compute
@@ -1162,9 +1685,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery")
Sleep(200);
}
- REQUIRE_MESSAGE(
- ResultCode == HttpResponseCode::OK,
- fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+ fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput()));
REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
@@ -1349,9 +1871,8 @@ TEST_CASE("function.remote.queue_association")
Sleep(200);
}
- REQUIRE_MESSAGE(
- ResultCode == HttpResponseCode::OK,
- fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+ fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput()));
REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
@@ -1481,7 +2002,7 @@ TEST_CASE("function.abandon_running_http")
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode)));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", SubmitResp.StatusCode));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN");
@@ -1498,7 +2019,7 @@ TEST_CASE("function.abandon_running_http")
// Trigger abandon via the HTTP endpoint
HttpClient::Response AbandonResp = Client.Post("/abandon"sv);
REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText()));
+ fmt::format("Abandon request failed: status={}, body={}", AbandonResp.StatusCode, AbandonResp.ToText()));
// Ready endpoint should now return 503
{
@@ -1529,7 +2050,7 @@ TEST_CASE("function.abandon_running_http")
CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state");
}
-TEST_CASE("function.session.abandon_pending")
+TEST_CASE("function.session.abandon_pending" * doctest::skip())
{
// Create a session with no runners so actions stay pending
InMemoryChunkResolver Resolver;
@@ -1577,7 +2098,7 @@ TEST_CASE("function.session.abandon_pending")
}
Sleep(100);
}
- CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code)));
+ CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, Code));
}
// Queue should show 0 active, 3 abandoned
@@ -1979,11 +2500,11 @@ TEST_CASE("function.retract_http")
// Submit a long-running Sleep action to occupy the single execution slot
const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode)));
+ REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", BlockerResp.StatusCode));
// Submit a second action — it will stay pending because the slot is occupied
HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode)));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", SubmitResp.StatusCode));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
@@ -1996,7 +2517,7 @@ TEST_CASE("function.retract_http")
HttpClient::Response RetractResp = Client.Post(RetractUrl);
CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK,
fmt::format("Retract failed: status={}, body={}\nServer log:\n{}",
- int(RetractResp.StatusCode),
+ RetractResp.StatusCode,
RetractResp.ToText(),
Instance.GetLogOutput()));
@@ -2011,7 +2532,42 @@ TEST_CASE("function.retract_http")
Sleep(500);
HttpClient::Response RetractResp2 = Client.Post(RetractUrl);
CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK,
- fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText()));
+ fmt::format("Second retract failed: status={}, body={}", RetractResp2.StatusCode, RetractResp2.ToText()));
+}
+
+TEST_CASE("function.session.immediate_query_after_enqueue")
+{
+ // Verify that actions are immediately visible to GetActionResult and
+ // FindActionResult right after enqueue, without waiting for the
+ // scheduler thread to process the update.
+
+ InMemoryChunkResolver Resolver;
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ CbObject ActionObj = BuildRot13ActionForSession("immediate-query"sv, Resolver);
+
+ auto EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+ REQUIRE_MESSAGE(EnqueueRes, "Failed to enqueue action");
+
+ // Query by LSN immediately — must not return NotFound
+ CbPackage Result;
+ HttpResponseCode Code = Session.GetActionResult(EnqueueRes.Lsn, Result);
+ CHECK_MESSAGE(Code == HttpResponseCode::Accepted,
+ fmt::format("GetActionResult returned {} immediately after enqueue, expected Accepted", Code));
+
+ // Query by ActionId immediately — must not return NotFound
+ const IoHash ActionId = ActionObj.GetHash();
+ CbPackage FindResult;
+ HttpResponseCode FindCode = Session.FindActionResult(ActionId, FindResult);
+ CHECK_MESSAGE(FindCode == HttpResponseCode::Accepted,
+ fmt::format("FindActionResult returned {} immediately after enqueue, expected Accepted", FindCode));
+
+ Session.Shutdown();
}
TEST_SUITE_END();
diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp
index b2da552fc..82dfd7e91 100644
--- a/src/zenserver-test/hub-tests.cpp
+++ b/src/zenserver-test/hub-tests.cpp
@@ -377,7 +377,7 @@ TEST_CASE("hub.consul.kv")
consul::ConsulProcess ConsulProc;
ConsulProc.SpawnConsulAgent();
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
Client.SetKeyValue("zen/hub/testkey", "testvalue");
std::string RetrievedValue = Client.GetKeyValue("zen/hub/testkey");
@@ -399,7 +399,7 @@ TEST_CASE("hub.consul.hub.registration")
"--consul-health-interval-seconds=5 --consul-deregister-after-seconds=60");
REQUIRE(PortNumber != 0);
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
// Verify custom intervals flowed through to the registered check
@@ -480,7 +480,7 @@ TEST_CASE("hub.consul.hub.registration.token")
// Use a plain client -- dev-mode Consul doesn't enforce ACLs, but the
// server has exercised the ConsulTokenEnv -> GetEnvVariable -> ConsulClient path.
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
@@ -501,7 +501,7 @@ TEST_CASE("hub.consul.provision.registration")
Instance.SpawnServerAndWaitUntilReady("--consul-endpoint=http://localhost:8500/ --instance-id=test-instance");
REQUIRE(PortNumber != 0);
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp
index a37ecb6be..cec453511 100644
--- a/src/zenserver-test/projectstore-tests.cpp
+++ b/src/zenserver-test/projectstore-tests.cpp
@@ -22,6 +22,7 @@ ZEN_THIRD_PARTY_INCLUDES_START
ZEN_THIRD_PARTY_INCLUDES_END
# include <random>
+# include <thread>
namespace zen::tests {
@@ -340,6 +341,102 @@ TEST_CASE("project.basic")
ZEN_INFO("+++++++");
}
+ SUBCASE("snapshot zero byte file")
+ {
+ // A zero-byte file referenced in an oplog entry must survive a
+ // snapshot: the file is read, compressed, stored in CidStore, and
+ // the oplog is rewritten with a BinaryAttachment reference. After
+ // the snapshot the chunk must be retrievable and decompress to an
+ // empty payload.
+
+ std::filesystem::path EmptyFileRelPath = std::filesystem::path("zerobyte_snapshot_test") / "empty.bin";
+ std::filesystem::path EmptyFileAbsPath = RootPath / EmptyFileRelPath;
+ CreateDirectories(MakeSafeAbsolutePath(EmptyFileAbsPath.parent_path()));
+ // Create a zero-byte file on disk.
+ WriteFile(MakeSafeAbsolutePath(EmptyFileAbsPath), IoBuffer{});
+ REQUIRE(IsFile(MakeSafeAbsolutePath(EmptyFileAbsPath)));
+
+ const std::string_view EmptyChunkId{
+ "00000000"
+ "00000000"
+ "00030000"};
+ auto EmptyFileOid = zen::Oid::FromHexString(EmptyChunkId);
+
+ zen::CbObjectWriter OpWriter;
+ OpWriter << "key"
+ << "zero_byte_test";
+ OpWriter.BeginArray("files");
+ OpWriter.BeginObject();
+ OpWriter << "id" << EmptyFileOid;
+ OpWriter << "clientpath"
+ << "/{engine}/empty_file";
+ OpWriter << "serverpath" << EmptyFileRelPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.EndArray();
+
+ zen::CbObject Op = OpWriter.Save();
+ zen::CbPackage OpPackage(Op);
+
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+
+ HttpClient Http{BaseUri};
+
+ {
+ auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::Created);
+ }
+
+ // Read file data before snapshot - raw and uncompressed, 0 bytes.
+ // http.sys converts a 200 OK with empty body to 204 No Content, so
+ // accept either status code.
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << EmptyChunkId;
+ auto Response = Http.Get(ChunkGetUri);
+
+ REQUIRE(Response);
+ CHECK((Response.StatusCode == HttpResponseCode::OK || Response.StatusCode == HttpResponseCode::NoContent));
+ CHECK(Response.ResponsePayload.GetSize() == 0);
+ }
+
+ // Trigger snapshot.
+ {
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); });
+ auto Response = Http.Post("/rpc"sv, Payload);
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
+ }
+
+ // Read chunk after snapshot - compressed, decompresses to 0 bytes.
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << EmptyChunkId;
+ auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
+
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ IoBuffer Data = Response.ResponsePayload;
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
+ REQUIRE(Compressed);
+ CHECK(RawSize == 0);
+ IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
+ CHECK(DataDecompressed.GetSize() == 0);
+ }
+
+ // Cleanup
+ {
+ std::error_code Ec;
+ DeleteDirectories(MakeSafeAbsolutePath(RootPath / "zerobyte_snapshot_test"), Ec);
+ }
+
+ ZEN_INFO("+++++++");
+ }
+
SUBCASE("test chunk not found error")
{
HttpClient Http{BaseUri};
@@ -1154,6 +1251,412 @@ TEST_CASE("project.rpcappendop")
}
}
+TEST_CASE("project.file.data.transitions")
+{
+ using namespace utils;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+ const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady();
+
+ zen::StringBuilder<64> ServerBaseUri;
+ ServerBaseUri << fmt::format("http://localhost:{}", PortNumber);
+
+ // Set up a root directory with a test file on disk for path-referenced serving
+ std::filesystem::path RootDir = TestDir / "root";
+ std::filesystem::path TestFilePath = RootDir / "content" / "testfile.bin";
+ std::filesystem::path RelServerPath = std::filesystem::path("content") / "testfile.bin";
+ CreateDirectories(TestFilePath.parent_path());
+ IoBuffer FileBlob = CreateRandomBlob(4096);
+ WriteFile(TestFilePath, FileBlob);
+
+ // Create a compressed blob to use as a CAS-referenced attachment (content differs from FileBlob)
+ CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(CreateRandomBlob(2048)));
+
+ // Fixed chunk IDs for the file entry across sub-tests
+ const std::string_view FileChunkIdStr{
+ "aa000000"
+ "bb000000"
+ "cc000001"};
+ Oid FileOid = Oid::FromHexString(FileChunkIdStr);
+
+ HttpClient Http{ServerBaseUri};
+
+ auto MakeProject = [&](std::string_view ProjectName) {
+ CbObjectWriter Project;
+ Project.AddString("id"sv, ProjectName);
+ Project.AddString("root"sv, PathToUtf8(RootDir.c_str()));
+ Project.AddString("engine"sv, ""sv);
+ Project.AddString("project"sv, ""sv);
+ Project.AddString("projectfile"sv, ""sv);
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), Project.Save());
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeProject"));
+ };
+
+ auto MakeOplog = [&](std::string_view ProjectName, std::string_view OplogName) {
+ HttpClient::Response Response =
+ Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject);
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeOplog"));
+ };
+
+ auto PostOplogEntry = [&](std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) {
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+ IoBuffer Body{IoBuffer::Wrap, MemOut.GetData(), MemOut.GetSize()};
+ Body.SetContentType(HttpContentType::kCbPackage);
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body);
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("PostOplogEntry"));
+ };
+
+ auto GetChunk = [&](std::string_view ProjectName) -> HttpClient::Response {
+ return Http.Get(fmt::format("/prj/{}/oplog/oplog/{}", ProjectName, FileChunkIdStr));
+ };
+
+ // Extract the raw decompressed bytes from a chunk response, handling both compressed and uncompressed payloads
+ auto GetDecompressedPayload = [](const HttpClient::Response& Response) -> IoBuffer {
+ if (Response.ResponsePayload.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Response.ResponsePayload), RawHash, RawSize);
+ REQUIRE(Compressed);
+ return Compressed.Decompress().AsIoBuffer();
+ }
+ return Response.ResponsePayload;
+ };
+
+ auto TriggerGcAndWait = [&]() {
+ HttpClient::Response TriggerResponse = Http.Post("/admin/gc?smallobjects=true"sv, IoBuffer{});
+ REQUIRE_MESSAGE(TriggerResponse.IsSuccess(), TriggerResponse.ErrorMessage("TriggerGc"));
+
+ for (int Attempt = 0; Attempt < 100; ++Attempt)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ HttpClient::Response StatusResponse = Http.Get("/admin/gc"sv);
+ REQUIRE_MESSAGE(StatusResponse.IsSuccess(), StatusResponse.ErrorMessage("GcStatus"));
+ CbObject StatusObj = StatusResponse.AsObject();
+ if (StatusObj["Status"sv].AsString() == "Idle"sv)
+ {
+ return;
+ }
+ }
+ FAIL("GC did not complete within timeout");
+ };
+
+ auto BuildPathReferencedFileOp = [&](const Oid& KeyId) -> CbPackage {
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(KeyId);
+ Object.BeginArray("files"sv);
+ Object.BeginObject();
+ Object << "id"sv << FileOid;
+ Object << "serverpath"sv << RelServerPath.string();
+ Object << "clientpath"sv
+ << "/{engine}/testfile.bin"sv;
+ Object.EndObject();
+ Object.EndArray();
+ Package.SetObject(Object.Save());
+ return Package;
+ };
+
+ auto BuildHashReferencedFileOp = [&](const Oid& KeyId, const CompressedBuffer& Blob) -> CbPackage {
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(KeyId);
+ CbAttachment Attach(Blob, Blob.DecodeRawHash());
+ Object.BeginArray("files"sv);
+ Object.BeginObject();
+ Object << "id"sv << FileOid;
+ Object << "data"sv << Attach;
+ Object << "clientpath"sv
+ << "/{engine}/testfile.bin"sv;
+ Object.EndObject();
+ Object.EndArray();
+ Package.AddAttachment(Attach);
+ Package.SetObject(Object.Save());
+ return Package;
+ };
+
+ SUBCASE("path-referenced file is retrievable")
+ {
+ MakeProject("proj_path"sv);
+ MakeOplog("proj_path"sv, "oplog"sv);
+
+ CbPackage Op = BuildPathReferencedFileOp(Oid::NewOid());
+ PostOplogEntry("proj_path"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_path"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ SUBCASE("hash-referenced file is retrievable")
+ {
+ MakeProject("proj_hash"sv);
+ MakeOplog("proj_hash"sv, "oplog"sv);
+
+ CbPackage Op = BuildHashReferencedFileOp(Oid::NewOid(), CompressedBlob);
+ PostOplogEntry("proj_hash"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_hash"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize());
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ SUBCASE("hash-referenced to path-referenced transition with different content")
+ {
+ MakeProject("proj_hash_to_path_diff"sv);
+ MakeOplog("proj_hash_to_path_diff"sv, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey;
+ bool RunGcAfterTransition = false;
+
+ SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); }
+ SUBCASE("same op key") { SecondOpKey = FirstOpKey; }
+ SUBCASE("new op key with gc")
+ {
+ SecondOpKey = Oid::NewOid();
+ RunGcAfterTransition = true;
+ }
+ SUBCASE("same op key with gc")
+ {
+ SecondOpKey = FirstOpKey;
+ RunGcAfterTransition = true;
+ }
+
+ // First op: file with CAS hash (content differs from the on-disk file)
+ {
+ CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, CompressedBlob);
+ PostOplogEntry("proj_hash_to_path_diff"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_hash_to_path_diff"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ // Second op: same FileId transitions to serverpath (different data)
+ {
+ CbPackage Op = BuildPathReferencedFileOp(SecondOpKey);
+ PostOplogEntry("proj_hash_to_path_diff"sv, "oplog"sv, Op);
+ }
+
+ if (RunGcAfterTransition)
+ {
+ TriggerGcAndWait();
+ }
+
+ // Must serve the on-disk file content, not the old CAS blob
+ HttpClient::Response Response = GetChunk("proj_hash_to_path_diff"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ SUBCASE("hash-referenced to path-referenced transition with identical content")
+ {
+ // Compress the same on-disk file content as a CAS blob so both references yield identical data
+ CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView()));
+
+ MakeProject("proj_hash_to_path_same"sv);
+ MakeOplog("proj_hash_to_path_same"sv, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey;
+ bool RunGcAfterTransition = false;
+
+ SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); }
+ SUBCASE("same op key") { SecondOpKey = FirstOpKey; }
+ SUBCASE("new op key with gc")
+ {
+ SecondOpKey = Oid::NewOid();
+ RunGcAfterTransition = true;
+ }
+ SUBCASE("same op key with gc")
+ {
+ SecondOpKey = FirstOpKey;
+ RunGcAfterTransition = true;
+ }
+
+ // First op: file with CAS hash (content matches the on-disk file)
+ {
+ CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, MatchingBlob);
+ PostOplogEntry("proj_hash_to_path_same"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_hash_to_path_same"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ // Second op: same FileId transitions to serverpath (same data)
+ {
+ CbPackage Op = BuildPathReferencedFileOp(SecondOpKey);
+ PostOplogEntry("proj_hash_to_path_same"sv, "oplog"sv, Op);
+ }
+
+ if (RunGcAfterTransition)
+ {
+ TriggerGcAndWait();
+ }
+
+ // Must still resolve successfully after the transition
+ HttpClient::Response Response = GetChunk("proj_hash_to_path_same"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ SUBCASE("path-referenced to hash-referenced transition with different content")
+ {
+ MakeProject("proj_path_to_hash_diff"sv);
+ MakeOplog("proj_path_to_hash_diff"sv, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey;
+ bool RunGcAfterTransition = false;
+
+ SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); }
+ SUBCASE("same op key") { SecondOpKey = FirstOpKey; }
+ SUBCASE("new op key with gc")
+ {
+ SecondOpKey = Oid::NewOid();
+ RunGcAfterTransition = true;
+ }
+ SUBCASE("same op key with gc")
+ {
+ SecondOpKey = FirstOpKey;
+ RunGcAfterTransition = true;
+ }
+
+ // First op: file with serverpath
+ {
+ CbPackage Op = BuildPathReferencedFileOp(FirstOpKey);
+ PostOplogEntry("proj_path_to_hash_diff"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_path_to_hash_diff"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ // Second op: same FileId transitions to CAS hash (different data)
+ {
+ CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, CompressedBlob);
+ PostOplogEntry("proj_path_to_hash_diff"sv, "oplog"sv, Op);
+ }
+
+ if (RunGcAfterTransition)
+ {
+ TriggerGcAndWait();
+ }
+
+ // Must serve the CAS blob content, not the old on-disk file
+ HttpClient::Response Response = GetChunk("proj_path_to_hash_diff"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize());
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ SUBCASE("path-referenced to hash-referenced transition with identical content")
+ {
+ // Compress the same on-disk file content as a CAS blob so both references yield identical data
+ CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView()));
+
+ MakeProject("proj_path_to_hash_same"sv);
+ MakeOplog("proj_path_to_hash_same"sv, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey;
+ bool RunGcAfterTransition = false;
+
+ SUBCASE("new op key") { SecondOpKey = Oid::NewOid(); }
+ SUBCASE("same op key") { SecondOpKey = FirstOpKey; }
+ SUBCASE("new op key with gc")
+ {
+ SecondOpKey = Oid::NewOid();
+ RunGcAfterTransition = true;
+ }
+ SUBCASE("same op key with gc")
+ {
+ SecondOpKey = FirstOpKey;
+ RunGcAfterTransition = true;
+ }
+
+ // First op: file with serverpath
+ {
+ CbPackage Op = BuildPathReferencedFileOp(FirstOpKey);
+ PostOplogEntry("proj_path_to_hash_same"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_path_to_hash_same"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ // Second op: same FileId transitions to CAS hash (same data)
+ {
+ CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, MatchingBlob);
+ PostOplogEntry("proj_path_to_hash_same"sv, "oplog"sv, Op);
+ }
+
+ if (RunGcAfterTransition)
+ {
+ TriggerGcAndWait();
+ }
+
+ // Must still resolve successfully after the transition
+ HttpClient::Response Response = GetChunk("proj_path_to_hash_same"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+}
+
TEST_SUITE_END();
} // namespace zen::tests
diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html
index b15b34577..41c80d3a3 100644
--- a/src/zenserver/frontend/html/compute/hub.html
+++ b/src/zenserver/frontend/html/compute/hub.html
@@ -83,7 +83,7 @@
}
async function fetchStats() {
- var data = await fetchJSON('/hub/stats');
+ var data = await fetchJSON('/stats/hub');
var current = data.currentInstanceCount || 0;
var max = data.maxInstanceCount || 0;
diff --git a/src/zenserver/frontend/html/pages/builds.js b/src/zenserver/frontend/html/pages/builds.js
index 095f0bf29..6b3426378 100644
--- a/src/zenserver/frontend/html/pages/builds.js
+++ b/src/zenserver/frontend/html/pages/builds.js
@@ -16,7 +16,7 @@ export class Page extends ZenPage
this.set_title("build store");
// Build Store Stats
- const stats_section = this.add_section("Build Store Stats");
+ const stats_section = this._collapsible_section("Build Store Service Stats");
stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => {
window.open("/stats/builds.yaml", "_blank");
});
diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js
index 1fc8227c8..e0f6f73b6 100644
--- a/src/zenserver/frontend/html/pages/cache.js
+++ b/src/zenserver/frontend/html/pages/cache.js
@@ -95,39 +95,6 @@ export class Page extends ZenPage
}
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
_render_stats(stats)
{
const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js
index ab3d49c27..2eb4d4e9b 100644
--- a/src/zenserver/frontend/html/pages/compute.js
+++ b/src/zenserver/frontend/html/pages/compute.js
@@ -24,6 +24,12 @@ function formatTime(date)
return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit", second: "2-digit" });
}
+function truncateHash(hash)
+{
+ if (!hash || hash.length <= 15) return hash;
+ return hash.slice(0, 6) + "\u2026" + hash.slice(-6);
+}
+
function formatDuration(startDate, endDate)
{
if (!startDate || !endDate) return "-";
@@ -100,39 +106,6 @@ export class Page extends ZenPage
}, 2000);
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
async _load_chartjs()
{
if (window.Chart)
@@ -338,11 +311,7 @@ export class Page extends ZenPage
{
const workerIds = data.workers || [];
- if (this._workers_table)
- {
- this._workers_table.clear();
- }
- else
+ if (!this._workers_table)
{
this._workers_table = this._workers_host.add_widget(
Table,
@@ -353,6 +322,7 @@ export class Page extends ZenPage
if (workerIds.length === 0)
{
+ this._workers_table.clear();
return;
}
@@ -382,6 +352,9 @@ export class Page extends ZenPage
id,
);
+ // Worker ID column: monospace for hex readability
+ row.get_cell(5).style("fontFamily", "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace");
+
// Make name clickable to expand detail
const cell = row.get_cell(0);
cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc));
@@ -579,6 +552,11 @@ export class Page extends ZenPage
["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"],
Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
);
+
+ // Right-align hash column headers to match data cells
+ const hdr = this._history_table.inner().firstElementChild;
+ hdr.children[7].style.textAlign = "right";
+ hdr.children[8].style.textAlign = "right";
}
// Entries arrive oldest-first; reverse to show newest at top
@@ -593,7 +571,10 @@ export class Page extends ZenPage
const startDate = filetimeToDate(entry.time_Running);
const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);
- this._history_table.add_row(
+ const workerId = entry.workerId || "-";
+ const actionId = entry.actionId || "-";
+
+ const row = this._history_table.add_row(
lsn,
queueId,
status,
@@ -601,9 +582,15 @@ export class Page extends ZenPage
formatTime(startDate),
formatTime(endDate),
formatDuration(startDate, endDate),
- entry.workerId || "-",
- entry.actionId || "-",
+ truncateHash(workerId),
+ truncateHash(actionId),
);
+
+ // Hash columns: force right-align (AlignNumeric misses hex strings starting with a-f),
+ // use monospace for readability, and show full value on hover
+ const mono = "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace";
+ row.get_cell(7).style("textAlign", "right").style("fontFamily", mono).attr("title", workerId);
+ row.get_cell(8).style("textAlign", "right").style("fontFamily", mono).attr("title", actionId);
}
}
diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js
index 1e4c82e3f..e381f4a71 100644
--- a/src/zenserver/frontend/html/pages/entry.js
+++ b/src/zenserver/frontend/html/pages/entry.js
@@ -168,7 +168,7 @@ export class Page extends ZenPage
if (key === "cook.artifacts")
{
action_tb.left().add("view-raw").on_click(() => {
- window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/");
+ window.open("/" + ["prj", project, "oplog", oplog, value+".json"].join("/"), "_self");
});
}
diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js
index 78e3a090c..c9652f31e 100644
--- a/src/zenserver/frontend/html/pages/hub.js
+++ b/src/zenserver/frontend/html/pages/hub.js
@@ -82,7 +82,7 @@ export class Page extends ZenPage
this.set_title("hub");
// Capacity
- const stats_section = this.add_section("Capacity");
+ const stats_section = this._collapsible_section("Hub Service Stats");
this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
// Modules
@@ -152,6 +152,8 @@ export class Page extends ZenPage
this._btn_next.className = "module-pager-btn";
this._btn_next.textContent = "Next \u2192";
this._btn_next.addEventListener("click", () => this._go_page(this._page + 1));
+ this._btn_provision = _make_bulk_btn("+", "Provision", () => this._show_provision_modal());
+ pager.appendChild(this._btn_provision);
pager.appendChild(this._btn_prev);
pager.appendChild(this._pager_label);
pager.appendChild(this._btn_next);
@@ -203,27 +205,47 @@ export class Page extends ZenPage
{
const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Active Modules");
+ tile.tag().classify("card-title").text("Instances");
const body = tile.tag().classify("tile-metrics");
this._metric(body, Friendly.sep(current), "currently provisioned", true);
+ this._metric(body, Friendly.sep(max), "high watermark");
+ this._metric(body, Friendly.sep(limit), "maximum allowed");
+ if (limit > 0)
+ {
+ const pct = ((current / limit) * 100).toFixed(0) + "%";
+ this._metric(body, pct, "utilization");
+ }
}
+ const machine = data.machine || {};
+ const limits = data.resource_limits || {};
+ if (machine.disk_total_bytes > 0 || machine.memory_total_mib > 0)
{
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Peak Modules");
- const body = tile.tag().classify("tile-metrics");
- this._metric(body, Friendly.sep(max), "high watermark", true);
- }
+ const disk_used = Math.max(0, (machine.disk_total_bytes || 0) - (machine.disk_free_bytes || 0));
+ const mem_used = Math.max(0, (machine.memory_total_mib || 0) - (machine.memory_avail_mib || 0)) * 1024 * 1024;
+ const vmem_used = Math.max(0, (machine.virtual_memory_total_mib || 0) - (machine.virtual_memory_avail_mib || 0)) * 1024 * 1024;
+ const disk_limit = limits.disk_bytes || 0;
+ const mem_limit = limits.memory_bytes || 0;
+ const disk_over = disk_limit > 0 && disk_used > disk_limit;
+ const mem_over = mem_limit > 0 && mem_used > mem_limit;
- {
const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Instance Limit");
- const body = tile.tag().classify("tile-metrics");
- this._metric(body, Friendly.sep(limit), "maximum allowed", true);
- if (limit > 0)
+ if (disk_over || mem_over) { tile.inner().setAttribute("data-over", "true"); }
+ tile.tag().classify("card-title").text("Resources");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, Friendly.bytes(disk_used), "disk used", true);
+ this._metric(left, Friendly.bytes(machine.disk_total_bytes), "disk total");
+ if (disk_limit > 0) { this._metric(left, Friendly.bytes(disk_limit), "disk limit"); }
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.bytes(mem_used), "memory used", true);
+ if (mem_limit > 0) { this._metric(right, Friendly.bytes(mem_limit), "memory limit"); }
+ if (machine.virtual_memory_total_mib > 0)
{
- const pct = ((current / limit) * 100).toFixed(0) + "%";
- this._metric(body, pct, "utilization");
+ this._metric(right, Friendly.bytes(vmem_used), "vmem used", true);
+ this._metric(right, Friendly.bytes(machine.virtual_memory_total_mib * 1024 * 1024), "vmem total");
}
}
}
@@ -284,6 +306,14 @@ export class Page extends ZenPage
}
row.state_text.nodeValue = state;
row.port_text.nodeValue = m.port ? String(m.port) : "";
+ if (m.state_change_time)
+ {
+ const state_label = state.charAt(0).toUpperCase() + state.slice(1);
+ row.state_since_label.textContent = state_label + " since";
+ row.state_age_label.textContent = state_label + " for";
+ row.state_since_node.nodeValue = m.state_change_time;
+ row.state_age_node.nodeValue = Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime());
+ }
row.btn_open.disabled = state !== "provisioned";
row.btn_hibernate.disabled = !_btn_enabled(state, "hibernate");
row.btn_wake.disabled = !_btn_enabled(state, "wake");
@@ -373,7 +403,7 @@ export class Page extends ZenPage
const td_action = document.createElement("td");
td_action.className = "module-action-cell";
const [wrap_o, btn_o] = _make_action_btn("\u2197", "Open dashboard", () => {
- window.open(`${window.location.protocol}//${window.location.hostname}:${port}`, "_blank");
+ window.open(`/hub/proxy/${port}/dashboard/`, "_blank");
});
btn_o.disabled = state !== "provisioned";
const [wrap_h, btn_h] = _make_action_btn("\u23F8", "Hibernate", () => this._post_module_action(id, "hibernate").then(() => this._update()));
@@ -388,7 +418,7 @@ export class Page extends ZenPage
td_action.appendChild(wrap_o);
tr.appendChild(td_action);
- // Build metrics grid from process_metrics keys.
+ // Build metrics grid: fixed state-time rows followed by process_metrics keys.
// Keys are split into two halves and interleaved so the grid fills
// top-to-bottom in the left column before continuing in the right column.
const metric_nodes = new Map();
@@ -396,6 +426,28 @@ export class Page extends ZenPage
metrics_td.colSpan = 6;
const metrics_grid = document.createElement("div");
metrics_grid.className = "module-metrics-grid";
+
+ const _add_fixed_pair = (label, value_str) => {
+ const label_el = document.createElement("span");
+ label_el.className = "module-metrics-label";
+ label_el.textContent = label;
+ const value_node = document.createTextNode(value_str);
+ const value_el = document.createElement("span");
+ value_el.className = "module-metrics-value";
+ value_el.appendChild(value_node);
+ metrics_grid.appendChild(label_el);
+ metrics_grid.appendChild(value_el);
+ return { label_el, value_node };
+ };
+
+ const state_label = m.state ? m.state.charAt(0).toUpperCase() + m.state.slice(1) : "State";
+ const state_since_str = m.state_change_time || "";
+ const state_age_str = m.state_change_time
+ ? Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime())
+ : "";
+ const { label_el: state_since_label, value_node: state_since_node } = _add_fixed_pair(state_label + " since", state_since_str);
+ const { label_el: state_age_label, value_node: state_age_node } = _add_fixed_pair(state_label + " for", state_age_str);
+
const keys = Object.keys(m.process_metrics || {});
const half = Math.ceil(keys.length / 2);
const add_metric_pair = (key) => {
@@ -423,7 +475,7 @@ export class Page extends ZenPage
metrics_td.appendChild(metrics_grid);
metrics_tr.appendChild(metrics_td);
- row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes };
+ row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes, state_since_node, state_age_node, state_since_label, state_age_label };
this._row_cache.set(id, row);
}
@@ -614,4 +666,135 @@ export class Page extends ZenPage
await fetch(`/hub/modules/${moduleId}/${action}`, { method: "POST" });
}
+ _show_provision_modal()
+ {
+ const MODULE_ID_RE = /^[A-Za-z0-9][A-Za-z0-9-]*$/;
+
+ const overlay = document.createElement("div");
+ overlay.className = "zen_modal";
+
+ const bg = document.createElement("div");
+ bg.className = "zen_modal_bg";
+ bg.addEventListener("click", () => overlay.remove());
+ overlay.appendChild(bg);
+
+ const dialog = document.createElement("div");
+ overlay.appendChild(dialog);
+
+ const title = document.createElement("div");
+ title.className = "zen_modal_title";
+ title.textContent = "Provision Module";
+ dialog.appendChild(title);
+
+ const content = document.createElement("div");
+ content.className = "zen_modal_message";
+ content.style.textAlign = "center";
+
+ const input = document.createElement("input");
+ input.type = "text";
+ input.placeholder = "module-name";
+ input.style.cssText = "width:100%;font-size:14px;padding:8px 12px;";
+ content.appendChild(input);
+
+ const error_div = document.createElement("div");
+ error_div.style.cssText = "color:var(--theme_fail);font-size:12px;margin-top:8px;min-height:1.2em;";
+ content.appendChild(error_div);
+
+ dialog.appendChild(content);
+
+ const buttons = document.createElement("div");
+ buttons.className = "zen_modal_buttons";
+
+ const btn_cancel = document.createElement("div");
+ btn_cancel.textContent = "Cancel";
+ btn_cancel.addEventListener("click", () => overlay.remove());
+
+ const btn_submit = document.createElement("div");
+ btn_submit.textContent = "Provision";
+
+ buttons.appendChild(btn_cancel);
+ buttons.appendChild(btn_submit);
+ dialog.appendChild(buttons);
+
+ let submitting = false;
+
+ const set_submit_enabled = (enabled) => {
+ btn_submit.style.opacity = enabled ? "" : "0.4";
+ btn_submit.style.pointerEvents = enabled ? "" : "none";
+ };
+
+ set_submit_enabled(false);
+
+ const validate = () => {
+ if (submitting) { return false; }
+ const val = input.value.trim();
+ if (val.length === 0)
+ {
+ error_div.textContent = "";
+ set_submit_enabled(false);
+ return false;
+ }
+ if (!MODULE_ID_RE.test(val))
+ {
+ error_div.textContent = "Only letters, numbers, and hyphens allowed (must start with a letter or number)";
+ set_submit_enabled(false);
+ return false;
+ }
+ error_div.textContent = "";
+ set_submit_enabled(true);
+ return true;
+ };
+
+ input.addEventListener("input", validate);
+
+ const submit = async () => {
+ if (submitting) { return; }
+ const moduleId = input.value.trim();
+ if (!MODULE_ID_RE.test(moduleId)) { return; }
+
+ submitting = true;
+ set_submit_enabled(false);
+ error_div.textContent = "";
+
+ try
+ {
+ const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}/provision`, { method: "POST" });
+ if (resp.ok)
+ {
+ overlay.remove();
+ await this._update();
+ this._navigate_to_module(moduleId);
+ return;
+ }
+ const msg = await resp.text();
+ error_div.textContent = msg || ("HTTP " + resp.status);
+ }
+ catch (e)
+ {
+ error_div.textContent = e.message || "Request failed";
+ }
+ submitting = false;
+ set_submit_enabled(true);
+ };
+
+ btn_submit.addEventListener("click", submit);
+ input.addEventListener("keydown", (e) => {
+ if (e.key === "Enter" && validate()) { submit(); }
+ if (e.key === "Escape") { overlay.remove(); }
+ });
+
+ document.body.appendChild(overlay);
+ input.focus();
+ }
+
+ _navigate_to_module(moduleId)
+ {
+ const idx = this._modules_data.findIndex(m => m.moduleId === moduleId);
+ if (idx >= 0)
+ {
+ this._page = Math.floor(idx / this._page_size);
+ this._render_page();
+ }
+ }
+
}
diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js
index 4a9290a3c..a280fabdb 100644
--- a/src/zenserver/frontend/html/pages/orchestrator.js
+++ b/src/zenserver/frontend/html/pages/orchestrator.js
@@ -46,39 +46,6 @@ export class Page extends ZenPage
this._connect_ws();
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
async _fetch_all()
{
try
diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js
index cf8d3e3dd..ff530ff8e 100644
--- a/src/zenserver/frontend/html/pages/page.js
+++ b/src/zenserver/frontend/html/pages/page.js
@@ -337,4 +337,37 @@ export class ZenPage extends PageBase
this._metric(right, Friendly.duration(reqData.t_max), "max");
}
}
+
+ _collapsible_section(name)
+ {
+ const section = this.add_section(name);
+ const container = section._parent.inner();
+ const heading = container.firstElementChild;
+
+ heading.style.cursor = "pointer";
+ heading.style.userSelect = "none";
+
+ const indicator = document.createElement("span");
+ indicator.textContent = " \u25BC";
+ indicator.style.fontSize = "0.7em";
+ heading.appendChild(indicator);
+
+ let collapsed = false;
+ heading.addEventListener("click", (e) => {
+ if (e.target !== heading && e.target !== indicator)
+ {
+ return;
+ }
+ collapsed = !collapsed;
+ indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
+ let sibling = heading.nextElementSibling;
+ while (sibling)
+ {
+ sibling.style.display = collapsed ? "none" : "";
+ sibling = sibling.nextElementSibling;
+ }
+ });
+
+ return section;
+ }
}
diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js
index 2469bf70b..dfe4faeb8 100644
--- a/src/zenserver/frontend/html/pages/projects.js
+++ b/src/zenserver/frontend/html/pages/projects.js
@@ -110,39 +110,6 @@ export class Page extends ZenPage
}
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
_clear_param(name)
{
this._params.delete(name);
diff --git a/src/zenserver/frontend/html/pages/workspaces.js b/src/zenserver/frontend/html/pages/workspaces.js
index d31fd7373..2442fb35b 100644
--- a/src/zenserver/frontend/html/pages/workspaces.js
+++ b/src/zenserver/frontend/html/pages/workspaces.js
@@ -13,7 +13,7 @@ export class Page extends ZenPage
this.set_title("workspaces");
// Workspace Service Stats
- const stats_section = this.add_section("Workspace Service Stats");
+ const stats_section = this._collapsible_section("Workspace Service Stats");
this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
const stats = await new Fetcher().resource("stats", "ws").json().catch(() => null);
diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css
index d9f7491ea..cb3d78cf2 100644
--- a/src/zenserver/frontend/html/zen.css
+++ b/src/zenserver/frontend/html/zen.css
@@ -816,6 +816,10 @@ zen-banner + zen-nav::part(nav-bar) {
border-color: var(--theme_p0);
}
+.stats-tile[data-over="true"] {
+ border-color: var(--theme_fail);
+}
+
.stats-tile-detailed {
position: relative;
}
diff --git a/src/zenserver/hub/httphubservice.cpp b/src/zenserver/hub/httphubservice.cpp
index ebefcf2e3..eba816793 100644
--- a/src/zenserver/hub/httphubservice.cpp
+++ b/src/zenserver/hub/httphubservice.cpp
@@ -2,6 +2,7 @@
#include "httphubservice.h"
+#include "httpproxyhandler.h"
#include "hub.h"
#include "storageserverinstance.h"
@@ -43,10 +44,11 @@ namespace {
}
} // namespace
-HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService)
+HttpHubService::HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService)
: m_Hub(Hub)
, m_StatsService(StatsService)
, m_StatusService(StatusService)
+, m_Proxy(Proxy)
{
using namespace std::literals;
@@ -67,6 +69,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
return true;
});
+ m_Router.AddMatcher("port", [](std::string_view Str) -> bool {
+ if (Str.empty())
+ {
+ return false;
+ }
+ for (const auto C : Str)
+ {
+ if (!std::isdigit(C))
+ {
+ return false;
+ }
+ }
+ return true;
+ });
+
+ m_Router.AddMatcher("proxypath", [](std::string_view Str) -> bool { return !Str.empty(); });
+
m_Router.RegisterRoute(
"status",
[this](HttpRouterRequest& Req) {
@@ -78,6 +97,10 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
Obj << "moduleId" << ModuleId;
Obj << "state" << ToString(Info.State);
Obj << "port" << Info.Port;
+ if (Info.StateChangeTime != std::chrono::system_clock::time_point::min())
+ {
+ Obj << "state_change_time" << ToDateTime(Info.StateChangeTime);
+ }
Obj.BeginObject("process_metrics");
{
Obj << "MemoryBytes" << Info.Metrics.MemoryBytes;
@@ -229,15 +252,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
HttpVerb::kPost);
m_Router.RegisterRoute(
- "stats",
+ "proxy/{port}/{proxypath}",
[this](HttpRouterRequest& Req) {
- CbObjectWriter Obj;
- Obj << "currentInstanceCount" << m_Hub.GetInstanceCount();
- Obj << "maxInstanceCount" << m_Hub.GetMaxInstanceCount();
- Obj << "instanceLimit" << m_Hub.GetConfig().InstanceLimit;
- Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ std::string_view PortStr = Req.GetCapture(1);
+
+ // Use RelativeUriWithExtension to preserve the file extension that the
+ // router's URI parser strips (e.g. ".css", ".js") - the upstream server
+ // needs the full path including the extension.
+ std::string_view FullUri = Req.ServerRequest().RelativeUriWithExtension();
+ std::string_view Prefix = "proxy/";
+
+ // FullUri is "proxy/{port}/{path...}" - skip past "proxy/{port}/"
+ size_t PathStart = Prefix.size() + PortStr.size() + 1;
+ std::string_view PathTail = (PathStart < FullUri.size()) ? FullUri.substr(PathStart) : std::string_view{};
+
+ m_Proxy.HandleProxyRequest(Req.ServerRequest(), PortStr, PathTail);
},
- HttpVerb::kGet);
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
m_StatsService.RegisterHandler("hub", *this);
m_StatusService.RegisterHandler("hub", *this);
@@ -286,7 +317,37 @@ HttpHubService::HandleStatusRequest(HttpServerRequest& Request)
void
HttpHubService::HandleStatsRequest(HttpServerRequest& Request)
{
- Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+ CbObjectWriter Cbo;
+
+ EmitSnapshot("requests", m_HttpRequests, Cbo);
+
+ Cbo << "currentInstanceCount" << m_Hub.GetInstanceCount();
+ Cbo << "maxInstanceCount" << m_Hub.GetMaxInstanceCount();
+ Cbo << "instanceLimit" << m_Hub.GetConfig().InstanceLimit;
+
+ SystemMetrics SysMetrics;
+ DiskSpace Disk;
+ m_Hub.GetMachineMetrics(SysMetrics, Disk);
+ Cbo.BeginObject("machine");
+ {
+ Cbo << "disk_free_bytes" << Disk.Free;
+ Cbo << "disk_total_bytes" << Disk.Total;
+ Cbo << "memory_avail_mib" << SysMetrics.AvailSystemMemoryMiB;
+ Cbo << "memory_total_mib" << SysMetrics.SystemMemoryMiB;
+ Cbo << "virtual_memory_avail_mib" << SysMetrics.AvailVirtualMemoryMiB;
+ Cbo << "virtual_memory_total_mib" << SysMetrics.VirtualMemoryMiB;
+ }
+ Cbo.EndObject();
+
+ const ResourceMetrics& Limits = m_Hub.GetConfig().ResourceLimits;
+ Cbo.BeginObject("resource_limits");
+ {
+ Cbo << "disk_bytes" << Limits.DiskUsageBytes;
+ Cbo << "memory_bytes" << Limits.MemoryUsageBytes;
+ }
+ Cbo.EndObject();
+
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
CbObject
@@ -369,4 +430,22 @@ HttpHubService::HandleModuleDelete(HttpServerRequest& Request, std::string_view
Request.WriteResponse(HttpResponseCode::OK, Obj.Save());
}
+void
+HttpHubService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
+{
+ m_Proxy.OnWebSocketOpen(std::move(Connection), RelativeUri);
+}
+
+void
+HttpHubService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
+{
+ m_Proxy.OnWebSocketMessage(Conn, Msg);
+}
+
+void
+HttpHubService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason)
+{
+ m_Proxy.OnWebSocketClose(Conn, Code, Reason);
+}
+
} // namespace zen
diff --git a/src/zenserver/hub/httphubservice.h b/src/zenserver/hub/httphubservice.h
index 1bb1c303e..ff2cb0029 100644
--- a/src/zenserver/hub/httphubservice.h
+++ b/src/zenserver/hub/httphubservice.h
@@ -2,11 +2,16 @@
#pragma once
+#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/httpstatus.h>
+#include <zenhttp/websocket.h>
+
+#include <memory>
namespace zen {
+class HttpProxyHandler;
class HttpStatsService;
class Hub;
@@ -16,10 +21,10 @@ class Hub;
* use in UEFN content worker style scenarios.
*
*/
-class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider
+class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider, public IWebSocketHandler
{
public:
- HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService);
+ HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService);
~HttpHubService();
HttpHubService(const HttpHubService&) = delete;
@@ -32,6 +37,11 @@ public:
virtual CbObject CollectStats() override;
virtual uint64_t GetActivityCounter() override;
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId);
private:
@@ -45,6 +55,8 @@ private:
void HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId);
void HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId);
+
+ HttpProxyHandler& m_Proxy;
};
} // namespace zen
diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp
new file mode 100644
index 000000000..25842623a
--- /dev/null
+++ b/src/zenserver/hub/httpproxyhandler.cpp
@@ -0,0 +1,504 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpproxyhandler.h"
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpwsclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <charconv>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+
+namespace {
+
+ std::string InjectProxyScript(std::string_view Html, uint16_t Port)
+ {
+ ExtendableStringBuilder<2048> Script;
+ Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/");
+ Script.Append(fmt::format("{}", Port));
+ Script.Append(
+ "\";\n"
+ " var OF = window.fetch;\n"
+ " window.fetch = function(u, o) {\n"
+ " if (typeof u === \"string\") {\n"
+ " try {\n"
+ " var p = new URL(u, location.origin);\n"
+ " if (p.origin === location.origin && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {\n"
+ " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n"
+ " }\n"
+ " }\n"
+ " return OF.call(this, u, o);\n"
+ " };\n"
+ " var OW = window.WebSocket;\n"
+ " window.WebSocket = function(u, pr) {\n"
+ " try {\n"
+ " var p = new URL(u);\n"
+ " if (p.hostname === location.hostname\n"
+ " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n"
+ " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n"
+ " && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {}\n"
+ " return pr !== undefined ? new OW(u, pr) : new OW(u);\n"
+ " };\n"
+ " window.WebSocket.prototype = OW.prototype;\n"
+ " window.WebSocket.CONNECTING = OW.CONNECTING;\n"
+ " window.WebSocket.OPEN = OW.OPEN;\n"
+ " window.WebSocket.CLOSING = OW.CLOSING;\n"
+ " window.WebSocket.CLOSED = OW.CLOSED;\n"
+ " var OO = window.open;\n"
+ " window.open = function(u, t, f) {\n"
+ " if (typeof u === \"string\") {\n"
+ " try {\n"
+ " var p = new URL(u, location.origin);\n"
+ " if (p.origin === location.origin && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {}\n"
+ " }\n"
+ " return OO.call(this, u, t, f);\n"
+ " };\n"
+ " document.addEventListener(\"click\", function(e) {\n"
+ " var t = e.composedPath ? e.composedPath()[0] : e.target;\n"
+ " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n"
+ " if (!t || !t.href) return;\n"
+ " try {\n"
+ " var h = new URL(t.href);\n"
+ " if (h.origin === location.origin && !h.pathname.startsWith(P))\n"
+ " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n"
+ " } catch(x) {}\n"
+ " }, true);\n"
+ "})();\n</script>");
+
+ std::string ScriptStr = Script.ToString();
+
+ size_t HeadClose = Html.find("</head>");
+ if (HeadClose != std::string_view::npos)
+ {
+ std::string Result;
+ Result.reserve(Html.size() + ScriptStr.size());
+ Result.append(Html.substr(0, HeadClose));
+ Result.append(ScriptStr);
+ Result.append(Html.substr(HeadClose));
+ return Result;
+ }
+
+ std::string Result;
+ Result.reserve(Html.size() + ScriptStr.size());
+ Result.append(ScriptStr);
+ Result.append(Html);
+ return Result;
+ }
+
+} // namespace
+
+struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler
+{
+ Ref<WebSocketConnection> ClientConn;
+ std::unique_ptr<HttpWsClient> UpstreamClient;
+ uint16_t Port = 0;
+
+ void OnWsOpen() override {}
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (!ClientConn->IsOpen())
+ {
+ return;
+ }
+ switch (Msg.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kBinary:
+ ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ default:
+ break;
+ }
+ }
+
+ void OnWsClose(uint16_t Code, std::string_view Reason) override
+ {
+ if (ClientConn->IsOpen())
+ {
+ ClientConn->Close(Code, Reason);
+ }
+ }
+};
+
+HttpProxyHandler::HttpProxyHandler()
+{
+}
+
+HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort))
+{
+}
+
+void
+HttpProxyHandler::SetPortValidator(PortValidator ValidatePort)
+{
+ m_ValidatePort = std::move(ValidatePort);
+}
+
+HttpProxyHandler::~HttpProxyHandler()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (...)
+ {
+ }
+}
+
+HttpClient&
+HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port)
+{
+ HttpClient* Result = nullptr;
+ m_ProxyClientsLock.WithExclusiveLock([&] {
+ auto It = m_ProxyClients.find(Port);
+ if (It == m_ProxyClients.end())
+ {
+ HttpClientSettings Settings;
+ Settings.LogCategory = "hub-proxy";
+ Settings.ConnectTimeout = std::chrono::milliseconds(5000);
+ Settings.Timeout = std::chrono::milliseconds(30000);
+ auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings);
+ Result = Client.get();
+ m_ProxyClients.emplace(Port, std::move(Client));
+ }
+ else
+ {
+ Result = It->second.get();
+ }
+ });
+ return *Result;
+}
+
+void
+HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail)
+{
+ uint16_t Port = 0;
+ auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port);
+ if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size())
+ {
+ Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL");
+ return;
+ }
+
+ if (!m_ValidatePort(Port))
+ {
+ Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "target instance not available");
+ return;
+ }
+
+ HttpClient& Client = GetOrCreateProxyClient(Port);
+
+ std::string RequestPath;
+ RequestPath.reserve(1 + PathTail.size());
+ RequestPath.push_back('/');
+ RequestPath.append(PathTail);
+
+ std::string_view QueryString = Request.QueryString();
+ if (!QueryString.empty())
+ {
+ RequestPath.push_back('?');
+ RequestPath.append(QueryString);
+ }
+
+ HttpClient::KeyValueMap ForwardHeaders;
+ HttpContentType AcceptType = Request.AcceptContentType();
+ if (AcceptType != HttpContentType::kUnknownContentType)
+ {
+ ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType)));
+ }
+
+ std::string_view Auth = Request.GetAuthorizationHeader();
+ if (!Auth.empty())
+ {
+ ForwardHeaders->emplace("Authorization", std::string(Auth));
+ }
+
+ HttpContentType ReqContentType = Request.RequestContentType();
+ if (ReqContentType != HttpContentType::kUnknownContentType)
+ {
+ ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType)));
+ }
+
+ HttpClient::Response Response;
+
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ Response = Client.Get(RequestPath, ForwardHeaders);
+ break;
+ case HttpVerb::kPost:
+ {
+ IoBuffer Payload = Request.ReadPayload();
+ Response = Client.Post(RequestPath, Payload, ForwardHeaders);
+ break;
+ }
+ case HttpVerb::kPut:
+ {
+ IoBuffer Payload = Request.ReadPayload();
+ Response = Client.Put(RequestPath, Payload, ForwardHeaders);
+ break;
+ }
+ case HttpVerb::kDelete:
+ Response = Client.Delete(RequestPath, ForwardHeaders);
+ break;
+ case HttpVerb::kHead:
+ Response = Client.Head(RequestPath, ForwardHeaders);
+ break;
+ default:
+ Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported");
+ return;
+ }
+
+ if (Response.Error)
+ {
+ ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage);
+ Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "upstream request failed");
+ return;
+ }
+
+ HttpContentType ContentType = Response.ResponsePayload.GetContentType();
+
+ if (ContentType == HttpContentType::kHTML)
+ {
+ std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize());
+ std::string Injected = InjectProxyScript(Html, Port);
+ Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected));
+ }
+ else
+ {
+ Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload));
+ }
+}
+
+void
+HttpProxyHandler::PrunePort(uint16_t Port)
+{
+ m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); });
+
+ std::vector<Ref<WsBridge>> Stale;
+ m_WsBridgesLock.WithExclusiveLock([&] {
+ for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();)
+ {
+ if (It->second->Port == Port)
+ {
+ Stale.push_back(std::move(It->second));
+ It = m_WsBridges.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+
+ for (auto& Bridge : Stale)
+ {
+ if (Bridge->UpstreamClient)
+ {
+ Bridge->UpstreamClient->Close(1001, "instance shutting down");
+ }
+ if (Bridge->ClientConn->IsOpen())
+ {
+ Bridge->ClientConn->Close(1001, "instance shutting down");
+ }
+ }
+}
+
+void
+HttpProxyHandler::Shutdown()
+{
+ m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); });
+ m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// WebSocket proxy
+//
+
+void
+HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
+{
+ const std::string_view ProxyPrefix = "proxy/";
+ if (!RelativeUri.starts_with(ProxyPrefix))
+ {
+ Connection->Close(1008, "unsupported WebSocket endpoint");
+ return;
+ }
+
+ std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size());
+
+ size_t SlashPos = ProxyTail.find('/');
+ std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail;
+ std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/";
+
+ uint16_t Port = 0;
+ auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port);
+ if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size())
+ {
+ Connection->Close(1008, "invalid proxy URL");
+ return;
+ }
+
+ if (!m_ValidatePort(Port))
+ {
+ Connection->Close(1008, "target instance not available");
+ return;
+ }
+
+ std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path);
+
+ Ref<WsBridge> Bridge(new WsBridge());
+ Bridge->ClientConn = Connection;
+ Bridge->Port = Port;
+
+ Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge);
+
+ try
+ {
+ Bridge->UpstreamClient->Connect();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what());
+ Connection->Close(1011, "upstream connect failed");
+ return;
+ }
+
+ WebSocketConnection* Key = Connection.Get();
+ m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); });
+}
+
+void
+HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
+{
+ Ref<WsBridge> Bridge;
+ m_WsBridgesLock.WithSharedLock([&] {
+ auto It = m_WsBridges.find(&Conn);
+ if (It != m_WsBridges.end())
+ {
+ Bridge = It->second;
+ }
+ });
+
+ if (!Bridge || !Bridge->UpstreamClient)
+ {
+ return;
+ }
+
+ switch (Msg.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kBinary:
+ Bridge->UpstreamClient->SendBinary(
+ std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kClose:
+ Bridge->UpstreamClient->Close(Msg.CloseCode, {});
+ break;
+ default:
+ break;
+ }
+}
+
+void
+HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason)
+{
+ Ref<WsBridge> Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> {
+ auto It = m_WsBridges.find(&Conn);
+ if (It != m_WsBridges.end())
+ {
+ Ref<WsBridge> Bridge = std::move(It->second);
+ m_WsBridges.erase(It);
+ return Bridge;
+ }
+ return {};
+ });
+
+ if (Bridge && Bridge->UpstreamClient)
+ {
+ Bridge->UpstreamClient->Close(Code, Reason);
+ }
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("server.httpproxyhandler");
+
+TEST_CASE("server.httpproxyhandler.html_injection")
+{
+ SUBCASE("injects before </head>")
+ {
+ std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005);
+ CHECK(Result.find("<script>") != std::string::npos);
+ CHECK(Result.find("/hub/proxy/21005") != std::string::npos);
+ size_t ScriptEnd = Result.find("</script>");
+ size_t HeadClose = Result.find("</head>");
+ REQUIRE(ScriptEnd != std::string::npos);
+ REQUIRE(HeadClose != std::string::npos);
+ CHECK(ScriptEnd < HeadClose);
+ }
+
+ SUBCASE("prepends when no </head>")
+ {
+ std::string Result = InjectProxyScript("<body>content</body>", 21005);
+ CHECK(Result.find("<script>") == 0);
+ CHECK(Result.find("<body>content</body>") != std::string::npos);
+ }
+
+ SUBCASE("empty html")
+ {
+ std::string Result = InjectProxyScript("", 21005);
+ CHECK(Result.find("<script>") != std::string::npos);
+ CHECK(Result.find("/hub/proxy/21005") != std::string::npos);
+ }
+
+ SUBCASE("preserves original content")
+ {
+ std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>";
+ std::string Result = InjectProxyScript(Html, 21005);
+ CHECK(Result.find("<title>Test</title>") != std::string::npos);
+ CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos);
+ }
+}
+
+TEST_CASE("server.httpproxyhandler.port_embedding")
+{
+ std::string Result = InjectProxyScript("<head></head>", 80);
+ CHECK(Result.find("/hub/proxy/80") != std::string::npos);
+
+ Result = InjectProxyScript("<head></head>", 65535);
+ CHECK(Result.find("/hub/proxy/65535") != std::string::npos);
+}
+
+TEST_SUITE_END();
+
+void
+httpproxyhandler_forcelink()
+{
+}
+#endif // ZEN_WITH_TESTS
+
+} // namespace zen
diff --git a/src/zenserver/hub/httpproxyhandler.h b/src/zenserver/hub/httpproxyhandler.h
new file mode 100644
index 000000000..8667c0ca1
--- /dev/null
+++ b/src/zenserver/hub/httpproxyhandler.h
@@ -0,0 +1,52 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+
+namespace zen {
+
+class HttpClient;
+
+class HttpProxyHandler
+{
+public:
+ using PortValidator = std::function<bool(uint16_t)>;
+
+ HttpProxyHandler();
+ explicit HttpProxyHandler(PortValidator ValidatePort);
+ ~HttpProxyHandler();
+
+ void SetPortValidator(PortValidator ValidatePort);
+
+ HttpProxyHandler(const HttpProxyHandler&) = delete;
+ HttpProxyHandler& operator=(const HttpProxyHandler&) = delete;
+
+ void HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail);
+ void PrunePort(uint16_t Port);
+ void Shutdown();
+
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri);
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg);
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason);
+
+private:
+ PortValidator m_ValidatePort;
+
+ HttpClient& GetOrCreateProxyClient(uint16_t Port);
+
+ RwLock m_ProxyClientsLock;
+ std::unordered_map<uint16_t, std::unique_ptr<HttpClient>> m_ProxyClients;
+
+ struct WsBridge;
+ RwLock m_WsBridgesLock;
+ std::unordered_map<WebSocketConnection*, Ref<WsBridge>> m_WsBridges;
+};
+
+} // namespace zen
diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp
index 6c44e2333..82f4a00ba 100644
--- a/src/zenserver/hub/hub.cpp
+++ b/src/zenserver/hub/hub.cpp
@@ -19,7 +19,6 @@ ZEN_THIRD_PARTY_INCLUDES_START
ZEN_THIRD_PARTY_INCLUDES_END
#if ZEN_WITH_TESTS
-# include <zencore/filesystem.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
#endif
@@ -122,6 +121,55 @@ private:
//////////////////////////////////////////////////////////////////////////
+ProcessMetrics
+Hub::AtomicProcessMetrics::Load() const
+{
+ return {
+ .MemoryBytes = MemoryBytes.load(),
+ .KernelTimeMs = KernelTimeMs.load(),
+ .UserTimeMs = UserTimeMs.load(),
+ .WorkingSetSize = WorkingSetSize.load(),
+ .PeakWorkingSetSize = PeakWorkingSetSize.load(),
+ .PagefileUsage = PagefileUsage.load(),
+ .PeakPagefileUsage = PeakPagefileUsage.load(),
+ };
+}
+
+void
+Hub::AtomicProcessMetrics::Store(const ProcessMetrics& Metrics)
+{
+ MemoryBytes.store(Metrics.MemoryBytes);
+ KernelTimeMs.store(Metrics.KernelTimeMs);
+ UserTimeMs.store(Metrics.UserTimeMs);
+ WorkingSetSize.store(Metrics.WorkingSetSize);
+ PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize);
+ PagefileUsage.store(Metrics.PagefileUsage);
+ PeakPagefileUsage.store(Metrics.PeakPagefileUsage);
+}
+
+void
+Hub::AtomicProcessMetrics::Reset()
+{
+ MemoryBytes.store(0);
+ KernelTimeMs.store(0);
+ UserTimeMs.store(0);
+ WorkingSetSize.store(0);
+ PeakWorkingSetSize.store(0);
+ PagefileUsage.store(0);
+ PeakPagefileUsage.store(0);
+}
+
+void
+Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const
+{
+ m_Lock.WithSharedLock([&]() {
+ OutSystemMetrict = m_SystemMetrics;
+ OutDiskSpace = m_DiskSpace;
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
Hub::Hub(const Configuration& Config,
ZenServerEnvironment&& RunEnvironment,
WorkerThreadPool* OptionalWorkerPool,
@@ -134,11 +182,11 @@ Hub::Hub(const Configuration& Config,
, m_ActiveInstances(Config.InstanceLimit)
, m_FreeActiveInstanceIndexes(Config.InstanceLimit)
{
- m_HostMetrics = GetSystemMetrics();
- m_ResourceLimits.DiskUsageBytes = 1000ull * 1024 * 1024 * 1024;
- m_ResourceLimits.MemoryUsageBytes = 16ull * 1024 * 1024 * 1024;
-
- if (m_Config.HydrationTargetSpecification.empty())
+ if (!m_Config.HydrationTargetSpecification.empty())
+ {
+ m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification;
+ }
+ else if (!m_Config.HydrationOptions)
{
std::filesystem::path FileHydrationPath = m_RunEnvironment.CreateChildDir("hydration_storage");
ZEN_INFO("using file hydration path: '{}'", FileHydrationPath);
@@ -146,7 +194,7 @@ Hub::Hub(const Configuration& Config,
}
else
{
- m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification;
+ m_HydrationOptions = m_Config.HydrationOptions;
}
m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp");
@@ -171,6 +219,9 @@ Hub::Hub(const Configuration& Config,
}
}
#endif
+
+ UpdateMachineMetrics();
+
m_WatchDog = std::thread([this]() { WatchDog(); });
}
@@ -195,6 +246,9 @@ Hub::Shutdown()
{
ZEN_INFO("Hub service shutting down, deprovisioning any current instances");
+ bool Expected = false;
+ bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true);
+
m_WatchDogEvent.Set();
if (m_WatchDog.joinable())
{
@@ -203,8 +257,6 @@ Hub::Shutdown()
m_WatchDog = {};
- bool Expected = false;
- bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true);
if (WaitForBackgroundWork && m_WorkerPool)
{
m_BackgroundWorkLatch.CountDown();
@@ -254,7 +306,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end())
{
std::string Reason;
- if (!CanProvisionInstance(ModuleId, /* out */ Reason))
+ if (!CanProvisionInstanceLocked(ModuleId, /* out */ Reason))
{
ZEN_WARN("Cannot provision new storage server instance for module '{}': {}", ModuleId, Reason);
@@ -274,6 +326,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex),
.HydrationTempPath = m_HydrationTempPath,
.HydrationTargetSpecification = m_HydrationTargetSpecification,
+ .HydrationOptions = m_HydrationOptions,
.HttpThreadCount = m_Config.InstanceHttpThreadCount,
.CoreLimit = m_Config.InstanceCoreLimit,
.ConfigPath = m_Config.InstanceConfigPath},
@@ -289,6 +342,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
Instance = NewInstance->LockExclusive(/*Wait*/ true);
m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance);
+ m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset();
m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex);
// Set Provisioning while both hub lock and instance lock are held so that any
// concurrent Deprovision sees the in-flight state, not Unprovisioned.
@@ -947,12 +1001,10 @@ Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo)
ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size());
const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance;
ZEN_ASSERT(Instance);
- InstanceInfo Info{
- m_ActiveInstances[ActiveInstanceIndex].State.load(),
- std::chrono::system_clock::now() // TODO
- };
- Instance->GetProcessMetrics(Info.Metrics);
- Info.Port = Instance->GetBasePort();
+ InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(),
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()};
+ Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load();
+ Info.Port = Instance->GetBasePort();
*OutInstanceInfo = Info;
}
@@ -971,12 +1023,10 @@ Hub::EnumerateModules(std::function<void(std::string_view ModuleId, const Instan
{
const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance;
ZEN_ASSERT(Instance);
- InstanceInfo Info{
- m_ActiveInstances[ActiveInstanceIndex].State.load(),
- std::chrono::system_clock::now() // TODO
- };
- Instance->GetProcessMetrics(Info.Metrics);
- Info.Port = Instance->GetBasePort();
+ InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(),
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()};
+ Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load();
+ Info.Port = Instance->GetBasePort();
Infos.push_back(std::make_pair(std::string(Instance->GetModuleId()), Info));
}
@@ -994,28 +1044,8 @@ Hub::GetInstanceCount()
return m_Lock.WithSharedLock([this]() { return gsl::narrow_cast<int>(m_InstanceLookup.size()); });
}
-void
-Hub::UpdateCapacityMetrics()
-{
- m_HostMetrics = GetSystemMetrics();
-
- // TODO: Should probably go into WatchDog and use atomic for update so it can be read without locks...
- // Per-instance stats are already refreshed by WatchDog and are readable via the Find and EnumerateModules
-}
-
-void
-Hub::UpdateStats()
-{
- int CurrentInstanceCount = m_Lock.WithSharedLock([this] { return gsl::narrow_cast<int>(m_InstanceLookup.size()); });
- int CurrentMaxCount = m_MaxInstanceCount.load();
-
- int NewMax = Max(CurrentMaxCount, CurrentInstanceCount);
-
- m_MaxInstanceCount.compare_exchange_weak(CurrentMaxCount, NewMax);
-}
-
bool
-Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason)
+Hub::CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason)
{
ZEN_UNUSED(ModuleId);
if (m_FreeActiveInstanceIndexes.empty())
@@ -1025,7 +1055,24 @@ Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason)
return false;
}
- // TODO: handle additional resource metrics
+ const uint64_t DiskUsedBytes = m_DiskSpace.Free <= m_DiskSpace.Total ? m_DiskSpace.Total - m_DiskSpace.Free : 0;
+ if (m_Config.ResourceLimits.DiskUsageBytes > 0 && DiskUsedBytes > m_Config.ResourceLimits.DiskUsageBytes)
+ {
+ OutReason =
+ fmt::format("disk usage ({}) exceeds ({})", NiceBytes(DiskUsedBytes), NiceBytes(m_Config.ResourceLimits.DiskUsageBytes));
+ return false;
+ }
+
+ const uint64_t RamUsedMiB = m_SystemMetrics.AvailSystemMemoryMiB <= m_SystemMetrics.SystemMemoryMiB
+ ? m_SystemMetrics.SystemMemoryMiB - m_SystemMetrics.AvailSystemMemoryMiB
+ : 0;
+ const uint64_t RamUsedBytes = RamUsedMiB * 1024 * 1024;
+ if (m_Config.ResourceLimits.MemoryUsageBytes > 0 && RamUsedBytes > m_Config.ResourceLimits.MemoryUsageBytes)
+ {
+ OutReason =
+ fmt::format("ram usage ({}) exceeds ({})", NiceBytes(RamUsedBytes), NiceBytes(m_Config.ResourceLimits.MemoryUsageBytes));
+ return false;
+ }
return true;
}
@@ -1036,6 +1083,21 @@ Hub::GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const
return gsl::narrow<uint16_t>(m_Config.BasePortNumber + ActiveInstanceIndex);
}
+bool
+Hub::IsInstancePort(uint16_t Port) const
+{
+ if (Port < m_Config.BasePortNumber)
+ {
+ return false;
+ }
+ size_t Index = Port - m_Config.BasePortNumber;
+ if (Index >= m_ActiveInstances.size())
+ {
+ return false;
+ }
+ return m_ActiveInstances[Index].State.load(std::memory_order_relaxed) != HubInstanceState::Unprovisioned;
+}
+
HubInstanceState
Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState)
{
@@ -1065,8 +1127,10 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS
}
return false;
}(m_ActiveInstances[ActiveInstanceIndex].State.load(), NewState));
+ const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now();
m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.store(0);
- m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now());
+ m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(Now);
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.store(Now);
return m_ActiveInstances[ActiveInstanceIndex].State.exchange(NewState);
}
@@ -1173,14 +1237,14 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
StorageServerInstance::SharedLockedPtr&& LockedInstance,
size_t ActiveInstanceIndex)
{
+ const std::string ModuleId(LockedInstance.GetModuleId());
+
HubInstanceState InstanceState = m_ActiveInstances[ActiveInstanceIndex].State.load();
if (LockedInstance.IsRunning())
{
- LockedInstance.UpdateMetrics();
+ m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Store(LockedInstance.GetProcessMetrics());
if (InstanceState == HubInstanceState::Provisioned)
{
- const std::string ModuleId(LockedInstance.GetModuleId());
-
const uint16_t Port = LockedInstance.GetBasePort();
const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load();
const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load();
@@ -1260,8 +1324,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
else if (InstanceState == HubInstanceState::Provisioned)
{
// Process is not running but state says it should be - instance died unexpectedly.
- const std::string ModuleId(LockedInstance.GetModuleId());
- const uint16_t Port = LockedInstance.GetBasePort();
+ const uint16_t Port = LockedInstance.GetBasePort();
UpdateInstanceState(LockedInstance, ActiveInstanceIndex, HubInstanceState::Crashed);
NotifyStateUpdate(ModuleId, HubInstanceState::Provisioned, HubInstanceState::Crashed, Port, {});
LockedInstance = {};
@@ -1272,7 +1335,6 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
{
// Process is not running - no HTTP activity check is possible.
// Use a pure time-based check; the margin window does not apply here.
- const std::string ModuleId = std::string(LockedInstance.GetModuleId());
const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load();
const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load();
const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now();
@@ -1312,6 +1374,43 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
}
void
+Hub::UpdateMachineMetrics()
+{
+ try
+ {
+ bool DiskSpaceOk = false;
+ DiskSpace Disk;
+
+ std::filesystem::path ChildDir = m_RunEnvironment.GetChildBaseDir();
+ if (!ChildDir.empty())
+ {
+ if (DiskSpaceInfo(ChildDir, Disk))
+ {
+ DiskSpaceOk = true;
+ }
+ else
+ {
+ ZEN_WARN("Failed to query disk space for '{}'; disk-based provisioning limits will not be enforced", ChildDir);
+ }
+ }
+
+ SystemMetrics Metrics = GetSystemMetrics();
+
+ m_Lock.WithExclusiveLock([&]() {
+ if (DiskSpaceOk)
+ {
+ m_DiskSpace = Disk;
+ }
+ m_SystemMetrics = Metrics;
+ });
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("Failed to update machine metrics. Reason: {}", Ex.what());
+ }
+}
+
+void
Hub::WatchDog()
{
const uint64_t CycleIntervalMs = std::chrono::duration_cast<std::chrono::milliseconds>(m_Config.WatchDog.CycleInterval).count();
@@ -1326,16 +1425,18 @@ Hub::WatchDog()
[&]() -> bool { return m_WatchDogEvent.Wait(0); });
size_t CheckInstanceIndex = SIZE_MAX; // first increment wraps to 0
- while (!m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs)))
+ while (!m_ShutdownFlag.load() && !m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs)))
{
try
{
+ UpdateMachineMetrics();
+
// Snapshot slot count. We iterate all slots (including freed nulls) so
// round-robin coverage is not skewed by deprovisioned entries.
size_t SlotsRemaining = m_Lock.WithSharedLock([this]() { return m_ActiveInstances.size(); });
Stopwatch Timer;
- bool ShuttingDown = false;
+ bool ShuttingDown = m_ShutdownFlag.load();
while (SlotsRemaining > 0 && Timer.GetElapsedTimeMs() < CycleProcessingBudgetMs && !ShuttingDown)
{
StorageServerInstance::SharedLockedPtr LockedInstance;
@@ -1366,16 +1467,24 @@ Hub::WatchDog()
std::string ModuleId(LockedInstance.GetModuleId());
- bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex);
- if (InstanceIsOk)
+ try
{
- ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs));
+ bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex);
+ if (InstanceIsOk)
+ {
+ ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs));
+ }
+ else
+ {
+ ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId);
+ AttemptRecoverInstance(ModuleId);
+ }
}
- else
+ catch (const std::exception& Ex)
{
- ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId);
- AttemptRecoverInstance(ModuleId);
+ ZEN_WARN("Failed to check status of module {}. Reason: {}", ModuleId, Ex.what());
}
+ ShuttingDown |= m_ShutdownFlag.load();
}
}
catch (const std::exception& Ex)
@@ -1515,6 +1624,8 @@ TEST_CASE("hub.provision_basic")
Hub::InstanceInfo InstanceInfo;
REQUIRE(HubInstance->Find("module_a", &InstanceInfo));
CHECK_EQ(InstanceInfo.State, HubInstanceState::Provisioned);
+ CHECK_NE(InstanceInfo.StateChangeTime, std::chrono::system_clock::time_point::min());
+ CHECK_LE(InstanceInfo.StateChangeTime, std::chrono::system_clock::now());
{
HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout);
@@ -1934,6 +2045,9 @@ TEST_CASE("hub.hibernate_wake")
}
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Provisioned);
+ const std::chrono::system_clock::time_point ProvisionedTime = Info.StateChangeTime;
+ CHECK_NE(ProvisionedTime, std::chrono::system_clock::time_point::min());
+ CHECK_LE(ProvisionedTime, std::chrono::system_clock::now());
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(ModClient.Get("/health/"));
@@ -1944,6 +2058,8 @@ TEST_CASE("hub.hibernate_wake")
REQUIRE_MESSAGE(HibernateResult.ResponseCode == Hub::EResponseCode::Completed, HibernateResult.Message);
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Hibernated);
+ const std::chrono::system_clock::time_point HibernatedTime = Info.StateChangeTime;
+ CHECK_GE(HibernatedTime, ProvisionedTime);
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(!ModClient.Get("/health/"));
@@ -1954,6 +2070,7 @@ TEST_CASE("hub.hibernate_wake")
REQUIRE_MESSAGE(WakeResult.ResponseCode == Hub::EResponseCode::Completed, WakeResult.Message);
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Provisioned);
+ CHECK_GE(Info.StateChangeTime, HibernatedTime);
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(ModClient.Get("/health/"));
@@ -2352,7 +2469,7 @@ TEST_CASE("hub.async_provision_shutdown_waits")
TEST_CASE("hub.async_provision_rejected")
{
- // Rejection from CanProvisionInstance fires synchronously even when a WorkerPool is present.
+ // Rejection from CanProvisionInstanceLocked fires synchronously even when a WorkerPool is present.
ScopedTemporaryDirectory TempDir;
Hub::Configuration Config;
@@ -2369,7 +2486,7 @@ TEST_CASE("hub.async_provision_rejected")
REQUIRE_MESSAGE(FirstResult.ResponseCode == Hub::EResponseCode::Accepted, FirstResult.Message);
REQUIRE_NE(Info.Port, 0);
- // Second provision: CanProvisionInstance rejects synchronously (limit reached), returns Rejected
+ // Second provision: CanProvisionInstanceLocked rejects synchronously (limit reached), returns Rejected
HubProvisionedInstanceInfo Info2;
const Hub::Response SecondResult = HubInstance->Provision("async_r2", Info2);
CHECK(SecondResult.ResponseCode == Hub::EResponseCode::Rejected);
@@ -2485,6 +2602,55 @@ TEST_CASE("hub.instance.inactivity.deprovision")
HubInstance->Shutdown();
}
+TEST_CASE("hub.machine_metrics")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {});
+
+ // UpdateMachineMetrics() is called synchronously in the Hub constructor, so metrics
+ // are available immediately without waiting for a watchdog cycle.
+ SystemMetrics SysMetrics;
+ DiskSpace Disk;
+ HubInstance->GetMachineMetrics(SysMetrics, Disk);
+
+ CHECK_GT(Disk.Total, 0u);
+ CHECK_LE(Disk.Free, Disk.Total);
+
+ CHECK_GT(SysMetrics.SystemMemoryMiB, 0u);
+ CHECK_LE(SysMetrics.AvailSystemMemoryMiB, SysMetrics.SystemMemoryMiB);
+
+ CHECK_GT(SysMetrics.VirtualMemoryMiB, 0u);
+ CHECK_LE(SysMetrics.AvailVirtualMemoryMiB, SysMetrics.VirtualMemoryMiB);
+}
+
+TEST_CASE("hub.provision_rejected_resource_limits")
+{
+ // The Hub constructor calls UpdateMachineMetrics() synchronously, so CanProvisionInstanceLocked
+ // can enforce limits immediately without waiting for a watchdog cycle.
+ ScopedTemporaryDirectory TempDir;
+
+ {
+ Hub::Configuration Config;
+ Config.ResourceLimits.DiskUsageBytes = 1;
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
+ HubProvisionedInstanceInfo Info;
+ const Hub::Response Result = HubInstance->Provision("disk_limit", Info);
+ CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected);
+ CHECK_NE(Result.Message.find("disk usage"), std::string::npos);
+ }
+
+ {
+ Hub::Configuration Config;
+ Config.ResourceLimits.MemoryUsageBytes = 1;
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
+ HubProvisionedInstanceInfo Info;
+ const Hub::Response Result = HubInstance->Provision("mem_limit", Info);
+ CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected);
+ CHECK_NE(Result.Message.find("ram usage"), std::string::npos);
+ }
+}
+
TEST_SUITE_END();
void
diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h
index c343b19e2..8ee9130f6 100644
--- a/src/zenserver/hub/hub.h
+++ b/src/zenserver/hub/hub.h
@@ -6,6 +6,8 @@
#include "resourcemetrics.h"
#include "storageserverinstance.h"
+#include <zencore/compactbinary.h>
+#include <zencore/filesystem.h>
#include <zencore/system.h>
#include <zenutil/zenserverprocess.h>
@@ -66,8 +68,11 @@ public:
int InstanceCoreLimit = 0; // Automatic
std::filesystem::path InstanceConfigPath;
std::string HydrationTargetSpecification;
+ CbObject HydrationOptions;
WatchDogConfiguration WatchDog;
+
+ ResourceMetrics ResourceLimits;
};
typedef std::function<
@@ -86,7 +91,7 @@ public:
struct InstanceInfo
{
HubInstanceState State = HubInstanceState::Unprovisioned;
- std::chrono::system_clock::time_point ProvisionTime;
+ std::chrono::system_clock::time_point StateChangeTime;
ProcessMetrics Metrics;
uint16_t Port = 0;
};
@@ -160,6 +165,10 @@ public:
int GetMaxInstanceCount() const { return m_MaxInstanceCount.load(); }
+ void GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const;
+
+ bool IsInstancePort(uint16_t Port) const;
+
const Configuration& GetConfig() const { return m_Config; }
#if ZEN_WITH_TESTS
@@ -176,14 +185,31 @@ private:
AsyncModuleStateChangeCallbackFunc m_ModuleStateChangeCallback;
std::string m_HydrationTargetSpecification;
+ CbObject m_HydrationOptions;
std::filesystem::path m_HydrationTempPath;
#if ZEN_PLATFORM_WINDOWS
JobObject m_JobObject;
#endif
- RwLock m_Lock;
+ mutable RwLock m_Lock;
std::unordered_map<std::string, size_t> m_InstanceLookup;
+ // Mirrors ProcessMetrics with atomic fields, enabling lock-free reads alongside watchdog writes.
+ struct AtomicProcessMetrics
+ {
+ std::atomic<uint64_t> MemoryBytes = 0;
+ std::atomic<uint64_t> KernelTimeMs = 0;
+ std::atomic<uint64_t> UserTimeMs = 0;
+ std::atomic<uint64_t> WorkingSetSize = 0;
+ std::atomic<uint64_t> PeakWorkingSetSize = 0;
+ std::atomic<uint64_t> PagefileUsage = 0;
+ std::atomic<uint64_t> PeakPagefileUsage = 0;
+
+ ProcessMetrics Load() const;
+ void Store(const ProcessMetrics& Metrics);
+ void Reset();
+ };
+
struct ActiveInstance
{
// Invariant: Instance == nullptr if and only if State == Unprovisioned.
@@ -192,11 +218,16 @@ private:
// without holding the hub lock.
std::unique_ptr<StorageServerInstance> Instance;
std::atomic<HubInstanceState> State = HubInstanceState::Unprovisioned;
- // TODO: We should move current metrics here (from StorageServerInstance)
- // Read and updated by WatchDog, updates to State triggers a reset of both
+ // Process metrics - written by WatchDog (inside instance shared lock), read lock-free.
+ AtomicProcessMetrics ProcessMetrics;
+
+ // Activity tracking - written by WatchDog, reset on every state transition.
std::atomic<uint64_t> LastKnownActivitySum = 0;
std::atomic<std::chrono::system_clock::time_point> LastActivityTime = std::chrono::system_clock::time_point::min();
+
+ // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules.
+ std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min();
};
// UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive
@@ -226,21 +257,20 @@ private:
std::vector<ActiveInstance> m_ActiveInstances;
std::deque<size_t> m_FreeActiveInstanceIndexes;
- ResourceMetrics m_ResourceLimits;
- SystemMetrics m_HostMetrics;
+ SystemMetrics m_SystemMetrics;
+ DiskSpace m_DiskSpace;
std::atomic<int> m_MaxInstanceCount = 0;
std::thread m_WatchDog;
Event m_WatchDogEvent;
void WatchDog();
+ void UpdateMachineMetrics();
bool CheckInstanceStatus(HttpClient& ActivityHttpClient,
StorageServerInstance::SharedLockedPtr&& LockedInstance,
size_t ActiveInstanceIndex);
void AttemptRecoverInstance(std::string_view ModuleId);
- void UpdateStats();
- void UpdateCapacityMetrics();
- bool CanProvisionInstance(std::string_view ModuleId, std::string& OutReason);
+ bool CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason);
uint16_t GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const;
Response InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveInstance& Instance)>&& DeprovisionGate);
diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp
index 541127590..ed16bfe56 100644
--- a/src/zenserver/hub/hydration.cpp
+++ b/src/zenserver/hub/hydration.cpp
@@ -10,6 +10,7 @@
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/system.h>
+#include <zencore/timer.h>
#include <zenutil/cloud/imdscredentials.h>
#include <zenutil/cloud/s3client.h>
@@ -60,6 +61,7 @@ namespace {
///////////////////////////////////////////////////////////////////////////
constexpr std::string_view FileHydratorPrefix = "file://";
+constexpr std::string_view FileHydratorType = "file";
struct FileHydrator : public HydrationStrategyBase
{
@@ -77,7 +79,21 @@ FileHydrator::Configure(const HydrationConfig& Config)
{
m_Config = Config;
- std::filesystem::path ConfigPath(Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length())));
+ std::filesystem::path ConfigPath;
+ if (!m_Config.TargetSpecification.empty())
+ {
+ ConfigPath = Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length()));
+ }
+ else
+ {
+ CbObjectView Settings = m_Config.Options["settings"].AsObjectView();
+ std::string_view Path = Settings["path"].AsString();
+ if (Path.empty())
+ {
+ throw zen::runtime_error("Hydration config 'file' type requires 'settings.path'");
+ }
+ ConfigPath = Utf8ToWide(std::string(Path));
+ }
MakeSafeAbsolutePathInPlace(ConfigPath);
if (!std::filesystem::exists(ConfigPath))
@@ -95,6 +111,8 @@ FileHydrator::Hydrate()
{
ZEN_INFO("Hydrating state from '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir);
+ Stopwatch Timer;
+
// Ensure target is clean
ZEN_DEBUG("Wiping server state at '{}'", m_Config.ServerStateDir);
const bool ForceRemoveReadOnlyFiles = true;
@@ -120,6 +138,10 @@ FileHydrator::Hydrate()
ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir);
CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
}
+ else
+ {
+ ZEN_INFO("Hydration complete in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ }
}
void
@@ -127,6 +149,8 @@ FileHydrator::Dehydrate()
{
ZEN_INFO("Dehydrating state from '{}' to '{}'", m_Config.ServerStateDir, m_StorageModuleRootDir);
+ Stopwatch Timer;
+
const std::filesystem::path TargetDir = m_StorageModuleRootDir;
// Ensure target is clean. This could be replaced with an atomic copy at a later date
@@ -141,7 +165,23 @@ FileHydrator::Dehydrate()
try
{
ZEN_DEBUG("Copying '{}' to '{}'", m_Config.ServerStateDir, TargetDir);
- CopyTree(m_Config.ServerStateDir, TargetDir, {.EnableClone = true});
+ for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.ServerStateDir))
+ {
+ if (Entry.path().filename() == ".sentry-native")
+ {
+ continue;
+ }
+ std::filesystem::path Dest = TargetDir / Entry.path().filename();
+ if (Entry.is_directory())
+ {
+ CreateDirectories(Dest);
+ CopyTree(Entry.path(), Dest, {.EnableClone = true});
+ }
+ else
+ {
+ CopyFile(Entry.path(), Dest, {.EnableClone = true});
+ }
+ }
}
catch (std::exception& Ex)
{
@@ -159,11 +199,17 @@ FileHydrator::Dehydrate()
ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir);
CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
+
+ if (CopySuccess)
+ {
+ ZEN_INFO("Dehydration complete in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ }
}
///////////////////////////////////////////////////////////////////////////
constexpr std::string_view S3HydratorPrefix = "s3://";
+constexpr std::string_view S3HydratorType = "s3";
struct S3Hydrator : public HydrationStrategyBase
{
@@ -182,6 +228,8 @@ private:
std::string m_Region;
SigV4Credentials m_Credentials;
Ref<ImdsCredentialProvider> m_CredentialProvider;
+
+ static constexpr uint64_t MultipartChunkSize = 8 * 1024 * 1024;
};
void
@@ -189,8 +237,23 @@ S3Hydrator::Configure(const HydrationConfig& Config)
{
m_Config = Config;
- std::string_view Spec = m_Config.TargetSpecification;
- Spec.remove_prefix(S3HydratorPrefix.size());
+ CbObjectView Settings = m_Config.Options["settings"].AsObjectView();
+ std::string_view Spec;
+ if (!m_Config.TargetSpecification.empty())
+ {
+ Spec = m_Config.TargetSpecification;
+ Spec.remove_prefix(S3HydratorPrefix.size());
+ }
+ else
+ {
+ std::string_view Uri = Settings["uri"].AsString();
+ if (Uri.empty())
+ {
+ throw zen::runtime_error("Hydration config 's3' type requires 'settings.uri'");
+ }
+ Spec = Uri;
+ Spec.remove_prefix(S3HydratorPrefix.size());
+ }
size_t SlashPos = Spec.find('/');
std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{};
@@ -199,7 +262,11 @@ S3Hydrator::Configure(const HydrationConfig& Config)
ZEN_ASSERT(!m_Bucket.empty());
- std::string Region = GetEnvVariable("AWS_DEFAULT_REGION");
+ std::string Region = std::string(Settings["region"].AsString());
+ if (Region.empty())
+ {
+ Region = GetEnvVariable("AWS_DEFAULT_REGION");
+ }
if (Region.empty())
{
Region = GetEnvVariable("AWS_REGION");
@@ -230,10 +297,12 @@ S3Hydrator::CreateS3Client() const
Options.BucketName = m_Bucket;
Options.Region = m_Region;
- if (!m_Config.S3Endpoint.empty())
+ CbObjectView Settings = m_Config.Options["settings"].AsObjectView();
+ std::string_view Endpoint = Settings["endpoint"].AsString();
+ if (!Endpoint.empty())
{
- Options.Endpoint = m_Config.S3Endpoint;
- Options.PathStyle = m_Config.S3PathStyle;
+ Options.Endpoint = std::string(Endpoint);
+ Options.PathStyle = Settings["path-style"].AsBool();
}
if (m_CredentialProvider)
@@ -245,6 +314,8 @@ S3Hydrator::CreateS3Client() const
Options.Credentials = m_Credentials;
}
+ Options.HttpSettings.MaximumInMemoryDownloadSize = 16u * 1024u;
+
return S3Client(Options);
}
@@ -275,11 +346,11 @@ S3Hydrator::Dehydrate()
try
{
- S3Client Client = CreateS3Client();
- std::string FolderName = BuildTimestampFolderName();
- uint64_t TotalBytes = 0;
- uint32_t FileCount = 0;
- std::chrono::steady_clock::time_point UploadStart = std::chrono::steady_clock::now();
+ S3Client Client = CreateS3Client();
+ std::string FolderName = BuildTimestampFolderName();
+ uint64_t TotalBytes = 0;
+ uint32_t FileCount = 0;
+ Stopwatch Timer;
DirectoryContent DirContent;
GetDirectoryContent(m_Config.ServerStateDir, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive, DirContent);
@@ -295,13 +366,20 @@ S3Hydrator::Dehydrate()
AbsPath.string(),
m_Config.ServerStateDir.string());
}
+ if (*RelPath.begin() == ".sentry-native")
+ {
+ continue;
+ }
std::string Key = MakeObjectKey(FolderName, RelPath);
BasicFile File(AbsPath, BasicFile::Mode::kRead);
uint64_t FileSize = File.FileSize();
- S3Result UploadResult =
- Client.PutObjectMultipart(Key, FileSize, [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); });
+ S3Result UploadResult = Client.PutObjectMultipart(
+ Key,
+ FileSize,
+ [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); },
+ MultipartChunkSize);
if (!UploadResult.IsSuccess())
{
throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, UploadResult.Error);
@@ -312,8 +390,7 @@ S3Hydrator::Dehydrate()
}
// Write current-state.json
- int64_t UploadDurationMs =
- std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - UploadStart).count();
+ uint64_t UploadDurationMs = Timer.GetElapsedTimeMs();
UtcTime Now = UtcTime::Now();
std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z",
@@ -346,7 +423,7 @@ S3Hydrator::Dehydrate()
throw zen::runtime_error("Failed to write current-state.json to '{}': {}", MetaKey, MetaUploadResult.Error);
}
- ZEN_INFO("Dehydration complete: {} files, {} bytes, {} ms", FileCount, TotalBytes, UploadDurationMs);
+ ZEN_INFO("Dehydration complete: {} files, {}, {}", FileCount, NiceBytes(TotalBytes), NiceTimeSpanMs(UploadDurationMs));
}
catch (std::exception& Ex)
{
@@ -361,6 +438,7 @@ S3Hydrator::Hydrate()
{
ZEN_INFO("Hydrating state from s3://{}/{} to '{}'", m_Bucket, m_KeyPrefix, m_Config.ServerStateDir);
+ Stopwatch Timer;
const bool ForceRemoveReadOnlyFiles = true;
// Clean temp dir before starting in case of leftover state from a previous failed hydration
@@ -374,19 +452,17 @@ S3Hydrator::Hydrate()
S3Client Client = CreateS3Client();
std::string MetaKey = m_KeyPrefix + "/current-state.json";
- S3HeadObjectResult HeadResult = Client.HeadObject(MetaKey);
- if (HeadResult.Status == HeadObjectResult::NotFound)
- {
- throw zen::runtime_error("No state found in S3 at '{}'", MetaKey);
- }
- if (!HeadResult.IsSuccess())
- {
- throw zen::runtime_error("Failed to check for state in S3 at '{}': {}", MetaKey, HeadResult.Error);
- }
-
S3GetObjectResult MetaResult = Client.GetObject(MetaKey);
if (!MetaResult.IsSuccess())
{
+ if (MetaResult.Error == S3GetObjectResult::NotFoundErrorText)
+ {
+ ZEN_INFO("No state found in S3 at {}", MetaKey);
+
+ ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir);
+ CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
+ return;
+ }
throw zen::runtime_error("Failed to read current-state.json from '{}': {}", MetaKey, MetaResult.Error);
}
@@ -426,17 +502,17 @@ S3Hydrator::Hydrate()
std::filesystem::path DestPath = MakeSafeAbsolutePath(m_Config.TempDir / std::filesystem::path(RelKey));
CreateDirectories(DestPath.parent_path());
- BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate);
- DestFile.SetFileSize(Obj.Size);
-
- if (Obj.Size > 0)
+ if (Obj.Size > MultipartChunkSize)
{
+ BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate);
+ DestFile.SetFileSize(Obj.Size);
+
BasicFileWriter Writer(DestFile, 64 * 1024);
uint64_t Offset = 0;
while (Offset < Obj.Size)
{
- uint64_t ChunkSize = std::min<uint64_t>(8 * 1024 * 1024, Obj.Size - Offset);
+ uint64_t ChunkSize = std::min<uint64_t>(MultipartChunkSize, Obj.Size - Offset);
S3GetObjectResult Chunk = Client.GetObjectRange(Obj.Key, Offset, ChunkSize);
if (!Chunk.IsSuccess())
{
@@ -453,6 +529,34 @@ S3Hydrator::Hydrate()
Writer.Flush();
}
+ else
+ {
+ S3GetObjectResult Chunk = Client.GetObject(Obj.Key, m_Config.TempDir);
+ if (!Chunk.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to download '{}' from S3: {}", Obj.Key, Chunk.Error);
+ }
+
+ if (IoBufferFileReference FileRef; Chunk.Content.GetFileReference(FileRef))
+ {
+ std::error_code Ec;
+ std::filesystem::path ChunkPath = PathFromHandle(FileRef.FileHandle, Ec);
+ if (Ec)
+ {
+ WriteFile(DestPath, Chunk.Content);
+ }
+ else
+ {
+ Chunk.Content.SetDeleteOnClose(false);
+ Chunk.Content = {};
+ RenameFile(ChunkPath, DestPath, Ec);
+ }
+ }
+ else
+ {
+ WriteFile(DestPath, Chunk.Content);
+ }
+ }
}
// Downloaded successfully - swap into ServerStateDir
@@ -465,19 +569,20 @@ S3Hydrator::Hydrate()
std::mismatch(m_Config.TempDir.begin(), m_Config.TempDir.end(), m_Config.ServerStateDir.begin(), m_Config.ServerStateDir.end());
if (ItTmp != m_Config.TempDir.begin())
{
- // Fast path: atomic renames - no data copying needed
- for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.TempDir))
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_Config.TempDir, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs, DirContent);
+
+ for (const std::filesystem::path& AbsPath : DirContent.Directories)
{
- std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / Entry.path().filename());
- if (Entry.is_directory())
- {
- RenameDirectory(Entry.path(), Dest);
- }
- else
- {
- RenameFile(Entry.path(), Dest);
- }
+ std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / AbsPath.filename());
+ RenameDirectory(AbsPath, Dest);
+ }
+ for (const std::filesystem::path& AbsPath : DirContent.Files)
+ {
+ std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / AbsPath.filename());
+ RenameFile(AbsPath, Dest);
}
+
ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
}
@@ -491,7 +596,7 @@ S3Hydrator::Hydrate()
CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
}
- ZEN_INFO("Hydration complete from folder '{}'", FolderName);
+ ZEN_INFO("Hydration complete from folder '{}' in {}", FolderName, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
catch (std::exception& Ex)
{
@@ -513,19 +618,41 @@ S3Hydrator::Hydrate()
std::unique_ptr<HydrationStrategyBase>
CreateHydrator(const HydrationConfig& Config)
{
- if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0)
+ if (!Config.TargetSpecification.empty())
+ {
+ if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0)
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>();
+ Hydrator->Configure(Config);
+ return Hydrator;
+ }
+ if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0)
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>();
+ Hydrator->Configure(Config);
+ return Hydrator;
+ }
+ throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification));
+ }
+
+ std::string_view Type = Config.Options["type"].AsString();
+ if (Type == FileHydratorType)
{
std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>();
Hydrator->Configure(Config);
return Hydrator;
}
- if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0)
+ if (Type == S3HydratorType)
{
std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>();
Hydrator->Configure(Config);
return Hydrator;
}
- throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification));
+ if (!Type.empty())
+ {
+ throw zen::runtime_error("Unknown hydration target type '{}'", Type);
+ }
+ throw zen::runtime_error("No hydration target configured");
}
#if ZEN_WITH_TESTS
@@ -607,6 +734,12 @@ namespace {
AddFile("file_a.bin", CreateSemiRandomBlob(1024));
AddFile("subdir/file_b.bin", CreateSemiRandomBlob(2048));
AddFile("subdir/nested/file_c.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_d.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_e.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_f.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/medium.bulk", CreateSemiRandomBlob(256u * 1024u));
+ AddFile("subdir/nested/big.bulk", CreateSemiRandomBlob(512u * 1024u));
+ AddFile("subdir/nested/huge.bulk", CreateSemiRandomBlob(9u * 1024u * 1024u));
return Files;
}
@@ -844,12 +977,16 @@ TEST_CASE("hydration.s3.dehydrate_hydrate")
auto TestFiles = CreateTestTree(ServerStateDir);
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = ModuleId;
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
// Dehydrate: upload server state to MinIO
{
@@ -902,12 +1039,18 @@ TEST_CASE("hydration.s3.current_state_json_selects_latest_folder")
const std::string ModuleId = "s3test_folder_select";
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = ModuleId;
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
// v1: dehydrate without a marker file
CreateTestTree(ServerStateDir);
@@ -972,13 +1115,19 @@ TEST_CASE("hydration.s3.module_isolation")
CreateDirectories(TempPath);
ModuleData Data;
- Data.Config.ServerStateDir = StateDir;
- Data.Config.TempDir = TempPath;
- Data.Config.ModuleId = ModuleId;
- Data.Config.TargetSpecification = "s3://zen-hydration-test";
- Data.Config.S3Endpoint = Minio.Endpoint();
- Data.Config.S3PathStyle = true;
- Data.Files = CreateTestTree(StateDir);
+ Data.Config.ServerStateDir = StateDir;
+ Data.Config.TempDir = TempPath;
+ Data.Config.ModuleId = ModuleId;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Data.Config.Options = std::move(Root).AsObject();
+ }
+ Data.Files = CreateTestTree(StateDir);
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Data.Config);
Hydrator->Dehydrate();
@@ -1015,7 +1164,8 @@ TEST_CASE("hydration.s3.concurrent")
ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
- constexpr int kModuleCount = 4;
+ constexpr int kModuleCount = 16;
+ constexpr int kThreadCount = 4;
ScopedTemporaryDirectory TempDir;
@@ -1034,18 +1184,24 @@ TEST_CASE("hydration.s3.concurrent")
CreateDirectories(StateDir);
CreateDirectories(TempPath);
- Modules[I].Config.ServerStateDir = StateDir;
- Modules[I].Config.TempDir = TempPath;
- Modules[I].Config.ModuleId = ModuleId;
- Modules[I].Config.TargetSpecification = "s3://zen-hydration-test";
- Modules[I].Config.S3Endpoint = Minio.Endpoint();
- Modules[I].Config.S3PathStyle = true;
- Modules[I].Files = CreateTestTree(StateDir);
+ Modules[I].Config.ServerStateDir = StateDir;
+ Modules[I].Config.TempDir = TempPath;
+ Modules[I].Config.ModuleId = ModuleId;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Modules[I].Config.Options = std::move(Root).AsObject();
+ }
+ Modules[I].Files = CreateTestTree(StateDir);
}
// Concurrent dehydrate
{
- WorkerThreadPool Pool(kModuleCount, "hydration_s3_dehy");
+ WorkerThreadPool Pool(kThreadCount, "hydration_s3_dehy");
std::atomic<bool> AbortFlag{false};
std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
@@ -1063,7 +1219,7 @@ TEST_CASE("hydration.s3.concurrent")
// Concurrent hydrate
{
- WorkerThreadPool Pool(kModuleCount, "hydration_s3_hy");
+ WorkerThreadPool Pool(kThreadCount, "hydration_s3_hy");
std::atomic<bool> AbortFlag{false};
std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
@@ -1116,12 +1272,18 @@ TEST_CASE("hydration.s3.no_prior_state")
WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256));
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = "s3test_no_prior";
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_no_prior";
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
Hydrator->Hydrate();
@@ -1159,12 +1321,71 @@ TEST_CASE("hydration.s3.path_prefix")
std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFiles = CreateTestTree(ServerStateDir);
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = "s3test_prefix";
- Config.TargetSpecification = "s3://zen-hydration-test/team/project";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_prefix";
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test/team/project","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate();
+ }
+
+ CleanDirectory(ServerStateDir, true);
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Hydrate();
+ }
+
+ VerifyTree(ServerStateDir, TestFiles);
+}
+
+TEST_CASE("hydration.s3.options_region_override")
+{
+ // Verify that 'region' in Options["settings"] takes precedence over AWS_DEFAULT_REGION env var.
+ // AWS_DEFAULT_REGION is set to a bogus value; hydration must succeed using the region from Options.
+
+ MinioProcessOptions MinioOpts;
+ MinioOpts.Port = 19016;
+ MinioProcess Minio(MinioOpts);
+ Minio.SpawnMinioServer();
+ Minio.CreateBucket("zen-hydration-test");
+
+ ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
+ ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
+ ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "wrong-region");
+
+ ScopedTemporaryDirectory TempDir;
+
+ std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
+ std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
+ CreateDirectories(ServerStateDir);
+ CreateDirectories(HydrationTemp);
+
+ auto TestFiles = CreateTestTree(ServerStateDir);
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_region_override";
+ {
+ std::string ConfigJson = fmt::format(
+ R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"region":"us-east-1"}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h
index d29ffe5c0..19a96c248 100644
--- a/src/zenserver/hub/hydration.h
+++ b/src/zenserver/hub/hydration.h
@@ -2,6 +2,8 @@
#pragma once
+#include <zencore/compactbinary.h>
+
#include <filesystem>
namespace zen {
@@ -16,12 +18,8 @@ struct HydrationConfig
std::string ModuleId;
// Back-end specific target specification (e.g. S3 bucket, file path, etc)
std::string TargetSpecification;
-
- // Optional S3 endpoint override (e.g. "http://localhost:9000" for MinIO).
- std::string S3Endpoint;
- // Use path-style S3 URLs (endpoint/bucket/key) instead of virtual-hosted-style
- // (bucket.endpoint/key). Required for MinIO and other non-AWS endpoints.
- bool S3PathStyle = false;
+ // Full config object when using --hub-hydration-target-config (mutually exclusive with TargetSpecification)
+ CbObject Options;
};
/**
diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp
index 6b139dbf1..0c9354990 100644
--- a/src/zenserver/hub/storageserverinstance.cpp
+++ b/src/zenserver/hub/storageserverinstance.cpp
@@ -57,16 +57,15 @@ StorageServerInstance::SpawnServerProcess()
m_ServerInstance.EnableShutdownOnDestroy();
}
-void
-StorageServerInstance::GetProcessMetrics(ProcessMetrics& OutMetrics) const
+ProcessMetrics
+StorageServerInstance::GetProcessMetrics() const
{
- OutMetrics.MemoryBytes = m_MemoryBytes.load();
- OutMetrics.KernelTimeMs = m_KernelTimeMs.load();
- OutMetrics.UserTimeMs = m_UserTimeMs.load();
- OutMetrics.WorkingSetSize = m_WorkingSetSize.load();
- OutMetrics.PeakWorkingSetSize = m_PeakWorkingSetSize.load();
- OutMetrics.PagefileUsage = m_PagefileUsage.load();
- OutMetrics.PeakPagefileUsage = m_PeakPagefileUsage.load();
+ ProcessMetrics Metrics;
+ if (m_ServerInstance.IsRunning())
+ {
+ zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics);
+ }
+ return Metrics;
}
void
@@ -158,7 +157,8 @@ StorageServerInstance::Hydrate()
HydrationConfig Config{.ServerStateDir = m_BaseDir,
.TempDir = m_TempDir,
.ModuleId = m_ModuleId,
- .TargetSpecification = m_Config.HydrationTargetSpecification};
+ .TargetSpecification = m_Config.HydrationTargetSpecification,
+ .Options = m_Config.HydrationOptions};
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
@@ -171,7 +171,8 @@ StorageServerInstance::Dehydrate()
HydrationConfig Config{.ServerStateDir = m_BaseDir,
.TempDir = m_TempDir,
.ModuleId = m_ModuleId,
- .TargetSpecification = m_Config.HydrationTargetSpecification};
+ .TargetSpecification = m_Config.HydrationTargetSpecification,
+ .Options = m_Config.HydrationOptions};
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
@@ -249,25 +250,6 @@ StorageServerInstance::SharedLockedPtr::IsRunning() const
return m_Instance->m_ServerInstance.IsRunning();
}
-void
-StorageServerInstance::UpdateMetricsLocked()
-{
- if (m_ServerInstance.IsRunning())
- {
- ProcessMetrics Metrics;
- zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics);
-
- m_MemoryBytes.store(Metrics.MemoryBytes);
- m_KernelTimeMs.store(Metrics.KernelTimeMs);
- m_UserTimeMs.store(Metrics.UserTimeMs);
- m_WorkingSetSize.store(Metrics.WorkingSetSize);
- m_PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize);
- m_PagefileUsage.store(Metrics.PagefileUsage);
- m_PeakPagefileUsage.store(Metrics.PeakPagefileUsage);
- }
- // TODO: Resource metrics...
-}
-
#if ZEN_WITH_TESTS
void
StorageServerInstance::SharedLockedPtr::TerminateForTesting() const
diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h
index 94c47630c..1b0078d87 100644
--- a/src/zenserver/hub/storageserverinstance.h
+++ b/src/zenserver/hub/storageserverinstance.h
@@ -2,8 +2,7 @@
#pragma once
-#include "resourcemetrics.h"
-
+#include <zencore/compactbinary.h>
#include <zenutil/zenserverprocess.h>
#include <atomic>
@@ -26,6 +25,7 @@ public:
uint16_t BasePort;
std::filesystem::path HydrationTempPath;
std::string HydrationTargetSpecification;
+ CbObject HydrationOptions;
uint32_t HttpThreadCount = 0; // Automatic
int CoreLimit = 0; // Automatic
std::filesystem::path ConfigPath;
@@ -34,11 +34,9 @@ public:
StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId);
~StorageServerInstance();
- const ResourceMetrics& GetResourceMetrics() const { return m_ResourceMetrics; }
-
inline std::string_view GetModuleId() const { return m_ModuleId; }
inline uint16_t GetBasePort() const { return m_Config.BasePort; }
- void GetProcessMetrics(ProcessMetrics& OutMetrics) const;
+ ProcessMetrics GetProcessMetrics() const;
#if ZEN_PLATFORM_WINDOWS
void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; }
@@ -68,15 +66,10 @@ public:
}
bool IsRunning() const;
- const ResourceMetrics& GetResourceMetrics() const
- {
- ZEN_ASSERT(m_Instance);
- return m_Instance->m_ResourceMetrics;
- }
- void UpdateMetrics()
+ ProcessMetrics GetProcessMetrics() const
{
ZEN_ASSERT(m_Instance);
- return m_Instance->UpdateMetricsLocked();
+ return m_Instance->GetProcessMetrics();
}
#if ZEN_WITH_TESTS
@@ -114,12 +107,6 @@ public:
}
bool IsRunning() const;
- const ResourceMetrics& GetResourceMetrics() const
- {
- ZEN_ASSERT(m_Instance);
- return m_Instance->m_ResourceMetrics;
- }
-
void Provision();
void Deprovision();
void Hibernate();
@@ -139,8 +126,6 @@ private:
void HibernateLocked();
void WakeLocked();
- void UpdateMetricsLocked();
-
mutable RwLock m_Lock;
const Configuration m_Config;
std::string m_ModuleId;
@@ -149,15 +134,6 @@ private:
std::filesystem::path m_BaseDir;
std::filesystem::path m_TempDir;
- ResourceMetrics m_ResourceMetrics;
-
- std::atomic<uint64_t> m_MemoryBytes = 0;
- std::atomic<uint64_t> m_KernelTimeMs = 0;
- std::atomic<uint64_t> m_UserTimeMs = 0;
- std::atomic<uint64_t> m_WorkingSetSize = 0;
- std::atomic<uint64_t> m_PeakWorkingSetSize = 0;
- std::atomic<uint64_t> m_PagefileUsage = 0;
- std::atomic<uint64_t> m_PeakPagefileUsage = 0;
#if ZEN_PLATFORM_WINDOWS
JobObject* m_JobObject = nullptr;
diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp
index 314031246..d01e5f3f2 100644
--- a/src/zenserver/hub/zenhubserver.cpp
+++ b/src/zenserver/hub/zenhubserver.cpp
@@ -2,17 +2,24 @@
#include "zenhubserver.h"
+#include "config/luaconfig.h"
#include "frontend/frontend.h"
#include "httphubservice.h"
+#include "httpproxyhandler.h"
#include "hub.h"
+#include <zencore/compactbinary.h>
#include <zencore/config.h>
+#include <zencore/except.h>
+#include <zencore/except_fmt.h>
+#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/memory/llm.h>
#include <zencore/memory/memorytrace.h>
#include <zencore/memory/tagtrace.h>
#include <zencore/scopeguard.h>
#include <zencore/sentryintegration.h>
+#include <zencore/system.h>
#include <zencore/windows.h>
#include <zenhttp/httpapiservice.h>
#include <zenutil/service.h>
@@ -53,12 +60,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("hub",
"",
"instance-id",
- "Instance ID for use in notifications",
+ "Instance ID for use in notifications (deprecated, use --upstream-notification-instance-id)",
cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""),
"");
Options.add_option("hub",
"",
+ "upstream-notification-instance-id",
+ "Instance ID for use in notifications",
+ cxxopts::value<std::string>(m_ServerOptions.InstanceId),
+ "");
+
+ Options.add_option("hub",
+ "",
"consul-endpoint",
"Consul endpoint URL for service registration (empty = disabled)",
cxxopts::value<std::string>(m_ServerOptions.ConsulEndpoint)->default_value(""),
@@ -89,12 +103,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("hub",
"",
"hub-base-port-number",
- "Base port number for provisioned instances",
+ "Base port number for provisioned instances (deprecated, use --hub-instance-base-port-number)",
cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber)->default_value("21000"),
"");
Options.add_option("hub",
"",
+ "hub-instance-base-port-number",
+ "Base port number for provisioned instances",
+ cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber),
+ "");
+
+ Options.add_option("hub",
+ "",
"hub-instance-limit",
"Maximum number of provisioned instances for this hub",
cxxopts::value<int>(m_ServerOptions.HubInstanceLimit)->default_value("1000"),
@@ -139,6 +160,14 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
cxxopts::value(m_ServerOptions.HydrationTargetSpecification),
"<hydration-target-spec>");
+ Options.add_option("hub",
+ "",
+ "hub-hydration-target-config",
+ "Path to JSON file specifying the hydration target (mutually exclusive with "
+ "--hub-hydration-target-spec). Supported types: 'file', 's3'.",
+ cxxopts::value(m_ServerOptions.HydrationTargetConfigPath),
+ "<path>");
+
#if ZEN_PLATFORM_WINDOWS
Options.add_option("hub",
"",
@@ -203,12 +232,103 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
"Request timeout in milliseconds for instance activity check requests",
cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs)->default_value("200"),
"<ms>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-disk-limit-bytes",
+ "Reject provisioning when used disk bytes exceed this value (0 = no limit).",
+ cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionDiskLimitBytes),
+ "<bytes>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-disk-limit-percent",
+ "Reject provisioning when used disk exceeds this percentage of total disk (0 = no limit).",
+ cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionDiskLimitPercent),
+ "<percent>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-memory-limit-bytes",
+ "Reject provisioning when used memory bytes exceed this value (0 = no limit).",
+ cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionMemoryLimitBytes),
+ "<bytes>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-memory-limit-percent",
+ "Reject provisioning when used memory exceeds this percentage of total RAM (0 = no limit).",
+ cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionMemoryLimitPercent),
+ "<percent>");
}
void
ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options)
{
- ZEN_UNUSED(Options);
+ using namespace std::literals;
+
+ Options.AddOption("hub.upstreamnotification.endpoint"sv,
+ m_ServerOptions.UpstreamNotificationEndpoint,
+ "upstream-notification-endpoint"sv);
+ Options.AddOption("hub.upstreamnotification.instanceid"sv, m_ServerOptions.InstanceId, "upstream-notification-instance-id"sv);
+
+ Options.AddOption("hub.consul.endpoint"sv, m_ServerOptions.ConsulEndpoint, "consul-endpoint"sv);
+ Options.AddOption("hub.consul.tokenenv"sv, m_ServerOptions.ConsulTokenEnv, "consul-token-env"sv);
+ Options.AddOption("hub.consul.healthintervalseconds"sv,
+ m_ServerOptions.ConsulHealthIntervalSeconds,
+ "consul-health-interval-seconds"sv);
+ Options.AddOption("hub.consul.deregisterafterseconds"sv,
+ m_ServerOptions.ConsulDeregisterAfterSeconds,
+ "consul-deregister-after-seconds"sv);
+
+ Options.AddOption("hub.instance.baseportnumber"sv, m_ServerOptions.HubBasePortNumber, "hub-instance-base-port-number"sv);
+ Options.AddOption("hub.instance.http"sv, m_ServerOptions.HubInstanceHttpClass, "hub-instance-http"sv);
+ Options.AddOption("hub.instance.httpthreads"sv, m_ServerOptions.HubInstanceHttpThreadCount, "hub-instance-http-threads"sv);
+ Options.AddOption("hub.instance.corelimit"sv, m_ServerOptions.HubInstanceCoreLimit, "hub-instance-corelimit"sv);
+ Options.AddOption("hub.instance.config"sv, m_ServerOptions.HubInstanceConfigPath, "hub-instance-config"sv);
+ Options.AddOption("hub.instance.limits.count"sv, m_ServerOptions.HubInstanceLimit, "hub-instance-limit"sv);
+ Options.AddOption("hub.instance.limits.disklimitbytes"sv,
+ m_ServerOptions.HubProvisionDiskLimitBytes,
+ "hub-provision-disk-limit-bytes"sv);
+ Options.AddOption("hub.instance.limits.disklimitpercent"sv,
+ m_ServerOptions.HubProvisionDiskLimitPercent,
+ "hub-provision-disk-limit-percent"sv);
+ Options.AddOption("hub.instance.limits.memorylimitbytes"sv,
+ m_ServerOptions.HubProvisionMemoryLimitBytes,
+ "hub-provision-memory-limit-bytes"sv);
+ Options.AddOption("hub.instance.limits.memorylimitpercent"sv,
+ m_ServerOptions.HubProvisionMemoryLimitPercent,
+ "hub-provision-memory-limit-percent"sv);
+
+ Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv);
+ Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv);
+
+ Options.AddOption("hub.watchdog.cycleintervalms"sv, m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv);
+ Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv,
+ m_ServerOptions.WatchdogConfig.CycleProcessingBudgetMs,
+ "hub-watchdog-cycle-processing-budget-ms"sv);
+ Options.AddOption("hub.watchdog.instancecheckthrottlems"sv,
+ m_ServerOptions.WatchdogConfig.InstanceCheckThrottleMs,
+ "hub-watchdog-instance-check-throttle-ms"sv);
+ Options.AddOption("hub.watchdog.provisionedinactivitytimeoutseconds"sv,
+ m_ServerOptions.WatchdogConfig.ProvisionedInactivityTimeoutSeconds,
+ "hub-watchdog-provisioned-inactivity-timeout-seconds"sv);
+ Options.AddOption("hub.watchdog.hibernatedinactivitytimeoutseconds"sv,
+ m_ServerOptions.WatchdogConfig.HibernatedInactivityTimeoutSeconds,
+ "hub-watchdog-hibernated-inactivity-timeout-seconds"sv);
+ Options.AddOption("hub.watchdog.inactivitycheckmarginseconds"sv,
+ m_ServerOptions.WatchdogConfig.InactivityCheckMarginSeconds,
+ "hub-watchdog-inactivity-check-margin-seconds"sv);
+ Options.AddOption("hub.watchdog.activitycheckconnecttimeoutms"sv,
+ m_ServerOptions.WatchdogConfig.ActivityCheckConnectTimeoutMs,
+ "hub-watchdog-activity-check-connect-timeout-ms"sv);
+ Options.AddOption("hub.watchdog.activitycheckrequesttimeoutms"sv,
+ m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs,
+ "hub-watchdog-activity-check-request-timeout-ms"sv);
+
+#if ZEN_PLATFORM_WINDOWS
+ Options.AddOption("hub.usejobobject"sv, m_ServerOptions.HubUseJobObject, "hub-use-job-object"sv);
+#endif
}
void
@@ -226,6 +346,28 @@ ZenHubServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions)
void
ZenHubServerConfigurator::ValidateOptions()
{
+ if (m_ServerOptions.HubProvisionDiskLimitPercent > 100)
+ {
+ throw OptionParseException(
+ fmt::format("'--hub-provision-disk-limit-percent' ({}) must be in range 0..100", m_ServerOptions.HubProvisionDiskLimitPercent),
+ {});
+ }
+ if (m_ServerOptions.HubProvisionMemoryLimitPercent > 100)
+ {
+ throw OptionParseException(fmt::format("'--hub-provision-memory-limit-percent' ({}) must be in range 0..100",
+ m_ServerOptions.HubProvisionMemoryLimitPercent),
+ {});
+ }
+ if (!m_ServerOptions.HydrationTargetSpecification.empty() && !m_ServerOptions.HydrationTargetConfigPath.empty())
+ {
+ throw OptionParseException("'--hub-hydration-target-spec' and '--hub-hydration-target-config' are mutually exclusive", {});
+ }
+ if (!m_ServerOptions.HydrationTargetConfigPath.empty() && !std::filesystem::exists(m_ServerOptions.HydrationTargetConfigPath))
+ {
+ throw OptionParseException(
+ fmt::format("'--hub-hydration-target-config': file not found: '{}'", m_ServerOptions.HydrationTargetConfigPath.string()),
+ {});
+ }
}
///////////////////////////////////////////////////////////////////////////
@@ -247,6 +389,15 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId,
HubInstanceState NewState)
{
ZEN_UNUSED(PreviousState);
+
+ if (NewState == HubInstanceState::Deprovisioning || NewState == HubInstanceState::Hibernating)
+ {
+ if (Info.Port != 0)
+ {
+ m_Proxy->PrunePort(Info.Port);
+ }
+ }
+
if (!m_ConsulClient)
{
return;
@@ -294,8 +445,8 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId,
ZEN_INFO("Deregistered storage server instance for module '{}' at port {} from Consul", ModuleId, Info.Port);
}
}
- // Transitional states (Deprovisioning, Hibernating, Waking, Recovering, Crashed)
- // and Hibernated are intentionally ignored.
+ // Transitional states (Waking, Recovering, Crashed) and stable states
+ // not handled above (Hibernated) are intentionally ignored by Consul.
}
int
@@ -348,6 +499,11 @@ ZenHubServer::Cleanup()
m_Http->Close();
}
+ if (m_Proxy)
+ {
+ m_Proxy->Shutdown();
+ }
+
if (m_Hub)
{
m_Hub->Shutdown();
@@ -357,6 +513,7 @@ ZenHubServer::Cleanup()
m_HubService.reset();
m_ApiService.reset();
m_Hub.reset();
+ m_Proxy.reset();
m_ConsulRegistration.reset();
m_ConsulClient.reset();
@@ -373,49 +530,116 @@ ZenHubServer::InitializeState(const ZenHubServerConfig& ServerConfig)
ZEN_UNUSED(ServerConfig);
}
+ResourceMetrics
+ZenHubServer::ResolveLimits(const ZenHubServerConfig& ServerConfig)
+{
+ uint64_t DiskTotal = 0;
+ uint64_t MemoryTotal = 0;
+
+ if (ServerConfig.HubProvisionDiskLimitPercent > 0)
+ {
+ DiskSpace Disk;
+ if (DiskSpaceInfo(ServerConfig.DataDir, Disk))
+ {
+ DiskTotal = Disk.Total;
+ }
+ else
+ {
+ ZEN_WARN("Failed to query disk space for '{}'; disk percent limit will not be applied", ServerConfig.DataDir);
+ }
+ }
+ if (ServerConfig.HubProvisionMemoryLimitPercent > 0)
+ {
+ MemoryTotal = GetSystemMetrics().SystemMemoryMiB * 1024 * 1024;
+ }
+
+ auto Resolve = [](uint64_t Bytes, uint32_t Pct, uint64_t Total) -> uint64_t {
+ const uint64_t PctBytes = Pct > 0 ? (Total * Pct) / 100 : 0;
+ if (Bytes > 0 && PctBytes > 0)
+ {
+ return Min(Bytes, PctBytes);
+ }
+ return Bytes > 0 ? Bytes : PctBytes;
+ };
+
+ return {
+ .DiskUsageBytes = Resolve(ServerConfig.HubProvisionDiskLimitBytes, ServerConfig.HubProvisionDiskLimitPercent, DiskTotal),
+ .MemoryUsageBytes = Resolve(ServerConfig.HubProvisionMemoryLimitBytes, ServerConfig.HubProvisionMemoryLimitPercent, MemoryTotal),
+ };
+}
+
void
ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig)
{
ZEN_INFO("instantiating Hub");
+ Hub::Configuration HubConfig{
+ .UseJobObject = ServerConfig.HubUseJobObject,
+ .BasePortNumber = ServerConfig.HubBasePortNumber,
+ .InstanceLimit = ServerConfig.HubInstanceLimit,
+ .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount,
+ .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit,
+ .InstanceConfigPath = ServerConfig.HubInstanceConfigPath,
+ .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification,
+ .WatchDog =
+ {
+ .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs),
+ .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs),
+ .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs),
+ .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds),
+ .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds),
+ .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds),
+ .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs),
+ .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs),
+ },
+ .ResourceLimits = ResolveLimits(ServerConfig)};
+
+ if (!ServerConfig.HydrationTargetConfigPath.empty())
+ {
+ FileContents Contents = ReadFile(ServerConfig.HydrationTargetConfigPath);
+ if (!Contents)
+ {
+ throw zen::runtime_error("Failed to read hydration config '{}': {}",
+ ServerConfig.HydrationTargetConfigPath.string(),
+ Contents.ErrorCode.message());
+ }
+ IoBuffer Buffer(Contents.Flatten());
+ std::string_view JsonText(static_cast<const char*>(Buffer.GetData()), Buffer.GetSize());
+
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(JsonText, ParseError);
+ if (!ParseError.empty() || !Root.IsObject())
+ {
+ throw zen::runtime_error("Failed to parse hydration config '{}': {}",
+ ServerConfig.HydrationTargetConfigPath.string(),
+ ParseError.empty() ? "root must be a JSON object" : ParseError);
+ }
+ HubConfig.HydrationOptions = std::move(Root).AsObject();
+ }
+
+ m_Proxy = std::make_unique<HttpProxyHandler>();
+
m_Hub = std::make_unique<Hub>(
- Hub::Configuration{
- .UseJobObject = ServerConfig.HubUseJobObject,
- .BasePortNumber = ServerConfig.HubBasePortNumber,
- .InstanceLimit = ServerConfig.HubInstanceLimit,
- .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount,
- .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit,
- .InstanceConfigPath = ServerConfig.HubInstanceConfigPath,
- .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification,
- .WatchDog =
- {
- .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs),
- .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs),
- .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs),
- .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds),
- .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds),
- .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds),
- .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs),
- .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs),
- }},
+ std::move(HubConfig),
ZenServerEnvironment(ZenServerEnvironment::Hub,
ServerConfig.DataDir / "hub",
ServerConfig.DataDir / "servers",
ServerConfig.HubInstanceHttpClass),
&GetMediumWorkerPool(EWorkloadType::Background),
- m_ConsulClient ? Hub::AsyncModuleStateChangeCallbackFunc{[this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](
- std::string_view ModuleId,
- const HubProvisionedInstanceInfo& Info,
- HubInstanceState PreviousState,
- HubInstanceState NewState) {
- OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState);
- }}
- : Hub::AsyncModuleStateChangeCallbackFunc{});
+ Hub::AsyncModuleStateChangeCallbackFunc{
+ [this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](std::string_view ModuleId,
+ const HubProvisionedInstanceInfo& Info,
+ HubInstanceState PreviousState,
+ HubInstanceState NewState) {
+ OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState);
+ }});
+
+ m_Proxy->SetPortValidator([Hub = m_Hub.get()](uint16_t Port) { return Hub->IsInstancePort(Port); });
ZEN_INFO("instantiating API service");
m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http);
ZEN_INFO("instantiating hub service");
- m_HubService = std::make_unique<HttpHubService>(*m_Hub, m_StatsService, m_StatusService);
+ m_HubService = std::make_unique<HttpHubService>(*m_Hub, *m_Proxy, m_StatsService, m_StatusService);
m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId);
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService);
@@ -465,12 +689,15 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi
}
else
{
- ZEN_INFO("Consul token read from environment variable '{}'", ConsulAccessTokenEnvName);
+ ZEN_INFO("Consul token will be read from environment variable '{}'", ConsulAccessTokenEnvName);
}
try
{
- m_ConsulClient = std::make_unique<consul::ConsulClient>(ServerConfig.ConsulEndpoint, ConsulAccessToken);
+ m_ConsulClient = std::make_unique<consul::ConsulClient>(consul::ConsulClient::Configuration{
+ .BaseUri = ServerConfig.ConsulEndpoint,
+ .TokenEnvName = ConsulAccessTokenEnvName,
+ });
m_ConsulHealthIntervalSeconds = ServerConfig.ConsulHealthIntervalSeconds;
m_ConsulDeregisterAfterSeconds = ServerConfig.ConsulDeregisterAfterSeconds;
@@ -479,7 +706,7 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi
Info.ServiceName = "zen-hub";
// Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri?
Info.Port = static_cast<uint16_t>(EffectivePort);
- Info.HealthEndpoint = "hub/health";
+ Info.HealthEndpoint = "health";
Info.Tags = std::vector<std::pair<std::string, std::string>>{
std::make_pair("zen-hub", Info.ServiceId),
std::make_pair("version", std::string(ZEN_CFG_VERSION)),
diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h
index 77df3eaa3..d1add7690 100644
--- a/src/zenserver/hub/zenhubserver.h
+++ b/src/zenserver/hub/zenhubserver.h
@@ -3,6 +3,7 @@
#pragma once
#include "hubinstancestate.h"
+#include "resourcemetrics.h"
#include "zenserver.h"
#include <zenutil/consul.h>
@@ -19,6 +20,7 @@ namespace zen {
class HttpApiService;
class HttpFrontendService;
class HttpHubService;
+class HttpProxyHandler;
struct ZenHubWatchdogConfig
{
@@ -48,7 +50,12 @@ struct ZenHubServerConfig : public ZenServerConfig
int HubInstanceCoreLimit = 0; // Automatic
std::filesystem::path HubInstanceConfigPath; // Path to Lua config file
std::string HydrationTargetSpecification; // hydration/dehydration target specification
+ std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification)
ZenHubWatchdogConfig WatchdogConfig;
+ uint64_t HubProvisionDiskLimitBytes = 0;
+ uint32_t HubProvisionDiskLimitPercent = 0;
+ uint64_t HubProvisionMemoryLimitBytes = 0;
+ uint32_t HubProvisionMemoryLimitPercent = 0;
};
class Hub;
@@ -115,7 +122,8 @@ private:
std::filesystem::path m_ContentRoot;
bool m_DebugOptionForcedCrash = false;
- std::unique_ptr<Hub> m_Hub;
+ std::unique_ptr<HttpProxyHandler> m_Proxy;
+ std::unique_ptr<Hub> m_Hub;
std::unique_ptr<HttpHubService> m_HubService;
std::unique_ptr<HttpApiService> m_ApiService;
@@ -126,6 +134,8 @@ private:
uint32_t m_ConsulHealthIntervalSeconds = 10;
uint32_t m_ConsulDeregisterAfterSeconds = 30;
+ static ResourceMetrics ResolveLimits(const ZenHubServerConfig& ServerConfig);
+
void InitializeState(const ZenHubServerConfig& ServerConfig);
void InitializeServices(const ZenHubServerConfig& ServerConfig);
void RegisterServices(const ZenHubServerConfig& ServerConfig);
diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp
index fdf2e1f21..c21ae6a5c 100644
--- a/src/zenserver/sessions/httpsessions.cpp
+++ b/src/zenserver/sessions/httpsessions.cpp
@@ -512,8 +512,9 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req)
//
void
-HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_INFO("Sessions WebSocket client connected");
m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h
index 86a23f835..6ebe61c8d 100644
--- a/src/zenserver/sessions/httpsessions.h
+++ b/src/zenserver/sessions/httpsessions.h
@@ -37,7 +37,7 @@ public:
void SetSelfSessionId(const Oid& Id) { m_SelfSessionId = Id; }
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp
index c1727270c..8ad48225b 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.cpp
+++ b/src/zenserver/storage/cache/httpstructuredcache.cpp
@@ -80,7 +80,8 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach
HttpStatusService& StatusService,
UpstreamCache& UpstreamCache,
const DiskWriteBlocker* InDiskWriteBlocker,
- OpenProcessCache& InOpenProcessCache)
+ OpenProcessCache& InOpenProcessCache,
+ const ILocalRefPolicy* InLocalRefPolicy)
: m_Log(logging::Get("cache"))
, m_CacheStore(InCacheStore)
, m_StatsService(StatsService)
@@ -90,6 +91,7 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach
, m_DiskWriteBlocker(InDiskWriteBlocker)
, m_OpenProcessCache(InOpenProcessCache)
, m_RpcHandler(m_Log, m_CacheStats, UpstreamCache, InCacheStore, InCidStore, InDiskWriteBlocker)
+, m_LocalRefPolicy(InLocalRefPolicy)
{
m_StatsService.RegisterHandler("z$", *this);
m_StatusService.RegisterHandler("z$", *this);
@@ -114,6 +116,18 @@ HttpStructuredCacheService::BaseUri() const
return "/z$/";
}
+bool
+HttpStructuredCacheService::AcceptsLocalFileReferences() const
+{
+ return true;
+}
+
+const ILocalRefPolicy*
+HttpStructuredCacheService::GetLocalRefPolicy() const
+{
+ return m_LocalRefPolicy;
+}
+
void
HttpStructuredCacheService::Flush()
{
diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h
index fc80b449e..f606126d6 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.h
+++ b/src/zenserver/storage/cache/httpstructuredcache.h
@@ -76,11 +76,14 @@ public:
HttpStatusService& StatusService,
UpstreamCache& UpstreamCache,
const DiskWriteBlocker* InDiskWriteBlocker,
- OpenProcessCache& InOpenProcessCache);
+ OpenProcessCache& InOpenProcessCache,
+ const ILocalRefPolicy* InLocalRefPolicy = nullptr);
~HttpStructuredCacheService();
- virtual const char* BaseUri() const override;
- virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual bool AcceptsLocalFileReferences() const override;
+ virtual const ILocalRefPolicy* GetLocalRefPolicy() const override;
void Flush();
@@ -125,6 +128,7 @@ private:
const DiskWriteBlocker* m_DiskWriteBlocker = nullptr;
OpenProcessCache& m_OpenProcessCache;
CacheRpcHandler m_RpcHandler;
+ const ILocalRefPolicy* m_LocalRefPolicy = nullptr;
void ReplayRequestRecorder(const CacheRequestContext& Context, cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount);
diff --git a/src/zenserver/storage/localrefpolicy.cpp b/src/zenserver/storage/localrefpolicy.cpp
new file mode 100644
index 000000000..47ef13b28
--- /dev/null
+++ b/src/zenserver/storage/localrefpolicy.cpp
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "localrefpolicy.h"
+
+#include <zencore/except_fmt.h>
+#include <zencore/fmtutils.h>
+
+#include <filesystem>
+
+namespace zen {
+
+DataRootLocalRefPolicy::DataRootLocalRefPolicy(const std::filesystem::path& DataRoot)
+: m_CanonicalRoot(std::filesystem::weakly_canonical(DataRoot).string())
+{
+}
+
+void
+DataRootLocalRefPolicy::ValidatePath(const std::filesystem::path& Path) const
+{
+ std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(Path);
+ std::string FileStr = CanonicalFile.string();
+
+ if (FileStr.size() < m_CanonicalRoot.size() || FileStr.compare(0, m_CanonicalRoot.size(), m_CanonicalRoot) != 0)
+ {
+ throw zen::invalid_argument("local file reference '{}' is outside allowed data root", CanonicalFile);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/storage/localrefpolicy.h b/src/zenserver/storage/localrefpolicy.h
new file mode 100644
index 000000000..3686d1880
--- /dev/null
+++ b/src/zenserver/storage/localrefpolicy.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/localrefpolicy.h>
+
+#include <filesystem>
+#include <string>
+
+namespace zen {
+
+/// Local ref policy that restricts file paths to a canonical data root directory.
+/// Uses weakly_canonical + string prefix comparison to detect path traversal.
+class DataRootLocalRefPolicy : public ILocalRefPolicy
+{
+public:
+ explicit DataRootLocalRefPolicy(const std::filesystem::path& DataRoot);
+
+ void ValidatePath(const std::filesystem::path& Path) const override;
+
+private:
+ std::string m_CanonicalRoot;
+};
+
+} // namespace zen
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp
index a7c8c66b6..afd0d8f82 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.cpp
+++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp
@@ -656,7 +656,8 @@ HttpProjectService::HttpProjectService(CidStore& Store,
JobQueue& InJobQueue,
bool InRestrictContentTypes,
const std::filesystem::path& InOidcTokenExePath,
- bool InAllowExternalOidcTokenExe)
+ bool InAllowExternalOidcTokenExe,
+ const ILocalRefPolicy* InLocalRefPolicy)
: m_Log(logging::Get("project"))
, m_CidStore(Store)
, m_ProjectStore(Projects)
@@ -668,6 +669,7 @@ HttpProjectService::HttpProjectService(CidStore& Store,
, m_RestrictContentTypes(InRestrictContentTypes)
, m_OidcTokenExePath(InOidcTokenExePath)
, m_AllowExternalOidcTokenExe(InAllowExternalOidcTokenExe)
+, m_LocalRefPolicy(InLocalRefPolicy)
{
ZEN_MEMSCOPE(GetProjectHttpTag());
@@ -820,6 +822,18 @@ HttpProjectService::BaseUri() const
return "/prj/";
}
+bool
+HttpProjectService::AcceptsLocalFileReferences() const
+{
+ return true;
+}
+
+const ILocalRefPolicy*
+HttpProjectService::GetLocalRefPolicy() const
+{
+ return m_LocalRefPolicy;
+}
+
void
HttpProjectService::HandleRequest(HttpServerRequest& Request)
{
@@ -1668,7 +1682,8 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req)
CbPackage Package;
- if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver))
+ const bool ValidateHashes = false;
+ if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver, ValidateHashes))
{
CbValidateError ValidateResult;
if (CbObject Core = ValidateAndReadCompactBinaryObject(IoBuffer(Payload), ValidateResult);
@@ -2763,7 +2778,11 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
case HttpContentType::kCbPackage:
try
{
- Package = ParsePackageMessage(Payload);
+ ParseFlags PkgFlags = (HttpReq.IsLocalMachineRequest() && AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ const ILocalRefPolicy* PkgPolicy =
+ EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? GetLocalRefPolicy() : nullptr;
+ Package = ParsePackageMessage(Payload, {}, PkgFlags, PkgPolicy);
Cb = Package.GetObject();
}
catch (const std::invalid_argument& ex)
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h
index e3ed02f26..8aa345fa7 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.h
+++ b/src/zenserver/storage/projectstore/httpprojectstore.h
@@ -47,11 +47,14 @@ public:
JobQueue& InJobQueue,
bool InRestrictContentTypes,
const std::filesystem::path& InOidcTokenExePath,
- bool AllowExternalOidcTokenExe);
+ bool AllowExternalOidcTokenExe,
+ const ILocalRefPolicy* InLocalRefPolicy = nullptr);
~HttpProjectService();
- virtual const char* BaseUri() const override;
- virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual bool AcceptsLocalFileReferences() const override;
+ virtual const ILocalRefPolicy* GetLocalRefPolicy() const override;
virtual void HandleStatusRequest(HttpServerRequest& Request) override;
virtual void HandleStatsRequest(HttpServerRequest& Request) override;
@@ -117,6 +120,7 @@ private:
bool m_RestrictContentTypes;
std::filesystem::path m_OidcTokenExePath;
bool m_AllowExternalOidcTokenExe;
+ const ILocalRefPolicy* m_LocalRefPolicy;
Ref<TransferThreadWorkers> GetThreadWorkers(bool BoostWorkers, bool SingleThreaded);
};
diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp
index bc0a8f4ac..6b1da5f12 100644
--- a/src/zenserver/storage/zenstorageserver.cpp
+++ b/src/zenserver/storage/zenstorageserver.cpp
@@ -223,6 +223,7 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
ZEN_INFO("instantiating project service");
+ m_LocalRefPolicy = std::make_unique<DataRootLocalRefPolicy>(m_DataRoot);
m_JobQueue = MakeJobQueue(8, "bgjobs");
m_OpenProcessCache = std::make_unique<OpenProcessCache>();
@@ -236,7 +237,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
*m_JobQueue,
ServerOptions.RestrictContentTypes,
ServerOptions.OidcTokenExecutable,
- ServerOptions.AllowExternalOidcTokenExe});
+ ServerOptions.AllowExternalOidcTokenExe,
+ m_LocalRefPolicy.get()});
if (ServerOptions.WorksSpacesConfig.Enabled)
{
@@ -713,7 +715,8 @@ ZenStorageServer::InitializeStructuredCache(const ZenStorageServerConfig& Server
m_StatusService,
*m_UpstreamCache,
m_GcManager.GetDiskWriteBlocker(),
- *m_OpenProcessCache);
+ *m_OpenProcessCache,
+ m_LocalRefPolicy.get());
m_StatsReporter.AddProvider(m_CacheStore.Get());
m_StatsReporter.AddProvider(m_CidStore.get());
diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h
index fad22ad54..e3c6248e6 100644
--- a/src/zenserver/storage/zenstorageserver.h
+++ b/src/zenserver/storage/zenstorageserver.h
@@ -11,6 +11,7 @@
#include <zenstore/cache/structuredcachestore.h>
#include <zenstore/gc.h>
#include <zenstore/projectstore.h>
+#include "localrefpolicy.h"
#include "admin/admin.h"
#include "buildstore/httpbuildstore.h"
@@ -65,15 +66,16 @@ private:
void InitializeServices(const ZenStorageServerConfig& ServerOptions);
void RegisterServices();
- std::unique_ptr<JobQueue> m_JobQueue;
- GcManager m_GcManager;
- GcScheduler m_GcScheduler{m_GcManager};
- std::unique_ptr<CidStore> m_CidStore;
- Ref<ZenCacheStore> m_CacheStore;
- std::unique_ptr<OpenProcessCache> m_OpenProcessCache;
- HttpTestService m_TestService;
- std::unique_ptr<CidStore> m_BuildCidStore;
- std::unique_ptr<BuildStore> m_BuildStore;
+ std::unique_ptr<DataRootLocalRefPolicy> m_LocalRefPolicy;
+ std::unique_ptr<JobQueue> m_JobQueue;
+ GcManager m_GcManager;
+ GcScheduler m_GcScheduler{m_GcManager};
+ std::unique_ptr<CidStore> m_CidStore;
+ Ref<ZenCacheStore> m_CacheStore;
+ std::unique_ptr<OpenProcessCache> m_OpenProcessCache;
+ HttpTestService m_TestService;
+ std::unique_ptr<CidStore> m_BuildCidStore;
+ std::unique_ptr<BuildStore> m_BuildStore;
#if ZEN_WITH_TESTS
HttpTestingService m_TestingService;
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
index 6aa02eb87..087b40d6a 100644
--- a/src/zenserver/zenserver.cpp
+++ b/src/zenserver/zenserver.cpp
@@ -272,6 +272,8 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co
OutOptions << Separator;
OutOptions << "ZEN_WITH_MEMTRACK=" << (ZEN_WITH_MEMTRACK ? "1" : "0");
OutOptions << Separator;
+ OutOptions << "ZEN_WITH_COMPUTE_SERVICES=" << (ZEN_WITH_COMPUTE_SERVICES ? "1" : "0");
+ OutOptions << Separator;
OutOptions << "ZEN_WITH_TRACE=" << (ZEN_WITH_TRACE ? "1" : "0");
}
diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp
index 13674da4d..7cd6b9e37 100644
--- a/src/zenstore/projectstore.cpp
+++ b/src/zenstore/projectstore.cpp
@@ -3180,6 +3180,7 @@ ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&,
}
else
{
+ m_ChunkMap.erase(FileId);
Entry.ServerPath = ServerPath;
}
@@ -5168,7 +5169,10 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac
const bool IsFullRange = (Offset == 0) && ((Size == ~(0ull)) || (Size == ChunkSize));
if (IsFullRange)
{
- Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk)));
+ if (ChunkSize > 0)
+ {
+ Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk)));
+ }
Result.RawSize = 0;
}
else
diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp
index 54e54edde..73cb7ff2d 100644
--- a/src/zentest-appstub/zentest-appstub.cpp
+++ b/src/zentest-appstub/zentest-appstub.cpp
@@ -33,6 +33,13 @@ using namespace zen;
// Some basic functions to implement some test "compute" functions
+struct ForcedExitException : std::exception
+{
+ int Code;
+ explicit ForcedExitException(int InCode) : Code(InCode) {}
+ const char* what() const noexcept override { return "forced exit"; }
+};
+
std::string
Rot13Function(std::string_view InputString)
{
@@ -111,6 +118,16 @@ DescribeFunctions()
<< "Sleep"sv;
Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv);
Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Fail"sv;
+ Versions << "Version"sv << Guid::FromString("fa11fa11-fa11-fa11-fa11-fa11fa11fa11"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Crash"sv;
+ Versions << "Version"sv << Guid::FromString("c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"sv);
+ Versions.EndObject();
Versions.EndArray();
return Versions.Save();
@@ -201,6 +218,38 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver)
zen::Sleep(static_cast<int>(SleepTimeMs));
return Apply(IdentityFunction);
}
+ else if (Function == "Fail"sv)
+ {
+ int FailExitCode = static_cast<int>(Action["Constants"sv].AsObjectView()["ExitCode"sv].AsUInt64());
+ if (FailExitCode == 0)
+ {
+ FailExitCode = 1;
+ }
+ throw ForcedExitException(FailExitCode);
+ }
+ else if (Function == "Crash"sv)
+ {
+ // Crash modes:
+ // "abort" - calls std::abort() (SIGABRT / process termination)
+ // "nullptr" - dereferences a null pointer (SIGSEGV / access violation)
+ std::string_view Mode = Action["Constants"sv].AsObjectView()["Mode"sv].AsString();
+
+ printf("[zentest] crashing with mode: %.*s\n", int(Mode.size()), Mode.data());
+ fflush(stdout);
+
+ if (Mode == "nullptr"sv)
+ {
+ volatile int* Ptr = nullptr;
+ *Ptr = 42;
+ }
+
+ // Default crash mode (also reached after nullptr write on platforms
+ // that don't immediately fault on null dereference)
+#if defined(_MSC_VER)
+ _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT);
+#endif
+ std::abort();
+ }
else
{
return {};
@@ -421,6 +470,12 @@ main(int argc, char* argv[])
}
}
}
+ catch (ForcedExitException& Ex)
+ {
+ printf("[zentest] forced exit with code: %d\n", Ex.Code);
+
+ ExitCode = Ex.Code;
+ }
catch (std::exception& Ex)
{
printf("[zentest] exception caught in main: '%s'\n", Ex.what());
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp
index dde1dc019..5a6cf45d2 100644
--- a/src/zenutil/cloud/imdscredentials.cpp
+++ b/src/zenutil/cloud/imdscredentials.cpp
@@ -115,7 +115,7 @@ ImdsCredentialProvider::FetchToken()
HttpClient::KeyValueMap Headers;
Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600");
- HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers);
+ HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers);
if (!Response.IsSuccess())
{
ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token"));
diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp
index 457453bd8..e146f6677 100644
--- a/src/zenutil/cloud/minioprocess.cpp
+++ b/src/zenutil/cloud/minioprocess.cpp
@@ -45,7 +45,7 @@ struct MinioProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser);
Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword);
diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp
index 26d1023f4..d9fde05d9 100644
--- a/src/zenutil/cloud/s3client.cpp
+++ b/src/zenutil/cloud/s3client.cpp
@@ -137,6 +137,8 @@ namespace {
} // namespace
+std::string_view S3GetObjectResult::NotFoundErrorText = "Not found";
+
S3Client::S3Client(const S3ClientOptions& Options)
: m_Log(logging::Get("s3"))
, m_BucketName(Options.BucketName)
@@ -145,13 +147,7 @@ S3Client::S3Client(const S3ClientOptions& Options)
, m_PathStyle(Options.PathStyle)
, m_Credentials(Options.Credentials)
, m_CredentialProvider(Options.CredentialProvider)
-, m_HttpClient(BuildEndpoint(),
- HttpClientSettings{
- .LogCategory = "s3",
- .ConnectTimeout = Options.ConnectTimeout,
- .Timeout = Options.Timeout,
- .RetryCount = Options.RetryCount,
- })
+, m_HttpClient(BuildEndpoint(), Options.HttpSettings)
{
m_Host = BuildHostHeader();
ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})",
@@ -347,15 +343,20 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content)
}
S3GetObjectResult
-S3Client::GetObject(std::string_view Key)
+S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath)
{
std::string Path = KeyToPath(Key);
HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash);
- HttpClient::Response Response = m_HttpClient.Get(Path, Headers);
+ HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers);
if (!Response.IsSuccess())
{
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}};
+ }
+
std::string Err = Response.ErrorMessage("S3 GET failed");
ZEN_WARN("S3 GET '{}' failed: {}", Key, Err);
return S3GetObjectResult{S3Result{std::move(Err)}, {}};
@@ -377,6 +378,11 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran
HttpClient::Response Response = m_HttpClient.Get(Path, Headers);
if (!Response.IsSuccess())
{
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}};
+ }
+
std::string Err = Response.ErrorMessage("S3 GET range failed");
ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err);
return S3GetObjectResult{S3Result{std::move(Err)}, {}};
@@ -749,7 +755,7 @@ S3Client::PutObjectMultipart(std::string_view Key,
return PutObject(Key, TotalSize > 0 ? FetchRange(0, TotalSize) : IoBuffer{});
}
- ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize);
+ ZEN_DEBUG("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize);
S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key);
if (!InitResult)
@@ -797,7 +803,7 @@ S3Client::PutObjectMultipart(std::string_view Key,
throw;
}
- ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize);
+ ZEN_DEBUG("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize);
return {};
}
@@ -885,7 +891,10 @@ TEST_CASE("s3client.minio_integration")
{
using namespace std::literals;
- // Spawn a local MinIO server
+ // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered
+ // the TEST_CASE from the top, spawning and killing MinIO per subcase — slow and flaky on
+ // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance
+ // that is torn down via RAII at scope exit.
MinioProcessOptions MinioOpts;
MinioOpts.Port = 19000;
MinioOpts.RootUser = "testuser";
@@ -893,11 +902,8 @@ TEST_CASE("s3client.minio_integration")
MinioProcess Minio(MinioOpts);
Minio.SpawnMinioServer();
-
- // Pre-create the test bucket (creates a subdirectory in MinIO's data dir)
Minio.CreateBucket("integration-test");
- // Configure S3Client for the test bucket
S3ClientOptions Opts;
Opts.BucketName = "integration-test";
Opts.Region = "us-east-1";
@@ -908,7 +914,7 @@ TEST_CASE("s3client.minio_integration")
S3Client Client(Opts);
- SUBCASE("put_get_delete")
+ // -- put_get_delete -------------------------------------------------------
{
// PUT
std::string_view TestData = "hello, minio integration test!"sv;
@@ -937,14 +943,14 @@ TEST_CASE("s3client.minio_integration")
CHECK(HeadRes2.Status == HeadObjectResult::NotFound);
}
- SUBCASE("head_not_found")
+ // -- head_not_found -------------------------------------------------------
{
S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat");
CHECK(Res.IsSuccess());
CHECK(Res.Status == HeadObjectResult::NotFound);
}
- SUBCASE("list_objects")
+ // -- list_objects ---------------------------------------------------------
{
// Upload several objects with a common prefix
for (int i = 0; i < 3; ++i)
@@ -979,7 +985,7 @@ TEST_CASE("s3client.minio_integration")
}
}
- SUBCASE("multipart_upload")
+ // -- multipart_upload -----------------------------------------------------
{
// Create a payload large enough to exercise multipart (use minimum part size)
constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum
@@ -1006,7 +1012,7 @@ TEST_CASE("s3client.minio_integration")
Client.DeleteObject("multipart/large.bin");
}
- SUBCASE("presigned_urls")
+ // -- presigned_urls -------------------------------------------------------
{
// Upload an object
std::string_view TestData = "presigned-url-test-data"sv;
@@ -1032,8 +1038,6 @@ TEST_CASE("s3client.minio_integration")
// Cleanup
Client.DeleteObject("presigned/test.txt");
}
-
- Minio.StopMinioServer();
}
TEST_SUITE_END();
diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp
index c9144e589..951beed65 100644
--- a/src/zenutil/consul/consul.cpp
+++ b/src/zenutil/consul/consul.cpp
@@ -9,9 +9,13 @@
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
#include <zencore/thread.h>
#include <zencore/timer.h>
+#include <zenhttp/httpserver.h>
+
#include <fmt/format.h>
namespace zen::consul {
@@ -31,7 +35,7 @@ struct ConsulProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL);
CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options);
@@ -107,7 +111,7 @@ ConsulProcess::StopConsulAgent()
//////////////////////////////////////////////////////////////////////////
-ConsulClient::ConsulClient(std::string_view BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri)
+ConsulClient::ConsulClient(const Configuration& Config) : m_Config(Config), m_HttpClient(m_Config.BaseUri)
{
}
@@ -193,7 +197,9 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info)
// when no interval is configured (e.g. during Provisioning).
Writer.BeginObject("Check"sv);
{
- Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint));
+ Writer.AddString(
+ "HTTP"sv,
+ fmt::format("http://{}:{}/{}", Info.Address.empty() ? "localhost" : Info.Address, Info.Port, Info.HealthEndpoint));
Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds));
if (Info.DeregisterAfterSeconds != 0)
{
@@ -223,27 +229,112 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info)
bool
ConsulClient::DeregisterService(std::string_view ServiceId)
{
+ using namespace std::literals;
+
HttpClient::KeyValueMap AdditionalHeaders;
ApplyCommonHeaders(AdditionalHeaders);
AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON));
- HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders);
+ HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), IoBuffer{}, AdditionalHeaders);
+ if (Result)
+ {
+ return true;
+ }
+
+ // Agent deregister failed — fall back to catalog deregister.
+ // This handles cases where the service was registered via a different Consul agent
+ // (e.g. load-balanced endpoint routing to different agents).
+ std::string NodeName = GetNodeName();
+ if (!NodeName.empty())
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("Node"sv, NodeName);
+ Writer.AddString("ServiceID"sv, ServiceId);
+
+ ExtendableStringBuilder<256> SB;
+ CompactBinaryToJson(Writer.Save(), SB);
+
+ IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size());
+ PayloadBuffer.SetContentType(HttpContentType::kJSON);
+
+ HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders);
+ if (CatalogResult)
+ {
+ ZEN_INFO("ConsulClient::DeregisterService() deregistered service '{}' via catalog fallback (agent error: {})",
+ ServiceId,
+ Result.ErrorMessage(""));
+ return true;
+ }
+
+ ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, catalog: {})",
+ ServiceId,
+ Result.ErrorMessage(""),
+ CatalogResult.ErrorMessage(""));
+ }
+ else
+ {
+ ZEN_WARN(
+ "ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, could not determine node name for catalog "
+ "fallback)",
+ ServiceId,
+ Result.ErrorMessage(""));
+ }
+
+ return false;
+}
+
+std::string
+ConsulClient::GetNodeName()
+{
+ using namespace std::literals;
+
+ HttpClient::KeyValueMap AdditionalHeaders;
+ ApplyCommonHeaders(AdditionalHeaders);
+ HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders);
if (!Result)
{
- ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage(""));
- return false;
+ return {};
}
- return true;
+ std::string JsonError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError);
+ if (!Root || !JsonError.empty())
+ {
+ return {};
+ }
+
+ for (CbFieldView Field : Root)
+ {
+ if (Field.GetName() == "Config"sv)
+ {
+ CbObjectView Config = Field.AsObjectView();
+ if (Config)
+ {
+ return std::string(Config["NodeName"sv].AsString());
+ }
+ }
+ }
+
+ return {};
}
void
ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap)
{
- if (!m_Token.empty())
+ std::string Token;
+ if (!m_Config.StaticToken.empty())
+ {
+ Token = m_Config.StaticToken;
+ }
+ else if (!m_Config.TokenEnvName.empty())
+ {
+ Token = GetEnvVariable(m_Config.TokenEnvName);
+ }
+
+ if (!Token.empty())
{
- InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token);
+ InOutHeaderMap.Entries.emplace("X-Consul-Token", Token);
}
}
@@ -446,4 +537,191 @@ ServiceRegistration::RegistrationLoop()
}
}
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+#if ZEN_WITH_TESTS
+
+void
+consul_forcelink()
+{
+}
+
+struct MockHealthService : public HttpService
+{
+ std::atomic<bool> FailHealth{false};
+ std::atomic<int> HealthCheckCount{0};
+
+ const char* BaseUri() const override { return "/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ std::string_view Uri = Request.RelativeUri();
+ if (Uri == "health/" || Uri == "health")
+ {
+ HealthCheckCount.fetch_add(1);
+ if (FailHealth.load())
+ {
+ Request.WriteResponse(HttpResponseCode::ServiceUnavailable);
+ }
+ else
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok");
+ }
+ return;
+ }
+ Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+};
+
+struct TestHealthServer
+{
+ MockHealthService Mock;
+
+ void Start()
+ {
+ m_TmpDir.emplace();
+ m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
+ m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http");
+ REQUIRE(m_Port != -1);
+ m_Server->RegisterService(Mock);
+ m_ServerThread = std::thread([this]() { m_Server->Run(false); });
+ }
+
+ int Port() const { return m_Port; }
+
+ ~TestHealthServer()
+ {
+ if (m_Server)
+ {
+ m_Server->RequestExit();
+ }
+ if (m_ServerThread.joinable())
+ {
+ m_ServerThread.join();
+ }
+ if (m_Server)
+ {
+ m_Server->Close();
+ }
+ }
+
+private:
+ std::optional<ScopedTemporaryDirectory> m_TmpDir;
+ Ref<HttpServer> m_Server;
+ std::thread m_ServerThread;
+ int m_Port = -1;
+};
+
+static bool
+WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200)
+{
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs))
+ {
+ if (Predicate())
+ {
+ return true;
+ }
+ Sleep(PollIntervalMs);
+ }
+ return Predicate();
+}
+
+static std::string
+GetCheckStatus(ConsulClient& Client, std::string_view ServiceId)
+{
+ using namespace std::literals;
+
+ std::string JsonError;
+ CbFieldIterator ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError);
+ if (!ChecksRoot || !JsonError.empty())
+ {
+ return {};
+ }
+
+ for (CbFieldView F : ChecksRoot)
+ {
+ if (!F.IsObject())
+ {
+ continue;
+ }
+ for (CbFieldView C : F.AsObjectView())
+ {
+ CbObjectView Check = C.AsObjectView();
+ if (Check["ServiceID"sv].AsString() == ServiceId)
+ {
+ return std::string(Check["Status"sv].AsString());
+ }
+ }
+ }
+ return {};
+}
+
+TEST_SUITE_BEGIN("util.consul");
+
+TEST_CASE("util.consul.service_lifecycle")
+{
+ ConsulProcess ConsulProc;
+ ConsulProc.SpawnConsulAgent();
+
+ TestHealthServer HealthServer;
+ HealthServer.Start();
+
+ ConsulClient Client({.BaseUri = "http://localhost:8500/"});
+
+ const std::string ServiceId = "test-health-svc";
+
+ ServiceRegistrationInfo Info;
+ Info.ServiceId = ServiceId;
+ Info.ServiceName = "zen-test-health";
+ Info.Address = "127.0.0.1";
+ Info.Port = static_cast<uint16_t>(HealthServer.Port());
+ Info.HealthEndpoint = "health/";
+ Info.HealthIntervalSeconds = 1;
+ Info.DeregisterAfterSeconds = 60;
+
+ // Phase 1: Register and verify Consul sends health checks to our service
+ REQUIRE(Client.RegisterService(Info));
+ REQUIRE(Client.HasService(ServiceId));
+
+ REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000));
+ CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1);
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing");
+
+ // Phase 2: Explicit deregister
+ REQUIRE(Client.DeregisterService(ServiceId));
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ // Phase 3: Register again, verify passing, then fail health and verify check goes critical
+ HealthServer.Mock.HealthCheckCount.store(0);
+ HealthServer.Mock.FailHealth.store(false);
+
+ REQUIRE(Client.RegisterService(Info));
+ REQUIRE(Client.HasService(ServiceId));
+
+ REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000));
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing");
+
+ HealthServer.Mock.FailHealth.store(true);
+
+ // Wait for Consul to observe the failing check
+ REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000));
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical");
+
+ // Phase 4: Explicit deregister while critical
+ REQUIRE(Client.DeregisterService(ServiceId));
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ // Phase 5: Deregister an already-deregistered service - should not crash
+ Client.DeregisterService(ServiceId);
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ ConsulProc.StopConsulAgent();
+}
+
+TEST_SUITE_END();
+
+#endif
+
} // namespace zen::consul
diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h
index bd30aa8a2..f1f0df0e4 100644
--- a/src/zenutil/include/zenutil/cloud/s3client.h
+++ b/src/zenutil/include/zenutil/cloud/s3client.h
@@ -35,9 +35,7 @@ struct S3ClientOptions
/// Overrides the static Credentials field.
Ref<ImdsCredentialProvider> CredentialProvider;
- std::chrono::milliseconds ConnectTimeout{5000};
- std::chrono::milliseconds Timeout{};
- uint8_t RetryCount = 3;
+ HttpClientSettings HttpSettings = {.LogCategory = "s3", .ConnectTimeout = std::chrono::milliseconds(5000), .RetryCount = 3};
};
struct S3ObjectInfo
@@ -70,6 +68,8 @@ struct S3GetObjectResult : S3Result
IoBuffer Content;
std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); }
+
+ static std::string_view NotFoundErrorText;
};
/// Result of HeadObject - carries object metadata and existence status.
@@ -119,7 +119,7 @@ public:
S3Result PutObject(std::string_view Key, IoBuffer Content);
/// Download an object from S3
- S3GetObjectResult GetObject(std::string_view Key);
+ S3GetObjectResult GetObject(std::string_view Key, const std::filesystem::path& TempFilePath = {});
/// Download a byte range of an object from S3
/// @param RangeStart First byte offset (inclusive)
diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h
index 4002d5d23..7517ddd1e 100644
--- a/src/zenutil/include/zenutil/consul.h
+++ b/src/zenutil/include/zenutil/consul.h
@@ -28,7 +28,14 @@ struct ServiceRegistrationInfo
class ConsulClient
{
public:
- ConsulClient(std::string_view BaseUri, std::string_view Token = "");
+ struct Configuration
+ {
+ std::string BaseUri;
+ std::string StaticToken;
+ std::string TokenEnvName;
+ };
+
+ ConsulClient(const Configuration& Config);
~ConsulClient();
ConsulClient(const ConsulClient&) = delete;
@@ -55,9 +62,10 @@ public:
private:
static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId);
void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap);
+ std::string GetNodeName();
- std::string m_Token;
- HttpClient m_HttpClient;
+ Configuration m_Config;
+ HttpClient m_HttpClient;
};
class ConsulProcess
@@ -109,4 +117,6 @@ private:
void RegistrationLoop();
};
+void consul_forcelink();
+
} // namespace zen::consul
diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h
index 4a25170df..e16c0c446 100644
--- a/src/zenutil/include/zenutil/process/subprocessmanager.h
+++ b/src/zenutil/include/zenutil/process/subprocessmanager.h
@@ -95,14 +95,19 @@ public:
/// Spawn a new child process and begin monitoring it.
///
/// If Options.StdoutPipe is set, the pipe is consumed and async reading
- /// begins automatically. Similarly for Options.StderrPipe.
+ /// begins automatically. Similarly for Options.StderrPipe. When providing
+ /// pipes, pass the corresponding data callback here so it is installed
+ /// before the first async read completes — setting it later via
+ /// SetStdoutCallback risks losing early output.
///
/// Returns a non-owning pointer valid until Remove() or manager destruction.
/// The exit callback fires on an io_context thread when the process terminates.
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout = {},
+ ProcessDataCallback OnStderr = {});
/// Adopt an already-running process by handle. Takes ownership of handle internals.
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
@@ -182,12 +187,6 @@ public:
/// yet computed.
[[nodiscard]] float GetCpuUsagePercent() const;
- /// Set per-process stdout callback (overrides manager default).
- void SetStdoutCallback(ProcessDataCallback Callback);
-
- /// Set per-process stderr callback (overrides manager default).
- void SetStderrCallback(ProcessDataCallback Callback);
-
/// Return all stdout captured so far. When a callback is set, output is
/// delivered there instead of being accumulated.
[[nodiscard]] std::string GetCapturedStdout() const;
@@ -237,11 +236,14 @@ public:
/// Group name (as passed to CreateGroup).
[[nodiscard]] std::string_view GetName() const;
- /// Spawn a process into this group.
+ /// Spawn a process into this group. See SubprocessManager::Spawn for
+ /// details on the stdout/stderr callback parameters.
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout = {},
+ ProcessDataCallback OnStderr = {});
/// Adopt an already-running process into this group.
/// On Windows the process is assigned to the group's JobObject.
diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h
index 03d507400..d6f66fbea 100644
--- a/src/zenutil/include/zenutil/zenserverprocess.h
+++ b/src/zenutil/include/zenutil/zenserverprocess.h
@@ -66,6 +66,7 @@ public:
std::filesystem::path CreateNewTestDir();
std::filesystem::path CreateChildDir(std::string_view ChildName);
std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; }
+ std::filesystem::path GetChildBaseDir() const { return m_ChildProcessBaseDir; }
std::filesystem::path GetTestRootDir(std::string_view Path);
inline bool IsInitialized() const { return m_IsInitialized; }
inline bool IsTestEnvironment() const { return m_IsTestInstance; }
diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp
index 673a03c94..c63ad891e 100644
--- a/src/zenutil/logging/jsonformatter.cpp
+++ b/src/zenutil/logging/jsonformatter.cpp
@@ -19,8 +19,6 @@ static void
WriteEscapedString(MemoryBuffer& Dest, std::string_view Text)
{
// Strip ANSI SGR sequences before escaping so they don't appear in JSON output
- static const auto IsEscapeStart = [](char C) { return C == '\033'; };
-
const char* RangeStart = Text.data();
const char* End = Text.data() + Text.size();
diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp
index 3a91b0a61..e908dd63a 100644
--- a/src/zenutil/process/subprocessmanager.cpp
+++ b/src/zenutil/process/subprocessmanager.cpp
@@ -196,18 +196,6 @@ ManagedProcess::GetCpuUsagePercent() const
return m_Impl->m_CpuUsagePercent.load();
}
-void
-ManagedProcess::SetStdoutCallback(ProcessDataCallback Callback)
-{
- m_Impl->m_StdoutCallback = std::move(Callback);
-}
-
-void
-ManagedProcess::SetStderrCallback(ProcessDataCallback Callback)
-{
- m_Impl->m_StderrCallback = std::move(Callback);
-}
-
std::string
ManagedProcess::GetCapturedStdout() const
{
@@ -288,7 +276,9 @@ struct SubprocessManager::Impl
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr);
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
void Remove(int Pid);
void RemoveAll();
@@ -462,7 +452,9 @@ ManagedProcess*
SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
bool HasStdout = Options.StdoutPipe != nullptr;
bool HasStderr = Options.StderrPipe != nullptr;
@@ -476,6 +468,16 @@ SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable,
ImplPtr->m_Handle.Initialize(static_cast<int>(Result));
#endif
+ // Install callbacks before starting async readers so no data is missed.
+ if (OnStdout)
+ {
+ ImplPtr->m_StdoutCallback = std::move(OnStdout);
+ }
+ if (OnStderr)
+ {
+ ImplPtr->m_StderrCallback = std::move(OnStderr);
+ }
+
auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr)));
ManagedProcess* Ptr = AddProcess(std::move(Proc));
@@ -719,10 +721,12 @@ ManagedProcess*
SubprocessManager::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
ZEN_TRACE_CPU("SubprocessManager::Spawn");
- return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit));
+ return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr));
}
ManagedProcess*
@@ -835,7 +839,9 @@ struct ProcessGroup::Impl
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr);
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
void Remove(int Pid);
void KillAll();
@@ -884,7 +890,9 @@ ManagedProcess*
ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
bool HasStdout = Options.StdoutPipe != nullptr;
bool HasStderr = Options.StderrPipe != nullptr;
@@ -895,7 +903,11 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
Options.AssignToJob = &m_JobObject;
}
#else
- if (m_Pgid > 0)
+ if (m_Pgid == 0)
+ {
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
+ }
+ else
{
Options.ProcessGroupId = m_Pgid;
}
@@ -917,6 +929,16 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
}
#endif
+ // Install callbacks before starting async readers so no data is missed.
+ if (OnStdout)
+ {
+ ImplPtr->m_StdoutCallback = std::move(OnStdout);
+ }
+ if (OnStderr)
+ {
+ ImplPtr->m_StderrCallback = std::move(OnStderr);
+ }
+
auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr)));
ManagedProcess* Ptr = AddProcess(std::move(Proc));
@@ -1077,10 +1099,12 @@ ManagedProcess*
ProcessGroup::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
ZEN_TRACE_CPU("ProcessGroup::Spawn");
- return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit));
+ return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr));
}
ManagedProcess*
@@ -1289,9 +1313,12 @@ TEST_CASE("SubprocessManager.StdoutCallback")
std::string ReceivedData;
bool Exited = false;
- ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; });
-
- Proc->SetStdoutCallback([&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); });
+ ManagedProcess* Proc = Manager.Spawn(
+ AppStub,
+ CmdLine,
+ Options,
+ [&](ManagedProcess&, int) { Exited = true; },
+ [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); });
IoContext.run_for(5s);
diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp
index 2ca380c75..516eec3a9 100644
--- a/src/zenutil/zenutil.cpp
+++ b/src/zenutil/zenutil.cpp
@@ -5,6 +5,7 @@
#if ZEN_WITH_TESTS
# include <zenutil/cloud/imdscredentials.h>
+# include <zenutil/consul.h>
# include <zenutil/cloud/s3client.h>
# include <zenutil/cloud/sigv4.h>
# include <zenutil/config/commandlineoptions.h>
@@ -20,6 +21,7 @@ zenutil_forcelinktests()
{
cache::rpcrecord_forcelink();
commandlineoptions_forcelink();
+ consul::consul_forcelink();
imdscredentials_forcelink();
logstreamlistener_forcelink();
subprocessmanager_forcelink();