aboutsummaryrefslogtreecommitdiff
path: root/src/zenserver/proxy/tcpproxy.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenserver/proxy/tcpproxy.cpp')
-rw-r--r--src/zenserver/proxy/tcpproxy.cpp610
1 files changed, 610 insertions, 0 deletions
diff --git a/src/zenserver/proxy/tcpproxy.cpp b/src/zenserver/proxy/tcpproxy.cpp
new file mode 100644
index 000000000..bdc0de164
--- /dev/null
+++ b/src/zenserver/proxy/tcpproxy.cpp
@@ -0,0 +1,610 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "tcpproxy.h"
+
+#include <zencore/logging.h>
+
+#include <filesystem>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+// ProxyMapping
+
+std::string
+ProxyMapping::ListenDescription() const
+{
+ if (IsUnixListen())
+ {
+ return fmt::format("unix:{}", ListenUnixSocket);
+ }
+ std::string Addr = ListenAddress.empty() ? "0.0.0.0" : ListenAddress;
+ return fmt::format("{}:{}", Addr, ListenPort);
+}
+
+std::string
+ProxyMapping::TargetDescription() const
+{
+ if (IsUnixTarget())
+ {
+ return fmt::format("unix:{}", TargetUnixSocket);
+ }
+ return fmt::format("{}:{}", TargetHost, TargetPort);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// TcpProxySession
+
+TcpProxySession::TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner)
+: m_ClientTcpSocket(std::move(ClientSocket))
+, m_UpstreamTcpSocket(m_ClientTcpSocket.get_executor())
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+, m_ClientUnixSocket(m_ClientTcpSocket.get_executor())
+, m_UpstreamUnixSocket(m_ClientTcpSocket.get_executor())
+, m_IsUnixClient(false)
+, m_IsUnixTarget(Mapping.IsUnixTarget())
+#endif
+, m_TargetHost(Mapping.TargetHost)
+, m_TargetPort(Mapping.TargetPort)
+, m_TargetUnixSocket(Mapping.TargetUnixSocket)
+, m_Owner(Owner)
+{
+}
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+TcpProxySession::TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner)
+: m_ClientTcpSocket(ClientSocket.get_executor())
+, m_UpstreamTcpSocket(ClientSocket.get_executor())
+, m_ClientUnixSocket(std::move(ClientSocket))
+, m_UpstreamUnixSocket(m_ClientUnixSocket.get_executor())
+, m_IsUnixClient(true)
+, m_IsUnixTarget(Mapping.IsUnixTarget())
+, m_TargetHost(Mapping.TargetHost)
+, m_TargetPort(Mapping.TargetPort)
+, m_TargetUnixSocket(Mapping.TargetUnixSocket)
+, m_Owner(Owner)
+{
+}
+#endif
+
+LoggerRef
+TcpProxySession::Log()
+{
+ return m_Owner.Log();
+}
+
+void
+TcpProxySession::Start()
+{
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixTarget)
+ {
+ ConnectToUnixTarget();
+ return;
+ }
+#endif
+ ConnectToTcpTarget();
+}
+
+void
+TcpProxySession::ConnectToTcpTarget()
+{
+ auto Self = shared_from_this();
+ auto Resolver = std::make_shared<asio::ip::tcp::resolver>(m_UpstreamTcpSocket.get_executor());
+
+ Resolver->async_resolve(m_TargetHost,
+ std::to_string(m_TargetPort),
+ [this, Self, Resolver](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_WARN("failed to resolve {}:{} - {}", m_TargetHost, m_TargetPort, Ec.message());
+ Shutdown();
+ return;
+ }
+
+ asio::async_connect(
+ m_UpstreamTcpSocket,
+ Results,
+ [this, Self](const asio::error_code& ConnectEc, const asio::ip::tcp::endpoint& /*Endpoint*/) {
+ if (ConnectEc)
+ {
+ ZEN_WARN("failed to connect to {}:{} - {}", m_TargetHost, m_TargetPort, ConnectEc.message());
+ Shutdown();
+ return;
+ }
+
+ StartRelay();
+ });
+ });
+}
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+void
+TcpProxySession::ConnectToUnixTarget()
+{
+ auto Self = shared_from_this();
+
+ asio::local::stream_protocol::endpoint Endpoint(m_TargetUnixSocket);
+
+ m_UpstreamUnixSocket.async_connect(Endpoint, [this, Self](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ ZEN_WARN("failed to connect to unix:{} - {}", m_TargetUnixSocket, Ec.message());
+ Shutdown();
+ return;
+ }
+
+ StartRelay();
+ });
+}
+#endif
+
+void
+TcpProxySession::StartRelay()
+{
+ asio::error_code Ec;
+
+ // TCP no_delay only applies to TCP sockets.
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_IsUnixClient)
+#endif
+ {
+ m_ClientTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec);
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_IsUnixTarget)
+#endif
+ {
+ m_UpstreamTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec);
+ }
+
+ std::string TargetLabel = m_Owner.GetMapping().TargetDescription();
+ std::string ClientLabel;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixClient)
+ {
+ ClientLabel = "unix";
+ }
+ else
+#endif
+ {
+ asio::ip::tcp::endpoint ClientEndpoint = m_ClientTcpSocket.remote_endpoint(Ec);
+ if (!Ec)
+ {
+ ClientLabel = fmt::format("{}:{}", ClientEndpoint.address().to_string(), ClientEndpoint.port());
+ }
+ else
+ {
+ ClientLabel = "?";
+ }
+ }
+
+ m_ClientLabel = ClientLabel;
+ m_StartTime = std::chrono::steady_clock::now();
+
+ std::string SessionLabel = fmt::format("{} -> {}", ClientLabel, TargetLabel);
+
+ ZEN_DEBUG("session established {}", SessionLabel);
+
+ m_RequestInspector.emplace(HttpTrafficInspector::Direction::Request, SessionLabel);
+ m_ResponseInspector.emplace(HttpTrafficInspector::Direction::Response, SessionLabel);
+
+ if (m_Owner.IsRecording())
+ {
+ std::string RecordDir = m_Owner.GetRecordDir();
+ if (!RecordDir.empty())
+ {
+ auto Now = std::chrono::system_clock::now();
+ uint64_t Ms = uint64_t(std::chrono::duration_cast<std::chrono::milliseconds>(Now.time_since_epoch()).count());
+ uint64_t Seq = m_Owner.m_RecordSessionCounter.fetch_add(1, std::memory_order_relaxed);
+
+ std::filesystem::path ConnDir = std::filesystem::path(RecordDir) / fmt::format("{}_{}", Ms, Seq);
+
+ m_Recorder = std::make_unique<HttpTrafficRecorder>(ConnDir, ClientLabel, TargetLabel);
+ if (m_Recorder->IsValid())
+ {
+ m_RequestInspector->SetObserver(m_Recorder.get());
+ m_ResponseInspector->SetObserver(m_Recorder.get());
+ }
+ else
+ {
+ m_Recorder.reset();
+ }
+ }
+ }
+
+ ReadFromClient();
+ ReadFromUpstream();
+}
+
+template<typename Fn>
+void
+TcpProxySession::DispatchClientSocket(Fn&& F)
+{
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixClient)
+ {
+ F(m_ClientUnixSocket);
+ return;
+ }
+#endif
+ F(m_ClientTcpSocket);
+}
+
+template<typename Fn>
+void
+TcpProxySession::DispatchUpstreamSocket(Fn&& F)
+{
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixTarget)
+ {
+ F(m_UpstreamUnixSocket);
+ return;
+ }
+#endif
+ F(m_UpstreamTcpSocket);
+}
+
+void
+TcpProxySession::ReadFromClient()
+{
+ DispatchClientSocket([this](auto& Client) { DoReadFromClient(Client); });
+}
+
+void
+TcpProxySession::ReadFromUpstream()
+{
+ DispatchUpstreamSocket([this](auto& Upstream) { DoReadFromUpstream(Upstream); });
+}
+
+template<typename SocketT>
+void
+TcpProxySession::DoReadFromClient(SocketT& ClientSocket)
+{
+ auto Self = shared_from_this();
+
+ ClientSocket.async_read_some(asio::buffer(m_ClientBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_DEBUG("client read error - {}", Ec.message());
+ }
+ Shutdown();
+ return;
+ }
+
+ uint64_t RequestsBefore = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0;
+ if (m_RequestInspector)
+ {
+ m_RequestInspector->Inspect(m_ClientBuffer.data(), BytesRead);
+ }
+ if (m_Recorder)
+ {
+ m_Recorder->WriteRequest(m_ClientBuffer.data(), BytesRead);
+ }
+ uint64_t RequestsAfter = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0;
+ uint64_t NewRequests = RequestsAfter - RequestsBefore;
+
+ DispatchUpstreamSocket(
+ [this, Self, BytesRead, NewRequests](auto& Upstream) { DoForwardToUpstream(Upstream, BytesRead, NewRequests); });
+ });
+}
+
+template<typename SocketT>
+void
+TcpProxySession::DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests)
+{
+ auto Self = shared_from_this();
+
+ asio::async_write(UpstreamSocket,
+ asio::buffer(m_ClientBuffer.data(), BytesToWrite),
+ [this, Self, BytesToWrite, NewRequests](const asio::error_code& WriteEc, size_t /*BytesWritten*/) {
+ if (WriteEc)
+ {
+ if (WriteEc != asio::error::operation_aborted)
+ {
+ ZEN_DEBUG("upstream write error - {}", WriteEc.message());
+ }
+ Shutdown();
+ return;
+ }
+
+ m_Owner.m_TotalBytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed);
+ m_BytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed);
+ m_Owner.m_BytesMeter.Mark(BytesToWrite);
+ if (NewRequests > 0)
+ {
+ m_Owner.m_RequestMeter.Mark(NewRequests);
+ }
+ ReadFromClient();
+ });
+}
+
+template<typename SocketT>
+void
+TcpProxySession::DoReadFromUpstream(SocketT& UpstreamSocket)
+{
+ auto Self = shared_from_this();
+
+ UpstreamSocket.async_read_some(asio::buffer(m_UpstreamBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_DEBUG("upstream read error - {}", Ec.message());
+ }
+ Shutdown();
+ return;
+ }
+
+ if (m_ResponseInspector)
+ {
+ m_ResponseInspector->Inspect(m_UpstreamBuffer.data(), BytesRead);
+ }
+ if (m_Recorder)
+ {
+ m_Recorder->WriteResponse(m_UpstreamBuffer.data(), BytesRead);
+ }
+
+ DispatchClientSocket([this, Self, BytesRead](auto& Client) { DoForwardToClient(Client, BytesRead); });
+ });
+}
+
+template<typename SocketT>
+void
+TcpProxySession::DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite)
+{
+ auto Self = shared_from_this();
+
+ asio::async_write(ClientSocket,
+ asio::buffer(m_UpstreamBuffer.data(), BytesToWrite),
+ [this, Self, BytesToWrite](const asio::error_code& WriteEc, size_t /*BytesWritten*/) {
+ if (WriteEc)
+ {
+ if (WriteEc != asio::error::operation_aborted)
+ {
+ ZEN_DEBUG("client write error - {}", WriteEc.message());
+ }
+ Shutdown();
+ return;
+ }
+
+ m_Owner.m_TotalBytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed);
+ m_BytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed);
+ m_Owner.m_BytesMeter.Mark(BytesToWrite);
+ ReadFromUpstream();
+ });
+}
+
+template<typename SocketT>
+void
+TcpProxySession::DoShutdownSocket(SocketT& Socket)
+{
+ if (Socket.is_open())
+ {
+ asio::error_code Ec;
+ Socket.shutdown(asio::socket_base::shutdown_both, Ec);
+ Socket.close(Ec);
+ }
+}
+
+void
+TcpProxySession::Shutdown()
+{
+ if (m_ShutdownCalled.exchange(true))
+ {
+ return;
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixClient)
+ {
+ DoShutdownSocket(m_ClientUnixSocket);
+ }
+ else
+#endif
+ {
+ DoShutdownSocket(m_ClientTcpSocket);
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_IsUnixTarget)
+ {
+ DoShutdownSocket(m_UpstreamUnixSocket);
+ }
+ else
+#endif
+ {
+ DoShutdownSocket(m_UpstreamTcpSocket);
+ }
+
+ if (m_Recorder)
+ {
+ bool WebSocket = m_RequestInspector && m_RequestInspector->IsUpgraded();
+ Oid SessionId = m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero;
+ m_Recorder->Finalize(WebSocket, SessionId);
+ }
+
+ m_Owner.m_ActiveConnections.fetch_sub(1, std::memory_order_relaxed);
+ m_Owner.RemoveSession(this);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// TcpProxyService
+
+TcpProxyService::TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping)
+: m_Log(logging::Get("proxy"))
+, m_Mapping(Mapping)
+, m_IoContext(IoContext)
+, m_TcpAcceptor(IoContext)
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+, m_UnixAcceptor(IoContext)
+#endif
+{
+ if (!Mapping.IsUnixListen())
+ {
+ asio::ip::address ListenAddr =
+ Mapping.ListenAddress.empty() ? asio::ip::address_v4::any() : asio::ip::make_address(Mapping.ListenAddress);
+ m_ListenEndpoint = asio::ip::tcp::endpoint(ListenAddr, Mapping.ListenPort);
+ }
+}
+
+void
+TcpProxyService::Start()
+{
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_Mapping.IsUnixListen())
+ {
+ // Remove stale socket file if it exists.
+ std::error_code RemoveEc;
+ std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc);
+
+ asio::local::stream_protocol::endpoint Endpoint(m_Mapping.ListenUnixSocket);
+ m_UnixAcceptor.open(Endpoint.protocol());
+ m_UnixAcceptor.bind(Endpoint);
+ m_UnixAcceptor.listen();
+
+ ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription());
+
+ DoAcceptUnix();
+ return;
+ }
+#endif
+
+ m_TcpAcceptor.open(m_ListenEndpoint.protocol());
+ m_TcpAcceptor.set_option(asio::ip::tcp::acceptor::reuse_address(true));
+ m_TcpAcceptor.bind(m_ListenEndpoint);
+ m_TcpAcceptor.listen();
+
+ ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription());
+
+ DoAccept();
+}
+
+void
+TcpProxyService::Stop()
+{
+ m_Stopped = true;
+
+ asio::error_code Ec;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_Mapping.IsUnixListen())
+ {
+ m_UnixAcceptor.close(Ec);
+
+ // Clean up the socket file.
+ std::error_code RemoveEc;
+ std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc);
+ return;
+ }
+#endif
+
+ m_TcpAcceptor.close(Ec);
+}
+
+void
+TcpProxyService::OnAcceptedSession(std::shared_ptr<TcpProxySession> Session)
+{
+ m_TotalConnections.fetch_add(1, std::memory_order_relaxed);
+ uint64_t Active = m_ActiveConnections.fetch_add(1, std::memory_order_relaxed) + 1;
+ uint64_t Peak = m_PeakActiveConnections.load(std::memory_order_relaxed);
+ while (Active > Peak && !m_PeakActiveConnections.compare_exchange_weak(Peak, Active, std::memory_order_relaxed))
+ ;
+ AddSession(Session);
+ Session->Start();
+}
+
+void
+TcpProxyService::AddSession(std::shared_ptr<TcpProxySession> Session)
+{
+ RwLock::ExclusiveLockScope Lock(m_SessionsLock);
+ m_Sessions.push_back(std::move(Session));
+}
+
+void
+TcpProxyService::RemoveSession(TcpProxySession* Session)
+{
+ RwLock::ExclusiveLockScope Lock(m_SessionsLock);
+ auto It = std::find_if(m_Sessions.begin(), m_Sessions.end(), [Session](const std::shared_ptr<TcpProxySession>& S) {
+ return S.get() == Session;
+ });
+ if (It != m_Sessions.end())
+ {
+ // Swap-and-pop for O(1) removal; order doesn't matter.
+ std::swap(*It, m_Sessions.back());
+ m_Sessions.pop_back();
+ }
+}
+
+std::vector<std::shared_ptr<TcpProxySession>>
+TcpProxyService::GetActiveSessions() const
+{
+ RwLock::SharedLockScope Lock(m_SessionsLock);
+ return m_Sessions;
+}
+
+void
+TcpProxyService::SetRecording(bool Enabled, const std::string& Dir)
+{
+ {
+ RwLock::ExclusiveLockScope Lock(m_RecordDirLock);
+ m_RecordDir = Dir;
+ }
+ m_RecordingEnabled.store(Enabled, std::memory_order_relaxed);
+ ZEN_INFO("proxy recording {} (dir: {})", Enabled ? "enabled" : "disabled", Dir);
+}
+
+std::string
+TcpProxyService::GetRecordDir() const
+{
+ RwLock::SharedLockScope Lock(m_RecordDirLock);
+ return m_RecordDir;
+}
+
+void
+TcpProxyService::DoAccept()
+{
+ m_TcpAcceptor.async_accept([this](const asio::error_code& Ec, asio::ip::tcp::socket Socket) {
+ if (Ec)
+ {
+ if (!m_Stopped)
+ {
+ ZEN_WARN("accept error - {}", Ec.message());
+ }
+ return;
+ }
+
+ ZEN_DEBUG("accepted connection from {}", Socket.remote_endpoint().address().to_string());
+
+ OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this));
+ DoAccept();
+ });
+}
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+void
+TcpProxyService::DoAcceptUnix()
+{
+ m_UnixAcceptor.async_accept([this](const asio::error_code& Ec, asio::local::stream_protocol::socket Socket) {
+ if (Ec)
+ {
+ if (!m_Stopped)
+ {
+ ZEN_WARN("accept error - {}", Ec.message());
+ }
+ return;
+ }
+
+ ZEN_DEBUG("accepted unix connection");
+
+ OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this));
+ DoAcceptUnix();
+ });
+}
+#endif
+
+} // namespace zen