diff options
Diffstat (limited to 'src/zenserver/proxy/tcpproxy.h')
| -rw-r--r-- | src/zenserver/proxy/tcpproxy.h | 196 |
1 files changed, 196 insertions, 0 deletions
diff --git a/src/zenserver/proxy/tcpproxy.h b/src/zenserver/proxy/tcpproxy.h new file mode 100644 index 000000000..7eb5c8dff --- /dev/null +++ b/src/zenserver/proxy/tcpproxy.h @@ -0,0 +1,196 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httptrafficinspector.h" +#include "httptrafficrecorder.h" + +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zentelemetry/stats.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <chrono> +#include <cstdint> +#include <memory> +#include <optional> +#include <string> +#include <vector> + +namespace zen { + +struct ProxyMapping +{ + std::string ListenAddress; + uint16_t ListenPort = 0; + std::string ListenUnixSocket; + std::string TargetHost; + uint16_t TargetPort = 0; + std::string TargetUnixSocket; + + bool IsUnixListen() const { return !ListenUnixSocket.empty(); } + bool IsUnixTarget() const { return !TargetUnixSocket.empty(); } + std::string ListenDescription() const; + std::string TargetDescription() const; +}; + +class TcpProxyService; + +class TcpProxySession : public std::enable_shared_from_this<TcpProxySession> +{ +public: + TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner); +#endif + + void Start(); + + const std::string& GetClientLabel() const { return m_ClientLabel; } + std::chrono::steady_clock::time_point GetStartTime() const { return m_StartTime; } + uint64_t GetBytesFromClient() const { return m_BytesFromClient.load(std::memory_order_relaxed); } + uint64_t GetBytesToClient() const { return m_BytesToClient.load(std::memory_order_relaxed); } + uint64_t GetRequestCount() const { return m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; } + bool IsWebSocket() const { return m_RequestInspector && m_RequestInspector->IsUpgraded(); } + bool HasSessionId() const { return m_RequestInspector && m_RequestInspector->HasSessionId(); } + Oid GetSessionId() const { return m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; } + +private: + LoggerRef Log(); + + void ConnectToTcpTarget(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void ConnectToUnixTarget(); +#endif + void StartRelay(); + + void ReadFromClient(); + void ReadFromUpstream(); + + template<typename Fn> + void DispatchClientSocket(Fn&& F); + template<typename Fn> + void DispatchUpstreamSocket(Fn&& F); + + template<typename SocketT> + void DoReadFromClient(SocketT& ClientSocket); + template<typename SocketT> + void DoReadFromUpstream(SocketT& UpstreamSocket); + template<typename SocketT> + void DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests); + template<typename SocketT> + void DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite); + template<typename SocketT> + void DoShutdownSocket(SocketT& Socket); + + void Shutdown(); + + asio::ip::tcp::socket m_ClientTcpSocket; + asio::ip::tcp::socket m_UpstreamTcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + asio::local::stream_protocol::socket m_ClientUnixSocket; + asio::local::stream_protocol::socket m_UpstreamUnixSocket; + bool m_IsUnixClient = false; + bool m_IsUnixTarget = false; +#endif + + std::string m_TargetHost; + uint16_t m_TargetPort; + std::string m_TargetUnixSocket; + + TcpProxyService& m_Owner; + + static constexpr size_t kBufferSize = 16 * 1024; + + std::array<char, kBufferSize> m_ClientBuffer; + std::array<char, kBufferSize> m_UpstreamBuffer; + + std::atomic<bool> m_ShutdownCalled{false}; + + std::string m_ClientLabel; + std::chrono::steady_clock::time_point m_StartTime; + std::atomic<uint64_t> m_BytesFromClient{0}; + std::atomic<uint64_t> m_BytesToClient{0}; + + std::optional<HttpTrafficInspector> m_RequestInspector; + std::optional<HttpTrafficInspector> m_ResponseInspector; + std::unique_ptr<HttpTrafficRecorder> m_Recorder; +}; + +class TcpProxyService +{ +public: + TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping); + + void Start(); + void Stop(); + + const ProxyMapping& GetMapping() const { return m_Mapping; } + + uint64_t GetTotalConnections() const { return m_TotalConnections.load(std::memory_order_relaxed); } + uint64_t GetActiveConnections() const { return m_ActiveConnections.load(std::memory_order_relaxed); } + uint64_t GetPeakActiveConnections() const { return m_PeakActiveConnections.load(std::memory_order_relaxed); } + uint64_t GetTotalBytesFromClient() const { return m_TotalBytesFromClient.load(std::memory_order_relaxed); } + uint64_t GetTotalBytesToClient() const { return m_TotalBytesToClient.load(std::memory_order_relaxed); } + + metrics::Meter& GetRequestMeter() { return m_RequestMeter; } + metrics::Meter& GetBytesMeter() { return m_BytesMeter; } + + // Returns a snapshot of active sessions under a shared lock. + std::vector<std::shared_ptr<TcpProxySession>> GetActiveSessions() const; + + void SetRecording(bool Enabled, const std::string& Dir); + bool IsRecording() const { return m_RecordingEnabled.load(std::memory_order_relaxed); } + std::string GetRecordDir() const; + + LoggerRef Log() { return m_Log; } + +private: + friend class TcpProxySession; + + void DoAccept(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoAcceptUnix(); +#endif + + void OnAcceptedSession(std::shared_ptr<TcpProxySession> Session); + + LoggerRef m_Log; + ProxyMapping m_Mapping; + asio::io_context& m_IoContext; + asio::ip::tcp::acceptor m_TcpAcceptor; + asio::ip::tcp::endpoint m_ListenEndpoint; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + asio::local::stream_protocol::acceptor m_UnixAcceptor; +#endif + bool m_Stopped = false; + + void AddSession(std::shared_ptr<TcpProxySession> Session); + void RemoveSession(TcpProxySession* Session); + + std::atomic<uint64_t> m_TotalConnections{0}; + std::atomic<uint64_t> m_ActiveConnections{0}; + std::atomic<uint64_t> m_PeakActiveConnections{0}; + std::atomic<uint64_t> m_TotalBytesFromClient{0}; + std::atomic<uint64_t> m_TotalBytesToClient{0}; + + metrics::Meter m_RequestMeter; + metrics::Meter m_BytesMeter; + + mutable RwLock m_SessionsLock; + std::vector<std::shared_ptr<TcpProxySession>> m_Sessions; + + std::atomic<bool> m_RecordingEnabled{false}; + mutable RwLock m_RecordDirLock; + std::string m_RecordDir; + std::atomic<uint64_t> m_RecordSessionCounter{0}; +}; + +} // namespace zen |