diff options
Diffstat (limited to 'src/zenserver/hub/httpproxyhandler.cpp')
| -rw-r--r-- | src/zenserver/hub/httpproxyhandler.cpp | 504 |
1 files changed, 504 insertions, 0 deletions
diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp new file mode 100644 index 000000000..25842623a --- /dev/null +++ b/src/zenserver/hub/httpproxyhandler.cpp @@ -0,0 +1,504 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpproxyhandler.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpwsclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <charconv> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + + std::string InjectProxyScript(std::string_view Html, uint16_t Port) + { + ExtendableStringBuilder<2048> Script; + Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/"); + Script.Append(fmt::format("{}", Port)); + Script.Append( + "\";\n" + " var OF = window.fetch;\n" + " window.fetch = function(u, o) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {\n" + " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n" + " }\n" + " }\n" + " return OF.call(this, u, o);\n" + " };\n" + " var OW = window.WebSocket;\n" + " window.WebSocket = function(u, pr) {\n" + " try {\n" + " var p = new URL(u);\n" + " if (p.hostname === location.hostname\n" + " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n" + " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n" + " && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " return pr !== undefined ? new OW(u, pr) : new OW(u);\n" + " };\n" + " window.WebSocket.prototype = OW.prototype;\n" + " window.WebSocket.CONNECTING = OW.CONNECTING;\n" + " window.WebSocket.OPEN = OW.OPEN;\n" + " window.WebSocket.CLOSING = OW.CLOSING;\n" + " window.WebSocket.CLOSED = OW.CLOSED;\n" + " var OO = window.open;\n" + " window.open = function(u, t, f) {\n" + " if (typeof u === \"string\") {\n" + " try {\n" + " var p = new URL(u, location.origin);\n" + " if (p.origin === location.origin && !p.pathname.startsWith(P))\n" + " { p.pathname = P + p.pathname; u = p.toString(); }\n" + " } catch(e) {}\n" + " }\n" + " return OO.call(this, u, t, f);\n" + " };\n" + " document.addEventListener(\"click\", function(e) {\n" + " var t = e.composedPath ? e.composedPath()[0] : e.target;\n" + " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n" + " if (!t || !t.href) return;\n" + " try {\n" + " var h = new URL(t.href);\n" + " if (h.origin === location.origin && !h.pathname.startsWith(P))\n" + " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n" + " } catch(x) {}\n" + " }, true);\n" + "})();\n</script>"); + + std::string ScriptStr = Script.ToString(); + + size_t HeadClose = Html.find("</head>"); + if (HeadClose != std::string_view::npos) + { + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(Html.substr(0, HeadClose)); + Result.append(ScriptStr); + Result.append(Html.substr(HeadClose)); + return Result; + } + + std::string Result; + Result.reserve(Html.size() + ScriptStr.size()); + Result.append(ScriptStr); + Result.append(Html); + return Result; + } + +} // namespace + +struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler +{ + Ref<WebSocketConnection> ClientConn; + std::unique_ptr<HttpWsClient> UpstreamClient; + uint16_t Port = 0; + + void OnWsOpen() override {} + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (!ClientConn->IsOpen()) + { + return; + } + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + default: + break; + } + } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + if (ClientConn->IsOpen()) + { + ClientConn->Close(Code, Reason); + } + } +}; + +HttpProxyHandler::HttpProxyHandler() +{ +} + +HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort)) +{ +} + +void +HttpProxyHandler::SetPortValidator(PortValidator ValidatePort) +{ + m_ValidatePort = std::move(ValidatePort); +} + +HttpProxyHandler::~HttpProxyHandler() +{ + try + { + Shutdown(); + } + catch (...) + { + } +} + +HttpClient& +HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port) +{ + HttpClient* Result = nullptr; + m_ProxyClientsLock.WithExclusiveLock([&] { + auto It = m_ProxyClients.find(Port); + if (It == m_ProxyClients.end()) + { + HttpClientSettings Settings; + Settings.LogCategory = "hub-proxy"; + Settings.ConnectTimeout = std::chrono::milliseconds(5000); + Settings.Timeout = std::chrono::milliseconds(30000); + auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings); + Result = Client.get(); + m_ProxyClients.emplace(Port, std::move(Client)); + } + else + { + Result = It->second.get(); + } + }); + return *Result; +} + +void +HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail) +{ + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "target instance not available"); + return; + } + + HttpClient& Client = GetOrCreateProxyClient(Port); + + std::string RequestPath; + RequestPath.reserve(1 + PathTail.size()); + RequestPath.push_back('/'); + RequestPath.append(PathTail); + + std::string_view QueryString = Request.QueryString(); + if (!QueryString.empty()) + { + RequestPath.push_back('?'); + RequestPath.append(QueryString); + } + + HttpClient::KeyValueMap ForwardHeaders; + HttpContentType AcceptType = Request.AcceptContentType(); + if (AcceptType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType))); + } + + std::string_view Auth = Request.GetAuthorizationHeader(); + if (!Auth.empty()) + { + ForwardHeaders->emplace("Authorization", std::string(Auth)); + } + + HttpContentType ReqContentType = Request.RequestContentType(); + if (ReqContentType != HttpContentType::kUnknownContentType) + { + ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType))); + } + + HttpClient::Response Response; + + switch (Request.RequestVerb()) + { + case HttpVerb::kGet: + Response = Client.Get(RequestPath, ForwardHeaders); + break; + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Post(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kPut: + { + IoBuffer Payload = Request.ReadPayload(); + Response = Client.Put(RequestPath, Payload, ForwardHeaders); + break; + } + case HttpVerb::kDelete: + Response = Client.Delete(RequestPath, ForwardHeaders); + break; + case HttpVerb::kHead: + Response = Client.Head(RequestPath, ForwardHeaders); + break; + default: + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported"); + return; + } + + if (Response.Error) + { + ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage); + Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "upstream request failed"); + return; + } + + HttpContentType ContentType = Response.ResponsePayload.GetContentType(); + + if (ContentType == HttpContentType::kHTML) + { + std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); + std::string Injected = InjectProxyScript(Html, Port); + Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected)); + } + else + { + Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload)); + } +} + +void +HttpProxyHandler::PrunePort(uint16_t Port) +{ + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); }); + + std::vector<Ref<WsBridge>> Stale; + m_WsBridgesLock.WithExclusiveLock([&] { + for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();) + { + if (It->second->Port == Port) + { + Stale.push_back(std::move(It->second)); + It = m_WsBridges.erase(It); + } + else + { + ++It; + } + } + }); + + for (auto& Bridge : Stale) + { + if (Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(1001, "instance shutting down"); + } + if (Bridge->ClientConn->IsOpen()) + { + Bridge->ClientConn->Close(1001, "instance shutting down"); + } + } +} + +void +HttpProxyHandler::Shutdown() +{ + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); }); + m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); }); +} + +////////////////////////////////////////////////////////////////////////// +// +// WebSocket proxy +// + +void +HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) +{ + const std::string_view ProxyPrefix = "proxy/"; + if (!RelativeUri.starts_with(ProxyPrefix)) + { + Connection->Close(1008, "unsupported WebSocket endpoint"); + return; + } + + std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size()); + + size_t SlashPos = ProxyTail.find('/'); + std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail; + std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/"; + + uint16_t Port = 0; + auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); + if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) + { + Connection->Close(1008, "invalid proxy URL"); + return; + } + + if (!m_ValidatePort(Port)) + { + Connection->Close(1008, "target instance not available"); + return; + } + + std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path); + + Ref<WsBridge> Bridge(new WsBridge()); + Bridge->ClientConn = Connection; + Bridge->Port = Port; + + Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge); + + try + { + Bridge->UpstreamClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what()); + Connection->Close(1011, "upstream connect failed"); + return; + } + + WebSocketConnection* Key = Connection.Get(); + m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); }); +} + +void +HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + Ref<WsBridge> Bridge; + m_WsBridgesLock.WithSharedLock([&] { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Bridge = It->second; + } + }); + + if (!Bridge || !Bridge->UpstreamClient) + { + return; + } + + switch (Msg.Opcode) + { + case WebSocketOpcode::kText: + Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kBinary: + Bridge->UpstreamClient->SendBinary( + std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize())); + break; + case WebSocketOpcode::kClose: + Bridge->UpstreamClient->Close(Msg.CloseCode, {}); + break; + default: + break; + } +} + +void +HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) +{ + Ref<WsBridge> Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> { + auto It = m_WsBridges.find(&Conn); + if (It != m_WsBridges.end()) + { + Ref<WsBridge> Bridge = std::move(It->second); + m_WsBridges.erase(It); + return Bridge; + } + return {}; + }); + + if (Bridge && Bridge->UpstreamClient) + { + Bridge->UpstreamClient->Close(Code, Reason); + } +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("server.httpproxyhandler"); + +TEST_CASE("server.httpproxyhandler.html_injection") +{ + SUBCASE("injects before </head>") + { + std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + size_t ScriptEnd = Result.find("</script>"); + size_t HeadClose = Result.find("</head>"); + REQUIRE(ScriptEnd != std::string::npos); + REQUIRE(HeadClose != std::string::npos); + CHECK(ScriptEnd < HeadClose); + } + + SUBCASE("prepends when no </head>") + { + std::string Result = InjectProxyScript("<body>content</body>", 21005); + CHECK(Result.find("<script>") == 0); + CHECK(Result.find("<body>content</body>") != std::string::npos); + } + + SUBCASE("empty html") + { + std::string Result = InjectProxyScript("", 21005); + CHECK(Result.find("<script>") != std::string::npos); + CHECK(Result.find("/hub/proxy/21005") != std::string::npos); + } + + SUBCASE("preserves original content") + { + std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>"; + std::string Result = InjectProxyScript(Html, 21005); + CHECK(Result.find("<title>Test</title>") != std::string::npos); + CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos); + } +} + +TEST_CASE("server.httpproxyhandler.port_embedding") +{ + std::string Result = InjectProxyScript("<head></head>", 80); + CHECK(Result.find("/hub/proxy/80") != std::string::npos); + + Result = InjectProxyScript("<head></head>", 65535); + CHECK(Result.find("/hub/proxy/65535") != std::string::npos); +} + +TEST_SUITE_END(); + +void +httpproxyhandler_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen |