aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute/httporchestrator.cpp
diff options
context:
space:
mode:
authorLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
committerLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
commitd1abc50ee9d4fb72efc646e17decafea741caa34 (patch)
treee4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zencompute/httporchestrator.cpp
parentAllow requests with invalid content-types unless specified in command line or... (diff)
parentupdated chunk–block analyser (#818) (diff)
downloadzen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz
zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zencompute/httporchestrator.cpp')
-rw-r--r--src/zencompute/httporchestrator.cpp650
1 files changed, 650 insertions, 0 deletions
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
new file mode 100644
index 000000000..6cbe01e04
--- /dev/null
+++ b/src/zencompute/httporchestrator.cpp
@@ -0,0 +1,650 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/httporchestrator.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencompute/orchestratorservice.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/logging.h>
+# include <zencore/string.h>
+# include <zencore/system.h>
+
+namespace zen::compute {
+
+// Validates a worker (or client) identifier: the length must be within
+// [3, 64] and every character must be an ASCII letter, digit, '_' or '-'.
+static bool
+IsValidWorkerId(std::string_view Id)
+{
+    const auto IsAllowedChar = [](char Ch) {
+        const bool IsLower = Ch >= 'a' && Ch <= 'z';
+        const bool IsUpper = Ch >= 'A' && Ch <= 'Z';
+        const bool IsDigit = Ch >= '0' && Ch <= '9';
+        return IsLower || IsUpper || IsDigit || Ch == '_' || Ch == '-';
+    };
+
+    const bool LengthOk = Id.size() >= 3 && Id.size() <= 64;
+    if (!LengthOk)
+    {
+        return false;
+    }
+
+    for (const char Ch : Id)
+    {
+        if (!IsAllowedChar(Ch))
+        {
+            return false;
+        }
+    }
+    return true;
+}
+
+// Announce payload parser shared by the HTTP POST route and the WebSocket
+// message handler. Fills `Ann` from the fields of `Data` and returns the
+// worker ID, or an empty view when the ID or URI fails validation (in which
+// case only Ann.Id and Ann.Uri have been written). The string_view fields
+// stored in `Ann` reference the supplied CbObjectView, so the backing
+// CbObject must outlive `Ann`.
+static std::string_view
+ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann)
+{
+    // Identity fields first: nothing else is read unless these validate.
+    Ann.Id  = Data["id"].AsString("");
+    Ann.Uri = Data["uri"].AsString("");
+
+    const bool HasHttpScheme = Ann.Uri.starts_with("http://") || Ann.Uri.starts_with("https://");
+    if (!IsValidWorkerId(Ann.Id) || !HasHttpScheme)
+    {
+        return {};
+    }
+
+    // Descriptive fields
+    Ann.Hostname    = Data["hostname"].AsString("");
+    Ann.Platform    = Data["platform"].AsString("");
+    Ann.Provisioner = Data["provisioner"].AsString("");
+
+    // Resource usage
+    Ann.CpuUsagePercent  = Data["cpu_usage"].AsFloat(0.0f);
+    Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0);
+    Ann.MemoryUsedBytes  = Data["memory_used"].AsUInt64(0);
+    Ann.BytesReceived    = Data["bytes_received"].AsUInt64(0);
+    Ann.BytesSent        = Data["bytes_sent"].AsUInt64(0);
+
+    // Work counters
+    Ann.ActionsPending   = Data["actions_pending"].AsInt32(0);
+    Ann.ActionsRunning   = Data["actions_running"].AsInt32(0);
+    Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0);
+    Ann.ActiveQueues     = Data["active_queues"].AsInt32(0);
+
+    // The CPU count is only touched when a "metrics" object is present; a
+    // missing or non-positive "lp_count" is reported as a single CPU.
+    auto Metrics = Data["metrics"].AsObjectView();
+    if (Metrics)
+    {
+        Ann.Cpus = Metrics["lp_count"].AsInt32(0);
+        if (!(Ann.Cpus > 0))
+        {
+            Ann.Cpus = 1;
+        }
+    }
+
+    return Ann.Id;
+}
+
+// Parses the optional "limit" query parameter used by the history and
+// timeline routes. Returns DefaultLimit when the parameter is absent or is
+// not a valid integer. Replaces std::atoi, whose behaviour is undefined for
+// out-of-range input and which silently turns garbage into 0.
+static int
+ParseLimitOrDefault(std::string_view LimitStr, int DefaultLimit)
+{
+    if (LimitStr.empty())
+    {
+        return DefaultLimit;
+    }
+    return zen::ParseInt<int>(LimitStr).value_or(DefaultLimit);
+}
+
+// Creates the underlying OrchestratorService, registers all HTTP routes on
+// the router and, when WebSocket support is compiled in, starts the push
+// thread.
+//
+// DataDir               - forwarded to the OrchestratorService constructor.
+// EnableWorkerWebSocket - forwarded to the OrchestratorService constructor.
+HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
+: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
+, m_Hostname(GetMachineName())
+{
+    // Path captures for worker and client IDs share the same validation rule.
+    m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
+    m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
+
+    // dummy endpoint for websocket clients
+    m_Router.RegisterRoute(
+        "ws",
+        [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
+        HttpVerb::kGet);
+
+    // Reports the hostname of the machine running the orchestrator
+    m_Router.RegisterRoute(
+        "status",
+        [this](HttpRouterRequest& Req) {
+            CbObjectWriter Cbo;
+            Cbo << "hostname" << std::string_view(m_Hostname);
+            Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
+        },
+        HttpVerb::kGet);
+
+    m_Router.RegisterRoute(
+        "provision",
+        [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+        HttpVerb::kPost);
+
+    // Worker announcement over HTTP
+    m_Router.RegisterRoute(
+        "announce",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+
+            // Data must stay alive while Ann is in use: Ann holds views into it.
+            CbObject Data = HttpReq.ReadPayloadObject();
+
+            OrchestratorService::WorkerAnnouncement Ann;
+            std::string_view                        WorkerId = ParseWorkerAnnouncement(Data, Ann);
+
+            if (WorkerId.empty())
+            {
+                return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+                                             HttpContentType::kText,
+                                             "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash "
+                                             "characters and uri must start with http:// or https://");
+            }
+
+            m_Service->AnnounceWorker(Ann);
+
+            HttpReq.WriteResponse(HttpResponseCode::OK);
+
+# if ZEN_WITH_WEBSOCKETS
+            // Notify push thread that state may have changed
+            m_PushEvent.Set();
+# endif
+        },
+        HttpVerb::kPost);
+
+    m_Router.RegisterRoute(
+        "agents",
+        [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+        HttpVerb::kGet);
+
+    // Provisioning history; "limit" defaults to 100
+    m_Router.RegisterRoute(
+        "history",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+            auto               Params  = HttpReq.GetQueryParams();
+
+            const int Limit = ParseLimitOrDefault(Params.GetValue("limit"), 100);
+
+            HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit));
+        },
+        HttpVerb::kGet);
+
+    // Timeline for a single worker; "from"/"to" are optional integer bounds
+    // and a malformed bound is rejected with 400.
+    m_Router.RegisterRoute(
+        "timeline/{workerid}",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+
+            std::string_view WorkerId = Req.GetCapture(1);
+            auto             Params   = HttpReq.GetQueryParams();
+
+            auto FromStr  = Params.GetValue("from");
+            auto ToStr    = Params.GetValue("to");
+            auto LimitStr = Params.GetValue("limit");
+
+            std::optional<DateTime> From;
+            std::optional<DateTime> To;
+
+            if (!FromStr.empty())
+            {
+                auto Val = zen::ParseInt<uint64_t>(FromStr);
+                if (!Val)
+                {
+                    return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+                }
+                From = DateTime(*Val);
+            }
+
+            if (!ToStr.empty())
+            {
+                auto Val = zen::ParseInt<uint64_t>(ToStr);
+                if (!Val)
+                {
+                    return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+                }
+                To = DateTime(*Val);
+            }
+
+            // A missing or unparsable limit is passed to the service as 0.
+            const int Limit = ParseLimitOrDefault(LimitStr, 0);
+
+            CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit);
+
+            if (!Result)
+            {
+                return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+            }
+
+            HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
+        },
+        HttpVerb::kGet);
+
+    // Timelines for all workers; the range defaults to [DateTime(0), now]
+    m_Router.RegisterRoute(
+        "timeline",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+            auto               Params  = HttpReq.GetQueryParams();
+
+            auto FromStr = Params.GetValue("from");
+            auto ToStr   = Params.GetValue("to");
+
+            DateTime From = DateTime(0);
+            DateTime To   = DateTime::Now();
+
+            if (!FromStr.empty())
+            {
+                auto Val = zen::ParseInt<uint64_t>(FromStr);
+                if (!Val)
+                {
+                    return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+                }
+                From = DateTime(*Val);
+            }
+
+            if (!ToStr.empty())
+            {
+                auto Val = zen::ParseInt<uint64_t>(ToStr);
+                if (!Val)
+                {
+                    return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+                }
+                To = DateTime(*Val);
+            }
+
+            CbObject Result = m_Service->GetAllTimelines(From, To);
+
+            HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
+        },
+        HttpVerb::kGet);
+
+    // Client tracking endpoints
+
+    // Registers a client session and responds with the ID assigned by the service
+    m_Router.RegisterRoute(
+        "clients",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+
+            CbObject Data = HttpReq.ReadPayloadObject();
+
+            OrchestratorService::ClientAnnouncement Ann;
+            Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero);
+            Ann.Hostname  = Data["hostname"].AsString("");
+            Ann.Address   = HttpReq.GetRemoteAddress();
+
+            // Metadata is cloned so it does not reference the request payload.
+            auto MetadataView = Data["metadata"].AsObjectView();
+            if (MetadataView)
+            {
+                Ann.Metadata = CbObject::Clone(MetadataView);
+            }
+
+            std::string ClientId = m_Service->AnnounceClient(Ann);
+
+            CbObjectWriter ResponseObj;
+            ResponseObj << "id" << std::string_view(ClientId);
+            HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save());
+
+# if ZEN_WITH_WEBSOCKETS
+            m_PushEvent.Set();
+# endif
+        },
+        HttpVerb::kPost);
+
+    m_Router.RegisterRoute(
+        "clients/{clientid}/update",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq  = Req.ServerRequest();
+            std::string_view   ClientId = Req.GetCapture(1);
+
+            // An absent payload or absent "metadata" field results in an
+            // empty metadata object being passed to the service.
+            CbObject MetadataObj;
+            CbObject Data = HttpReq.ReadPayloadObject();
+            if (Data)
+            {
+                auto MetadataView = Data["metadata"].AsObjectView();
+                if (MetadataView)
+                {
+                    MetadataObj = CbObject::Clone(MetadataView);
+                }
+            }
+
+            if (m_Service->UpdateClient(ClientId, std::move(MetadataObj)))
+            {
+                HttpReq.WriteResponse(HttpResponseCode::OK);
+            }
+            else
+            {
+                HttpReq.WriteResponse(HttpResponseCode::NotFound);
+            }
+        },
+        HttpVerb::kPost);
+
+    m_Router.RegisterRoute(
+        "clients/{clientid}/complete",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq  = Req.ServerRequest();
+            std::string_view   ClientId = Req.GetCapture(1);
+
+            if (m_Service->CompleteClient(ClientId))
+            {
+                HttpReq.WriteResponse(HttpResponseCode::OK);
+            }
+            else
+            {
+                HttpReq.WriteResponse(HttpResponseCode::NotFound);
+            }
+
+# if ZEN_WITH_WEBSOCKETS
+            m_PushEvent.Set();
+# endif
+        },
+        HttpVerb::kPost);
+
+    m_Router.RegisterRoute(
+        "clients",
+        [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); },
+        HttpVerb::kGet);
+
+    // Client history; "limit" defaults to 100
+    m_Router.RegisterRoute(
+        "clients/history",
+        [this](HttpRouterRequest& Req) {
+            HttpServerRequest& HttpReq = Req.ServerRequest();
+            auto               Params  = HttpReq.GetQueryParams();
+
+            const int Limit = ParseLimitOrDefault(Params.GetValue("limit"), 100);
+
+            HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit));
+        },
+        HttpVerb::kGet);
+
+# if ZEN_WITH_WEBSOCKETS
+
+    // Start the WebSocket push thread
+    m_PushEnabled.store(true);
+    m_PushThread = std::thread([this] { PushThreadFunction(); });
+# endif
+}
+
+// Destruction goes through Shutdown() so teardown happens whether or not the
+// owner called it explicitly beforehand.
+HttpOrchestratorService::~HttpOrchestratorService()
+{
+    Shutdown();
+}
+
+// Stops the push thread and drops all tracked WebSocket connections. Only the
+// first call does any work: the m_PushEnabled exchange makes later calls
+// return immediately. Without WebSocket support this is a no-op.
+void
+HttpOrchestratorService::Shutdown()
+{
+# if ZEN_WITH_WEBSOCKETS
+    const bool WasEnabled = m_PushEnabled.exchange(false);
+    if (!WasEnabled)
+    {
+        return;
+    }
+
+    // The push thread goes first, before any connection state is touched, so
+    // that it is no longer reading m_WsConnections or calling into m_Service
+    // while the teardown below runs.
+    m_PushEvent.Set();
+    if (m_PushThread.joinable())
+    {
+        m_PushThread.join();
+    }
+
+    // Worker WebSocket connections: gather the IDs while holding the lock and
+    // notify the service only after releasing it, avoiding lock-order
+    // inversions.
+    std::vector<std::string> DisconnectedIds;
+    m_WorkerWsLock.WithExclusiveLock([&] {
+        DisconnectedIds.reserve(m_WorkerWsMap.size());
+        for (const auto& Entry : m_WorkerWsMap)
+        {
+            DisconnectedIds.push_back(Entry.second);
+        }
+        m_WorkerWsMap.clear();
+    });
+
+    for (const auto& WorkerId : DisconnectedIds)
+    {
+        m_Service->SetWorkerWebSocketConnected(WorkerId, false);
+    }
+
+    // With the push thread joined, the dashboard connections can be released.
+    m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
+# endif
+}
+
+// URI prefix for this service.
+const char*
+HttpOrchestratorService::BaseUri() const
+{
+    static constexpr const char* kBaseUri = "/orch/";
+    return kBaseUri;
+}
+
+// Dispatches an incoming request through the router and logs a warning when
+// no registered route matches it.
+void
+HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
+{
+    const bool Handled = m_Router.HandleRequest(Request);
+    if (!Handled)
+    {
+        ZEN_WARN("No route found for {0}", Request.RelativeUri());
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+# if ZEN_WITH_WEBSOCKETS
+// Tracks a newly opened WebSocket connection so the push thread will send
+// state to it. Connections arriving after pushing was disabled are ignored.
+void
+HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+    const bool PushActive = m_PushEnabled.load();
+    if (!PushActive)
+    {
+        return;
+    }
+
+    ZEN_INFO("WebSocket client connected");
+
+    m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+
+    // Signal the push thread so the new client gets its initial state
+    // without waiting for the next periodic wake-up.
+    m_PushEvent.Set();
+}
+
+// Handles an incoming WebSocket message. Only binary messages are processed,
+// and only while the worker WebSocket feature is enabled; they are treated as
+// worker announcements. The first valid announcement on a connection
+// associates that connection with the announced worker ID.
+void
+HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
+{
+    if (!m_Service->IsWorkerWebSocketEnabled())
+    {
+        return;
+    }
+    if (Msg.Opcode != WebSocketOpcode::kBinary)
+    {
+        return;
+    }
+
+    const std::string AnnouncedId = HandleWorkerWebSocketMessage(Msg);
+    if (AnnouncedId.empty())
+    {
+        return;
+    }
+
+    // Record the connection -> worker mapping if this connection has not been
+    // seen before; an existing mapping is left as it is.
+    bool FirstMessageOnConnection = false;
+    m_WorkerWsLock.WithExclusiveLock([&] {
+        if (m_WorkerWsMap.find(&Conn) != m_WorkerWsMap.end())
+        {
+            return;
+        }
+        m_WorkerWsMap[&Conn]     = AnnouncedId;
+        FirstMessageOnConnection = true;
+    });
+
+    // The service is notified outside the lock.
+    if (FirstMessageOnConnection)
+    {
+        m_Service->SetWorkerWebSocketConnected(AnnouncedId, true);
+    }
+
+    m_PushEvent.Set();
+}
+
+// Decodes a worker announcement received over WebSocket and forwards it to
+// the service. Returns the announced worker ID, or an empty string when the
+// payload is not a CbObject or the announcement fails validation.
+std::string
+HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg)
+{
+    // Workers send the CbObject in its native binary form over the WebSocket,
+    // which avoids the lossy CbObject↔JSON round-trip.
+    // NOTE(review): the payload is wrapped with MakeView and only checked for
+    // truthiness here — confirm whether structural validation of untrusted
+    // input happens elsewhere.
+    CbObject Payload = CbObject::MakeView(Msg.Payload.GetData());
+    if (!Payload)
+    {
+        ZEN_WARN("worker WebSocket message is not a valid CbObject");
+        return {};
+    }
+
+    // Announcement holds views into Payload, which stays alive for the rest
+    // of this function.
+    OrchestratorService::WorkerAnnouncement Announcement;
+    const std::string_view                  AnnouncedId = ParseWorkerAnnouncement(Payload, Announcement);
+    if (AnnouncedId.empty())
+    {
+        ZEN_WARN("invalid worker announcement via WebSocket");
+        return {};
+    }
+
+    m_Service->AnnounceWorker(Announcement);
+
+    // Copy the ID out before Payload goes out of scope.
+    return std::string(AnnouncedId);
+}
+
+// Handles a closed WebSocket connection: clears any worker association for it
+// and removes it from the list of connections the push thread broadcasts to.
+void
+HttpOrchestratorService::OnWebSocketClose(WebSocketConnection&              Conn,
+                                          [[maybe_unused]] uint16_t         Code,
+                                          [[maybe_unused]] std::string_view Reason)
+{
+    ZEN_INFO("WebSocket client disconnected (code {})", Code);
+
+    // If the connection belonged to a worker, take its ID out of the map
+    // while holding the worker lock; the service is notified afterwards,
+    // outside the lock.
+    std::string WorkerIdForConn;
+    m_WorkerWsLock.WithExclusiveLock([&] {
+        auto Found = m_WorkerWsMap.find(&Conn);
+        if (Found == m_WorkerWsMap.end())
+        {
+            return;
+        }
+        WorkerIdForConn = std::move(Found->second);
+        m_WorkerWsMap.erase(Found);
+    });
+
+    const bool WasWorkerConnection = !WorkerIdForConn.empty();
+    if (WasWorkerConnection)
+    {
+        m_Service->SetWorkerWebSocketConnected(WorkerIdForConn, false);
+        m_PushEvent.Set();
+    }
+
+    // Nothing more to do once pushing has been disabled.
+    if (!m_PushEnabled.load())
+    {
+        return;
+    }
+
+    // Drop the connection from the dashboard list
+    m_WsConnectionsLock.WithExclusiveLock([&] {
+        const auto IsThisConn = [&Conn](const Ref<WebSocketConnection>& Candidate) { return Candidate.Get() == &Conn; };
+        m_WsConnections.erase(std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), IsThisConn), m_WsConnections.end());
+    });
+}
+# endif
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Push thread
+//
+
+# if ZEN_WITH_WEBSOCKETS
+// Body of the WebSocket push thread. Until m_PushEnabled is cleared it wakes
+// on m_PushEvent (or after a 2000 ms timeout), builds one combined JSON
+// document from the service state and sends it as text to every open
+// connection in m_WsConnections, pruning connections that have closed.
+void
+HttpOrchestratorService::PushThreadFunction()
+{
+    SetCurrentThreadName("orch_ws_push");
+
+    while (m_PushEnabled.load())
+    {
+        m_PushEvent.Wait(2000);
+        m_PushEvent.Reset();
+
+        // Shutdown() clears the flag and then sets the event to wake us.
+        if (!m_PushEnabled.load())
+        {
+            break;
+        }
+
+        // Snapshot current connections so sending happens outside the lock
+        std::vector<Ref<WebSocketConnection>> Connections;
+        m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });
+
+        if (Connections.empty())
+        {
+            continue;
+        }
+
+        // Build combined JSON with hostname, worker list, provisioning history, clients, and client history
+        CbObjectWriter HostInfo;
+        HostInfo << "hostname" << std::string_view(m_Hostname);
+
+        ExtendableStringBuilder<4096> JsonBuilder;
+        bool                          NeedsSeparator = false;
+
+        // Renders Section to JSON, strips its outer braces and appends the
+        // remaining members to JsonBuilder, e.g. {"workers":[...]} ->
+        // "workers":[...]. A comma is emitted only between non-empty
+        // sections, so an empty object cannot produce a dangling separator.
+        const auto AppendMembers = [&JsonBuilder, &NeedsSeparator](const CbObject& Section) {
+            ExtendableStringBuilder<2048> SectionJson;
+            Section.ToJson(SectionJson);
+
+            std::string_view Members = SectionJson.ToView();
+            if (Members.size() < 2)
+            {
+                return;
+            }
+            Members = Members.substr(1, Members.size() - 2);
+
+            // Nothing but whitespace between the braces means an empty object.
+            if (Members.find_first_not_of(" \t\r\n") == std::string_view::npos)
+            {
+                return;
+            }
+
+            if (NeedsSeparator)
+            {
+                JsonBuilder.Append(",");
+            }
+            JsonBuilder.Append(Members);
+            NeedsSeparator = true;
+        };
+
+        JsonBuilder.Append("{");
+        // The hostname goes through the same serializer as the other sections
+        // rather than being spliced in as raw text, which would break the
+        // document if it contained a quote or backslash.
+        AppendMembers(HostInfo.Save());
+        AppendMembers(m_Service->GetWorkerList());
+        AppendMembers(m_Service->GetProvisioningHistory(50));
+        AppendMembers(m_Service->GetClientList());
+        AppendMembers(m_Service->GetClientHistory(50));
+        JsonBuilder.Append("}");
+
+        std::string_view Json = JsonBuilder.ToView();
+
+        // Broadcast to all connected clients, prune closed ones
+        bool HadClosedConnections = false;
+
+        for (auto& Conn : Connections)
+        {
+            if (Conn->IsOpen())
+            {
+                Conn->SendText(Json);
+            }
+            else
+            {
+                HadClosedConnections = true;
+            }
+        }
+
+        if (HadClosedConnections)
+        {
+            m_WsConnectionsLock.WithExclusiveLock([&] {
+                auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) {
+                    return !C->IsOpen();
+                });
+                m_WsConnections.erase(It, m_WsConnections.end());
+            });
+        }
+    }
+}
+# endif
+
+} // namespace zen::compute
+
+#endif