diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zencompute/runners/localrunner.cpp | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zencompute/runners/localrunner.cpp')
| -rw-r--r-- | src/zencompute/runners/localrunner.cpp | 674 |
1 files changed, 674 insertions, 0 deletions
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp new file mode 100644 index 000000000..7aaefb06e --- /dev/null +++ b/src/zencompute/runners/localrunner.cpp @@ -0,0 +1,674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/system.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> +# include <zenstore/cidstore.h> + +# include <span> + +namespace zen::compute { + +using namespace std::literals; + +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("local_exec")) +, m_ChunkResolver(Resolver) +, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) +, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) +{ + SystemMetrics Sm = GetSystemMetricsForReporting(); + + m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); + + bool DidCleanup = false; + + if (std::filesystem::is_directory(m_ActionsPath)) + { + ZEN_INFO("Cleaning '{}'", m_ActionsPath); + + std::error_code Ec; + CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); + } + + DidCleanup = true; + } + + if (std::filesystem::is_directory(m_SandboxPath)) + { + ZEN_INFO("Cleaning '{}'", m_SandboxPath); + std::error_code Ec; + CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); + } + + DidCleanup = true; + } + + // We clean out all workers on startup since we can't know they are good. They could be bad + // due to tampering, malware (which I also mean to include AV and antimalware software) or + // other processes we have no control over + if (std::filesystem::is_directory(m_WorkerPath)) + { + ZEN_INFO("Cleaning '{}'", m_WorkerPath); + std::error_code Ec; + CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); + } + + DidCleanup = true; + } + + if (DidCleanup) + { + ZEN_INFO("Cleanup complete"); + } + + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; + +# if ZEN_PLATFORM_WINDOWS + // Suppress any error dialogs caused by missing dependencies + UINT OldMode = ::SetErrorMode(0); + ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); +# endif + + m_AcceptNewActions = true; +} + +LocalProcessRunner::~LocalProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); + } +} + +void +LocalProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } + + CancelRunningActions(); +} + +std::filesystem::path +LocalProcessRunner::CreateNewSandbox() +{ + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); + std::string UniqueId = std::to_string(++m_SandboxCounter); + std::filesystem::path Path = m_SandboxPath / UniqueId; + zen::CreateDirectories(Path); + + return Path; +} + +void +LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); + if (m_DumpActions) + { + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + const IoHash& WorkerId = WorkerPackage.GetObjectHash(); + + std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); + + ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) { + std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); + zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); + }); + + ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); + } +} + +size_t +LocalProcessRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return 0; + } + + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); + + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) + { + return 0; + } + else + { + return MaxRunningActions - InFlightCount; + } +} + +std::vector<SubmitResult> +LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity upfront and accept what fits. + // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. + + size_t Available = QueryCapacity(); + + std::vector<SubmitResult> Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) + { + const Ref<RunnerAction>& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; + } + + return Results; +} + +std::optional<LocalProcessRunner::PreparedAction> +LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return std::nullopt; + } + + if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) + { + return std::nullopt; + } + } + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + + MaybeDumpAction(ActionLsn, ActionObj); + + std::filesystem::path SandboxPath = CreateNewSandbox(); + + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + + CbPackage WorkerPackage = Action->Worker.Descriptor; + + std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); + + // Write out action + + zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + + // Manifest inputs in sandbox + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); + } + + zen::WriteFile(FilePath, DataBuffer); + }); + + Action->ExecutionLocation = "local"; + + SandboxGuard.Dismiss(); + + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} + +SubmitResult +LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; +} + +size_t +LocalProcessRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RunningMap.size(); +} + +std::filesystem::path +LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); + RwLock::SharedLockScope _(m_WorkerLock); + + std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); + + if (!std::filesystem::exists(WorkerDir)) + { + _.ReleaseNow(); + + RwLock::ExclusiveLockScope $(m_WorkerLock); + + if (!std::filesystem::exists(WorkerDir)) + { + ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); + } + } + + return WorkerDir; +} + +void +LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback) +{ + std::string_view Name = FileEntry["name"sv].AsString(); + const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + CompressedBuffer Compressed; + + if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) + { + Compressed = Attachment->AsCompressedBinary(); + } + else + { + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); + } + + uint64_t DataRawSize = 0; + IoHash DataRawHash; + Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); + + if (DataRawSize != Size) + { + throw std::runtime_error( + fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size)); + } + } + + ChunkReferenceCallback(ChunkHash, Compressed); + + std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + + SharedBuffer Decompressed = Compressed.Decompress(); + zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); +} + +void +LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback) +{ + CbObject WorkerDescription = WorkerPackage.GetObject(); + + // Manifest worker in Sandbox + + for (auto& It : WorkerDescription["executables"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + std::string_view Name = It.AsString(); + std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + + zen::CreateDirectories(DirPath); + } + + for (auto& It : WorkerDescription["files"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); +} + +CbPackage +LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) +{ + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); + std::filesystem::path OutputFile = SandboxPath / "build.output"; + FileContents OutputData = zen::ReadFile(OutputFile); + + if (OutputData.ErrorCode) + { + throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); + } + + CbPackage OutputPackage; + CbObject Output = zen::LoadCompactBinaryObject(OutputData.Flatten()); + + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalRawAttachmentBytes = 0; + + Output.IterateAttachments([&](CbFieldView Field) { + IoHash Hash = Field.AsHash(); + std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; + FileContents ChunkData = zen::ReadFile(OutputPath); + + if (ChunkData.ErrorCode) + { + throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); + } + + uint64_t ChunkDataRawSize = 0; + IoHash ChunkDataHash; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); + + if (!AttachmentBuffer) + { + throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); + } + + TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + TotalRawAttachmentBytes += ChunkDataRawSize; + + CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); + OutputPackage.AddAttachment(Attachment); + }); + + OutputPackage.SetObject(Output); + + ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", + OutputPackage.GetAttachments().size(), + NiceBytes(TotalAttachmentBytes), + NiceBytes(TotalRawAttachmentBytes)); + + return OutputPackage; +} + +void +LocalProcessRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("LocalProcessRunner_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + // On Windows it's possible to wait on process handles, so we wait for either a process to exit + // or for the monitor event to be signaled (which indicates we should check for cancellation + // or shutdown). This could be further improved by using a completion port and registering process + // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing + // with an enormous number of concurrent processes. + // + // On other platforms we just wait on the monitor event and poll for process exits at intervals. +# if ZEN_PLATFORM_WINDOWS + auto WaitOnce = [&] { + HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; + + uint32_t NumHandles = 0; + + WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); + + m_RunningLock.WithSharedLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) + { + Ref<RunningAction> Action = It->second; + + WaitHandles[NumHandles++] = Action->ProcessHandle; + } + }); + + DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); + + // return true if a handle was signaled + return (WaitResult <= NumHandles); + }; +# else + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; +# endif + + while (!WaitOnce()) + { + if (m_MonitorThreadEnabled == false) + { + return; + } + + SweepRunningActions(); + SampleRunningProcessCpu(); + } + + // Signal received + + SweepRunningActions(); + SampleRunningProcessCpu(); + } while (m_MonitorThreadEnabled); +} + +void +LocalProcessRunner::CancelRunningActions() +{ + // Base class is not directly usable — platform subclasses override this +} + +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; + + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) + { + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) + { + SampleProcessCpu(*Running); + } + } + }); +} + +void +LocalProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} + +void +LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. + // Note that this must be called without holding any local locks + // otherwise we may end up with deadlocks. + + for (Ref<RunningAction> Running : CompletedActions) + { + const int ActionLsn = Running->Action->ActionLsn; + + if (Running->ExitCode == 0) + { + try + { + // Gather outputs + + CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); + + Running->Action->SetResult(std::move(OutputPackage)); + Running->Action->SetActionState(RunnerAction::State::Completed); + + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + + // Success -- continue with next iteration of the loop + continue; + } + catch (std::exception& Ex) + { + ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + } + } + + // Failed - clean up the sandbox in the background. + + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } +} + +} // namespace zen::compute + +#endif |