aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute/runners/functionrunner.h
blob: 56c3f3af059e7913250c20ac96550f5dccdf37bb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zencompute/computeservice.h>

#if ZEN_WITH_COMPUTE_SERVICES

#	include <atomic>
#	include <filesystem>
#	include <vector>

namespace zen::compute {

/** Outcome of offering one action to a runner or runner group.
 *  IsAccepted: true when the submission was taken; defaults to false.
 *  Reason:     free-form explanatory text (presumably populated when a submission is rejected).
 */
struct SubmitResult
{
	bool		IsAccepted{false};
	std::string Reason{};
};

/** Base interface for classes implementing a remote execution "runner".
 *
 *  A runner accepts actions (SubmitAction / SubmitActions), reports its load and health
 *  (GetSubmittedActionCount, QueryCapacity, IsHealthy), and is torn down via Shutdown().
 *  Instances are reference counted (RefCounted) and are neither copyable nor movable.
 */
class FunctionRunner : public RefCounted
{
	FunctionRunner(const FunctionRunner&)			 = delete;
	FunctionRunner& operator=(const FunctionRunner&) = delete;
	FunctionRunner(FunctionRunner&&)				 = delete;
	FunctionRunner& operator=(FunctionRunner&&)		 = delete;

public:
	// BasePath: root directory handed to the runner. NOTE(review): presumably used to derive
	// m_ActionsPath in the out-of-line constructor — confirm in the implementation.
	// explicit: a filesystem path must never silently convert into a runner.
	explicit FunctionRunner(std::filesystem::path BasePath);
	virtual ~FunctionRunner() = 0;

	// Stops the runner. Pure virtual; each concrete runner defines its own teardown.
	virtual void Shutdown() = 0;

	// Makes a worker package known to this runner. Returns false if it could not be registered.
	[[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0;

	// Offers a single action; the result reports whether it was accepted and, if not, why.
	[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;

	// Number of actions this runner currently has submitted.
	[[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;

	// Whether the runner is currently able to do work.
	[[nodiscard]] virtual bool IsHealthy() = 0;

	// How many more actions the runner can take. Virtual with a default implementation
	// provided out of line; runners may override.
	[[nodiscard]] virtual size_t QueryCapacity();

	// Batch form of SubmitAction. Virtual with a default implementation provided out of line;
	// NOTE(review): presumably returns one SubmitResult per input action — confirm.
	[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);

	// Best-effort cancellation of a specific in-flight action. Returns true if the
	// cancellation signal was successfully sent. The action will transition to Cancelled
	// asynchronously once the platform-level termination completes.
	// The default implementation does nothing and returns false.
	virtual bool CancelAction(int /*ActionLsn*/) { return false; }

	// Cancel the remote queue corresponding to the given local QueueId.
	// Only meaningful for remote runners; the default implementation is a no-op.
	virtual void CancelRemoteQueue(int /*QueueId*/) {}

protected:
	// Directory associated with action dumps. NOTE(review): presumably derived from BasePath.
	std::filesystem::path m_ActionsPath;

	// When true, MaybeDumpAction is expected to persist actions; defaults to false.
	bool m_DumpActions = false;

	// NOTE(review): presumably writes ActionObject (keyed by ActionLsn) under m_ActionsPath when
	// m_DumpActions is set — confirm against the implementation.
	void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject);
};

/** Non-templated core of RunnerGroup, working purely in terms of FunctionRunner references.
 *  Scheduling, capacity and lifecycle logic for the whole group is implemented here; typed
 *  subclasses only layer convenience APIs on top.
 *
 *  The runner list (m_Runners) is protected by m_RunnersLock; read-only queries take the shared
 *  side of the lock.
 */
class BaseRunnerGroup
{
public:
	// --- Lifecycle -----------------------------------------------------------------------------
	[[nodiscard]] bool RegisterWorker(CbPackage Worker);
	void			   Shutdown();

	// --- Capacity and submission ---------------------------------------------------------------
	size_t					  QueryCapacity();
	size_t					  GetSubmittedActionCount();
	SubmitResult			  SubmitAction(Ref<RunnerAction> Action);
	std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);

	// --- Cancellation --------------------------------------------------------------------------
	bool CancelAction(int ActionLsn);
	void CancelRemoteQueue(int QueueId);

	// Snapshot of how many runners the group holds, read under the shared lock.
	size_t GetRunnerCount()
	{
		auto CountRunners = [&]() -> size_t { return m_Runners.size(); };
		return m_RunnersLock.WithSharedLock(CountRunners);
	}

protected:
	// Appends a runner to the group (implemented out of line).
	void AddRunnerInternal(FunctionRunner* Runner);

	RwLock							 m_RunnersLock;			  // guards m_Runners
	std::vector<Ref<FunctionRunner>> m_Runners;				  // runners owned by this group
	std::atomic<int>				 m_NextSubmitIndex{0};	  // NOTE(review): presumably a round-robin cursor for submission
};

/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
 */
template<typename RunnerType>
struct RunnerGroup : public BaseRunnerGroup
{
	void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); }

	template<typename Predicate>
	size_t RemoveRunnerIf(Predicate&& Pred)
	{
		size_t RemovedCount = 0;
		m_RunnersLock.WithExclusiveLock([&] {
			auto It = m_Runners.begin();
			while (It != m_Runners.end())
			{
				if (Pred(static_cast<RunnerType&>(**It)))
				{
					(*It)->Shutdown();
					It = m_Runners.erase(It);
					++RemovedCount;
				}
				else
				{
					++It;
				}
			}
		});
		return RemovedCount;
	}
};

/**
 * This represents an action going through different stages of scheduling and execution.
 *
 * State machine
 * =============
 *
 * Normal forward flow (enforced by SetActionState rejecting backward transitions):
 *
 *   New -> Pending -> Submitting -> Running -> Completed
 *                                           -> Failed
 *                                           -> Abandoned
 *                                           -> Cancelled
 *
 * Rescheduling (via ResetActionStateToPending):
 *
 *   Failed    ---> Pending   (increments RetryCount, subject to retry limit)
 *   Abandoned ---> Pending   (increments RetryCount, subject to retry limit)
 *   Retracted ---> Pending   (does NOT increment RetryCount)
 *
 * Retraction (via RetractAction, idempotent):
 *
 *   Pending/Submitting/Running -> Retracted -> Pending (rescheduled)
 *
 * Retracted is placed after Cancelled in enum order so that once set,
 * no runner-side transition (Completed/Failed) can override it via
 * SetActionState's forward-only rule.
 *
 * Concurrency: the action state, CPU samples and retry counter are std::atomic, presumably so
 * they can be read while other threads update them. The remaining public fields are plain members.
 */
struct RunnerAction : public RefCounted
{
	explicit RunnerAction(ComputeServiceSession* OwnerSession);
	~RunnerAction();

	int			ActionLsn = 0;		// Per-action sequence number. NOTE(review): presumably unique per session — confirm.
	int			QueueId	  = 0;		// Local queue this action belongs to.
	WorkerDesc	Worker;				// Worker the action is to run on.
	IoHash		ActionId;			// Hash identifying the action.
	CbObject	ActionObj;			// The action description itself.
	int			Priority = 0;		// Scheduling priority. NOTE(review): confirm whether higher means more urgent.
	std::string ExecutionLocation;	// "local" or remote hostname

	// CPU usage and total CPU time of the running process, sampled periodically by the local runner.
	// CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
	// CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled.
	std::atomic<float> CpuUsagePercent{-1.0f};
	std::atomic<float> CpuSeconds{0.0f};

	// Number of retry-counted reschedules (Failed/Abandoned -> Pending); retraction does not count.
	std::atomic<int> RetryCount{0};

	enum class State
	{
		New,		 // Initial state at construction, before entering the scheduler
		Pending,	 // Queued and waiting for a runner slot
		Submitting,	 // Being handed off to a runner (async submission in progress)
		Running,	 // Executing on a runner process
		Completed,	 // Finished successfully with results available
		Failed,		 // Execution failed (transient error, eligible for retry)
		Abandoned,	 // Infrastructure termination (e.g. spot eviction, session abandon)
		Cancelled,	 // Intentional user cancellation (never retried)
		Retracted,	 // Pulled back for rescheduling on a different runner (no retry cost)
		_Count
	};
	static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned &&
					  State::Retracted > State::Cancelled,
				  "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it");

	// Returns the enumerator's name, or "Unknown" for values outside the named states (including _Count).
	static const char* ToString(State InState)
	{
		switch (InState)
		{
			case State::New:
				return "New";
			case State::Pending:
				return "Pending";
			case State::Submitting:
				return "Submitting";
			case State::Running:
				return "Running";
			case State::Completed:
				return "Completed";
			case State::Failed:
				return "Failed";
			case State::Abandoned:
				return "Abandoned";
			case State::Cancelled:
				return "Cancelled";
			case State::Retracted:
				return "Retracted";
			default:
				return "Unknown";
		}
	}

	// Inverse of ToString: exact, case-sensitive match against the state names. Returns Default
	// (State::Failed unless specified) when Name does not match any state.
	static State FromString(std::string_view Name, State Default = State::Failed)
	{
		for (int i = 0; i < static_cast<int>(State::_Count); ++i)
		{
			if (Name == ToString(static_cast<State>(i)))
			{
				return static_cast<State>(i);
			}
		}
		return Default;
	}

	// One slot per State, indexed by the state's ordinal. NOTE(review): presumably records when
	// each state was entered; units/clock not visible here — confirm in the implementation.
	uint64_t Timestamps[static_cast<int>(State::_Count)] = {};

	State ActionState() const { return m_ActionState; }
	void  SetActionState(State NewState);

	bool IsSuccess() const { return ActionState() == State::Completed; }
	bool RetractAction();
	bool ResetActionStateToPending();

	// True when the action is in a terminal outcome state: Completed, Failed, Abandoned or Cancelled.
	// Retracted is not included; per the state machine above, a retracted action returns to Pending.
	// The atomic state is loaded exactly once, so all comparisons see the same snapshot even if
	// another thread transitions the action concurrently.
	bool IsCompleted() const
	{
		const State Current = ActionState();
		return Current == State::Completed || Current == State::Failed || Current == State::Abandoned || Current == State::Cancelled;
	}

	void	   SetResult(CbPackage&& Result);
	CbPackage& GetResult();

	ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; }

private:
	std::atomic<State>	   m_ActionState  = State::New;
	ComputeServiceSession* m_OwnerSession = nullptr;
	CbPackage			   m_Result;
};

}  // namespace zen::compute

#endif	// ZEN_WITH_COMPUTE_SERVICES