aboutsummaryrefslogtreecommitdiff
path: root/src/zenserver/proxy/tcpproxy.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenserver/proxy/tcpproxy.h')
-rw-r--r--src/zenserver/proxy/tcpproxy.h196
1 files changed, 196 insertions, 0 deletions
diff --git a/src/zenserver/proxy/tcpproxy.h b/src/zenserver/proxy/tcpproxy.h
new file mode 100644
index 000000000..7eb5c8dff
--- /dev/null
+++ b/src/zenserver/proxy/tcpproxy.h
@@ -0,0 +1,196 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httptrafficinspector.h"
+#include "httptrafficrecorder.h"
+
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+#include <zentelemetry/stats.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <atomic>
+#include <chrono>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace zen {
+
+struct ProxyMapping
+{
+ std::string ListenAddress;
+ uint16_t ListenPort = 0;
+ std::string ListenUnixSocket;
+ std::string TargetHost;
+ uint16_t TargetPort = 0;
+ std::string TargetUnixSocket;
+
+ bool IsUnixListen() const { return !ListenUnixSocket.empty(); }
+ bool IsUnixTarget() const { return !TargetUnixSocket.empty(); }
+ std::string ListenDescription() const;
+ std::string TargetDescription() const;
+};
+
+class TcpProxyService;
+
+class TcpProxySession : public std::enable_shared_from_this<TcpProxySession>
+{
+public:
+ TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner);
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner);
+#endif
+
+ void Start();
+
+ const std::string& GetClientLabel() const { return m_ClientLabel; }
+ std::chrono::steady_clock::time_point GetStartTime() const { return m_StartTime; }
+ uint64_t GetBytesFromClient() const { return m_BytesFromClient.load(std::memory_order_relaxed); }
+ uint64_t GetBytesToClient() const { return m_BytesToClient.load(std::memory_order_relaxed); }
+ uint64_t GetRequestCount() const { return m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; }
+ bool IsWebSocket() const { return m_RequestInspector && m_RequestInspector->IsUpgraded(); }
+ bool HasSessionId() const { return m_RequestInspector && m_RequestInspector->HasSessionId(); }
+ Oid GetSessionId() const { return m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; }
+
+private:
+ LoggerRef Log();
+
+ void ConnectToTcpTarget();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void ConnectToUnixTarget();
+#endif
+ void StartRelay();
+
+ void ReadFromClient();
+ void ReadFromUpstream();
+
+ template<typename Fn>
+ void DispatchClientSocket(Fn&& F);
+ template<typename Fn>
+ void DispatchUpstreamSocket(Fn&& F);
+
+ template<typename SocketT>
+ void DoReadFromClient(SocketT& ClientSocket);
+ template<typename SocketT>
+ void DoReadFromUpstream(SocketT& UpstreamSocket);
+ template<typename SocketT>
+ void DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests);
+ template<typename SocketT>
+ void DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite);
+ template<typename SocketT>
+ void DoShutdownSocket(SocketT& Socket);
+
+ void Shutdown();
+
+ asio::ip::tcp::socket m_ClientTcpSocket;
+ asio::ip::tcp::socket m_UpstreamTcpSocket;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ asio::local::stream_protocol::socket m_ClientUnixSocket;
+ asio::local::stream_protocol::socket m_UpstreamUnixSocket;
+ bool m_IsUnixClient = false;
+ bool m_IsUnixTarget = false;
+#endif
+
+ std::string m_TargetHost;
+ uint16_t m_TargetPort;
+ std::string m_TargetUnixSocket;
+
+ TcpProxyService& m_Owner;
+
+ static constexpr size_t kBufferSize = 16 * 1024;
+
+ std::array<char, kBufferSize> m_ClientBuffer;
+ std::array<char, kBufferSize> m_UpstreamBuffer;
+
+ std::atomic<bool> m_ShutdownCalled{false};
+
+ std::string m_ClientLabel;
+ std::chrono::steady_clock::time_point m_StartTime;
+ std::atomic<uint64_t> m_BytesFromClient{0};
+ std::atomic<uint64_t> m_BytesToClient{0};
+
+ std::optional<HttpTrafficInspector> m_RequestInspector;
+ std::optional<HttpTrafficInspector> m_ResponseInspector;
+ std::unique_ptr<HttpTrafficRecorder> m_Recorder;
+};
+
+class TcpProxyService
+{
+public:
+ TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping);
+
+ void Start();
+ void Stop();
+
+ const ProxyMapping& GetMapping() const { return m_Mapping; }
+
+ uint64_t GetTotalConnections() const { return m_TotalConnections.load(std::memory_order_relaxed); }
+ uint64_t GetActiveConnections() const { return m_ActiveConnections.load(std::memory_order_relaxed); }
+ uint64_t GetPeakActiveConnections() const { return m_PeakActiveConnections.load(std::memory_order_relaxed); }
+ uint64_t GetTotalBytesFromClient() const { return m_TotalBytesFromClient.load(std::memory_order_relaxed); }
+ uint64_t GetTotalBytesToClient() const { return m_TotalBytesToClient.load(std::memory_order_relaxed); }
+
+ metrics::Meter& GetRequestMeter() { return m_RequestMeter; }
+ metrics::Meter& GetBytesMeter() { return m_BytesMeter; }
+
+ // Returns a snapshot of active sessions under a shared lock.
+ std::vector<std::shared_ptr<TcpProxySession>> GetActiveSessions() const;
+
+ void SetRecording(bool Enabled, const std::string& Dir);
+ bool IsRecording() const { return m_RecordingEnabled.load(std::memory_order_relaxed); }
+ std::string GetRecordDir() const;
+
+ LoggerRef Log() { return m_Log; }
+
+private:
+ friend class TcpProxySession;
+
+ void DoAccept();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void DoAcceptUnix();
+#endif
+
+ void OnAcceptedSession(std::shared_ptr<TcpProxySession> Session);
+
+ LoggerRef m_Log;
+ ProxyMapping m_Mapping;
+ asio::io_context& m_IoContext;
+ asio::ip::tcp::acceptor m_TcpAcceptor;
+ asio::ip::tcp::endpoint m_ListenEndpoint;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ asio::local::stream_protocol::acceptor m_UnixAcceptor;
+#endif
+ bool m_Stopped = false;
+
+ void AddSession(std::shared_ptr<TcpProxySession> Session);
+ void RemoveSession(TcpProxySession* Session);
+
+ std::atomic<uint64_t> m_TotalConnections{0};
+ std::atomic<uint64_t> m_ActiveConnections{0};
+ std::atomic<uint64_t> m_PeakActiveConnections{0};
+ std::atomic<uint64_t> m_TotalBytesFromClient{0};
+ std::atomic<uint64_t> m_TotalBytesToClient{0};
+
+ metrics::Meter m_RequestMeter;
+ metrics::Meter m_BytesMeter;
+
+ mutable RwLock m_SessionsLock;
+ std::vector<std::shared_ptr<TcpProxySession>> m_Sessions;
+
+ std::atomic<bool> m_RecordingEnabled{false};
+ mutable RwLock m_RecordDirLock;
+ std::string m_RecordDir;
+ std::atomic<uint64_t> m_RecordSessionCounter{0};
+};
+
+} // namespace zen