diff options
Diffstat (limited to 'src/zenserver')
| -rw-r--r-- | src/zenserver/frontend/html/pages/entry.js | 2 | ||||
| -rw-r--r-- | src/zenserver/frontend/html/pages/hub.js | 2 | ||||
| -rw-r--r-- | src/zenserver/hub/httphubservice.cpp | 58 | ||||
| -rw-r--r-- | src/zenserver/hub/httphubservice.h | 16 | ||||
| -rw-r--r-- | src/zenserver/hub/httpproxyhandler.cpp | 504 | ||||
| -rw-r--r-- | src/zenserver/hub/httpproxyhandler.h | 52 | ||||
| -rw-r--r-- | src/zenserver/hub/hub.cpp | 15 | ||||
| -rw-r--r-- | src/zenserver/hub/hub.h | 2 | ||||
| -rw-r--r-- | src/zenserver/hub/zenhubserver.cpp | 41 | ||||
| -rw-r--r-- | src/zenserver/hub/zenhubserver.h | 4 | ||||
| -rw-r--r-- | src/zenserver/sessions/httpsessions.cpp | 3 | ||||
| -rw-r--r-- | src/zenserver/sessions/httpsessions.h | 2 |
12 files changed, 682 insertions, 19 deletions
diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 1e4c82e3f..e381f4a71 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -168,7 +168,7 @@ export class Page extends ZenPage if (key === "cook.artifacts") { action_tb.left().add("view-raw").on_click(() => { - window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/"); + window.open("/" + ["prj", project, "oplog", oplog, value+".json"].join("/"), "_self"); }); } diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js index c6f96d496..fcc792ddc 100644 --- a/src/zenserver/frontend/html/pages/hub.js +++ b/src/zenserver/frontend/html/pages/hub.js @@ -400,7 +400,7 @@ export class Page extends ZenPage const td_action = document.createElement("td"); td_action.className = "module-action-cell"; const [wrap_o, btn_o] = _make_action_btn("\u2197", "Open dashboard", () => { - window.open(`${window.location.protocol}//${window.location.hostname}:${port}`, "_blank"); + window.open(`/hub/proxy/${port}/dashboard/`, "_blank"); }); btn_o.disabled = state !== "provisioned"; const [wrap_h, btn_h] = _make_action_btn("\u23F8", "Hibernate", () => this._post_module_action(id, "hibernate").then(() => this._update())); diff --git a/src/zenserver/hub/httphubservice.cpp b/src/zenserver/hub/httphubservice.cpp index d52da5ae7..eba816793 100644 --- a/src/zenserver/hub/httphubservice.cpp +++ b/src/zenserver/hub/httphubservice.cpp @@ -2,6 +2,7 @@ #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" #include "storageserverinstance.h" @@ -43,10 +44,11 @@ namespace { } } // namespace -HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService) +HttpHubService::HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService) : m_Hub(Hub) , m_StatsService(StatsService) , m_StatusService(StatusService) +, m_Proxy(Proxy) { using namespace std::literals; @@ -67,6 +69,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta return true; }); + m_Router.AddMatcher("port", [](std::string_view Str) -> bool { + if (Str.empty()) + { + return false; + } + for (const auto C : Str) + { + if (!std::isdigit(C)) + { + return false; + } + } + return true; + }); + + m_Router.AddMatcher("proxypath", [](std::string_view Str) -> bool { return !Str.empty(); }); + m_Router.RegisterRoute( "status", [this](HttpRouterRequest& Req) { @@ -232,6 +251,25 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta }, HttpVerb::kPost); + m_Router.RegisterRoute( + "proxy/{port}/{proxypath}", + [this](HttpRouterRequest& Req) { + std::string_view PortStr = Req.GetCapture(1); + + // Use RelativeUriWithExtension to preserve the file extension that the + // router's URI parser strips (e.g. ".css", ".js") - the upstream server + // needs the full path including the extension. + std::string_view FullUri = Req.ServerRequest().RelativeUriWithExtension(); + std::string_view Prefix = "proxy/"; + + // FullUri is "proxy/{port}/{path...}" - skip past "proxy/{port}/" + size_t PathStart = Prefix.size() + PortStr.size() + 1; + std::string_view PathTail = (PathStart < FullUri.size()) ? FullUri.substr(PathStart) : std::string_view{}; + + m_Proxy.HandleProxyRequest(Req.ServerRequest(), PortStr, PathTail); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + m_StatsService.RegisterHandler("hub", *this); m_StatusService.RegisterHandler("hub", *this); } @@ -392,4 +430,22 @@ HttpHubService::HandleModuleDelete(HttpServerRequest& Request, std::string_view Request.WriteResponse(HttpResponseCode::OK, Obj.Save()); } +void +HttpHubService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + m_Proxy.OnWebSocketOpen(std::move(Connection), RelativeUri); +} + +void +HttpHubService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + m_Proxy.OnWebSocketMessage(Conn, Msg); +} + +void +HttpHubService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + m_Proxy.OnWebSocketClose(Conn, Code, Reason); +} + } // namespace zen diff --git a/src/zenserver/hub/httphubservice.h b/src/zenserver/hub/httphubservice.h index 1bb1c303e..ff2cb0029 100644 --- a/src/zenserver/hub/httphubservice.h +++ b/src/zenserver/hub/httphubservice.h @@ -2,11 +2,16 @@ #pragma once +#include <zencore/thread.h> #include <zenhttp/httpserver.h> #include <zenhttp/httpstatus.h> +#include <zenhttp/websocket.h> + +#include <memory> namespace zen { +class HttpProxyHandler; class HttpStatsService; class Hub; @@ -16,10 +21,10 @@ class Hub; * use in UEFN content worker style scenarios. * */ -class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider +class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider, public IWebSocketHandler { public: - HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService); + HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService); ~HttpHubService(); HttpHubService(const HttpHubService&) = delete; @@ -32,6 +37,11 @@ public: virtual CbObject CollectStats() override; virtual uint64_t GetActivityCounter() override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId); private: @@ -45,6 +55,8 @@ private: void HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId); void HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId); + + HttpProxyHandler& m_Proxy; }; } // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp new file mode 100644 index 000000000..25842623a --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.cpp @@ -0,0 +1,504 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpproxyhandler.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpwsclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <charconv> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + + std::string InjectProxyScript(std::string_view Html, uint16_t Port) + { + ExtendableStringBuilder<2048> Script; + Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/"); + Script.Append(fmt::format("{}", Port)); + Script.Append( + "\";\n" + " var OF = window.fetch;\n" + " window.fetch = function(u, o) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {\n" + " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n" + " }\n" + " }\n" + " return OF.call(this, u, o);\n" + " };\n" + " var OW = window.WebSocket;\n" + " window.WebSocket = function(u, pr) {\n" + " try {\n" + " var p = new URL(u);\n" + " if (p.hostname === location.hostname\n" + " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n" + " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n" + " && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " return pr !== undefined ? new OW(u, pr) : new OW(u);\n" + " };\n" + " window.WebSocket.prototype = OW.prototype;\n" + " window.WebSocket.CONNECTING = OW.CONNECTING;\n" + " window.WebSocket.OPEN = OW.OPEN;\n" + " window.WebSocket.CLOSING = OW.CLOSING;\n" + " window.WebSocket.CLOSED = OW.CLOSED;\n" + " var OO = window.open;\n" + " window.open = function(u, t, f) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " }\n" + " return OO.call(this, u, t, f);\n" + " };\n" + " document.addEventListener(\"click\", function(e) {\n" + " var t = e.composedPath ? e.composedPath()[0] : e.target;\n" + " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n" + " if (!t || !t.href) return;\n" + " try {\n" + " var h = new URL(t.href);\n" + " if (h.origin === location.origin && !h.pathname.startsWith(P))\n" + " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n" + " } catch(x) {}\n" + " }, true);\n" + "})();\n</script>"); + + std::string ScriptStr = Script.ToString(); + + size_t HeadClose = Html.find("</head>"); + if (HeadClose != std::string_view::npos) + { + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(Html.substr(0, HeadClose)); + Result.append(ScriptStr); + Result.append(Html.substr(HeadClose)); + return Result; + } + + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(ScriptStr); + Result.append(Html); + return Result; + } + +} // namespace + +struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler +{ + Ref<WebSocketConnection> ClientConn; + std::unique_ptr<HttpWsClient> UpstreamClient; + uint16_t Port = 0; + + void OnWsOpen() override {} + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (!ClientConn->IsOpen()) + { + return; + } + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + default: + break; + } + } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + if (ClientConn->IsOpen()) + { + ClientConn->Close(Code, Reason); + } + } +}; + +HttpProxyHandler::HttpProxyHandler() +{ +} + +HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort)) +{ +} + +void +HttpProxyHandler::SetPortValidator(PortValidator ValidatePort) +{ + m_ValidatePort = std::move(ValidatePort); +} + +HttpProxyHandler::~HttpProxyHandler() +{ + try + { + Shutdown(); + } + catch (...) + { + } +} + +HttpClient& +HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port) +{ + HttpClient* Result = nullptr; + m_ProxyClientsLock.WithExclusiveLock([&] { + auto It = m_ProxyClients.find(Port); + if (It == m_ProxyClients.end()) + { + HttpClientSettings Settings; + Settings.LogCategory = "hub-proxy"; + Settings.ConnectTimeout = std::chrono::milliseconds(5000); + Settings.Timeout = std::chrono::milliseconds(30000); + auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings); + Result = Client.get(); + m_ProxyClients.emplace(Port, std::move(Client)); + } + else + { + Result = It->second.get(); + } + }); + return *Result; +} + +void +HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail) +{ + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "target instance not available"); + return; + } + + HttpClient& Client = GetOrCreateProxyClient(Port); + + std::string RequestPath; + RequestPath.reserve(1 + PathTail.size()); + RequestPath.push_back('/'); + RequestPath.append(PathTail); + + std::string_view QueryString = Request.QueryString(); + if (!QueryString.empty()) + { + RequestPath.push_back('?'); + RequestPath.append(QueryString); + } + + HttpClient::KeyValueMap ForwardHeaders; + HttpContentType AcceptType = Request.AcceptContentType(); + if (AcceptType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType))); + } + + std::string_view Auth = Request.GetAuthorizationHeader(); + if (!Auth.empty()) + { + ForwardHeaders->emplace("Authorization", std::string(Auth)); + } + + HttpContentType ReqContentType = Request.RequestContentType(); + if (ReqContentType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType))); + } + + HttpClient::Response Response; + + switch (Request.RequestVerb()) + { + case HttpVerb::kGet: + Response = Client.Get(RequestPath, ForwardHeaders); + break; + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Post(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kPut: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Put(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kDelete: + Response = Client.Delete(RequestPath, ForwardHeaders); + break; + case HttpVerb::kHead: + Response = Client.Head(RequestPath, ForwardHeaders); + break; + default: + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported"); + return; + } + + if (Response.Error) + { + ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage); + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "upstream request failed"); + return; + } + + HttpContentType ContentType = Response.ResponsePayload.GetContentType(); + + if (ContentType == HttpContentType::kHTML) + { + std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string Injected = InjectProxyScript(Html, Port); + Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected)); + } + else + { + Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload)); + } +} + +void +HttpProxyHandler::PrunePort(uint16_t Port) +{ + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); }); + + std::vector<Ref<WsBridge>> Stale; + m_WsBridgesLock.WithExclusiveLock([&] { + for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();) + { + if (It->second->Port == Port) + { + Stale.push_back(std::move(It->second)); + It = m_WsBridges.erase(It); + } + else + { + ++It; + } + } + }); + + for (auto& Bridge : Stale) + { + if (Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(1001, "instance shutting down"); + } + if (Bridge->ClientConn->IsOpen()) + { + Bridge->ClientConn->Close(1001, "instance shutting down"); + } + } +} + +void +HttpProxyHandler::Shutdown() +{ + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); }); + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); }); +} + +////////////////////////////////////////////////////////////////////////// +// +// WebSocket proxy +// + +void +HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + const std::string_view ProxyPrefix = "proxy/"; + if (!RelativeUri.starts_with(ProxyPrefix)) + { + Connection->Close(1008, "unsupported WebSocket endpoint"); + return; + } + + std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size()); + + size_t SlashPos = ProxyTail.find('/'); + std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail; + std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/"; + + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Connection->Close(1008, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Connection->Close(1008, "target instance not available"); + return; + } + + std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path); + + Ref<WsBridge> Bridge(new WsBridge()); + Bridge->ClientConn = Connection; + Bridge->Port = Port; + + Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge); + + try + { + Bridge->UpstreamClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what()); + Connection->Close(1011, "upstream connect failed"); + return; + } + + WebSocketConnection* Key = Connection.Get(); + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); }); +} + +void +HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + Ref<WsBridge> Bridge; + m_WsBridgesLock.WithSharedLock([&] { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Bridge = It->second; + } + }); + + if (!Bridge || !Bridge->UpstreamClient) + { + return; + } + + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + Bridge->UpstreamClient->SendBinary( + std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kClose: + Bridge->UpstreamClient->Close(Msg.CloseCode, {}); + break; + default: + break; + } +} + +void +HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + Ref<WsBridge> Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Ref<WsBridge> Bridge = std::move(It->second); + m_WsBridges.erase(It); + return Bridge; + } + return {}; + }); + + if (Bridge && Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(Code, Reason); + } +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("server.httpproxyhandler"); + +TEST_CASE("server.httpproxyhandler.html_injection") +{ + SUBCASE("injects before </head>") + { + std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + size_t ScriptEnd = Result.find("</script>"); + size_t HeadClose = Result.find("</head>"); + REQUIRE(ScriptEnd != std::string::npos); + REQUIRE(HeadClose != std::string::npos); + CHECK(ScriptEnd < HeadClose); + } + + SUBCASE("prepends when no </head>") + { + std::string Result = InjectProxyScript("<body>content</body>", 21005); + CHECK(Result.find("<script>") == 0); + CHECK(Result.find("<body>content</body>") != std::string::npos); + } + + SUBCASE("empty html") + { + std::string Result = InjectProxyScript("", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + } + + SUBCASE("preserves original content") + { + std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>"; + std::string Result = InjectProxyScript(Html, 21005); + CHECK(Result.find("<title>Test</title>") != std::string::npos); + CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos); + } +} + +TEST_CASE("server.httpproxyhandler.port_embedding") +{ + std::string Result = InjectProxyScript("<head></head>", 80); + CHECK(Result.find("/hub/proxy/80") != std::string::npos); + + Result = InjectProxyScript("<head></head>", 65535); + CHECK(Result.find("/hub/proxy/65535") != std::string::npos); +} + +TEST_SUITE_END(); + +void +httpproxyhandler_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenserver/hub/httpproxyhandler.h b/src/zenserver/hub/httpproxyhandler.h new file mode 100644 index 000000000..8667c0ca1 --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <functional> +#include <memory> +#include <unordered_map> + +namespace zen { + +class HttpClient; + +class HttpProxyHandler +{ +public: + using PortValidator = std::function<bool(uint16_t)>; + + HttpProxyHandler(); + explicit HttpProxyHandler(PortValidator ValidatePort); + ~HttpProxyHandler(); + + void SetPortValidator(PortValidator ValidatePort); + + HttpProxyHandler(const HttpProxyHandler&) = delete; + HttpProxyHandler& operator=(const HttpProxyHandler&) = delete; + + void HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail); + void PrunePort(uint16_t Port); + void Shutdown(); + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason); + +private: + PortValidator m_ValidatePort; + + HttpClient& GetOrCreateProxyClient(uint16_t Port); + + RwLock m_ProxyClientsLock; + std::unordered_map<uint16_t, std::unique_ptr<HttpClient>> m_ProxyClients; + + struct WsBridge; + RwLock m_WsBridgesLock; + std::unordered_map<WebSocketConnection*, Ref<WsBridge>> m_WsBridges; +}; + +} // namespace zen diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp index 76c7a8f6d..82f4a00ba 100644 --- a/src/zenserver/hub/hub.cpp +++ b/src/zenserver/hub/hub.cpp @@ -1083,6 +1083,21 @@ Hub::GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const return gsl::narrow<uint16_t>(m_Config.BasePortNumber + ActiveInstanceIndex); } +bool +Hub::IsInstancePort(uint16_t Port) const +{ + if (Port < m_Config.BasePortNumber) + { + return false; + } + size_t Index = Port - m_Config.BasePortNumber; + if (Index >= m_ActiveInstances.size()) + { + return false; + } + return m_ActiveInstances[Index].State.load(std::memory_order_relaxed) != HubInstanceState::Unprovisioned; +} + HubInstanceState Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState) { diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index ac3e680ae..8ee9130f6 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -167,6 +167,8 @@ public: void GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const; + bool IsInstancePort(uint16_t Port) const; + const Configuration& GetConfig() const { return m_Config; } #if ZEN_WITH_TESTS diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index b0e0023b1..5308a76f1 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -5,6 +5,7 @@ #include "config/luaconfig.h" #include "frontend/frontend.h" #include "httphubservice.h" +#include "httpproxyhandler.h" #include "hub.h" #include <zencore/compactbinary.h> @@ -388,6 +389,15 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, HubInstanceState NewState) { ZEN_UNUSED(PreviousState); + + if (NewState == HubInstanceState::Deprovisioning || NewState == HubInstanceState::Hibernating) + { + if (Info.Port != 0) + { + m_Proxy->PrunePort(Info.Port); + } + } + if (!m_ConsulClient) { return; @@ -435,8 +445,8 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId, ZEN_INFO("Deregistered storage server instance for module '{}' at port {} from Consul", ModuleId, Info.Port); } } - // Transitional states (Deprovisioning, Hibernating, Waking, Recovering, Crashed) - // and Hibernated are intentionally ignored. + // Transitional states (Waking, Recovering, Crashed) and stable states + // not handled above (Hibernated) are intentionally ignored by Consul. } int @@ -489,6 +499,11 @@ ZenHubServer::Cleanup() m_Http->Close(); } + if (m_Proxy) + { + m_Proxy->Shutdown(); + } + if (m_Hub) { m_Hub->Shutdown(); @@ -498,6 +513,7 @@ ZenHubServer::Cleanup() m_HubService.reset(); m_ApiService.reset(); m_Hub.reset(); + m_Proxy.reset(); m_ConsulRegistration.reset(); m_ConsulClient.reset(); @@ -600,6 +616,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) HubConfig.HydrationOptions = std::move(Root).AsObject(); } + m_Proxy = std::make_unique<HttpProxyHandler>(); + m_Hub = std::make_unique<Hub>( std::move(HubConfig), ZenServerEnvironment(ZenServerEnvironment::Hub, @@ -607,20 +625,21 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) ServerConfig.DataDir / "servers", ServerConfig.HubInstanceHttpClass), &GetMediumWorkerPool(EWorkloadType::Background), - m_ConsulClient ? Hub::AsyncModuleStateChangeCallbackFunc{[this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)]( - std::string_view ModuleId, - const HubProvisionedInstanceInfo& Info, - HubInstanceState PreviousState, - HubInstanceState NewState) { - OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); - }} - : Hub::AsyncModuleStateChangeCallbackFunc{}); + Hub::AsyncModuleStateChangeCallbackFunc{ + [this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](std::string_view ModuleId, + const HubProvisionedInstanceInfo& Info, + HubInstanceState PreviousState, + HubInstanceState NewState) { + OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState); + }}); + + m_Proxy->SetPortValidator([Hub = m_Hub.get()](uint16_t Port) { return Hub->IsInstancePort(Port); }); ZEN_INFO("instantiating API service"); m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http); ZEN_INFO("instantiating hub service"); - m_HubService = std::make_unique<HttpHubService>(*m_Hub, m_StatsService, m_StatusService); + m_HubService = std::make_unique<HttpHubService>(*m_Hub, *m_Proxy, m_StatsService, m_StatusService); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index b976c52b3..d1add7690 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -20,6 +20,7 @@ namespace zen { class HttpApiService; class HttpFrontendService; class HttpHubService; +class HttpProxyHandler; struct ZenHubWatchdogConfig { @@ -121,7 +122,8 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr<Hub> m_Hub; + std::unique_ptr<HttpProxyHandler> m_Proxy; + std::unique_ptr<Hub> m_Hub; std::unique_ptr<HttpHubService> m_HubService; std::unique_ptr<HttpApiService> m_ApiService; diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp index fdf2e1f21..c21ae6a5c 100644 --- a/src/zenserver/sessions/httpsessions.cpp +++ b/src/zenserver/sessions/httpsessions.cpp @@ -512,8 +512,9 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) // void -HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("Sessions WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h index 86a23f835..6ebe61c8d 100644 --- a/src/zenserver/sessions/httpsessions.h +++ b/src/zenserver/sessions/httpsessions.h @@ -37,7 +37,7 @@ public: void SetSelfSessionId(const Oid& Id) { m_SelfSessionId = Id; } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; |