// Copyright Epic Games, Inc. All Rights Reserved.

// Declarations for a TCP proxy. When ASIO_HAS_LOCAL_SOCKETS is defined, it also supports
// Unix-domain sockets.
//  - ProxyMapping     : one listen endpoint -> target endpoint route.
//  - TcpProxySession  : state for a single accepted client connection and its upstream connection.
//  - TcpProxyService  : owns the acceptor(s) for one ProxyMapping, tracks live sessions and
//                       aggregate statistics, and controls traffic recording.
//
// NOTE(review): in this copy of the file the targets of most #include directives and every
// template argument list are missing, so it does not compile as shown. Affected declarations:
//   - the std::enable_shared_from_this base;
//   - std::array, std::atomic, std::optional, std::unique_ptr, std::shared_ptr and std::vector;
//   - the `template` heads of the member templates.
// Restore them from the canonical source.

#pragma once

#include "httptrafficinspector.h"
#include "httptrafficrecorder.h"

#include
#include
#include

ZEN_THIRD_PARTY_INCLUDES_START
#include
#if defined(ASIO_HAS_LOCAL_SOCKETS)
#	include
#endif
ZEN_THIRD_PARTY_INCLUDES_END

#include
#include
#include
#include
#include
#include
#include

namespace zen {

// Describes one proxy route: where to listen and where to forward.
// A non-empty ListenUnixSocket selects a Unix-domain listener (see IsUnixListen()).
// Likewise, a non-empty TargetUnixSocket selects a Unix-domain target (see IsUnixTarget()).
// Otherwise the address/host + port fields presumably apply.
struct ProxyMapping
{
	std::string ListenAddress;
	uint16_t	ListenPort = 0;
	std::string ListenUnixSocket;
	std::string TargetHost;
	uint16_t	TargetPort = 0;
	std::string TargetUnixSocket;

	// True when the listen side is a Unix-domain socket path.
	bool IsUnixListen() const { return !ListenUnixSocket.empty(); }
	// True when the target side is a Unix-domain socket path.
	bool IsUnixTarget() const { return !TargetUnixSocket.empty(); }

	// Textual descriptions of each side, defined out of line (presumably for logging / status output).
	std::string ListenDescription() const;
	std::string TargetDescription() const;
};

class TcpProxyService;

// One proxied client connection.
// The session is created from an already-accepted client socket. It then connects to the
// mapping's target and relays data in both directions (see StartRelay / DoForwardTo*).
// Derives from enable_shared_from_this, presumably so that async handlers can keep the
// session alive while operations are in flight.
class TcpProxySession : public std::enable_shared_from_this
{
public:
	// Session for a client that connected over TCP.
	TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner);
#if defined(ASIO_HAS_LOCAL_SOCKETS)
	// Session for a client that connected over a Unix-domain socket.
	TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner);
#endif

	// Begins the session (connect to target, then relay). Defined out of line.
	void Start();

	const std::string&					  GetClientLabel() const { return m_ClientLabel; }
	std::chrono::steady_clock::time_point GetStartTime() const { return m_StartTime; }

	// Byte counters are atomics loaded with memory_order_relaxed. Each value is read
	// atomically, but the two counters are not ordered relative to one another.
	uint64_t GetBytesFromClient() const { return m_BytesFromClient.load(std::memory_order_relaxed); }
	uint64_t GetBytesToClient() const { return m_BytesToClient.load(std::memory_order_relaxed); }

	// The following accessors return neutral defaults (0 / false / Oid::Zero) when
	// m_RequestInspector is not engaged.
	// NOTE(review): m_RequestInspector is a plain std::optional, not an atomic. If these are
	// called from a different thread than the one driving the session's I/O (e.g. on
	// sessions obtained via TcpProxyService::GetActiveSessions()), confirm this is race-free.
	uint64_t GetRequestCount() const { return m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; }
	bool	 IsWebSocket() const { return m_RequestInspector && m_RequestInspector->IsUpgraded(); }
	bool	 HasSessionId() const { return m_RequestInspector && m_RequestInspector->HasSessionId(); }
	Oid		 GetSessionId() const { return m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; }

private:
	LoggerRef Log();

	// Establish the upstream connection for a TCP or Unix-domain target respectively.
	void ConnectToTcpTarget();
#if defined(ASIO_HAS_LOCAL_SOCKETS)
	void ConnectToUnixTarget();
#endif
	void StartRelay();
	void ReadFromClient();
	void ReadFromUpstream();

	// Socket-type-generic helpers.
	// Dispatch*Socket presumably invoke F with whichever concrete socket (TCP or Unix-domain)
	// is active for that side. The Do* templates are the bodies shared by both socket types.
	template void DispatchClientSocket(Fn&& F);
	template void DispatchUpstreamSocket(Fn&& F);
	template void DoReadFromClient(SocketT& ClientSocket);
	template void DoReadFromUpstream(SocketT& UpstreamSocket);
	template void DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests);
	template void DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite);
	template void DoShutdownSocket(SocketT& Socket);

	void Shutdown();

	// Both TCP and (optionally) Unix-domain sockets are held for each side.
	// m_IsUnixClient / m_IsUnixTarget presumably select which pair is in use.
	asio::ip::tcp::socket m_ClientTcpSocket;
	asio::ip::tcp::socket m_UpstreamTcpSocket;
#if defined(ASIO_HAS_LOCAL_SOCKETS)
	asio::local::stream_protocol::socket m_ClientUnixSocket;
	asio::local::stream_protocol::socket m_UpstreamUnixSocket;
	bool								 m_IsUnixClient = false;
	bool								 m_IsUnixTarget = false;
#endif
	std::string m_TargetHost;
	// NOTE(review): no in-class initializer, unlike ProxyMapping::TargetPort.
	// Presumably set by every constructor; consider `= 0` for consistency.
	uint16_t		 m_TargetPort;
	std::string		 m_TargetUnixSocket;
	// Non-owning reference: the owning TcpProxyService must outlive this session.
	TcpProxyService& m_Owner;

	// 16 KiB. Presumably the size of each per-direction read buffer below.
	static constexpr size_t kBufferSize = 16 * 1024;
	std::array				m_ClientBuffer;
	std::array				m_UpstreamBuffer;

	// Starts false. Presumably used to make Shutdown() take effect only once.
	std::atomic m_ShutdownCalled{false};

	std::string							  m_ClientLabel;
	std::chrono::steady_clock::time_point m_StartTime;
	std::atomic							  m_BytesFromClient{0};
	std::atomic							  m_BytesToClient{0};

	// Optional HTTP inspection of each direction, plus an optional traffic recorder.
	std::optional	m_RequestInspector;
	std::optional	m_ResponseInspector;
	std::unique_ptr m_Recorder;
};

// Accepts connections for a single ProxyMapping and spawns a TcpProxySession for each one.
// Also keeps aggregate connection/byte statistics and the set of live sessions.
class TcpProxyService
{
public:
	// IoContext is stored by reference (m_IoContext), so it must outlive this service.
	TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping);

	void Start();
	void Stop();

	const ProxyMapping& GetMapping() const { return m_Mapping; }

	// Aggregate counters, read with memory_order_relaxed (individually atomic, not mutually consistent).
	uint64_t GetTotalConnections() const { return m_TotalConnections.load(std::memory_order_relaxed); }
	uint64_t GetActiveConnections() const { return m_ActiveConnections.load(std::memory_order_relaxed); }
	uint64_t GetPeakActiveConnections() const { return m_PeakActiveConnections.load(std::memory_order_relaxed); }
	uint64_t GetTotalBytesFromClient() const { return m_TotalBytesFromClient.load(std::memory_order_relaxed); }
	uint64_t GetTotalBytesToClient() const { return m_TotalBytesToClient.load(std::memory_order_relaxed); }

	// Mutable references to the rate meters owned by this service.
	metrics::Meter& GetRequestMeter() { return m_RequestMeter; }
	metrics::Meter& GetBytesMeter() { return m_BytesMeter; }

	// Returns a snapshot of active sessions under a shared lock.
	std::vector> GetActiveSessions() const;

	// Enables or disables traffic recording into Dir.
	// The flag is an atomic. The directory is a string that presumably is guarded by m_RecordDirLock.
	void		SetRecording(bool Enabled, const std::string& Dir);
	bool		IsRecording() const { return m_RecordingEnabled.load(std::memory_order_relaxed); }
	// Returns the directory by value (a copy), presumably so no lock is held by the caller.
	std::string GetRecordDir() const;

	LoggerRef Log() { return m_Log; }

private:
	// Sessions presumably use the private Add/RemoveSession and counters below.
	friend class TcpProxySession;

	void DoAccept();
#if defined(ASIO_HAS_LOCAL_SOCKETS)
	void DoAcceptUnix();
#endif
	void OnAcceptedSession(std::shared_ptr Session);

	LoggerRef				m_Log;
	ProxyMapping			m_Mapping;
	asio::io_context&		m_IoContext;
	asio::ip::tcp::acceptor m_TcpAcceptor;
	asio::ip::tcp::endpoint m_ListenEndpoint;
#if defined(ASIO_HAS_LOCAL_SOCKETS)
	asio::local::stream_protocol::acceptor m_UnixAcceptor;
#endif
	// NOTE(review): plain bool, not atomic. If Stop() can run on a thread other than the
	// io_context thread(s) that read this flag, confirm it is adequately synchronized.
	bool m_Stopped = false;

	// RemoveSession takes a raw pointer (identity only). Presumably this lets a session
	// deregister itself during teardown.
	void AddSession(std::shared_ptr Session);
	void RemoveSession(TcpProxySession* Session);

	std::atomic	   m_TotalConnections{0};
	std::atomic	   m_ActiveConnections{0};
	std::atomic	   m_PeakActiveConnections{0};
	std::atomic	   m_TotalBytesFromClient{0};
	std::atomic	   m_TotalBytesToClient{0};
	metrics::Meter m_RequestMeter;
	metrics::Meter m_BytesMeter;

	// m_SessionsLock is mutable so const readers (GetActiveSessions) can take a shared lock.
	mutable RwLock m_SessionsLock;
	std::vector>   m_Sessions;

	std::atomic	   m_RecordingEnabled{false};
	mutable RwLock m_RecordDirLock;
	std::string	   m_RecordDir;
	// Presumably used to give each recorded session a distinct identifier.
	std::atomic	   m_RecordSessionCounter{0};
};

}  // namespace zen