// Copyright Epic Games, Inc. All Rights Reserved. #include "httpproxyhandler.h" #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #include #if ZEN_WITH_TESTS # include #endif // ZEN_WITH_TESTS namespace zen { namespace { std::string InjectProxyScript(std::string_view Html, uint16_t Port) { ExtendableStringBuilder<2048> Script; Script.Append(""); std::string ScriptStr = Script.ToString(); size_t HeadClose = Html.find(""); if (HeadClose != std::string_view::npos) { std::string Result; Result.reserve(Html.size() + ScriptStr.size()); Result.append(Html.substr(0, HeadClose)); Result.append(ScriptStr); Result.append(Html.substr(HeadClose)); return Result; } std::string Result; Result.reserve(Html.size() + ScriptStr.size()); Result.append(ScriptStr); Result.append(Html); return Result; } } // namespace struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler { Ref ClientConn; std::unique_ptr UpstreamClient; uint16_t Port = 0; void OnWsOpen() override {} void OnWsMessage(const WebSocketMessage& Msg) override { if (!ClientConn->IsOpen()) { return; } switch (Msg.Opcode) { case WebSocketOpcode::kText: ClientConn->SendText(std::string_view(static_cast(Msg.Payload.GetData()), Msg.Payload.GetSize())); break; case WebSocketOpcode::kBinary: ClientConn->SendBinary(std::span(static_cast(Msg.Payload.GetData()), Msg.Payload.GetSize())); break; default: break; } } void OnWsClose(uint16_t Code, std::string_view Reason) override { if (ClientConn->IsOpen()) { ClientConn->Close(Code, Reason); } } }; HttpProxyHandler::HttpProxyHandler() { } HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort)) { } void HttpProxyHandler::SetPortValidator(PortValidator ValidatePort) { m_ValidatePort = std::move(ValidatePort); } HttpProxyHandler::~HttpProxyHandler() { try { Shutdown(); } catch (...) { } } HttpClient& HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port) { HttpClient* Result = nullptr; m_ProxyClientsLock.WithExclusiveLock([&] { auto It = m_ProxyClients.find(Port); if (It == m_ProxyClients.end()) { HttpClientSettings Settings; Settings.LogCategory = "hub-proxy"; Settings.ConnectTimeout = std::chrono::milliseconds(5000); Settings.Timeout = std::chrono::milliseconds(30000); auto Client = std::make_unique(fmt::format("http://127.0.0.1:{}", Port), Settings); Result = Client.get(); m_ProxyClients.emplace(Port, std::move(Client)); } else { Result = It->second.get(); } }); return *Result; } void HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail) { uint16_t Port = 0; auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) { Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL"); return; } if (!m_ValidatePort(Port)) { Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "target instance not available"); return; } HttpClient& Client = GetOrCreateProxyClient(Port); std::string RequestPath; RequestPath.reserve(1 + PathTail.size()); RequestPath.push_back('/'); RequestPath.append(PathTail); std::string_view QueryString = Request.QueryString(); if (!QueryString.empty()) { RequestPath.push_back('?'); RequestPath.append(QueryString); } HttpClient::KeyValueMap ForwardHeaders; HttpContentType AcceptType = Request.AcceptContentType(); if (AcceptType != HttpContentType::kUnknownContentType) { ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType))); } std::string_view Auth = Request.GetAuthorizationHeader(); if (!Auth.empty()) { ForwardHeaders->emplace("Authorization", std::string(Auth)); } HttpContentType ReqContentType = Request.RequestContentType(); if (ReqContentType != HttpContentType::kUnknownContentType) { ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType))); } HttpClient::Response Response; switch (Request.RequestVerb()) { case HttpVerb::kGet: Response = Client.Get(RequestPath, ForwardHeaders); break; case HttpVerb::kPost: { IoBuffer Payload = Request.ReadPayload(); Response = Client.Post(RequestPath, Payload, ForwardHeaders); break; } case HttpVerb::kPut: { IoBuffer Payload = Request.ReadPayload(); Response = Client.Put(RequestPath, Payload, ForwardHeaders); break; } case HttpVerb::kDelete: Response = Client.Delete(RequestPath, ForwardHeaders); break; case HttpVerb::kHead: Response = Client.Head(RequestPath, ForwardHeaders); break; default: Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported"); return; } if (Response.Error) { ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage); Request.WriteResponse(HttpResponseCode::BadGateway, HttpContentType::kText, "upstream request failed"); return; } HttpContentType ContentType = Response.ResponsePayload.GetContentType(); if (ContentType == HttpContentType::kHTML) { std::string_view Html(static_cast(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize()); std::string Injected = InjectProxyScript(Html, Port); Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected)); } else { Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload)); } } void HttpProxyHandler::PrunePort(uint16_t Port) { m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); }); std::vector> Stale; m_WsBridgesLock.WithExclusiveLock([&] { for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();) { if (It->second->Port == Port) { Stale.push_back(std::move(It->second)); It = m_WsBridges.erase(It); } else { ++It; } } }); for (auto& Bridge : Stale) { if (Bridge->UpstreamClient) { Bridge->UpstreamClient->Close(1001, "instance shutting down"); } if (Bridge->ClientConn->IsOpen()) { Bridge->ClientConn->Close(1001, "instance shutting down"); } } } void HttpProxyHandler::Shutdown() { m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); }); m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); }); } ////////////////////////////////////////////////////////////////////////// // // WebSocket proxy // void HttpProxyHandler::OnWebSocketOpen(Ref Connection, std::string_view RelativeUri) { const std::string_view ProxyPrefix = "proxy/"; if (!RelativeUri.starts_with(ProxyPrefix)) { Connection->Close(1008, "unsupported WebSocket endpoint"); return; } std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size()); size_t SlashPos = ProxyTail.find('/'); std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail; std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/"; uint16_t Port = 0; auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port); if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size()) { Connection->Close(1008, "invalid proxy URL"); return; } if (!m_ValidatePort(Port)) { Connection->Close(1008, "target instance not available"); return; } std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path); Ref Bridge(new WsBridge()); Bridge->ClientConn = Connection; Bridge->Port = Port; Bridge->UpstreamClient = std::make_unique(WsUrl, *Bridge); try { Bridge->UpstreamClient->Connect(); } catch (const std::exception& Ex) { ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what()); Connection->Close(1011, "upstream connect failed"); return; } WebSocketConnection* Key = Connection.Get(); m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); }); } void HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) { Ref Bridge; m_WsBridgesLock.WithSharedLock([&] { auto It = m_WsBridges.find(&Conn); if (It != m_WsBridges.end()) { Bridge = It->second; } }); if (!Bridge || !Bridge->UpstreamClient) { return; } switch (Msg.Opcode) { case WebSocketOpcode::kText: Bridge->UpstreamClient->SendText(std::string_view(static_cast(Msg.Payload.GetData()), Msg.Payload.GetSize())); break; case WebSocketOpcode::kBinary: Bridge->UpstreamClient->SendBinary( std::span(static_cast(Msg.Payload.GetData()), Msg.Payload.GetSize())); break; case WebSocketOpcode::kClose: Bridge->UpstreamClient->Close(Msg.CloseCode, {}); break; default: break; } } void HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) { Ref Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref { auto It = m_WsBridges.find(&Conn); if (It != m_WsBridges.end()) { Ref Bridge = std::move(It->second); m_WsBridges.erase(It); return Bridge; } return {}; }); if (Bridge && Bridge->UpstreamClient) { Bridge->UpstreamClient->Close(Code, Reason); } } #if ZEN_WITH_TESTS TEST_SUITE_BEGIN("server.httpproxyhandler"); TEST_CASE("server.httpproxyhandler.html_injection") { SUBCASE("injects before ") { std::string Result = InjectProxyScript("", 21005); CHECK(Result.find(""); size_t HeadClose = Result.find(""); REQUIRE(ScriptEnd != std::string::npos); REQUIRE(HeadClose != std::string::npos); CHECK(ScriptEnd < HeadClose); } SUBCASE("prepends when no ") { std::string Result = InjectProxyScript("content", 21005); CHECK(Result.find("