// Copyright Epic Games, Inc. All Rights Reserved. #include "zencompute/httporchestrator.h" #if ZEN_WITH_COMPUTE_SERVICES # include # include # include namespace zen::compute { // Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes static bool IsValidWorkerId(std::string_view Id) { if (Id.size() < 3 || Id.size() > 64) { return false; } for (char c : Id) { if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') { continue; } return false; } return true; } HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir) : m_Service(std::make_unique(std::move(DataDir))) { m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); // dummy endpoint for websocket clients m_Router.RegisterRoute( "ws", [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, HttpVerb::kPost); m_Router.RegisterRoute( "announce", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); CbObject Data = HttpReq.ReadPayloadObject(); std::string_view WorkerId = Data["id"].AsString(""); std::string_view WorkerUri = Data["uri"].AsString(""); if (!IsValidWorkerId(WorkerId)) { return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid worker id: must be 3-64 alphanumeric, underscore, or dash characters"); } if (!WorkerUri.starts_with("http://") && !WorkerUri.starts_with("https://")) { return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid uri: must start with http:// or https://"); } OrchestratorService::WorkerAnnouncement Ann; Ann.Id = WorkerId; Ann.Uri = WorkerUri; Ann.Hostname = Data["hostname"].AsString(""); Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); if (auto Metrics = Data["metrics"].AsObjectView()) { Ann.Cpus = Metrics["lp_count"].AsInt32(0); if (Ann.Cpus <= 0) { Ann.Cpus = 1; } } m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); # if ZEN_WITH_WEBSOCKETS // Notify push thread that state may have changed m_PushEvent.Set(); # endif }, HttpVerb::kPost); m_Router.RegisterRoute( "agents", [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, HttpVerb::kGet); m_Router.RegisterRoute( "timeline/{workerid}", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); std::string_view WorkerId = Req.GetCapture(1); auto Params = HttpReq.GetQueryParams(); CbObject Result = m_Service->GetWorkerTimeline(WorkerId, Params.GetValue("from"), Params.GetValue("to"), Params.GetValue("limit")); if (!Result) { return HttpReq.WriteResponse(HttpResponseCode::NotFound); } HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); }, HttpVerb::kGet); m_Router.RegisterRoute( "timeline", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); auto Params = HttpReq.GetQueryParams(); CbObject Result = m_Service->GetAllTimelines(Params.GetValue("from"), Params.GetValue("to")); HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); }, HttpVerb::kGet); # if ZEN_WITH_WEBSOCKETS // Start the WebSocket push thread m_PushEnabled.store(true); m_PushThread = std::thread([this] { PushThreadFunction(); }); # endif } HttpOrchestratorService::~HttpOrchestratorService() { Shutdown(); } void HttpOrchestratorService::Shutdown() { # if ZEN_WITH_WEBSOCKETS if (!m_PushEnabled.exchange(false)) { return; } // Stop the push thread first, before touching connections. This ensures // the push thread is no longer reading m_WsConnections or calling into // m_Service when we start tearing things down. m_PushEvent.Set(); if (m_PushThread.joinable()) { m_PushThread.join(); } // Now that the push thread is gone, release all connections. m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); # endif } const char* HttpOrchestratorService::BaseUri() const { return "/orch/"; } void HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) { if (m_Router.HandleRequest(Request) == false) { ZEN_WARN("No route found for {0}", Request.RelativeUri()); } } ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler // # if ZEN_WITH_WEBSOCKETS void HttpOrchestratorService::OnWebSocketOpen(Ref Connection) { if (!m_PushEnabled.load()) { return; } ZEN_INFO("WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); // Wake push thread to send initial state immediately m_PushEvent.Set(); } void HttpOrchestratorService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& Msg) { // Dashboard clients don't send meaningful messages; ignore } void HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) { ZEN_INFO("WebSocket client disconnected (code {})", Code); if (!m_PushEnabled.load()) { return; } m_WsConnectionsLock.WithExclusiveLock([&] { auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref& C) { return C.Get() == &Conn; }); m_WsConnections.erase(It, m_WsConnections.end()); }); } # endif ////////////////////////////////////////////////////////////////////////// // // Push thread // # if ZEN_WITH_WEBSOCKETS void HttpOrchestratorService::PushThreadFunction() { SetCurrentThreadName("orch_ws_push"); while (m_PushEnabled.load()) { m_PushEvent.Wait(2000); m_PushEvent.Reset(); if (!m_PushEnabled.load()) { break; } // Snapshot current connections std::vector> Connections; m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); if (Connections.empty()) { continue; } // Build JSON from the worker list CbObject WorkerList = m_Service->GetWorkerList(); ExtendableStringBuilder<4096> JsonBuilder; WorkerList.ToJson(JsonBuilder); std::string_view Json = JsonBuilder.ToView(); // Broadcast to all connected clients, prune closed ones bool HadClosedConnections = false; for (auto& Conn : Connections) { if (Conn->IsOpen()) { Conn->SendText(Json); } else { HadClosedConnections = true; } } if (HadClosedConnections) { m_WsConnectionsLock.WithExclusiveLock([&] { auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref& C) { return !C->IsOpen(); }); m_WsConnections.erase(It, m_WsConnections.end()); }); } } } # endif } // namespace zen::compute #endif