// Copyright Epic Games, Inc. All Rights Reserved.

// HTTP + WebSocket front end for the compute orchestrator: exposes worker,
// provisioning, timeline and client-tracking routes under BaseUri(), accepts
// worker announcements over HTTP or (optionally) binary WebSocket messages,
// and, when websockets are compiled in, runs a background thread that pushes
// a combined JSON state snapshot to every connected dashboard socket.

#include "zencompute/httporchestrator.h"

#if ZEN_WITH_COMPUTE_SERVICES

#	include
#	include
#	include
#	include
#	include

namespace zen::compute {

// Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes.
// Only ASCII ranges are accepted; any other byte (including UTF-8 multi-byte
// sequences) makes the ID invalid. Also used as the URL matcher for
// {workerid} and {clientid} route captures.
static bool
IsValidWorkerId(std::string_view Id)
{
	if (Id.size() < 3 || Id.size() > 64)
	{
		return false;
	}
	for (char c : Id)
	{
		if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-')
		{
			continue;
		}
		return false;
	}
	return true;
}

// Shared announce payload parser used by both the HTTP POST route and the
// WebSocket message handler.
//
// Data: the announcement object. Required fields are "id" (must pass
//       IsValidWorkerId) and "uri" (must start with http:// or https://);
//       every other field falls back to a zero/empty default when absent.
// Ann:  out-parameter that is filled in. It is only partially written when
//       validation fails (Id and Uri are assigned before the checks).
//       Ann.Cpus is only written when a "metrics" object is present, and is
//       clamped to at least 1 in that case; otherwise it keeps the value Ann
//       already held.
// Returns the worker ID on success, or an empty view on validation failure.
//
// The populated WorkerAnnouncement has string_view fields that reference the
// supplied CbObjectView, so the CbObject must outlive the announcement (and
// the returned ID view).
static std::string_view
ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann)
{
	Ann.Id	= Data["id"].AsString("");
	Ann.Uri = Data["uri"].AsString("");

	if (!IsValidWorkerId(Ann.Id))
	{
		return {};
	}

	// Only absolute http(s) URIs are accepted as a worker's endpoint.
	if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://"))
	{
		return {};
	}

	Ann.Hostname		 = Data["hostname"].AsString("");
	Ann.Platform		 = Data["platform"].AsString("");
	Ann.CpuUsagePercent	 = Data["cpu_usage"].AsFloat(0.0f);
	Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0);
	Ann.MemoryUsedBytes	 = Data["memory_used"].AsUInt64(0);
	Ann.BytesReceived	 = Data["bytes_received"].AsUInt64(0);
	Ann.BytesSent		 = Data["bytes_sent"].AsUInt64(0);
	Ann.ActionsPending	 = Data["actions_pending"].AsInt32(0);
	Ann.ActionsRunning	 = Data["actions_running"].AsInt32(0);
	Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0);
	Ann.ActiveQueues	 = Data["active_queues"].AsInt32(0);
	Ann.Provisioner		 = Data["provisioner"].AsString("");

	if (auto Metrics = Data["metrics"].AsObjectView())
	{
		// "lp_count" presumably means logical processor count; a missing or
		// non-positive value is treated as a single CPU.
		Ann.Cpus = Metrics["lp_count"].AsInt32(0);
		if (Ann.Cpus <= 0)
		{
			Ann.Cpus = 1;
		}
	}

	return Ann.Id;
}

// Constructs the service, registers all HTTP routes on m_Router and, when
// websockets are compiled in, starts the dashboard push thread. The thread is
// started last, after every route has been registered.
//
// DataDir and EnableWorkerWebSocket are forwarded to the owned orchestrator
// service (m_Service); m_Hostname is captured once from GetMachineName().
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
: m_Service(std::make_unique(std::move(DataDir), EnableWorkerWebSocket))
, m_Hostname(GetMachineName())
{
	// Worker and client IDs in URLs share the same character/length rule.
	m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
	m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });

	// dummy endpoint for websocket clients
	// (a plain GET simply answers 200 OK; nothing else happens here)
	m_Router.RegisterRoute(
		"ws",
		[this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
		HttpVerb::kGet);

	// GET status: reports this orchestrator's hostname.
	m_Router.RegisterRoute(
		"status",
		[this](HttpRouterRequest& Req) {
			CbObjectWriter Cbo;
			Cbo << "hostname" << std::string_view(m_Hostname);
			Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
		},
		HttpVerb::kGet);

	// POST provision: currently answers with the same worker list as GET agents.
	m_Router.RegisterRoute(
		"provision",
		[this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
		HttpVerb::kPost);

	// POST announce: a worker reports itself and its current load.
	// Responds 400 with a text explanation when id/uri fail validation.
	m_Router.RegisterRoute(
		"announce",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			// Data must stay alive while Ann is used: Ann's views point into it.
			CbObject Data = HttpReq.ReadPayloadObject();

			OrchestratorService::WorkerAnnouncement Ann;
			std::string_view						WorkerId = ParseWorkerAnnouncement(Data, Ann);

			if (WorkerId.empty())
			{
				return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
											 HttpContentType::kText,
											 "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash "
											 "characters and uri must start with http:// or https://");
			}

			m_Service->AnnounceWorker(Ann);

			HttpReq.WriteResponse(HttpResponseCode::OK);

#	if ZEN_WITH_WEBSOCKETS
			// Notify push thread that state may have changed
			m_PushEvent.Set();
#	endif
		},
		HttpVerb::kPost);

	// GET agents: the current worker list.
	m_Router.RegisterRoute(
		"agents",
		[this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
		HttpVerb::kGet);

	// GET history?limit=N: provisioning history, 100 entries by default.
	m_Router.RegisterRoute(
		"history",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();
			auto			   Params  = HttpReq.GetQueryParams();

			int	 Limit	  = 100;
			auto LimitStr = Params.GetValue("limit");
			if (!LimitStr.empty())
			{
				// NOTE(review): std::atoi yields 0 for non-numeric input and
				// accepts negative values; neither is rejected here, unlike the
				// timeline routes, which answer 400 for unparsable parameters.
				// How GetProvisioningHistory treats 0/negative limits is not
				// visible from this file — confirm.
				Limit = std::atoi(std::string(LimitStr).c_str());
			}

			HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit));
		},
		HttpVerb::kGet);

	// GET timeline/{workerid}?from=&to=&limit=: timeline for one worker.
	// from/to are optional; an unparsable from/to answers 400, and an empty
	// result answers 404 (presumably an unknown worker — confirm).
	m_Router.RegisterRoute(
		"timeline/{workerid}",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			std::string_view WorkerId = Req.GetCapture(1);
			auto			 Params	  = HttpReq.GetQueryParams();

			auto FromStr  = Params.GetValue("from");
			auto ToStr	  = Params.GetValue("to");
			auto LimitStr = Params.GetValue("limit");

			// Left unset when the corresponding query parameter is absent.
			std::optional From;
			std::optional To;

			if (!FromStr.empty())
			{
				auto Val = zen::ParseInt(FromStr);
				if (!Val)
				{
					return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
				}
				// Integer passed straight to DateTime — presumably ticks; confirm units.
				From = DateTime(*Val);
			}

			if (!ToStr.empty())
			{
				auto Val = zen::ParseInt(ToStr);
				if (!Val)
				{
					return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
				}
				To = DateTime(*Val);
			}

			// An absent or unparsable limit becomes 0 (not a 400). What 0 means
			// to GetWorkerTimeline is not visible here — presumably "no limit".
			int Limit = !LimitStr.empty() ? zen::ParseInt(LimitStr).value_or(0) : 0;

			CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit);

			if (!Result)
			{
				return HttpReq.WriteResponse(HttpResponseCode::NotFound);
			}

			HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
		},
		HttpVerb::kGet);

	// GET timeline?from=&to=: timelines for all workers. The range defaults
	// to [DateTime(0), DateTime::Now()]; an unparsable bound answers 400.
	m_Router.RegisterRoute(
		"timeline",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			auto Params = HttpReq.GetQueryParams();

			auto FromStr = Params.GetValue("from");
			auto ToStr	 = Params.GetValue("to");

			DateTime From = DateTime(0);
			DateTime To	  = DateTime::Now();

			if (!FromStr.empty())
			{
				auto Val = zen::ParseInt(FromStr);
				if (!Val)
				{
					return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
				}
				From = DateTime(*Val);
			}

			if (!ToStr.empty())
			{
				auto Val = zen::ParseInt(ToStr);
				if (!Val)
				{
					return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
				}
				To = DateTime(*Val);
			}

			CbObject Result = m_Service->GetAllTimelines(From, To);

			HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
		},
		HttpVerb::kGet);

	// Client tracking endpoints

	// POST clients: registers a client session and returns its assigned id
	// as {"id": "..."}. The remote address is taken from the connection, not
	// from the payload; "metadata" is deep-copied so it outlives the payload.
	// NOTE(review): the payload is not validated — a missing/invalid body
	// still registers a client with default (zero/empty) fields.
	m_Router.RegisterRoute(
		"clients",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			CbObject Data = HttpReq.ReadPayloadObject();

			OrchestratorService::ClientAnnouncement Ann;
			Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero);
			Ann.Hostname  = Data["hostname"].AsString("");
			Ann.Address	  = HttpReq.GetRemoteAddress();

			auto MetadataView = Data["metadata"].AsObjectView();
			if (MetadataView)
			{
				Ann.Metadata = CbObject::Clone(MetadataView);
			}

			std::string ClientId = m_Service->AnnounceClient(Ann);

			CbObjectWriter ResponseObj;
			ResponseObj << "id" << std::string_view(ClientId);
			HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save());

#	if ZEN_WITH_WEBSOCKETS
			m_PushEvent.Set();
#	endif
		},
		HttpVerb::kPost);

	// POST clients/{clientid}/update: replaces/updates the client's metadata.
	// An absent body or absent "metadata" passes an empty CbObject through.
	// Answers 404 when UpdateClient reports failure (presumably unknown id).
	m_Router.RegisterRoute(
		"clients/{clientid}/update",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			std::string_view ClientId = Req.GetCapture(1);

			CbObject MetadataObj;
			CbObject Data = HttpReq.ReadPayloadObject();
			if (Data)
			{
				auto MetadataView = Data["metadata"].AsObjectView();
				if (MetadataView)
				{
					// Clone so the metadata does not reference the request payload.
					MetadataObj = CbObject::Clone(MetadataView);
				}
			}

			if (m_Service->UpdateClient(ClientId, std::move(MetadataObj)))
			{
				HttpReq.WriteResponse(HttpResponseCode::OK);
			}
			else
			{
				HttpReq.WriteResponse(HttpResponseCode::NotFound);
			}
		},
		HttpVerb::kPost);

	// POST clients/{clientid}/complete: marks the client session complete.
	// Answers 404 when CompleteClient reports failure; the push thread is
	// woken in either case.
	m_Router.RegisterRoute(
		"clients/{clientid}/complete",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();

			std::string_view ClientId = Req.GetCapture(1);

			if (m_Service->CompleteClient(ClientId))
			{
				HttpReq.WriteResponse(HttpResponseCode::OK);
			}
			else
			{
				HttpReq.WriteResponse(HttpResponseCode::NotFound);
			}

#	if ZEN_WITH_WEBSOCKETS
			m_PushEvent.Set();
#	endif
		},
		HttpVerb::kPost);

	// GET clients: the current client list.
	m_Router.RegisterRoute(
		"clients",
		[this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); },
		HttpVerb::kGet);

	// GET clients/history?limit=N: client history, 100 entries by default.
	m_Router.RegisterRoute(
		"clients/history",
		[this](HttpRouterRequest& Req) {
			HttpServerRequest& HttpReq = Req.ServerRequest();
			auto			   Params  = HttpReq.GetQueryParams();

			int	 Limit	  = 100;
			auto LimitStr = Params.GetValue("limit");
			if (!LimitStr.empty())
			{
				// NOTE(review): same std::atoi caveat as the "history" route —
				// garbage becomes 0 and negative values are passed through.
				Limit = std::atoi(std::string(LimitStr).c_str());
			}

			HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit));
		},
		HttpVerb::kGet);

#	if ZEN_WITH_WEBSOCKETS

	// Start the WebSocket push thread
	m_PushEnabled.store(true);
	m_PushThread = std::thread([this] { PushThreadFunction(); });
#	endif
}

// Tears everything down via Shutdown(), which returns early when it has
// already run (see the m_PushEnabled exchange there).
HttpOrchestratorService::~HttpOrchestratorService()
{
	Shutdown();
}

// Stops the push thread, reports all worker WebSocket connections as
// disconnected to the service, and drops all dashboard connections.
// Only the first call does any work; later calls return immediately because
// m_PushEnabled has already been exchanged to false. Without websocket
// support compiled in, this is a no-op.
void
HttpOrchestratorService::Shutdown()
{
#	if ZEN_WITH_WEBSOCKETS
	if (!m_PushEnabled.exchange(false))
	{
		return;
	}

	// Stop the push thread first, before touching connections. This ensures
	// the push thread is no longer reading m_WsConnections or calling into
	// m_Service when we start tearing things down.
	m_PushEvent.Set();
	if (m_PushThread.joinable())
	{
		m_PushThread.join();
	}

	// Clean up worker WebSocket connections — collect IDs under lock, then
	// notify the service outside the lock to avoid lock-order inversions.
	std::vector WorkerIds;
	m_WorkerWsLock.WithExclusiveLock([&] {
		WorkerIds.reserve(m_WorkerWsMap.size());
		for (const auto& [Conn, Id] : m_WorkerWsMap)
		{
			WorkerIds.push_back(Id);
		}
		m_WorkerWsMap.clear();
	});
	for (const auto& Id : WorkerIds)
	{
		m_Service->SetWorkerWebSocketConnected(Id, false);
	}

	// Now that the push thread is gone, release all dashboard connections.
	m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
#	endif
}

// URI prefix under which all routes registered in the constructor are served.
const char*
HttpOrchestratorService::BaseUri() const
{
	return "/orch/";
}

// Dispatches a request to the registered routes. On a miss only a warning is
// logged; nothing is written to Request here.
// NOTE(review): whether the HTTP server layer supplies a default (e.g. 404)
// response for unhandled requests is not visible from this file — confirm.
void
HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
{
	if (m_Router.HandleRequest(Request) == false)
	{
		ZEN_WARN("No route found for {0}", Request.RelativeUri());
	}
}

//////////////////////////////////////////////////////////////////////////
//
// IWebSocketHandler
//

#	if ZEN_WITH_WEBSOCKETS
// Registers a newly opened socket as a dashboard connection and wakes the
// push thread so the client receives a state snapshot right away. Sockets
// opened after Shutdown() has started are ignored.
// NOTE(review): every socket is added here, including ones that later
// identify as workers in OnWebSocketMessage (which never removes them from
// m_WsConnections), so worker sockets appear to receive the dashboard JSON
// pushes too — confirm this is intended.
void
HttpOrchestratorService::OnWebSocketOpen(Ref Connection)
{
	if (!m_PushEnabled.load())
	{
		return;
	}

	ZEN_INFO("WebSocket client connected");

	m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });

	// Wake push thread to send initial state immediately
	m_PushEvent.Set();
}

// Handles an incoming WebSocket message. Only binary frames are processed,
// and only when worker WebSockets are enabled on the service; each accepted
// frame is treated as a worker announcement. The first valid announcement
// on a connection binds that connection to the worker ID and marks the
// worker as WebSocket-connected.
// Only the first worker ID seen on a connection is recorded: a later
// announcement with a different ID on the same connection is still announced
// to the service but does not update m_WorkerWsMap.
void
HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
{
	// Only handle binary messages from workers when the feature is enabled.
	if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary)
	{
		return;
	}

	std::string WorkerId = HandleWorkerWebSocketMessage(Msg);
	if (WorkerId.empty())
	{
		return;
	}

	// Check if this is a new worker WebSocket connection
	bool IsNewWorkerWs = false;
	m_WorkerWsLock.WithExclusiveLock([&] {
		auto It = m_WorkerWsMap.find(&Conn);
		if (It == m_WorkerWsMap.end())
		{
			m_WorkerWsMap[&Conn] = WorkerId;
			IsNewWorkerWs		 = true;
		}
	});

	// Notify the service outside m_WorkerWsLock (same pattern as Shutdown).
	if (IsNewWorkerWs)
	{
		m_Service->SetWorkerWebSocketConnected(WorkerId, true);
	}

	m_PushEvent.Set();
}

// Parses one binary worker announcement and forwards it to the service.
// Returns the announced worker ID, or an empty string when the payload is
// not a CbObject or fails ParseWorkerAnnouncement validation.
// The announcement's views point into Data, which views Msg.Payload; both are
// alive for the AnnounceWorker call, and the ID is copied into a std::string
// before returning so the result does not dangle.
std::string
HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg)
{
	// Workers send CbObject in native binary format over the WebSocket to
	// avoid the lossy CbObject↔JSON round-trip.
	// NOTE(review): these bytes come from the network and are viewed in place;
	// no explicit compact-binary validation is visible before parsing. Confirm
	// whether CbObject::MakeView bounds-checks; if not, validate the buffer
	// against Msg.Payload's size first.
	CbObject Data = CbObject::MakeView(Msg.Payload.GetData());
	if (!Data)
	{
		ZEN_WARN("worker WebSocket message is not a valid CbObject");
		return {};
	}

	OrchestratorService::WorkerAnnouncement Ann;
	std::string_view						WorkerId = ParseWorkerAnnouncement(Data, Ann);
	if (WorkerId.empty())
	{
		ZEN_WARN("invalid worker announcement via WebSocket");
		return {};
	}

	m_Service->AnnounceWorker(Ann);
	return std::string(WorkerId);
}

// Handles a closed socket. If it was bound to a worker, the worker is marked
// WebSocket-disconnected (this also runs after Shutdown()). The socket is
// then removed from the dashboard list, unless Shutdown() has already begun,
// in which case Shutdown() clears that list itself.
void
HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
{
	ZEN_INFO("WebSocket client disconnected (code {})", Code);

	// Check if this was a worker WebSocket connection; collect the ID under
	// the worker lock, then notify the service outside the lock.
	std::string DisconnectedWorkerId;
	m_WorkerWsLock.WithExclusiveLock([&] {
		auto It = m_WorkerWsMap.find(&Conn);
		if (It != m_WorkerWsMap.end())
		{
			DisconnectedWorkerId = std::move(It->second);
			m_WorkerWsMap.erase(It);
		}
	});

	if (!DisconnectedWorkerId.empty())
	{
		m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false);
		m_PushEvent.Set();
	}

	if (!m_PushEnabled.load())
	{
		return;
	}

	// Remove from dashboard connections (matched by connection identity)
	m_WsConnectionsLock.WithExclusiveLock([&] {
		auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref& C) {
			return C.Get() == &Conn;
		});
		m_WsConnections.erase(It, m_WsConnections.end());
	});
}
#	endif

//////////////////////////////////////////////////////////////////////////
//
// Push thread
//

#	if ZEN_WITH_WEBSOCKETS
// Body of m_PushThread. Loops until m_PushEnabled is cleared, waking when
// m_PushEvent is set or after a 2000 timeout (presumably milliseconds). Each
// round snapshots the dashboard connections and, if there are any, builds a
// single JSON object of the form
//   {"hostname":"...", <worker list fields>, <history fields>,
//    <client list fields>, <client history fields>}
// by splicing the inner fields of each service object's JSON, sends it to
// every open connection, and prunes connections that are no longer open.
// Service queries and sends run without holding m_WsConnectionsLock.
void
HttpOrchestratorService::PushThreadFunction()
{
	SetCurrentThreadName("orch_ws_push");

	while (m_PushEnabled.load())
	{
		m_PushEvent.Wait(2000);
		m_PushEvent.Reset();

		if (!m_PushEnabled.load())
		{
			break;
		}

		// Snapshot current connections
		std::vector> Connections;
		m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });

		if (Connections.empty())
		{
			continue;
		}

		// Build combined JSON with worker list, provisioning history, clients, and client history
		CbObject WorkerList	   = m_Service->GetWorkerList();
		CbObject History	   = m_Service->GetProvisioningHistory(50);
		CbObject ClientList	   = m_Service->GetClientList();
		CbObject ClientHistory = m_Service->GetClientHistory(50);

		ExtendableStringBuilder<4096> JsonBuilder;
		JsonBuilder.Append("{");
		// NOTE(review): m_Hostname is inserted verbatim, without JSON escaping;
		// a hostname containing '"' or '\' would produce invalid JSON.
		JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname));

		// NOTE(review): each section below appends the inside of "{...}" and
		// then an unconditional ",". If a service object serializes to "{}"
		// (no fields), the fragment is empty and the output gains ",," /
		// a trailing comma — confirm the service always emits its array key.

		// Emit workers array from worker list
		ExtendableStringBuilder<2048> WorkerJson;
		WorkerList.ToJson(WorkerJson);
		std::string_view WorkerJsonView = WorkerJson.ToView();
		// Strip outer braces: {"workers":[...]} -> "workers":[...]
		if (WorkerJsonView.size() >= 2)
		{
			JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2));
		}
		JsonBuilder.Append(",");

		// Emit events array from history
		ExtendableStringBuilder<2048> HistoryJson;
		History.ToJson(HistoryJson);
		std::string_view HistoryJsonView = HistoryJson.ToView();
		if (HistoryJsonView.size() >= 2)
		{
			JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2));
		}
		JsonBuilder.Append(",");

		// Emit clients array from client list
		ExtendableStringBuilder<2048> ClientJson;
		ClientList.ToJson(ClientJson);
		std::string_view ClientJsonView = ClientJson.ToView();
		if (ClientJsonView.size() >= 2)
		{
			JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2));
		}
		JsonBuilder.Append(",");

		// Emit client_events array from client history
		ExtendableStringBuilder<2048> ClientHistoryJson;
		ClientHistory.ToJson(ClientHistoryJson);
		std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView();
		if (ClientHistoryJsonView.size() >= 2)
		{
			JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
		}

		JsonBuilder.Append("}");
		std::string_view Json = JsonBuilder.ToView();

		// Broadcast to all connected clients, prune closed ones
		bool HadClosedConnections = false;
		for (auto& Conn : Connections)
		{
			if (Conn->IsOpen())
			{
				Conn->SendText(Json);
			}
			else
			{
				HadClosedConnections = true;
			}
		}

		// Prune against the live list (not the snapshot) so connections added
		// since the snapshot are kept unless they are closed as well.
		if (HadClosedConnections)
		{
			m_WsConnectionsLock.WithExclusiveLock([&] {
				auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref& C) {
					return !C->IsOpen();
				});
				m_WsConnections.erase(It, m_WsConnections.end());
			});
		}
	}
}
#	endif

}  // namespace zen::compute
#endif