aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute/runners/functionrunner.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/zencompute/runners/functionrunner.h')
-rw-r--r--src/zencompute/runners/functionrunner.h214
1 files changed, 214 insertions, 0 deletions
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
new file mode 100644
index 000000000..f67414dbb
--- /dev/null
+++ b/src/zencompute/runners/functionrunner.h
@@ -0,0 +1,214 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/computeservice.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <atomic>
+# include <filesystem>
+# include <vector>
+
+namespace zen::compute {
+
+struct SubmitResult
+{
+ bool IsAccepted = false;
+ std::string Reason;
+};
+
+/** Base interface for classes implementing a remote execution "runner"
+ */
+class FunctionRunner : public RefCounted
+{
+ FunctionRunner(FunctionRunner&&) = delete;
+ FunctionRunner& operator=(FunctionRunner&&) = delete;
+
+public:
+ FunctionRunner(std::filesystem::path BasePath);
+ virtual ~FunctionRunner() = 0;
+
+ virtual void Shutdown() = 0;
+ virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;
+
+ [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;
+ [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
+ [[nodiscard]] virtual bool IsHealthy() = 0;
+ [[nodiscard]] virtual size_t QueryCapacity();
+ [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+
+ // Best-effort cancellation of a specific in-flight action. Returns true if the
+ // cancellation signal was successfully sent. The action will transition to Cancelled
+ // asynchronously once the platform-level termination completes.
+ virtual bool CancelAction(int /*ActionLsn*/) { return false; }
+
+ // Cancel the remote queue corresponding to the given local QueueId.
+ // Only meaningful for remote runners; local runners ignore this.
+ virtual void CancelRemoteQueue(int /*QueueId*/) {}
+
+protected:
+ std::filesystem::path m_ActionsPath;
+ bool m_DumpActions = false;
+ void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject);
+};
+
+/** Base class for RunnerGroup that operates on generic FunctionRunner references.
+ * All scheduling, capacity, and lifecycle logic lives here.
+ */
+class BaseRunnerGroup
+{
+public:
+ size_t QueryCapacity();
+ SubmitResult SubmitAction(Ref<RunnerAction> Action);
+ std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+ size_t GetSubmittedActionCount();
+ void RegisterWorker(CbPackage Worker);
+ void Shutdown();
+ bool CancelAction(int ActionLsn);
+ void CancelRemoteQueue(int QueueId);
+
+ size_t GetRunnerCount()
+ {
+ return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
+ }
+
+protected:
+ void AddRunnerInternal(FunctionRunner* Runner);
+
+ RwLock m_RunnersLock;
+ std::vector<Ref<FunctionRunner>> m_Runners;
+ std::atomic<int> m_NextSubmitIndex{0};
+};
+
+/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
+ */
+template<typename RunnerType>
+struct RunnerGroup : public BaseRunnerGroup
+{
+ void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); }
+
+ template<typename Predicate>
+ size_t RemoveRunnerIf(Predicate&& Pred)
+ {
+ size_t RemovedCount = 0;
+ m_RunnersLock.WithExclusiveLock([&] {
+ auto It = m_Runners.begin();
+ while (It != m_Runners.end())
+ {
+ if (Pred(static_cast<RunnerType&>(**It)))
+ {
+ (*It)->Shutdown();
+ It = m_Runners.erase(It);
+ ++RemovedCount;
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+ return RemovedCount;
+ }
+};
+
+/**
+ * This represents an action going through different stages of scheduling and execution.
+ */
+struct RunnerAction : public RefCounted
+{
+ explicit RunnerAction(ComputeServiceSession* OwnerSession);
+ ~RunnerAction();
+
+ int ActionLsn = 0;
+ int QueueId = 0;
+ WorkerDesc Worker;
+ IoHash ActionId;
+ CbObject ActionObj;
+ int Priority = 0;
+ std::string ExecutionLocation; // "local" or remote hostname
+
+ // CPU usage and total CPU time of the running process, sampled periodically by the local runner.
+ // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
+ // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled.
+ std::atomic<float> CpuUsagePercent{-1.0f};
+ std::atomic<float> CpuSeconds{0.0f};
+ std::atomic<int> RetryCount{0};
+
+ enum class State
+ {
+ New,
+ Pending,
+ Submitting,
+ Running,
+ Completed,
+ Failed,
+ Abandoned,
+ Cancelled,
+ _Count
+ };
+
+ static const char* ToString(State _)
+ {
+ switch (_)
+ {
+ case State::New:
+ return "New";
+ case State::Pending:
+ return "Pending";
+ case State::Submitting:
+ return "Submitting";
+ case State::Running:
+ return "Running";
+ case State::Completed:
+ return "Completed";
+ case State::Failed:
+ return "Failed";
+ case State::Abandoned:
+ return "Abandoned";
+ case State::Cancelled:
+ return "Cancelled";
+ default:
+ return "Unknown";
+ }
+ }
+
+ static State FromString(std::string_view Name, State Default = State::Failed)
+ {
+ for (int i = 0; i < static_cast<int>(State::_Count); ++i)
+ {
+ if (Name == ToString(static_cast<State>(i)))
+ {
+ return static_cast<State>(i);
+ }
+ }
+ return Default;
+ }
+
+ uint64_t Timestamps[static_cast<int>(State::_Count)] = {};
+
+ State ActionState() const { return m_ActionState; }
+ void SetActionState(State NewState);
+
+ bool IsSuccess() const { return ActionState() == State::Completed; }
+ bool ResetActionStateToPending();
+ bool IsCompleted() const
+ {
+ return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == State::Abandoned ||
+ ActionState() == State::Cancelled;
+ }
+
+ void SetResult(CbPackage&& Result);
+ CbPackage& GetResult();
+
+ ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; }
+
+private:
+ std::atomic<State> m_ActionState = State::New;
+ ComputeServiceSession* m_OwnerSession = nullptr;
+ CbPackage m_Result;
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file