diff options
Diffstat (limited to 'src/zenserver/proxy/tcpproxy.cpp')
| -rw-r--r-- | src/zenserver/proxy/tcpproxy.cpp | 610 |
1 files changed, 610 insertions, 0 deletions
diff --git a/src/zenserver/proxy/tcpproxy.cpp b/src/zenserver/proxy/tcpproxy.cpp new file mode 100644 index 000000000..bdc0de164 --- /dev/null +++ b/src/zenserver/proxy/tcpproxy.cpp @@ -0,0 +1,610 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "tcpproxy.h" + +#include <zencore/logging.h> + +#include <filesystem> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// ProxyMapping + +std::string +ProxyMapping::ListenDescription() const +{ + if (IsUnixListen()) + { + return fmt::format("unix:{}", ListenUnixSocket); + } + std::string Addr = ListenAddress.empty() ? "0.0.0.0" : ListenAddress; + return fmt::format("{}:{}", Addr, ListenPort); +} + +std::string +ProxyMapping::TargetDescription() const +{ + if (IsUnixTarget()) + { + return fmt::format("unix:{}", TargetUnixSocket); + } + return fmt::format("{}:{}", TargetHost, TargetPort); +} + +////////////////////////////////////////////////////////////////////////// +// TcpProxySession + +TcpProxySession::TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) +: m_ClientTcpSocket(std::move(ClientSocket)) +, m_UpstreamTcpSocket(m_ClientTcpSocket.get_executor()) +#if defined(ASIO_HAS_LOCAL_SOCKETS) +, m_ClientUnixSocket(m_ClientTcpSocket.get_executor()) +, m_UpstreamUnixSocket(m_ClientTcpSocket.get_executor()) +, m_IsUnixClient(false) +, m_IsUnixTarget(Mapping.IsUnixTarget()) +#endif +, m_TargetHost(Mapping.TargetHost) +, m_TargetPort(Mapping.TargetPort) +, m_TargetUnixSocket(Mapping.TargetUnixSocket) +, m_Owner(Owner) +{ +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +TcpProxySession::TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) +: m_ClientTcpSocket(ClientSocket.get_executor()) +, m_UpstreamTcpSocket(ClientSocket.get_executor()) +, m_ClientUnixSocket(std::move(ClientSocket)) +, m_UpstreamUnixSocket(m_ClientUnixSocket.get_executor()) +, m_IsUnixClient(true) +, m_IsUnixTarget(Mapping.IsUnixTarget()) +, m_TargetHost(Mapping.TargetHost) +, m_TargetPort(Mapping.TargetPort) +, m_TargetUnixSocket(Mapping.TargetUnixSocket) +, m_Owner(Owner) +{ +} +#endif + +LoggerRef +TcpProxySession::Log() +{ + return m_Owner.Log(); +} + +void +TcpProxySession::Start() +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + ConnectToUnixTarget(); + return; + } +#endif + ConnectToTcpTarget(); +} + +void +TcpProxySession::ConnectToTcpTarget() +{ + auto Self = shared_from_this(); + auto Resolver = std::make_shared<asio::ip::tcp::resolver>(m_UpstreamTcpSocket.get_executor()); + + Resolver->async_resolve(m_TargetHost, + std::to_string(m_TargetPort), + [this, Self, Resolver](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_WARN("failed to resolve {}:{} - {}", m_TargetHost, m_TargetPort, Ec.message()); + Shutdown(); + return; + } + + asio::async_connect( + m_UpstreamTcpSocket, + Results, + [this, Self](const asio::error_code& ConnectEc, const asio::ip::tcp::endpoint& /*Endpoint*/) { + if (ConnectEc) + { + ZEN_WARN("failed to connect to {}:{} - {}", m_TargetHost, m_TargetPort, ConnectEc.message()); + Shutdown(); + return; + } + + StartRelay(); + }); + }); +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +void +TcpProxySession::ConnectToUnixTarget() +{ + auto Self = shared_from_this(); + + asio::local::stream_protocol::endpoint Endpoint(m_TargetUnixSocket); + + m_UpstreamUnixSocket.async_connect(Endpoint, [this, Self](const asio::error_code& Ec) { + if (Ec) + { + ZEN_WARN("failed to connect to unix:{} - {}", m_TargetUnixSocket, Ec.message()); + Shutdown(); + return; + } + + StartRelay(); + }); +} +#endif + +void +TcpProxySession::StartRelay() +{ + asio::error_code Ec; + + // TCP no_delay only applies to TCP sockets. +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_IsUnixClient) +#endif + { + m_ClientTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_IsUnixTarget) +#endif + { + m_UpstreamTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); + } + + std::string TargetLabel = m_Owner.GetMapping().TargetDescription(); + std::string ClientLabel; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + ClientLabel = "unix"; + } + else +#endif + { + asio::ip::tcp::endpoint ClientEndpoint = m_ClientTcpSocket.remote_endpoint(Ec); + if (!Ec) + { + ClientLabel = fmt::format("{}:{}", ClientEndpoint.address().to_string(), ClientEndpoint.port()); + } + else + { + ClientLabel = "?"; + } + } + + m_ClientLabel = ClientLabel; + m_StartTime = std::chrono::steady_clock::now(); + + std::string SessionLabel = fmt::format("{} -> {}", ClientLabel, TargetLabel); + + ZEN_DEBUG("session established {}", SessionLabel); + + m_RequestInspector.emplace(HttpTrafficInspector::Direction::Request, SessionLabel); + m_ResponseInspector.emplace(HttpTrafficInspector::Direction::Response, SessionLabel); + + if (m_Owner.IsRecording()) + { + std::string RecordDir = m_Owner.GetRecordDir(); + if (!RecordDir.empty()) + { + auto Now = std::chrono::system_clock::now(); + uint64_t Ms = uint64_t(std::chrono::duration_cast<std::chrono::milliseconds>(Now.time_since_epoch()).count()); + uint64_t Seq = m_Owner.m_RecordSessionCounter.fetch_add(1, std::memory_order_relaxed); + + std::filesystem::path ConnDir = std::filesystem::path(RecordDir) / fmt::format("{}_{}", Ms, Seq); + + m_Recorder = std::make_unique<HttpTrafficRecorder>(ConnDir, ClientLabel, TargetLabel); + if (m_Recorder->IsValid()) + { + m_RequestInspector->SetObserver(m_Recorder.get()); + m_ResponseInspector->SetObserver(m_Recorder.get()); + } + else + { + m_Recorder.reset(); + } + } + } + + ReadFromClient(); + ReadFromUpstream(); +} + +template<typename Fn> +void +TcpProxySession::DispatchClientSocket(Fn&& F) +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + F(m_ClientUnixSocket); + return; + } +#endif + F(m_ClientTcpSocket); +} + +template<typename Fn> +void +TcpProxySession::DispatchUpstreamSocket(Fn&& F) +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + F(m_UpstreamUnixSocket); + return; + } +#endif + F(m_UpstreamTcpSocket); +} + +void +TcpProxySession::ReadFromClient() +{ + DispatchClientSocket([this](auto& Client) { DoReadFromClient(Client); }); +} + +void +TcpProxySession::ReadFromUpstream() +{ + DispatchUpstreamSocket([this](auto& Upstream) { DoReadFromUpstream(Upstream); }); +} + +template<typename SocketT> +void +TcpProxySession::DoReadFromClient(SocketT& ClientSocket) +{ + auto Self = shared_from_this(); + + ClientSocket.async_read_some(asio::buffer(m_ClientBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("client read error - {}", Ec.message()); + } + Shutdown(); + return; + } + + uint64_t RequestsBefore = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; + if (m_RequestInspector) + { + m_RequestInspector->Inspect(m_ClientBuffer.data(), BytesRead); + } + if (m_Recorder) + { + m_Recorder->WriteRequest(m_ClientBuffer.data(), BytesRead); + } + uint64_t RequestsAfter = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; + uint64_t NewRequests = RequestsAfter - RequestsBefore; + + DispatchUpstreamSocket( + [this, Self, BytesRead, NewRequests](auto& Upstream) { DoForwardToUpstream(Upstream, BytesRead, NewRequests); }); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests) +{ + auto Self = shared_from_this(); + + asio::async_write(UpstreamSocket, + asio::buffer(m_ClientBuffer.data(), BytesToWrite), + [this, Self, BytesToWrite, NewRequests](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { + if (WriteEc) + { + if (WriteEc != asio::error::operation_aborted) + { + ZEN_DEBUG("upstream write error - {}", WriteEc.message()); + } + Shutdown(); + return; + } + + m_Owner.m_TotalBytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_BytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_Owner.m_BytesMeter.Mark(BytesToWrite); + if (NewRequests > 0) + { + m_Owner.m_RequestMeter.Mark(NewRequests); + } + ReadFromClient(); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoReadFromUpstream(SocketT& UpstreamSocket) +{ + auto Self = shared_from_this(); + + UpstreamSocket.async_read_some(asio::buffer(m_UpstreamBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("upstream read error - {}", Ec.message()); + } + Shutdown(); + return; + } + + if (m_ResponseInspector) + { + m_ResponseInspector->Inspect(m_UpstreamBuffer.data(), BytesRead); + } + if (m_Recorder) + { + m_Recorder->WriteResponse(m_UpstreamBuffer.data(), BytesRead); + } + + DispatchClientSocket([this, Self, BytesRead](auto& Client) { DoForwardToClient(Client, BytesRead); }); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite) +{ + auto Self = shared_from_this(); + + asio::async_write(ClientSocket, + asio::buffer(m_UpstreamBuffer.data(), BytesToWrite), + [this, Self, BytesToWrite](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { + if (WriteEc) + { + if (WriteEc != asio::error::operation_aborted) + { + ZEN_DEBUG("client write error - {}", WriteEc.message()); + } + Shutdown(); + return; + } + + m_Owner.m_TotalBytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_BytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_Owner.m_BytesMeter.Mark(BytesToWrite); + ReadFromUpstream(); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoShutdownSocket(SocketT& Socket) +{ + if (Socket.is_open()) + { + asio::error_code Ec; + Socket.shutdown(asio::socket_base::shutdown_both, Ec); + Socket.close(Ec); + } +} + +void +TcpProxySession::Shutdown() +{ + if (m_ShutdownCalled.exchange(true)) + { + return; + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + DoShutdownSocket(m_ClientUnixSocket); + } + else +#endif + { + DoShutdownSocket(m_ClientTcpSocket); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + DoShutdownSocket(m_UpstreamUnixSocket); + } + else +#endif + { + DoShutdownSocket(m_UpstreamTcpSocket); + } + + if (m_Recorder) + { + bool WebSocket = m_RequestInspector && m_RequestInspector->IsUpgraded(); + Oid SessionId = m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; + m_Recorder->Finalize(WebSocket, SessionId); + } + + m_Owner.m_ActiveConnections.fetch_sub(1, std::memory_order_relaxed); + m_Owner.RemoveSession(this); +} + +////////////////////////////////////////////////////////////////////////// +// TcpProxyService + +TcpProxyService::TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping) +: m_Log(logging::Get("proxy")) +, m_Mapping(Mapping) +, m_IoContext(IoContext) +, m_TcpAcceptor(IoContext) +#if defined(ASIO_HAS_LOCAL_SOCKETS) +, m_UnixAcceptor(IoContext) +#endif +{ + if (!Mapping.IsUnixListen()) + { + asio::ip::address ListenAddr = + Mapping.ListenAddress.empty() ? asio::ip::address_v4::any() : asio::ip::make_address(Mapping.ListenAddress); + m_ListenEndpoint = asio::ip::tcp::endpoint(ListenAddr, Mapping.ListenPort); + } +} + +void +TcpProxyService::Start() +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_Mapping.IsUnixListen()) + { + // Remove stale socket file if it exists. + std::error_code RemoveEc; + std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); + + asio::local::stream_protocol::endpoint Endpoint(m_Mapping.ListenUnixSocket); + m_UnixAcceptor.open(Endpoint.protocol()); + m_UnixAcceptor.bind(Endpoint); + m_UnixAcceptor.listen(); + + ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); + + DoAcceptUnix(); + return; + } +#endif + + m_TcpAcceptor.open(m_ListenEndpoint.protocol()); + m_TcpAcceptor.set_option(asio::ip::tcp::acceptor::reuse_address(true)); + m_TcpAcceptor.bind(m_ListenEndpoint); + m_TcpAcceptor.listen(); + + ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); + + DoAccept(); +} + +void +TcpProxyService::Stop() +{ + m_Stopped = true; + + asio::error_code Ec; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_Mapping.IsUnixListen()) + { + m_UnixAcceptor.close(Ec); + + // Clean up the socket file. + std::error_code RemoveEc; + std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); + return; + } +#endif + + m_TcpAcceptor.close(Ec); +} + +void +TcpProxyService::OnAcceptedSession(std::shared_ptr<TcpProxySession> Session) +{ + m_TotalConnections.fetch_add(1, std::memory_order_relaxed); + uint64_t Active = m_ActiveConnections.fetch_add(1, std::memory_order_relaxed) + 1; + uint64_t Peak = m_PeakActiveConnections.load(std::memory_order_relaxed); + while (Active > Peak && !m_PeakActiveConnections.compare_exchange_weak(Peak, Active, std::memory_order_relaxed)) + ; + AddSession(Session); + Session->Start(); +} + +void +TcpProxyService::AddSession(std::shared_ptr<TcpProxySession> Session) +{ + RwLock::ExclusiveLockScope Lock(m_SessionsLock); + m_Sessions.push_back(std::move(Session)); +} + +void +TcpProxyService::RemoveSession(TcpProxySession* Session) +{ + RwLock::ExclusiveLockScope Lock(m_SessionsLock); + auto It = std::find_if(m_Sessions.begin(), m_Sessions.end(), [Session](const std::shared_ptr<TcpProxySession>& S) { + return S.get() == Session; + }); + if (It != m_Sessions.end()) + { + // Swap-and-pop for O(1) removal; order doesn't matter. + std::swap(*It, m_Sessions.back()); + m_Sessions.pop_back(); + } +} + +std::vector<std::shared_ptr<TcpProxySession>> +TcpProxyService::GetActiveSessions() const +{ + RwLock::SharedLockScope Lock(m_SessionsLock); + return m_Sessions; +} + +void +TcpProxyService::SetRecording(bool Enabled, const std::string& Dir) +{ + { + RwLock::ExclusiveLockScope Lock(m_RecordDirLock); + m_RecordDir = Dir; + } + m_RecordingEnabled.store(Enabled, std::memory_order_relaxed); + ZEN_INFO("proxy recording {} (dir: {})", Enabled ? "enabled" : "disabled", Dir); +} + +std::string +TcpProxyService::GetRecordDir() const +{ + RwLock::SharedLockScope Lock(m_RecordDirLock); + return m_RecordDir; +} + +void +TcpProxyService::DoAccept() +{ + m_TcpAcceptor.async_accept([this](const asio::error_code& Ec, asio::ip::tcp::socket Socket) { + if (Ec) + { + if (!m_Stopped) + { + ZEN_WARN("accept error - {}", Ec.message()); + } + return; + } + + ZEN_DEBUG("accepted connection from {}", Socket.remote_endpoint().address().to_string()); + + OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this)); + DoAccept(); + }); +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +void +TcpProxyService::DoAcceptUnix() +{ + m_UnixAcceptor.async_accept([this](const asio::error_code& Ec, asio::local::stream_protocol::socket Socket) { + if (Ec) + { + if (!m_Stopped) + { + ZEN_WARN("accept error - {}", Ec.message()); + } + return; + } + + ZEN_DEBUG("accepted unix connection"); + + OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this)); + DoAcceptUnix(); + }); +} +#endif + +} // namespace zen |