diff options
| author | Stefan Boberg <[email protected]> | 2026-03-04 14:13:46 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-04 14:13:46 +0100 |
| commit | 0763d09a81e5a1d3df11763a7ec75e7860c9510a (patch) | |
| tree | 074575ba6ea259044a179eab0bb396d37268fb09 /src/zencompute/include | |
| parent | native xmake toolchain definition for UE-clang (#805) (diff) | |
| download | zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.tar.xz zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.zip | |
compute orchestration (#763)
- Added local process runners for Linux/Wine and Mac, with some sandboxing support
- Horde & Nomad provisioning for development and testing
- Client session queues with lifecycle management (active/draining/cancelled), automatic retry with configurable limits, and manual reschedule API
- Improved web UI for orchestrator, compute, and hub dashboards with WebSocket push updates
- Some security hardening
- Improved scalability and `zen exec` command
Still experimental - compute support is disabled by default
Diffstat (limited to 'src/zencompute/include')
| -rw-r--r-- | src/zencompute/include/zencompute/cloudmetadata.h | 151 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/computeservice.h | 262 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/functionservice.h | 132 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/httpcomputeservice.h | 54 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/httpfunctionservice.h | 73 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/httporchestrator.h | 81 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/mockimds.h | 102 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/orchestratorservice.h | 177 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/recordingreader.h | 4 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/zencompute.h | 4 |
10 files changed, 822 insertions, 218 deletions
diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +#include <atomic> +#include <filesystem> +#include <string> +#include <thread> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. + * + * To avoid a ~200ms connect timeout penalty on every startup when running on + * bare-metal or non-cloud machines, failed probes write sentinel files + * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have + * a sentinel present. Delete the sentinel files to force re-detection. + * + * When a provider is detected, a background thread polls for termination + * signals every 5 seconds (spot interruption, autoscaling lifecycle changes, + * scheduled maintenance). The termination state is exposed as an atomic bool + * so the compute server can include it in coordinator announcements and react + * to imminent shutdown. 
+ * + * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a + * shared RwLock; the background monitor thread acquires the exclusive lock + * only when writing the termination reason (a one-time transition). The + * termination-pending flag itself is a lock-free atomic. + * + * Usage: + * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud"); + * if (Cloud->IsTerminationPending()) { ... } + * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB + */ +class CloudMetadata +{ +public: + /** Synchronously probes cloud providers and starts the termination monitor + * if a provider is detected. Creates DataDir if it does not exist. + */ + explicit CloudMetadata(std::filesystem::path DataDir); + + /** Synchronously probes cloud providers at the given IMDS endpoint. + * Intended for testing — allows redirecting all IMDS queries to a local + * mock HTTP server instead of the real 169.254.169.254 endpoint. + */ + CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); + + /** Stops the termination monitor thread and joins it. */ + ~CloudMetadata(); + + CloudMetadata(const CloudMetadata&) = delete; + CloudMetadata& operator=(const CloudMetadata&) = delete; + + CloudProvider GetProvider() const; + CloudInstanceInfo GetInstanceInfo() const; + bool IsTerminationPending() const; + std::string GetTerminationReason() const; + + /** Writes a "cloud" sub-object into the compact binary writer if a provider + * was detected. No-op when running on bare metal. + */ + void Describe(CbWriter& Writer) const; + + /** Executes a single termination-poll cycle for the detected provider. + * Public so tests can drive poll cycles synchronously without relying on + * the background thread's 5-second timer. + */ + void PollTermination(); + + /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure, + * .isNotGCP) from DataDir so subsequent detection probes are not skipped. 
+ * Primarily intended for tests that need to reset state between sub-cases. + */ + void ClearSentinelFiles(); + +private: + /** Tries each provider in order, stops on first successful detection. */ + void DetectProvider(); + bool TryDetectAWS(); + bool TryDetectAzure(); + bool TryDetectGCP(); + + void WriteSentinelFile(const std::filesystem::path& Path); + bool HasSentinelFile(const std::filesystem::path& Path) const; + + void StartTerminationMonitor(); + void TerminationMonitorThread(); + void PollAWSTermination(); + void PollAzureTermination(); + void PollGCPTermination(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_DataDir; + std::string m_ImdsEndpoint; + + mutable RwLock m_InfoLock; + CloudInstanceInfo m_Info; + + std::atomic<bool> m_TerminationPending{false}; + + mutable RwLock m_ReasonLock; + std::string m_TerminationReason; + + // IMDSv2 session token, acquired during AWS detection and reused for + // subsequent termination polling. Has a 300s TTL on the AWS side; if it + // expires mid-run the poll requests will get 401s which we treat as + // non-terminal (the monitor simply retries next cycle). + std::string m_AwsToken; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorEnabled{true}; + Event m_MonitorEvent; +}; + +void cloudmetadata_forcelink(); // internal + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h new file mode 100644 index 000000000..65ec5f9ee --- /dev/null +++ b/src/zencompute/include/zencompute/computeservice.h @@ -0,0 +1,262 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/iohash.h> +# include <zenstore/zenstore.h> +# include <zenhttp/httpcommon.h> + +# include <filesystem> + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class ComputeServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. + */ +class ComputeServiceSession final +{ +public: + /** + * Session lifecycle state machine. 
+ * + * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset + * Backward transitions: Draining -> Ready, Paused -> Ready + * Automatic transition: Draining -> Paused (when pending + running reaches 0) + * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset + * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out) + * + * | State | Accept new actions | Schedule pending | Finish running | + * |-----------|-------------------|-----------------|----------------| + * | Created | No | No | N/A | + * | Ready | Yes | Yes | Yes | + * | Draining | No | Yes | Yes | + * | Paused | No | No | No | + * | Abandoned | No | No | No (all abandoned) | + * | Sunset | No | No | No | + */ + enum class SessionState + { + Created, // Initial state before WaitUntilReady completes + Ready, // Normal operating state; accepts and schedules work + Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty + Paused, // Idle; no work accepted or scheduled; can resume to Ready + Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out + Sunset // Terminal; triggers full shutdown + }; + + ComputeServiceSession(ChunkResolver& InChunkResolver); + ~ComputeServiceSession(); + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + SessionState GetSessionState() const; + + // Request a state transition. Returns false if the transition is invalid. + // Sunset can be reached from any non-Sunset state. 
+ bool RequestStateTransition(SessionState NewState); + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Queue management + // + // Queues group actions submitted by a single client session. They allow + // cancelling or polling completion of all actions in the group. 
+ + struct CreateQueueResult + { + int QueueId = 0; // 0 if creation failed + }; + + enum class QueueState + { + Active, + Draining, + Cancelled, + }; + + struct QueueStatus + { + bool IsValid = false; + int QueueId = 0; + int ActiveCount = 0; // pending + running (not yet completed) + int CompletedCount = 0; // successfully completed + int FailedCount = 0; // failed + int AbandonedCount = 0; // abandoned + int CancelledCount = 0; // cancelled + QueueState State = QueueState::Active; + bool IsComplete = false; // ActiveCount == 0 + }; + + [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {}); + [[nodiscard]] std::vector<int> GetQueueIds(); + [[nodiscard]] QueueStatus GetQueueStatus(int QueueId); + [[nodiscard]] CbObject GetQueueMetadata(int QueueId); + [[nodiscard]] CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DrainQueue(int QueueId); + void DeleteQueue(int QueueId); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + + // Queue-scoped action submission. Actions submitted via these methods are + // tracked under the given queue in addition to the global LSN-based tracking. 
+ + [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + // Action rescheduling + + struct RescheduleResult + { + bool Success = false; + std::string Error; + int RetryCount = 0; + }; + + [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + + void GetCompleted(CbWriter&); + + // Running action tracking + + struct RunningActionInfo + { + int Lsn; + int QueueId; + IoHash ActionId; + float CpuUsagePercent; // -1.0 if not yet sampled + float CpuSeconds; // 0.0 if not yet sampled + }; + + [[nodiscard]] std::vector<RunningActionInfo> GetRunningActions(); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + int QueueId = 0; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + std::string ExecutionLocation; + bool Succeeded; + float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled + int RetryCount = 0; // number of times this action was rescheduled + // sized to match RunnerAction::State::_Count but we can't use the enum here + // for dependency reasons, so just use a fixed size array and static assert in + // the implementation file + uint64_t Timestamps[8] = {}; + }; + + [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); + [[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100); + + // Stats reporting + + struct ActionCounts + { + int 
Pending = 0; + int Running = 0; + int Completed = 0; + int ActiveQueues = 0; + }; + + [[nodiscard]] ActionCounts GetActionCounts(); + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void computeservice_forcelink(); + +} // namespace zen::compute + +namespace zen { +const char* ToString(compute::ComputeServiceSession::SessionState State); +const char* ToString(compute::ComputeServiceSession::QueueState State); +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/functionservice.h b/src/zencompute/include/zencompute/functionservice.h deleted file mode 100644 index 1deb99fd5..000000000 --- a/src/zencompute/include/zencompute/functionservice.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zenstore/zenstore.h> -# include <zenhttp/httpcommon.h> - -# include <filesystem> - -namespace zen { -class ChunkResolver; -class CbObjectWriter; -} // namespace zen - -namespace zen::compute { - -class ActionRecorder; -class FunctionServiceSession; -class IActionResultHandler; -class LocalProcessRunner; -class RemoteHttpRunner; -struct RunnerAction; -struct SubmitResult; - -struct WorkerDesc -{ - CbPackage Descriptor; - IoHash WorkerId{IoHash::Zero}; - - inline operator bool() const { return WorkerId != IoHash::Zero; } -}; - -/** - * Lambda style compute function service - * - * The responsibility of this class is to accept function execution requests, and - * schedule them using one or more FunctionRunner instances. It will basically always - * accept requests, queueing them if necessary, and then hand them off to runners - * as they become available. - * - * This is typically fronted by an API service that handles communication with clients. 
- */ -class FunctionServiceSession final -{ -public: - FunctionServiceSession(ChunkResolver& InChunkResolver); - ~FunctionServiceSession(); - - void Shutdown(); - bool IsHealthy(); - - // Worker registration and discovery - - void RegisterWorker(CbPackage Worker); - [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); - - // Action runners - - void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath); - void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); - - // Action submission - - struct EnqueueResult - { - int Lsn; - CbObject ResponseMessage; - - inline operator bool() const { return Lsn != 0; } - }; - - [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); - [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - - // Completed action tracking - - [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - void GetCompleted(CbWriter&); - - // Action history tracking (note that this is separate from completed action tracking, and - // will include actions which have been retired and no longer have their results available) - - struct ActionHistoryEntry - { - int Lsn; - IoHash ActionId; - IoHash WorkerId; - CbObject ActionDescriptor; - bool Succeeded; - uint64_t Timestamps[5] = {}; - }; - - [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); - - // Stats reporting - - void EmitStats(CbObjectWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); - -private: - void PostUpdate(RunnerAction* Action); - - friend class FunctionRunner; - friend struct RunnerAction; - - struct Impl; 
- std::unique_ptr<Impl> m_Impl; -}; - -void function_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h new file mode 100644 index 000000000..ee1cd2614 --- /dev/null +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/computeservice.h" + +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <memory> + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** + * HTTP interface for compute service + */ +class HttpComputeService : public HttpService, public IHttpStatsProvider +{ +public: + HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions = 0); + ~HttpComputeService(); + + void Shutdown(); + + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void httpcomputeservice_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpfunctionservice.h b/src/zencompute/include/zencompute/httpfunctionservice.h deleted file mode 100644 index 6e2344ae6..000000000 --- a/src/zencompute/include/zencompute/httpfunctionservice.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "zencompute/functionservice.h" - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zencore/logging.h> -# include <zentelemetry/stats.h> -# include <zenhttp/httpserver.h> - -# include <deque> -# include <filesystem> -# include <unordered_map> - -namespace zen { -class CidStore; -} - -namespace zen::compute { - -class HttpFunctionService; -class FunctionService; - -/** - * HTTP interface for compute function service - */ -class HttpFunctionService : public HttpService, public IHttpStatsProvider -{ -public: - HttpFunctionService(CidStore& InCidStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir); - ~HttpFunctionService(); - - void Shutdown(); - - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; - - // IHttpStatsProvider - - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - -protected: - CidStore& m_CidStore; - IHttpStatsService& m_StatsService; - LoggerRef Log() { return m_Log; } - -private: - LoggerRef m_Log; - std::filesystem ::path m_BaseDir; - HttpRequestRouter m_Router; - FunctionServiceSession m_FunctionService; - - // Metrics - - metrics::OperationTiming m_HttpRequests; -}; - -void httpfunction_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 168c6d7fe..da5c5dfc3 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,43 +2,100 @@ #pragma once +#include <zencompute/zencompute.h> + #include <zencore/logging.h> #include <zencore/thread.h> -#include <zencore/timer.h> 
#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> +#include <filesystem> +#include <memory> +#include <string> +#include <thread> #include <unordered_map> +#include <vector> + +#define ZEN_WITH_WEBSOCKETS 1 namespace zen::compute { +class OrchestratorService; + +// Experimental helper, to see if we can get rid of some error-prone +// boilerplate when declaring loggers as class members. + +class LoggerHelper +{ +public: + LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {} + + LoggerRef operator()() { return m_Log; } + +private: + LoggerRef m_Log; +}; + /** - * Mock orchestrator service, for testing dynamic provisioning + * Orchestrator HTTP service with WebSocket push support + * + * Normal HTTP requests are routed through the HttpRequestRouter as before. + * WebSocket clients connecting to /orch/ws receive periodic state broadcasts + * from a dedicated push thread, eliminating the need for polling. */ class HttpOrchestratorService : public HttpService +#if ZEN_WITH_WEBSOCKETS +, + public IWebSocketHandler +#endif { public: - HttpOrchestratorService(); + explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); ~HttpOrchestratorService(); HttpOrchestratorService(const HttpOrchestratorService&) = delete; HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + /** + * Gracefully shut down the WebSocket push thread and release connections. + * Must be called while the ASIO io_context is still alive. The destructor + * also calls this, so it is safe (but not ideal) to omit the explicit call. 
+ */ + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; + // IWebSocketHandler +#if ZEN_WITH_WEBSOCKETS + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; +#endif + private: - HttpRequestRouter m_Router; - LoggerRef m_Log; + HttpRequestRouter m_Router; + LoggerHelper Log{"orch"}; + std::unique_ptr<OrchestratorService> m_Service; + std::string m_Hostname; + + // WebSocket push - struct KnownWorker - { - std::string_view BaseUri; - Stopwatch LastSeen; - }; +#if ZEN_WITH_WEBSOCKETS + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::thread m_PushThread; + std::atomic<bool> m_PushEnabled{false}; + Event m_PushEvent; + void PushThreadFunction(); - RwLock m_KnownWorkersLock; - std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + // Worker WebSocket connections (worker→orchestrator persistent links) + RwLock m_WorkerWsLock; + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); +#endif }; } // namespace zen::compute diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h new file mode 100644 index 000000000..521722e63 --- /dev/null +++ b/src/zencompute/include/zencompute/mockimds.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/cloudmetadata.h> +#include <zenhttp/httpserver.h> + +#include <string> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. 
+ * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). + * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. */ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). 
+ std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. + std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. + */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h new file mode 100644 index 000000000..071e902b3 --- /dev/null +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -0,0 +1,177 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/uid.h> + +# include <deque> +# include <optional> +# include <filesystem> +# include <memory> +# include <string> +# include <string_view> +# include <thread> +# include <unordered_map> + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. "windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. 
"horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + 
std::unique_ptr<WorkerTimelineStore> m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque<ProvisioningEvent> m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map<std::string, KnownClient> m_KnownClients; + + RwLock m_ClientLogLock; + std::deque<ClientEvent> m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic<bool> m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h index bf1aff125..3f233fae0 100644 --- a/src/zencompute/include/zencompute/recordingreader.h +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -2,7 +2,9 @@ #pragma once -#include <zencompute/functionservice.h> +#include <zencompute/zencompute.h> + +#include <zencompute/computeservice.h> #include <zencompute/zencompute.h> #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h index 6dc32eeea..00be4d4a0 100644 --- a/src/zencompute/include/zencompute/zencompute.h +++ b/src/zencompute/include/zencompute/zencompute.h @@ -4,6 +4,10 @@ #include <zencore/zencore.h> +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + namespace zen { void zencompute_forcelinktests(); |