// Copyright Epic Games, Inc. All Rights Reserved. #include "tcpproxy.h" #include #include namespace zen { ////////////////////////////////////////////////////////////////////////// // ProxyMapping std::string ProxyMapping::ListenDescription() const { if (IsUnixListen()) { return fmt::format("unix:{}", ListenUnixSocket); } std::string Addr = ListenAddress.empty() ? "0.0.0.0" : ListenAddress; return fmt::format("{}:{}", Addr, ListenPort); } std::string ProxyMapping::TargetDescription() const { if (IsUnixTarget()) { return fmt::format("unix:{}", TargetUnixSocket); } return fmt::format("{}:{}", TargetHost, TargetPort); } ////////////////////////////////////////////////////////////////////////// // TcpProxySession TcpProxySession::TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) : m_ClientTcpSocket(std::move(ClientSocket)) , m_UpstreamTcpSocket(m_ClientTcpSocket.get_executor()) #if defined(ASIO_HAS_LOCAL_SOCKETS) , m_ClientUnixSocket(m_ClientTcpSocket.get_executor()) , m_UpstreamUnixSocket(m_ClientTcpSocket.get_executor()) , m_IsUnixClient(false) , m_IsUnixTarget(Mapping.IsUnixTarget()) #endif , m_TargetHost(Mapping.TargetHost) , m_TargetPort(Mapping.TargetPort) , m_TargetUnixSocket(Mapping.TargetUnixSocket) , m_Owner(Owner) { } #if defined(ASIO_HAS_LOCAL_SOCKETS) TcpProxySession::TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) : m_ClientTcpSocket(ClientSocket.get_executor()) , m_UpstreamTcpSocket(ClientSocket.get_executor()) , m_ClientUnixSocket(std::move(ClientSocket)) , m_UpstreamUnixSocket(m_ClientUnixSocket.get_executor()) , m_IsUnixClient(true) , m_IsUnixTarget(Mapping.IsUnixTarget()) , m_TargetHost(Mapping.TargetHost) , m_TargetPort(Mapping.TargetPort) , m_TargetUnixSocket(Mapping.TargetUnixSocket) , m_Owner(Owner) { } #endif LoggerRef TcpProxySession::Log() { return m_Owner.Log(); } void TcpProxySession::Start() { #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixTarget) { ConnectToUnixTarget(); return; } #endif ConnectToTcpTarget(); } void TcpProxySession::ConnectToTcpTarget() { auto Self = shared_from_this(); auto Resolver = std::make_shared(m_UpstreamTcpSocket.get_executor()); Resolver->async_resolve(m_TargetHost, std::to_string(m_TargetPort), [this, Self, Resolver](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { if (Ec) { ZEN_WARN("failed to resolve {}:{} - {}", m_TargetHost, m_TargetPort, Ec.message()); Shutdown(); return; } asio::async_connect( m_UpstreamTcpSocket, Results, [this, Self](const asio::error_code& ConnectEc, const asio::ip::tcp::endpoint& /*Endpoint*/) { if (ConnectEc) { ZEN_WARN("failed to connect to {}:{} - {}", m_TargetHost, m_TargetPort, ConnectEc.message()); Shutdown(); return; } StartRelay(); }); }); } #if defined(ASIO_HAS_LOCAL_SOCKETS) void TcpProxySession::ConnectToUnixTarget() { auto Self = shared_from_this(); asio::local::stream_protocol::endpoint Endpoint(m_TargetUnixSocket); m_UpstreamUnixSocket.async_connect(Endpoint, [this, Self](const asio::error_code& Ec) { if (Ec) { ZEN_WARN("failed to connect to unix:{} - {}", m_TargetUnixSocket, Ec.message()); Shutdown(); return; } StartRelay(); }); } #endif void TcpProxySession::StartRelay() { asio::error_code Ec; // TCP no_delay only applies to TCP sockets. #if defined(ASIO_HAS_LOCAL_SOCKETS) if (!m_IsUnixClient) #endif { m_ClientTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); } #if defined(ASIO_HAS_LOCAL_SOCKETS) if (!m_IsUnixTarget) #endif { m_UpstreamTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); } std::string TargetLabel = m_Owner.GetMapping().TargetDescription(); std::string ClientLabel; #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixClient) { ClientLabel = "unix"; } else #endif { asio::ip::tcp::endpoint ClientEndpoint = m_ClientTcpSocket.remote_endpoint(Ec); if (!Ec) { ClientLabel = fmt::format("{}:{}", ClientEndpoint.address().to_string(), ClientEndpoint.port()); } else { ClientLabel = "?"; } } m_ClientLabel = ClientLabel; m_StartTime = std::chrono::steady_clock::now(); std::string SessionLabel = fmt::format("{} -> {}", ClientLabel, TargetLabel); ZEN_DEBUG("session established {}", SessionLabel); m_RequestInspector.emplace(HttpTrafficInspector::Direction::Request, SessionLabel); m_ResponseInspector.emplace(HttpTrafficInspector::Direction::Response, SessionLabel); if (m_Owner.IsRecording()) { std::string RecordDir = m_Owner.GetRecordDir(); if (!RecordDir.empty()) { auto Now = std::chrono::system_clock::now(); uint64_t Ms = uint64_t(std::chrono::duration_cast(Now.time_since_epoch()).count()); uint64_t Seq = m_Owner.m_RecordSessionCounter.fetch_add(1, std::memory_order_relaxed); std::filesystem::path ConnDir = std::filesystem::path(RecordDir) / fmt::format("{}_{}", Ms, Seq); m_Recorder = std::make_unique(ConnDir, ClientLabel, TargetLabel); if (m_Recorder->IsValid()) { m_RequestInspector->SetObserver(m_Recorder.get()); m_ResponseInspector->SetObserver(m_Recorder.get()); } else { m_Recorder.reset(); } } } ReadFromClient(); ReadFromUpstream(); } template void TcpProxySession::DispatchClientSocket(Fn&& F) { #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixClient) { F(m_ClientUnixSocket); return; } #endif F(m_ClientTcpSocket); } template void TcpProxySession::DispatchUpstreamSocket(Fn&& F) { #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixTarget) { F(m_UpstreamUnixSocket); return; } #endif F(m_UpstreamTcpSocket); } void TcpProxySession::ReadFromClient() { DispatchClientSocket([this](auto& Client) { DoReadFromClient(Client); }); } void TcpProxySession::ReadFromUpstream() { DispatchUpstreamSocket([this](auto& Upstream) { DoReadFromUpstream(Upstream); }); } template void TcpProxySession::DoReadFromClient(SocketT& ClientSocket) { auto Self = shared_from_this(); ClientSocket.async_read_some(asio::buffer(m_ClientBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { if (Ec) { if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) { ZEN_DEBUG("client read error - {}", Ec.message()); } Shutdown(); return; } uint64_t RequestsBefore = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; if (m_RequestInspector) { m_RequestInspector->Inspect(m_ClientBuffer.data(), BytesRead); } if (m_Recorder) { m_Recorder->WriteRequest(m_ClientBuffer.data(), BytesRead); } uint64_t RequestsAfter = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; uint64_t NewRequests = RequestsAfter - RequestsBefore; DispatchUpstreamSocket( [this, Self, BytesRead, NewRequests](auto& Upstream) { DoForwardToUpstream(Upstream, BytesRead, NewRequests); }); }); } template void TcpProxySession::DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests) { auto Self = shared_from_this(); asio::async_write(UpstreamSocket, asio::buffer(m_ClientBuffer.data(), BytesToWrite), [this, Self, BytesToWrite, NewRequests](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { if (WriteEc) { if (WriteEc != asio::error::operation_aborted) { ZEN_DEBUG("upstream write error - {}", WriteEc.message()); } Shutdown(); return; } m_Owner.m_TotalBytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); m_BytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); m_Owner.m_BytesMeter.Mark(BytesToWrite); if (NewRequests > 0) { m_Owner.m_RequestMeter.Mark(NewRequests); } ReadFromClient(); }); } template void TcpProxySession::DoReadFromUpstream(SocketT& UpstreamSocket) { auto Self = shared_from_this(); UpstreamSocket.async_read_some(asio::buffer(m_UpstreamBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { if (Ec) { if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) { ZEN_DEBUG("upstream read error - {}", Ec.message()); } Shutdown(); return; } if (m_ResponseInspector) { m_ResponseInspector->Inspect(m_UpstreamBuffer.data(), BytesRead); } if (m_Recorder) { m_Recorder->WriteResponse(m_UpstreamBuffer.data(), BytesRead); } DispatchClientSocket([this, Self, BytesRead](auto& Client) { DoForwardToClient(Client, BytesRead); }); }); } template void TcpProxySession::DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite) { auto Self = shared_from_this(); asio::async_write(ClientSocket, asio::buffer(m_UpstreamBuffer.data(), BytesToWrite), [this, Self, BytesToWrite](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { if (WriteEc) { if (WriteEc != asio::error::operation_aborted) { ZEN_DEBUG("client write error - {}", WriteEc.message()); } Shutdown(); return; } m_Owner.m_TotalBytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); m_BytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); m_Owner.m_BytesMeter.Mark(BytesToWrite); ReadFromUpstream(); }); } template void TcpProxySession::DoShutdownSocket(SocketT& Socket) { if (Socket.is_open()) { asio::error_code Ec; Socket.shutdown(asio::socket_base::shutdown_both, Ec); Socket.close(Ec); } } void TcpProxySession::Shutdown() { if (m_ShutdownCalled.exchange(true)) { return; } #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixClient) { DoShutdownSocket(m_ClientUnixSocket); } else #endif { DoShutdownSocket(m_ClientTcpSocket); } #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_IsUnixTarget) { DoShutdownSocket(m_UpstreamUnixSocket); } else #endif { DoShutdownSocket(m_UpstreamTcpSocket); } if (m_Recorder) { bool WebSocket = m_RequestInspector && m_RequestInspector->IsUpgraded(); Oid SessionId = m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; m_Recorder->Finalize(WebSocket, SessionId); } m_Owner.m_ActiveConnections.fetch_sub(1, std::memory_order_relaxed); m_Owner.RemoveSession(this); } ////////////////////////////////////////////////////////////////////////// // TcpProxyService TcpProxyService::TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping) : m_Log(logging::Get("proxy")) , m_Mapping(Mapping) , m_IoContext(IoContext) , m_TcpAcceptor(IoContext) #if defined(ASIO_HAS_LOCAL_SOCKETS) , m_UnixAcceptor(IoContext) #endif { if (!Mapping.IsUnixListen()) { asio::ip::address ListenAddr = Mapping.ListenAddress.empty() ? asio::ip::address_v4::any() : asio::ip::make_address(Mapping.ListenAddress); m_ListenEndpoint = asio::ip::tcp::endpoint(ListenAddr, Mapping.ListenPort); } } void TcpProxyService::Start() { #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_Mapping.IsUnixListen()) { // Remove stale socket file if it exists. std::error_code RemoveEc; std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); asio::local::stream_protocol::endpoint Endpoint(m_Mapping.ListenUnixSocket); m_UnixAcceptor.open(Endpoint.protocol()); m_UnixAcceptor.bind(Endpoint); m_UnixAcceptor.listen(); ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); DoAcceptUnix(); return; } #endif m_TcpAcceptor.open(m_ListenEndpoint.protocol()); m_TcpAcceptor.set_option(asio::ip::tcp::acceptor::reuse_address(true)); m_TcpAcceptor.bind(m_ListenEndpoint); m_TcpAcceptor.listen(); ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); DoAccept(); } void TcpProxyService::Stop() { m_Stopped = true; asio::error_code Ec; #if defined(ASIO_HAS_LOCAL_SOCKETS) if (m_Mapping.IsUnixListen()) { m_UnixAcceptor.close(Ec); // Clean up the socket file. std::error_code RemoveEc; std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); return; } #endif m_TcpAcceptor.close(Ec); } void TcpProxyService::OnAcceptedSession(std::shared_ptr Session) { m_TotalConnections.fetch_add(1, std::memory_order_relaxed); uint64_t Active = m_ActiveConnections.fetch_add(1, std::memory_order_relaxed) + 1; uint64_t Peak = m_PeakActiveConnections.load(std::memory_order_relaxed); while (Active > Peak && !m_PeakActiveConnections.compare_exchange_weak(Peak, Active, std::memory_order_relaxed)) ; AddSession(Session); Session->Start(); } void TcpProxyService::AddSession(std::shared_ptr Session) { RwLock::ExclusiveLockScope Lock(m_SessionsLock); m_Sessions.push_back(std::move(Session)); } void TcpProxyService::RemoveSession(TcpProxySession* Session) { RwLock::ExclusiveLockScope Lock(m_SessionsLock); auto It = std::find_if(m_Sessions.begin(), m_Sessions.end(), [Session](const std::shared_ptr& S) { return S.get() == Session; }); if (It != m_Sessions.end()) { // Swap-and-pop for O(1) removal; order doesn't matter. std::swap(*It, m_Sessions.back()); m_Sessions.pop_back(); } } std::vector> TcpProxyService::GetActiveSessions() const { RwLock::SharedLockScope Lock(m_SessionsLock); return m_Sessions; } void TcpProxyService::SetRecording(bool Enabled, const std::string& Dir) { { RwLock::ExclusiveLockScope Lock(m_RecordDirLock); m_RecordDir = Dir; } m_RecordingEnabled.store(Enabled, std::memory_order_relaxed); ZEN_INFO("proxy recording {} (dir: {})", Enabled ? "enabled" : "disabled", Dir); } std::string TcpProxyService::GetRecordDir() const { RwLock::SharedLockScope Lock(m_RecordDirLock); return m_RecordDir; } void TcpProxyService::DoAccept() { m_TcpAcceptor.async_accept([this](const asio::error_code& Ec, asio::ip::tcp::socket Socket) { if (Ec) { if (!m_Stopped) { ZEN_WARN("accept error - {}", Ec.message()); } return; } ZEN_DEBUG("accepted connection from {}", Socket.remote_endpoint().address().to_string()); OnAcceptedSession(std::make_shared(std::move(Socket), m_Mapping, *this)); DoAccept(); }); } #if defined(ASIO_HAS_LOCAL_SOCKETS) void TcpProxyService::DoAcceptUnix() { m_UnixAcceptor.async_accept([this](const asio::error_code& Ec, asio::local::stream_protocol::socket Socket) { if (Ec) { if (!m_Stopped) { ZEN_WARN("accept error - {}", Ec.message()); } return; } ZEN_DEBUG("accepted unix connection"); OnAcceptedSession(std::make_shared(std::move(Socket), m_Mapping, *this)); DoAcceptUnix(); }); } #endif } // namespace zen