aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute/httporchestrator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zencompute/httporchestrator.cpp')
-rw-r--r--src/zencompute/httporchestrator.cpp135
1 files changed, 132 insertions, 3 deletions
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
index d92af8716..1f51e560e 100644
--- a/src/zencompute/httporchestrator.cpp
+++ b/src/zencompute/httporchestrator.cpp
@@ -7,6 +7,7 @@
# include <zencompute/orchestratorservice.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/logging.h>
+# include <zencore/session.h>
# include <zencore/string.h>
# include <zencore/system.h>
@@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn
return Ann.Id;
}
+static OrchestratorService::WorkerAnnotator
+MakeWorkerAnnotator(IProvisionerStateProvider* Prov)
+{
+ if (!Prov)
+ {
+ return {};
+ }
+ return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) {
+ AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId);
+ if (Status != AgentProvisioningStatus::Unknown)
+ {
+ const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active";
+ Cbo << "provisioner_status" << std::string_view(StatusStr);
+ }
+ };
+}
+
+bool
+HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId)
+{
+ std::string_view SessionStr = Data["coordinator_session"].AsString("");
+ if (SessionStr.empty())
+ {
+ return true; // backwards compatibility: accept announcements without a session
+ }
+ Oid Session = Oid::TryFromHexString(SessionStr);
+ if (Session == m_SessionId)
+ {
+ return true;
+ }
+ ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString());
+ return false;
+}
+
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
, m_Hostname(GetMachineName())
{
+ m_SessionId = zen::GetSessionId();
+ ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString());
+
m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
@@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
[this](HttpRouterRequest& Req) {
CbObjectWriter Cbo;
Cbo << "hostname" << std::string_view(m_Hostname);
+ Cbo << "session_id" << m_SessionId.ToString();
Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kGet);
m_Router.RegisterRoute(
"provision",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kPost);
m_Router.RegisterRoute(
@@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
"characters and uri must start with http:// or https://");
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session");
+ }
+
m_Service->AnnounceWorker(Ann);
HttpReq.WriteResponse(HttpResponseCode::OK);
@@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
m_Router.RegisterRoute(
"agents",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kGet);
m_Router.RegisterRoute(
@@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
},
HttpVerb::kGet);
+ // Provisioner endpoints
+
+ m_Router.RegisterRoute(
+ "provisioner/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ Cbo << "name" << Prov->GetName();
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount();
+ Cbo << "active_cores" << Prov->GetActiveCoreCount();
+ Cbo << "agents" << Prov->GetAgentCount();
+ Cbo << "agents_draining" << Prov->GetDrainingAgentCount();
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "provisioner/target",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+ int32_t Cores = Data["target_cores"].AsInt32(-1);
+
+ ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false);
+
+ if (Cores < 0)
+ {
+ ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field");
+ }
+
+ IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire);
+ if (!Prov)
+ {
+ ZEN_WARN("provisioner/target: no provisioner configured");
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured");
+ }
+
+ ZEN_INFO("provisioner/target: setting target to {} cores", Cores);
+ Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores));
+
+ CbObjectWriter Cbo;
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
// Client tracking endpoints
m_Router.RegisterRoute(
@@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
}
}
+void
+HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+ m_Service->SetProvisionerStateProvider(Provider);
+}
+
//////////////////////////////////////////////////////////////////////////
//
// IWebSocketHandler
@@ -488,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms
return {};
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return {};
+ }
+
m_Service->AnnounceWorker(Ann);
return std::string(WorkerId);
}
@@ -563,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction()
}
// Build combined JSON with worker list, provisioning history, clients, and client history
- CbObject WorkerList = m_Service->GetWorkerList();
+ CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)));
CbObject History = m_Service->GetProvisioningHistory(50);
CbObject ClientList = m_Service->GetClientList();
CbObject ClientHistory = m_Service->GetClientHistory(50);
@@ -615,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction()
JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
}
+ // Emit provisioner stats if available
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ JsonBuilder.Append(
+ fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}"
+ ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}",
+ Prov->GetName(),
+ Prov->GetTargetCoreCount(),
+ Prov->GetEstimatedCoreCount(),
+ Prov->GetActiveCoreCount(),
+ Prov->GetAgentCount(),
+ Prov->GetDrainingAgentCount()));
+ }
+
JsonBuilder.Append("}");
std::string_view Json = JsonBuilder.ToView();