diff options
Diffstat (limited to 'src/zencompute/httporchestrator.cpp')
| -rw-r--r-- | src/zencompute/httporchestrator.cpp | 135 |
1 files changed, 132 insertions, 3 deletions
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index d92af8716..1f51e560e 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -7,6 +7,7 @@ # include <zencompute/orchestratorservice.h> # include <zencore/compactbinarybuilder.h> # include <zencore/logging.h> +# include <zencore/session.h> # include <zencore/string.h> # include <zencore/system.h> @@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn return Ann.Id; } +static OrchestratorService::WorkerAnnotator +MakeWorkerAnnotator(IProvisionerStateProvider* Prov) +{ + if (!Prov) + { + return {}; + } + return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) { + AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId); + if (Status != AgentProvisioningStatus::Unknown) + { + const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active"; + Cbo << "provisioner_status" << std::string_view(StatusStr); + } + }; +} + +bool +HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId) +{ + std::string_view SessionStr = Data["coordinator_session"].AsString(""); + if (SessionStr.empty()) + { + return true; // backwards compatibility: accept announcements without a session + } + Oid Session = Oid::TryFromHexString(SessionStr); + if (Session == m_SessionId) + { + return true; + } + ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString()); + return false; +} + HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) : m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) , m_Hostname(GetMachineName()) { + m_SessionId = zen::GetSessionId(); + ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString()); + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); @@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, [this](HttpRouterRequest& Req) { CbObjectWriter Cbo; Cbo << "hostname" << std::string_view(m_Hostname); + Cbo << "session_id" << m_SessionId.ToString(); Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kPost); m_Router.RegisterRoute( @@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, "characters and uri must start with http:// or https://"); } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session"); + } + m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); @@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, m_Router.RegisterRoute( "agents", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kGet); m_Router.RegisterRoute( @@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, }, HttpVerb::kGet); + // Provisioner endpoints + + m_Router.RegisterRoute( + "provisioner/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + Cbo << "name" << Prov->GetName(); + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount(); + Cbo << "active_cores" << Prov->GetActiveCoreCount(); + Cbo << "agents" << Prov->GetAgentCount(); + Cbo << "agents_draining" << Prov->GetDrainingAgentCount(); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provisioner/target", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + int32_t Cores = Data["target_cores"].AsInt32(-1); + + ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false); + + if (Cores < 0) + { + ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field"); + } + + IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire); + if (!Prov) + { + ZEN_WARN("provisioner/target: no provisioner configured"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured"); + } + + ZEN_INFO("provisioner/target: setting target to {} cores", Cores); + Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores)); + + CbObjectWriter Cbo; + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + // Client tracking endpoints m_Router.RegisterRoute( @@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +void +HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); + m_Service->SetProvisionerStateProvider(Provider); +} + ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler @@ -488,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms return {}; } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return {}; + } + m_Service->AnnounceWorker(Ann); return std::string(WorkerId); } @@ -563,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction() } // Build combined JSON with worker list, provisioning history, clients, and client history - CbObject WorkerList = m_Service->GetWorkerList(); + CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))); CbObject History = m_Service->GetProvisioningHistory(50); CbObject ClientList = m_Service->GetClientList(); CbObject ClientHistory = m_Service->GetClientHistory(50); @@ -615,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction() JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); } + // Emit provisioner stats if available + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + JsonBuilder.Append( + fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}" + ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}", + Prov->GetName(), + Prov->GetTargetCoreCount(), + Prov->GetEstimatedCoreCount(), + Prov->GetActiveCoreCount(), + Prov->GetAgentCount(), + Prov->GetDrainingAgentCount())); + } + JsonBuilder.Append("}"); std::string_view Json = JsonBuilder.ToView(); |