diff options
| author | Stefan Boberg <[email protected]> | 2026-03-12 15:03:03 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-12 15:03:03 +0100 |
| commit | 81bc43aa96f0059cecb28d1bd88338b7d84667f9 (patch) | |
| tree | a3428cb7fddceae0b284d33562af5bf3e64a367e /src/zenserver/proxy | |
| parent | update fmt 12.0.0 -> 12.1.0 (#828) (diff) | |
| download | zen-81bc43aa96f0059cecb28d1bd88338b7d84667f9.tar.xz zen-81bc43aa96f0059cecb28d1bd88338b7d84667f9.zip | |
Transparent proxy mode (#823)
Adds a **transparent TCP proxy mode** to zenserver (activated via `zenserver proxy`), allowing it to sit between clients and upstream Zen servers to inspect and monitor HTTP/1.x traffic in real time. Primarily useful during development, to be able to observe multi-server/client interactions in one place.
- **Dedicated proxy port** -- Proxy mode defaults to port 8118 with its own data directory to avoid collisions with a normal zenserver instance.
- **TCP proxy core** (`src/zenserver/proxy/`) -- A new transparent TCP proxy that forwards connections to upstream targets, with support for both TCP/IP and Unix socket listeners. Multi-threaded I/O for connection handling. Supports Unix domain sockets for both upstream/downstream.
- **HTTP traffic inspection** -- Parses HTTP/1.x request/response streams inline to extract method, path, status, content length, and WebSocket upgrades without breaking the proxied data.
- **Proxy dashboard** -- A web UI showing live connection stats, per-target request counts, active connections, bytes transferred, and client IP/session ID rollups.
- **Server mode display** -- Dashboard banner now shows the running server mode (Zen Proxy, Zen Compute, etc.).
Supporting changes included in this branch:
- **Wildcard log level matching** -- Log levels can now be set per-category using wildcard patterns (e.g. `proxy.*=debug`).
- **`zen down --all`** -- New flag to shut down all running zenserver instances; also used by the new `xmake kill` task.
- Minor test stability fixes (flaky hash collisions, per-thread RNG seeds).
- Support ZEN_MALLOC environment variable for default allocator selection and switch default to rpmalloc
- Fixed sentry-native build to allow LTO on Windows
Diffstat (limited to 'src/zenserver/proxy')
| -rw-r--r-- | src/zenserver/proxy/httpproxystats.cpp | 234 | ||||
| -rw-r--r-- | src/zenserver/proxy/httpproxystats.h | 39 | ||||
| -rw-r--r-- | src/zenserver/proxy/httptrafficinspector.cpp | 197 | ||||
| -rw-r--r-- | src/zenserver/proxy/httptrafficinspector.h | 85 | ||||
| -rw-r--r-- | src/zenserver/proxy/httptrafficrecorder.cpp | 191 | ||||
| -rw-r--r-- | src/zenserver/proxy/httptrafficrecorder.h | 81 | ||||
| -rw-r--r-- | src/zenserver/proxy/tcpproxy.cpp | 610 | ||||
| -rw-r--r-- | src/zenserver/proxy/tcpproxy.h | 196 | ||||
| -rw-r--r-- | src/zenserver/proxy/zenproxyserver.cpp | 517 | ||||
| -rw-r--r-- | src/zenserver/proxy/zenproxyserver.h | 96 |
10 files changed, 2246 insertions, 0 deletions
diff --git a/src/zenserver/proxy/httpproxystats.cpp b/src/zenserver/proxy/httpproxystats.cpp new file mode 100644 index 000000000..6aa3e5c9b --- /dev/null +++ b/src/zenserver/proxy/httpproxystats.cpp @@ -0,0 +1,234 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpproxystats.h" + +#include "tcpproxy.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/fmtutils.h> + +#include <chrono> +#include <filesystem> + +namespace zen { + +HttpProxyStatsService::HttpProxyStatsService(const std::vector<std::unique_ptr<TcpProxyService>>& ProxyServices, + IHttpStatsService& StatsService, + std::string DefaultRecordDir) +: m_ProxyServices(ProxyServices) +, m_StatsService(StatsService) +, m_DefaultRecordDir(std::move(DefaultRecordDir)) +{ + m_StatsService.RegisterHandler("proxy", *this); +} + +HttpProxyStatsService::~HttpProxyStatsService() +{ + m_StatsService.UnregisterHandler("proxy", *this); +} + +const char* +HttpProxyStatsService::BaseUri() const +{ + return "/proxy/"; +} + +void +HttpProxyStatsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "stats" || Uri == "stats/") + { + HandleStatsRequest(Request); + } + else if (Uri == "record/start" || Uri == "record/start/") + { + HandleRecordStart(Request); + } + else if (Uri == "record/stop" || Uri == "record/stop/") + { + HandleRecordStop(Request); + } + else if (Uri == "record" || Uri == "record/") + { + HandleRecordStatus(Request); + } + else + { + Request.WriteResponse(HttpResponseCode::NotFound); + } +} + +void +HttpProxyStatsService::HandleRecordStart(HttpServerRequest& Request) +{ + if (Request.RequestVerb() != HttpVerb::kPost) + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed); + return; + } + + auto Params = Request.GetQueryParams(); + std::string_view Dir = Params.GetValue("dir"); + + std::string RecordDir; + if (Dir.empty()) + { + RecordDir = m_DefaultRecordDir; + } + else + { + // Treat dir as a subdirectory name within the default record directory. + // Reject path separators and parent references to prevent path traversal. + if (Dir.find("..") != std::string_view::npos || Dir.find('/') != std::string_view::npos || Dir.find('\\') != std::string_view::npos) + { + Request.WriteResponse(HttpResponseCode::BadRequest); + return; + } + RecordDir = (std::filesystem::path(m_DefaultRecordDir) / std::string(Dir)).string(); + } + + for (const std::unique_ptr<TcpProxyService>& Service : m_ProxyServices) + { + Service->SetRecording(true, RecordDir); + } + + CbObjectWriter Cbo; + Cbo << "recording" << true; + Cbo << "dir" << std::string_view(RecordDir); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpProxyStatsService::HandleRecordStop(HttpServerRequest& Request) +{ + if (Request.RequestVerb() != HttpVerb::kPost) + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed); + return; + } + + for (const std::unique_ptr<TcpProxyService>& Service : m_ProxyServices) + { + Service->SetRecording(false, Service->GetRecordDir()); + } + + CbObjectWriter Cbo; + Cbo << "recording" << false; + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpProxyStatsService::HandleRecordStatus(HttpServerRequest& Request) +{ + bool IsRecording = false; + std::string RecordDir; + for (const std::unique_ptr<TcpProxyService>& Service : m_ProxyServices) + { + if (Service->IsRecording()) + { + IsRecording = true; + RecordDir = Service->GetRecordDir(); + break; + } + } + + CbObjectWriter Cbo; + Cbo << "recording" << IsRecording; + Cbo << "dir" << std::string_view(RecordDir); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +CbObject +HttpProxyStatsService::CollectStats() +{ + CbObjectWriter Cbo; + + // Include recording status in stats output. + { + bool IsRecording = false; + std::string RecordDir; + for (const std::unique_ptr<TcpProxyService>& Service : m_ProxyServices) + { + if (Service->IsRecording()) + { + IsRecording = true; + RecordDir = Service->GetRecordDir(); + break; + } + } + Cbo << "recording" << IsRecording; + Cbo << "recordDir" << std::string_view(RecordDir); + } + + Cbo.BeginArray("mappings"); + for (const std::unique_ptr<TcpProxyService>& Service : m_ProxyServices) + { + const ProxyMapping& Mapping = Service->GetMapping(); + + Cbo.BeginObject(); + { + std::string ListenAddr = Mapping.ListenDescription(); + Cbo << "listen" << std::string_view(ListenAddr); + + std::string TargetAddr = Mapping.TargetDescription(); + Cbo << "target" << std::string_view(TargetAddr); + + Cbo << "activeConnections" << Service->GetActiveConnections(); + Cbo << "peakActiveConnections" << Service->GetPeakActiveConnections(); + Cbo << "totalConnections" << Service->GetTotalConnections(); + Cbo << "bytesFromClient" << Service->GetTotalBytesFromClient(); + Cbo << "bytesToClient" << Service->GetTotalBytesToClient(); + + Cbo << "requestRate1" << Service->GetRequestMeter().Rate1(); + Cbo << "byteRate1" << Service->GetBytesMeter().Rate1(); + Cbo << "byteRate5" << Service->GetBytesMeter().Rate5(); + + auto Now = std::chrono::steady_clock::now(); + auto Sessions = Service->GetActiveSessions(); + + Cbo.BeginArray("connections"); + for (const auto& Session : Sessions) + { + Cbo.BeginObject(); + { + std::string ClientLabel = Session->GetClientLabel(); + Cbo << "client" << std::string_view(ClientLabel); + + std::string TargetLabel = Mapping.TargetDescription(); + Cbo << "target" << std::string_view(TargetLabel); + + Cbo << "bytesFromClient" << Session->GetBytesFromClient(); + Cbo << "bytesToClient" << Session->GetBytesToClient(); + + Cbo << "requests" << Session->GetRequestCount(); + Cbo << "websocket" << Session->IsWebSocket(); + + if (Session->HasSessionId()) + { + std::string SessionId = Session->GetSessionId().ToString(); + Cbo << "sessionId" << std::string_view(SessionId); + } + + double DurationMs = std::chrono::duration<double, std::milli>(Now - Session->GetStartTime()).count(); + Cbo << "durationMs" << DurationMs; + } + Cbo.EndObject(); + } + Cbo.EndArray(); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void +HttpProxyStatsService::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + +} // namespace zen diff --git a/src/zenserver/proxy/httpproxystats.h b/src/zenserver/proxy/httpproxystats.h new file mode 100644 index 000000000..76ac7c875 --- /dev/null +++ b/src/zenserver/proxy/httpproxystats.h @@ -0,0 +1,39 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#include <memory> +#include <vector> + +namespace zen { + +class TcpProxyService; + +class HttpProxyStatsService : public HttpService, public IHttpStatsProvider +{ +public: + HttpProxyStatsService(const std::vector<std::unique_ptr<TcpProxyService>>& ProxyServices, + IHttpStatsService& StatsService, + std::string DefaultRecordDir); + ~HttpProxyStatsService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual CbObject CollectStats() override; + +private: + void HandleRecordStart(HttpServerRequest& Request); + void HandleRecordStop(HttpServerRequest& Request); + void HandleRecordStatus(HttpServerRequest& Request); + + const std::vector<std::unique_ptr<TcpProxyService>>& m_ProxyServices; + IHttpStatsService& m_StatsService; + std::string m_DefaultRecordDir; +}; + +} // namespace zen diff --git a/src/zenserver/proxy/httptrafficinspector.cpp b/src/zenserver/proxy/httptrafficinspector.cpp new file mode 100644 index 000000000..74ecbfd48 --- /dev/null +++ b/src/zenserver/proxy/httptrafficinspector.cpp @@ -0,0 +1,197 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptrafficinspector.h" + +#include <zencore/logging.h> +#include <zencore/string.h> + +#include <charconv> + +namespace zen { + +// clang-format off +http_parser_settings HttpTrafficInspector::s_RequestSettings{ + .on_message_begin = [](http_parser*) { return 0; }, + .on_url = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnUrl(Data, Len); }, + .on_status = [](http_parser*, const char*, size_t) { return 0; }, + .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }, + .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser*, const char*, size_t) { return 0; }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +http_parser_settings HttpTrafficInspector::s_ResponseSettings{ + .on_message_begin = [](http_parser*) { return 0; }, + .on_url = [](http_parser*, const char*, size_t) { return 0; }, + .on_status = [](http_parser*, const char*, size_t) { return 0; }, + .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); }, + .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser*, const char*, size_t) { return 0; }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; +// clang-format on + +HttpTrafficInspector::HttpTrafficInspector(Direction Dir, std::string_view SessionLabel) +: m_Log(logging::Get("proxy.http")) +, m_Direction(Dir) +, m_SessionLabel(SessionLabel) +{ + http_parser_init(&m_Parser, Dir == Direction::Request ? HTTP_REQUEST : HTTP_RESPONSE); + m_Parser.data = this; +} + +void +HttpTrafficInspector::Inspect(const char* Data, size_t Length) +{ + if (m_Disabled) + { + return; + } + + http_parser_settings* Settings = (m_Direction == Direction::Request) ? &s_RequestSettings : &s_ResponseSettings; + + size_t Parsed = http_parser_execute(&m_Parser, Settings, Data, Length); + + if (m_Parser.upgrade) + { + if (m_Direction == Direction::Request) + { + ZEN_DEBUG("[{}] >> {} {} (upgrade to WebSocket)", m_SessionLabel, m_Method, m_Url); + } + else + { + ZEN_DEBUG("[{}] << {} (upgrade to WebSocket)", m_SessionLabel, m_StatusCode); + } + ResetMessageState(); + m_Upgraded.store(true, std::memory_order_relaxed); + m_Disabled = true; + return; + } + + http_errno Error = HTTP_PARSER_ERRNO(&m_Parser); + if (Error != HPE_OK) + { + ZEN_DEBUG("[{}] non-HTTP traffic detected ({}), disabling inspection", m_SessionLabel, http_errno_name(Error)); + m_Disabled = true; + } + else if (Parsed != Length) + { + ZEN_DEBUG("[{}] parser consumed {}/{} bytes, disabling inspection", m_SessionLabel, Parsed, Length); + m_Disabled = true; + } +} + +int +HttpTrafficInspector::OnUrl(const char* Data, size_t Length) +{ + m_Url.append(Data, Length); + return 0; +} + +int +HttpTrafficInspector::OnHeaderField(const char* Data, size_t Length) +{ + m_CurrentHeaderField.assign(Data, Length); + return 0; +} + +int +HttpTrafficInspector::OnHeaderValue(const char* Data, size_t Length) +{ + if (m_CurrentHeaderField.size() == 14 && StrCaseCompare(m_CurrentHeaderField.c_str(), "Content-Length", 14) == 0) + { + int64_t Value = 0; + std::from_chars(Data, Data + Length, Value); + m_ContentLength = Value; + } + else if (!m_SessionIdCaptured && m_CurrentHeaderField.size() == 10 && + StrCaseCompare(m_CurrentHeaderField.c_str(), "UE-Session", 10) == 0) + { + Oid Parsed; + if (Oid::TryParse(std::string_view(Data, Length), Parsed)) + { + m_SessionId = Parsed; + m_SessionIdCaptured = true; + } + } + m_CurrentHeaderField.clear(); + return 0; +} + +int +HttpTrafficInspector::OnHeadersComplete() +{ + if (m_Direction == Direction::Request) + { + m_Method = http_method_str(static_cast<http_method>(m_Parser.method)); + } + else + { + m_StatusCode = m_Parser.status_code; + } + return 0; +} + +int +HttpTrafficInspector::OnMessageComplete() +{ + if (m_Direction == Direction::Request) + { + if (m_ContentLength >= 0) + { + ZEN_DEBUG("[{}] >> {} {} (content-length: {})", m_SessionLabel, m_Method, m_Url, m_ContentLength); + } + else + { + ZEN_DEBUG("[{}] >> {} {}", m_SessionLabel, m_Method, m_Url); + } + } + else + { + if (m_ContentLength >= 0) + { + ZEN_DEBUG("[{}] << {} (content-length: {})", m_SessionLabel, m_StatusCode, m_ContentLength); + } + else + { + ZEN_DEBUG("[{}] << {}", m_SessionLabel, m_StatusCode); + } + } + + if (m_Observer) + { + m_Observer->OnMessageComplete(m_Direction, m_Method, m_Url, m_StatusCode, m_ContentLength); + } + + m_MessageCount.fetch_add(1, std::memory_order_relaxed); + ResetMessageState(); + return 0; +} + +Oid +HttpTrafficInspector::GetSessionId() const +{ + return m_SessionId; +} + +bool +HttpTrafficInspector::HasSessionId() const +{ + return m_SessionIdCaptured; +} + +void +HttpTrafficInspector::ResetMessageState() +{ + m_Url.clear(); + m_Method.clear(); + m_StatusCode = 0; + m_ContentLength = -1; + m_CurrentHeaderField.clear(); +} + +} // namespace zen diff --git a/src/zenserver/proxy/httptrafficinspector.h b/src/zenserver/proxy/httptrafficinspector.h new file mode 100644 index 000000000..f4af0e77e --- /dev/null +++ b/src/zenserver/proxy/httptrafficinspector.h @@ -0,0 +1,85 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +#include <zencore/uid.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <http_parser.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <cstdint> +#include <string> +#include <string_view> + +namespace zen { + +class HttpTrafficInspector +{ +public: + enum class Direction + { + Request, + Response + }; + + HttpTrafficInspector(Direction Dir, std::string_view SessionLabel); + + void Inspect(const char* Data, size_t Length); + + uint64_t GetMessageCount() const { return m_MessageCount.load(std::memory_order_relaxed); } + bool IsUpgraded() const { return m_Upgraded.load(std::memory_order_relaxed); } + Oid GetSessionId() const; + bool HasSessionId() const; + + void SetObserver(class IHttpTrafficObserver* Observer) { m_Observer = Observer; } + +private: + int OnUrl(const char* Data, size_t Length); + int OnHeaderField(const char* Data, size_t Length); + int OnHeaderValue(const char* Data, size_t Length); + int OnHeadersComplete(); + int OnMessageComplete(); + + void ResetMessageState(); + + static HttpTrafficInspector* GetThis(http_parser* Parser) { return static_cast<HttpTrafficInspector*>(Parser->data); } + + static http_parser_settings s_RequestSettings; + static http_parser_settings s_ResponseSettings; + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + http_parser m_Parser; + Direction m_Direction; + std::string m_SessionLabel; + bool m_Disabled = false; + + // Per-message state + std::string m_Url; + std::string m_Method; + uint16_t m_StatusCode = 0; + int64_t m_ContentLength = -1; + std::string m_CurrentHeaderField; + std::atomic<uint64_t> m_MessageCount{0}; + std::atomic<bool> m_Upgraded{false}; + Oid m_SessionId = Oid::Zero; + bool m_SessionIdCaptured = false; + IHttpTrafficObserver* m_Observer = nullptr; +}; + +class IHttpTrafficObserver +{ +public: + virtual ~IHttpTrafficObserver() = default; + virtual void OnMessageComplete(HttpTrafficInspector::Direction Dir, + std::string_view Method, + std::string_view Url, + uint16_t StatusCode, + int64_t ContentLength) = 0; +}; + +} // namespace zen diff --git a/src/zenserver/proxy/httptrafficrecorder.cpp b/src/zenserver/proxy/httptrafficrecorder.cpp new file mode 100644 index 000000000..0279555a0 --- /dev/null +++ b/src/zenserver/proxy/httptrafficrecorder.cpp @@ -0,0 +1,191 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptrafficrecorder.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> + +#include <chrono> +#include <filesystem> + +namespace zen { + +HttpTrafficRecorder::HttpTrafficRecorder(const std::filesystem::path& OutputDir, std::string_view ClientLabel, std::string_view TargetLabel) +: m_Log(logging::Get("proxy.record")) +, m_Dir(OutputDir) +, m_ClientLabel(ClientLabel) +, m_TargetLabel(TargetLabel) +{ + auto Now = std::chrono::system_clock::now(); + m_StartTimeMs = uint64_t(std::chrono::duration_cast<std::chrono::milliseconds>(Now.time_since_epoch()).count()); + + std::error_code Ec; + std::filesystem::create_directories(m_Dir, Ec); + if (Ec) + { + ZEN_WARN("failed to create recording directory {} - {}", m_Dir.string(), Ec.message()); + return; + } + + std::error_code ReqEc; + m_RequestFile.Open(m_Dir / "request.bin", BasicFile::Mode::kTruncate, ReqEc); + if (ReqEc) + { + ZEN_WARN("failed to open request.bin in {} - {}", m_Dir.string(), ReqEc.message()); + return; + } + + std::error_code RespEc; + m_ResponseFile.Open(m_Dir / "response.bin", BasicFile::Mode::kTruncate, RespEc); + if (RespEc) + { + ZEN_WARN("failed to open response.bin in {} - {}", m_Dir.string(), RespEc.message()); + m_RequestFile.Close(); + return; + } + + m_Valid = true; + ZEN_DEBUG("recording started in {}", m_Dir.string()); +} + +HttpTrafficRecorder::~HttpTrafficRecorder() +{ + if (m_Valid && !m_Finalized) + { + Finalize(false, Oid::Zero); + } +} + +void +HttpTrafficRecorder::WriteRequest(const char* Data, size_t Length) +{ + if (!m_Valid) + { + return; + } + m_RequestFile.Write(Data, Length, m_RequestOffset); + m_RequestOffset += Length; +} + +void +HttpTrafficRecorder::WriteResponse(const char* Data, size_t Length) +{ + if (!m_Valid) + { + return; + } + m_ResponseFile.Write(Data, Length, m_ResponseOffset); + m_ResponseOffset += Length; +} + +void +HttpTrafficRecorder::OnMessageComplete(HttpTrafficInspector::Direction Dir, + std::string_view Method, + std::string_view Url, + uint16_t StatusCode, + int64_t /*ContentLength*/) +{ + if (!m_Valid) + { + return; + } + + if (Dir == HttpTrafficInspector::Direction::Request) + { + // Record the request boundary. The request spans from m_CurrentRequestStart to m_RequestOffset. + m_PendingReqOffset = m_CurrentRequestStart; + m_PendingReqSize = m_RequestOffset - m_CurrentRequestStart; + m_PendingMethod = Method; + m_PendingUrl = Url; + m_HasPendingRequest = true; + + // Advance start to current offset for the next request. + m_CurrentRequestStart = m_RequestOffset; + } + else + { + // Response complete -- pair with pending request. + RecordedEntry Entry; + if (m_HasPendingRequest) + { + Entry.ReqOffset = m_PendingReqOffset; + Entry.ReqSize = m_PendingReqSize; + Entry.Method = std::move(m_PendingMethod); + Entry.Url = std::move(m_PendingUrl); + m_HasPendingRequest = false; + } + + Entry.RespOffset = m_CurrentResponseStart; + Entry.RespSize = m_ResponseOffset - m_CurrentResponseStart; + Entry.Status = StatusCode; + + m_Entries.push_back(std::move(Entry)); + + // Advance start to current offset for the next response. + m_CurrentResponseStart = m_ResponseOffset; + } +} + +void +HttpTrafficRecorder::Finalize(bool WebSocket, const Oid& SessionId) +{ + if (!m_Valid || m_Finalized) + { + return; + } + m_Finalized = true; + + m_RequestFile.Close(); + m_ResponseFile.Close(); + + // Write index.ucb as a CbObject. + CbObjectWriter Cbo; + + Cbo << "client" << std::string_view(m_ClientLabel); + Cbo << "target" << std::string_view(m_TargetLabel); + Cbo << "startTime" << m_StartTimeMs; + Cbo << "websocket" << WebSocket; + if (SessionId != Oid::Zero) + { + std::string SessionIdStr = SessionId.ToString(); + Cbo << "sessionId" << std::string_view(SessionIdStr); + } + + Cbo.BeginArray("entries"); + for (const RecordedEntry& Entry : m_Entries) + { + Cbo.BeginObject(); + Cbo << "reqOffset" << Entry.ReqOffset; + Cbo << "reqSize" << Entry.ReqSize; + Cbo << "respOffset" << Entry.RespOffset; + Cbo << "respSize" << Entry.RespSize; + Cbo << "method" << std::string_view(Entry.Method); + Cbo << "url" << std::string_view(Entry.Url); + Cbo << "status" << Entry.Status; + Cbo.EndObject(); + } + Cbo.EndArray(); + + CbObject IndexObj = Cbo.Save(); + + MemoryView View = IndexObj.GetView(); + std::error_code Ec; + BasicFile IndexFile(m_Dir / "index.ucb", BasicFile::Mode::kTruncate, Ec); + if (!Ec) + { + IndexFile.Write(View, 0, Ec); + if (Ec) + { + ZEN_WARN("failed to write index.ucb in {} - {}", m_Dir.string(), Ec.message()); + } + IndexFile.Close(); + } + else + { + ZEN_WARN("failed to create index.ucb in {} - {}", m_Dir.string(), Ec.message()); + } + + ZEN_DEBUG("recording finalized in {} ({} entries, websocket: {})", m_Dir.string(), m_Entries.size(), WebSocket); +} + +} // namespace zen diff --git a/src/zenserver/proxy/httptrafficrecorder.h b/src/zenserver/proxy/httptrafficrecorder.h new file mode 100644 index 000000000..bbf22a14e --- /dev/null +++ b/src/zenserver/proxy/httptrafficrecorder.h @@ -0,0 +1,81 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httptrafficinspector.h" + +#include <zencore/basicfile.h> +#include <zencore/logging.h> + +#include <cstdint> +#include <filesystem> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct RecordedEntry +{ + uint64_t ReqOffset = 0; + uint64_t ReqSize = 0; + uint64_t RespOffset = 0; + uint64_t RespSize = 0; + std::string Method; + std::string Url; + uint32_t Status = 0; +}; + +class IHttpTrafficObserver; + +class HttpTrafficRecorder : public IHttpTrafficObserver +{ +public: + HttpTrafficRecorder(const std::filesystem::path& OutputDir, std::string_view ClientLabel, std::string_view TargetLabel); + ~HttpTrafficRecorder(); + + bool IsValid() const { return m_Valid; } + + void WriteRequest(const char* Data, size_t Length); + void WriteResponse(const char* Data, size_t Length); + + // IHttpTrafficObserver + void OnMessageComplete(HttpTrafficInspector::Direction Dir, + std::string_view Method, + std::string_view Url, + uint16_t StatusCode, + int64_t ContentLength) override; + + void Finalize(bool WebSocket, const Oid& SessionId); + +private: + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_Dir; + std::string m_ClientLabel; + std::string m_TargetLabel; + uint64_t m_StartTimeMs = 0; + bool m_Valid = false; + bool m_Finalized = false; + + BasicFile m_RequestFile; + BasicFile m_ResponseFile; + + uint64_t m_RequestOffset = 0; + uint64_t m_ResponseOffset = 0; + + uint64_t m_CurrentRequestStart = 0; + uint64_t m_CurrentResponseStart = 0; + + // Pending request metadata waiting for its paired response. + std::string m_PendingMethod; + std::string m_PendingUrl; + uint64_t m_PendingReqOffset = 0; + uint64_t m_PendingReqSize = 0; + bool m_HasPendingRequest = false; + + std::vector<RecordedEntry> m_Entries; +}; + +} // namespace zen diff --git a/src/zenserver/proxy/tcpproxy.cpp b/src/zenserver/proxy/tcpproxy.cpp new file mode 100644 index 000000000..bdc0de164 --- /dev/null +++ b/src/zenserver/proxy/tcpproxy.cpp @@ -0,0 +1,610 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "tcpproxy.h" + +#include <zencore/logging.h> + +#include <filesystem> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// ProxyMapping + +std::string +ProxyMapping::ListenDescription() const +{ + if (IsUnixListen()) + { + return fmt::format("unix:{}", ListenUnixSocket); + } + std::string Addr = ListenAddress.empty() ? "0.0.0.0" : ListenAddress; + return fmt::format("{}:{}", Addr, ListenPort); +} + +std::string +ProxyMapping::TargetDescription() const +{ + if (IsUnixTarget()) + { + return fmt::format("unix:{}", TargetUnixSocket); + } + return fmt::format("{}:{}", TargetHost, TargetPort); +} + +////////////////////////////////////////////////////////////////////////// +// TcpProxySession + +TcpProxySession::TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) +: m_ClientTcpSocket(std::move(ClientSocket)) +, m_UpstreamTcpSocket(m_ClientTcpSocket.get_executor()) +#if defined(ASIO_HAS_LOCAL_SOCKETS) +, m_ClientUnixSocket(m_ClientTcpSocket.get_executor()) +, m_UpstreamUnixSocket(m_ClientTcpSocket.get_executor()) +, m_IsUnixClient(false) +, m_IsUnixTarget(Mapping.IsUnixTarget()) +#endif +, m_TargetHost(Mapping.TargetHost) +, m_TargetPort(Mapping.TargetPort) +, m_TargetUnixSocket(Mapping.TargetUnixSocket) +, m_Owner(Owner) +{ +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +TcpProxySession::TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner) +: m_ClientTcpSocket(ClientSocket.get_executor()) +, m_UpstreamTcpSocket(ClientSocket.get_executor()) +, m_ClientUnixSocket(std::move(ClientSocket)) +, m_UpstreamUnixSocket(m_ClientUnixSocket.get_executor()) +, m_IsUnixClient(true) +, m_IsUnixTarget(Mapping.IsUnixTarget()) +, m_TargetHost(Mapping.TargetHost) +, m_TargetPort(Mapping.TargetPort) +, m_TargetUnixSocket(Mapping.TargetUnixSocket) +, m_Owner(Owner) +{ +} +#endif + +LoggerRef +TcpProxySession::Log() +{ + return m_Owner.Log(); +} + +void +TcpProxySession::Start() +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + ConnectToUnixTarget(); + return; + } +#endif + ConnectToTcpTarget(); +} + +void +TcpProxySession::ConnectToTcpTarget() +{ + auto Self = shared_from_this(); + auto Resolver = std::make_shared<asio::ip::tcp::resolver>(m_UpstreamTcpSocket.get_executor()); + + Resolver->async_resolve(m_TargetHost, + std::to_string(m_TargetPort), + [this, Self, Resolver](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_WARN("failed to resolve {}:{} - {}", m_TargetHost, m_TargetPort, Ec.message()); + Shutdown(); + return; + } + + asio::async_connect( + m_UpstreamTcpSocket, + Results, + [this, Self](const asio::error_code& ConnectEc, const asio::ip::tcp::endpoint& /*Endpoint*/) { + if (ConnectEc) + { + ZEN_WARN("failed to connect to {}:{} - {}", m_TargetHost, m_TargetPort, ConnectEc.message()); + Shutdown(); + return; + } + + StartRelay(); + }); + }); +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +void +TcpProxySession::ConnectToUnixTarget() +{ + auto Self = shared_from_this(); + + asio::local::stream_protocol::endpoint Endpoint(m_TargetUnixSocket); + + m_UpstreamUnixSocket.async_connect(Endpoint, [this, Self](const asio::error_code& Ec) { + if (Ec) + { + ZEN_WARN("failed to connect to unix:{} - {}", m_TargetUnixSocket, Ec.message()); + Shutdown(); + return; + } + + StartRelay(); + }); +} +#endif + +void +TcpProxySession::StartRelay() +{ + asio::error_code Ec; + + // TCP no_delay only applies to TCP sockets. +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_IsUnixClient) +#endif + { + m_ClientTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_IsUnixTarget) +#endif + { + m_UpstreamTcpSocket.set_option(asio::ip::tcp::no_delay(true), Ec); + } + + std::string TargetLabel = m_Owner.GetMapping().TargetDescription(); + std::string ClientLabel; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + ClientLabel = "unix"; + } + else +#endif + { + asio::ip::tcp::endpoint ClientEndpoint = m_ClientTcpSocket.remote_endpoint(Ec); + if (!Ec) + { + ClientLabel = fmt::format("{}:{}", ClientEndpoint.address().to_string(), ClientEndpoint.port()); + } + else + { + ClientLabel = "?"; + } + } + + m_ClientLabel = ClientLabel; + m_StartTime = std::chrono::steady_clock::now(); + + std::string SessionLabel = fmt::format("{} -> {}", ClientLabel, TargetLabel); + + ZEN_DEBUG("session established {}", SessionLabel); + + m_RequestInspector.emplace(HttpTrafficInspector::Direction::Request, SessionLabel); + m_ResponseInspector.emplace(HttpTrafficInspector::Direction::Response, SessionLabel); + + if (m_Owner.IsRecording()) + { + std::string RecordDir = m_Owner.GetRecordDir(); + if (!RecordDir.empty()) + { + auto Now = std::chrono::system_clock::now(); + uint64_t Ms = uint64_t(std::chrono::duration_cast<std::chrono::milliseconds>(Now.time_since_epoch()).count()); + uint64_t Seq = m_Owner.m_RecordSessionCounter.fetch_add(1, std::memory_order_relaxed); + + std::filesystem::path ConnDir = std::filesystem::path(RecordDir) / fmt::format("{}_{}", Ms, Seq); + + m_Recorder = std::make_unique<HttpTrafficRecorder>(ConnDir, ClientLabel, TargetLabel); + if (m_Recorder->IsValid()) + { + m_RequestInspector->SetObserver(m_Recorder.get()); + m_ResponseInspector->SetObserver(m_Recorder.get()); + } + else + { + m_Recorder.reset(); + } + } + } + + ReadFromClient(); + ReadFromUpstream(); +} + +template<typename Fn> +void +TcpProxySession::DispatchClientSocket(Fn&& F) +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + F(m_ClientUnixSocket); + return; + } +#endif + F(m_ClientTcpSocket); +} + +template<typename Fn> +void +TcpProxySession::DispatchUpstreamSocket(Fn&& F) +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + F(m_UpstreamUnixSocket); + return; + } +#endif + F(m_UpstreamTcpSocket); +} + +void +TcpProxySession::ReadFromClient() +{ + DispatchClientSocket([this](auto& Client) { DoReadFromClient(Client); }); +} + +void +TcpProxySession::ReadFromUpstream() +{ + DispatchUpstreamSocket([this](auto& Upstream) { DoReadFromUpstream(Upstream); }); +} + +template<typename SocketT> +void +TcpProxySession::DoReadFromClient(SocketT& ClientSocket) +{ + auto Self = shared_from_this(); + + ClientSocket.async_read_some(asio::buffer(m_ClientBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("client read error - {}", Ec.message()); + } + Shutdown(); + return; + } + + uint64_t RequestsBefore = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; + if (m_RequestInspector) + { + m_RequestInspector->Inspect(m_ClientBuffer.data(), BytesRead); + } + if (m_Recorder) + { + m_Recorder->WriteRequest(m_ClientBuffer.data(), BytesRead); + } + uint64_t RequestsAfter = m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; + uint64_t NewRequests = RequestsAfter - RequestsBefore; + + DispatchUpstreamSocket( + [this, Self, BytesRead, NewRequests](auto& Upstream) { DoForwardToUpstream(Upstream, BytesRead, NewRequests); }); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests) +{ + auto Self = shared_from_this(); + + asio::async_write(UpstreamSocket, + asio::buffer(m_ClientBuffer.data(), BytesToWrite), + [this, Self, BytesToWrite, NewRequests](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { + if (WriteEc) + { + if (WriteEc != asio::error::operation_aborted) + { + ZEN_DEBUG("upstream write error - {}", WriteEc.message()); + } + Shutdown(); + return; + } + + m_Owner.m_TotalBytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_BytesFromClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_Owner.m_BytesMeter.Mark(BytesToWrite); + if (NewRequests > 0) + { + m_Owner.m_RequestMeter.Mark(NewRequests); + } + ReadFromClient(); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoReadFromUpstream(SocketT& UpstreamSocket) +{ + auto Self = shared_from_this(); + + UpstreamSocket.async_read_some(asio::buffer(m_UpstreamBuffer), [this, Self](const asio::error_code& Ec, size_t BytesRead) { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_DEBUG("upstream read error - {}", Ec.message()); + } + Shutdown(); + return; + } + + if (m_ResponseInspector) + { + m_ResponseInspector->Inspect(m_UpstreamBuffer.data(), BytesRead); + } + if (m_Recorder) + { + m_Recorder->WriteResponse(m_UpstreamBuffer.data(), BytesRead); + } + + DispatchClientSocket([this, Self, BytesRead](auto& Client) { DoForwardToClient(Client, BytesRead); }); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite) +{ + auto Self = shared_from_this(); + + asio::async_write(ClientSocket, + asio::buffer(m_UpstreamBuffer.data(), BytesToWrite), + [this, Self, BytesToWrite](const asio::error_code& WriteEc, size_t /*BytesWritten*/) { + if (WriteEc) + { + if (WriteEc != asio::error::operation_aborted) + { + ZEN_DEBUG("client write error - {}", WriteEc.message()); + } + Shutdown(); + return; + } + + m_Owner.m_TotalBytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_BytesToClient.fetch_add(BytesToWrite, std::memory_order_relaxed); + m_Owner.m_BytesMeter.Mark(BytesToWrite); + ReadFromUpstream(); + }); +} + +template<typename SocketT> +void +TcpProxySession::DoShutdownSocket(SocketT& Socket) +{ + if (Socket.is_open()) + { + asio::error_code Ec; + Socket.shutdown(asio::socket_base::shutdown_both, Ec); + Socket.close(Ec); + } +} + +void +TcpProxySession::Shutdown() +{ + if (m_ShutdownCalled.exchange(true)) + { + return; + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixClient) + { + DoShutdownSocket(m_ClientUnixSocket); + } + else +#endif + { + DoShutdownSocket(m_ClientTcpSocket); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_IsUnixTarget) + { + DoShutdownSocket(m_UpstreamUnixSocket); + } + else +#endif + { + DoShutdownSocket(m_UpstreamTcpSocket); + } + + if (m_Recorder) + { + bool WebSocket = m_RequestInspector && m_RequestInspector->IsUpgraded(); + Oid SessionId = m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; + m_Recorder->Finalize(WebSocket, SessionId); + } + + m_Owner.m_ActiveConnections.fetch_sub(1, std::memory_order_relaxed); + m_Owner.RemoveSession(this); +} + +////////////////////////////////////////////////////////////////////////// +// TcpProxyService + +TcpProxyService::TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping) +: m_Log(logging::Get("proxy")) +, m_Mapping(Mapping) +, m_IoContext(IoContext) +, m_TcpAcceptor(IoContext) +#if defined(ASIO_HAS_LOCAL_SOCKETS) +, m_UnixAcceptor(IoContext) +#endif +{ + if (!Mapping.IsUnixListen()) + { + asio::ip::address ListenAddr = + Mapping.ListenAddress.empty() ? asio::ip::address_v4::any() : asio::ip::make_address(Mapping.ListenAddress); + m_ListenEndpoint = asio::ip::tcp::endpoint(ListenAddr, Mapping.ListenPort); + } +} + +void +TcpProxyService::Start() +{ +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_Mapping.IsUnixListen()) + { + // Remove stale socket file if it exists. + std::error_code RemoveEc; + std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); + + asio::local::stream_protocol::endpoint Endpoint(m_Mapping.ListenUnixSocket); + m_UnixAcceptor.open(Endpoint.protocol()); + m_UnixAcceptor.bind(Endpoint); + m_UnixAcceptor.listen(); + + ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); + + DoAcceptUnix(); + return; + } +#endif + + m_TcpAcceptor.open(m_ListenEndpoint.protocol()); + m_TcpAcceptor.set_option(asio::ip::tcp::acceptor::reuse_address(true)); + m_TcpAcceptor.bind(m_ListenEndpoint); + m_TcpAcceptor.listen(); + + ZEN_INFO("listening on {} -> {}", m_Mapping.ListenDescription(), m_Mapping.TargetDescription()); + + DoAccept(); +} + +void +TcpProxyService::Stop() +{ + m_Stopped = true; + + asio::error_code Ec; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_Mapping.IsUnixListen()) + { + m_UnixAcceptor.close(Ec); + + // Clean up the socket file. + std::error_code RemoveEc; + std::filesystem::remove(m_Mapping.ListenUnixSocket, RemoveEc); + return; + } +#endif + + m_TcpAcceptor.close(Ec); +} + +void +TcpProxyService::OnAcceptedSession(std::shared_ptr<TcpProxySession> Session) +{ + m_TotalConnections.fetch_add(1, std::memory_order_relaxed); + uint64_t Active = m_ActiveConnections.fetch_add(1, std::memory_order_relaxed) + 1; + uint64_t Peak = m_PeakActiveConnections.load(std::memory_order_relaxed); + while (Active > Peak && !m_PeakActiveConnections.compare_exchange_weak(Peak, Active, std::memory_order_relaxed)) + ; + AddSession(Session); + Session->Start(); +} + +void +TcpProxyService::AddSession(std::shared_ptr<TcpProxySession> Session) +{ + RwLock::ExclusiveLockScope Lock(m_SessionsLock); + m_Sessions.push_back(std::move(Session)); +} + +void +TcpProxyService::RemoveSession(TcpProxySession* Session) +{ + RwLock::ExclusiveLockScope Lock(m_SessionsLock); + auto It = std::find_if(m_Sessions.begin(), m_Sessions.end(), [Session](const std::shared_ptr<TcpProxySession>& S) { + return S.get() == Session; + }); + if (It != m_Sessions.end()) + { + // Swap-and-pop for O(1) removal; order doesn't matter. + std::swap(*It, m_Sessions.back()); + m_Sessions.pop_back(); + } +} + +std::vector<std::shared_ptr<TcpProxySession>> +TcpProxyService::GetActiveSessions() const +{ + RwLock::SharedLockScope Lock(m_SessionsLock); + return m_Sessions; +} + +void +TcpProxyService::SetRecording(bool Enabled, const std::string& Dir) +{ + { + RwLock::ExclusiveLockScope Lock(m_RecordDirLock); + m_RecordDir = Dir; + } + m_RecordingEnabled.store(Enabled, std::memory_order_relaxed); + ZEN_INFO("proxy recording {} (dir: {})", Enabled ? "enabled" : "disabled", Dir); +} + +std::string +TcpProxyService::GetRecordDir() const +{ + RwLock::SharedLockScope Lock(m_RecordDirLock); + return m_RecordDir; +} + +void +TcpProxyService::DoAccept() +{ + m_TcpAcceptor.async_accept([this](const asio::error_code& Ec, asio::ip::tcp::socket Socket) { + if (Ec) + { + if (!m_Stopped) + { + ZEN_WARN("accept error - {}", Ec.message()); + } + return; + } + + ZEN_DEBUG("accepted connection from {}", Socket.remote_endpoint().address().to_string()); + + OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this)); + DoAccept(); + }); +} + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +void +TcpProxyService::DoAcceptUnix() +{ + m_UnixAcceptor.async_accept([this](const asio::error_code& Ec, asio::local::stream_protocol::socket Socket) { + if (Ec) + { + if (!m_Stopped) + { + ZEN_WARN("accept error - {}", Ec.message()); + } + return; + } + + ZEN_DEBUG("accepted unix connection"); + + OnAcceptedSession(std::make_shared<TcpProxySession>(std::move(Socket), m_Mapping, *this)); + DoAcceptUnix(); + }); +} +#endif + +} // namespace zen diff --git a/src/zenserver/proxy/tcpproxy.h b/src/zenserver/proxy/tcpproxy.h new file mode 100644 index 000000000..7eb5c8dff --- /dev/null +++ b/src/zenserver/proxy/tcpproxy.h @@ -0,0 +1,196 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httptrafficinspector.h" +#include "httptrafficrecorder.h" + +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zentelemetry/stats.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <chrono> +#include <cstdint> +#include <memory> +#include <optional> +#include <string> +#include <vector> + +namespace zen { + +struct ProxyMapping +{ + std::string ListenAddress; + uint16_t ListenPort = 0; + std::string ListenUnixSocket; + std::string TargetHost; + uint16_t TargetPort = 0; + std::string TargetUnixSocket; + + bool IsUnixListen() const { return !ListenUnixSocket.empty(); } + bool IsUnixTarget() const { return !TargetUnixSocket.empty(); } + std::string ListenDescription() const; + std::string TargetDescription() const; +}; + +class TcpProxyService; + +class TcpProxySession : public std::enable_shared_from_this<TcpProxySession> +{ +public: + TcpProxySession(asio::ip::tcp::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + TcpProxySession(asio::local::stream_protocol::socket ClientSocket, const ProxyMapping& Mapping, TcpProxyService& Owner); +#endif + + void Start(); + + const std::string& GetClientLabel() const { return m_ClientLabel; } + std::chrono::steady_clock::time_point GetStartTime() const { return m_StartTime; } + uint64_t GetBytesFromClient() const { return m_BytesFromClient.load(std::memory_order_relaxed); } + uint64_t GetBytesToClient() const { return m_BytesToClient.load(std::memory_order_relaxed); } + uint64_t GetRequestCount() const { return m_RequestInspector ? m_RequestInspector->GetMessageCount() : 0; } + bool IsWebSocket() const { return m_RequestInspector && m_RequestInspector->IsUpgraded(); } + bool HasSessionId() const { return m_RequestInspector && m_RequestInspector->HasSessionId(); } + Oid GetSessionId() const { return m_RequestInspector ? m_RequestInspector->GetSessionId() : Oid::Zero; } + +private: + LoggerRef Log(); + + void ConnectToTcpTarget(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void ConnectToUnixTarget(); +#endif + void StartRelay(); + + void ReadFromClient(); + void ReadFromUpstream(); + + template<typename Fn> + void DispatchClientSocket(Fn&& F); + template<typename Fn> + void DispatchUpstreamSocket(Fn&& F); + + template<typename SocketT> + void DoReadFromClient(SocketT& ClientSocket); + template<typename SocketT> + void DoReadFromUpstream(SocketT& UpstreamSocket); + template<typename SocketT> + void DoForwardToUpstream(SocketT& UpstreamSocket, size_t BytesToWrite, uint64_t NewRequests); + template<typename SocketT> + void DoForwardToClient(SocketT& ClientSocket, size_t BytesToWrite); + template<typename SocketT> + void DoShutdownSocket(SocketT& Socket); + + void Shutdown(); + + asio::ip::tcp::socket m_ClientTcpSocket; + asio::ip::tcp::socket m_UpstreamTcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + asio::local::stream_protocol::socket m_ClientUnixSocket; + asio::local::stream_protocol::socket m_UpstreamUnixSocket; + bool m_IsUnixClient = false; + bool m_IsUnixTarget = false; +#endif + + std::string m_TargetHost; + uint16_t m_TargetPort; + std::string m_TargetUnixSocket; + + TcpProxyService& m_Owner; + + static constexpr size_t kBufferSize = 16 * 1024; + + std::array<char, kBufferSize> m_ClientBuffer; + std::array<char, kBufferSize> m_UpstreamBuffer; + + std::atomic<bool> m_ShutdownCalled{false}; + + std::string m_ClientLabel; + std::chrono::steady_clock::time_point m_StartTime; + std::atomic<uint64_t> m_BytesFromClient{0}; + std::atomic<uint64_t> m_BytesToClient{0}; + + std::optional<HttpTrafficInspector> m_RequestInspector; + std::optional<HttpTrafficInspector> m_ResponseInspector; + std::unique_ptr<HttpTrafficRecorder> m_Recorder; +}; + +class TcpProxyService +{ +public: + TcpProxyService(asio::io_context& IoContext, const ProxyMapping& Mapping); + + void Start(); + void Stop(); + + const ProxyMapping& GetMapping() const { return m_Mapping; } + + uint64_t GetTotalConnections() const { return m_TotalConnections.load(std::memory_order_relaxed); } + uint64_t GetActiveConnections() const { return m_ActiveConnections.load(std::memory_order_relaxed); } + uint64_t GetPeakActiveConnections() const { return m_PeakActiveConnections.load(std::memory_order_relaxed); } + uint64_t GetTotalBytesFromClient() const { return m_TotalBytesFromClient.load(std::memory_order_relaxed); } + uint64_t GetTotalBytesToClient() const { return m_TotalBytesToClient.load(std::memory_order_relaxed); } + + metrics::Meter& GetRequestMeter() { return m_RequestMeter; } + metrics::Meter& GetBytesMeter() { return m_BytesMeter; } + + // Returns a snapshot of active sessions under a shared lock. + std::vector<std::shared_ptr<TcpProxySession>> GetActiveSessions() const; + + void SetRecording(bool Enabled, const std::string& Dir); + bool IsRecording() const { return m_RecordingEnabled.load(std::memory_order_relaxed); } + std::string GetRecordDir() const; + + LoggerRef Log() { return m_Log; } + +private: + friend class TcpProxySession; + + void DoAccept(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoAcceptUnix(); +#endif + + void OnAcceptedSession(std::shared_ptr<TcpProxySession> Session); + + LoggerRef m_Log; + ProxyMapping m_Mapping; + asio::io_context& m_IoContext; + asio::ip::tcp::acceptor m_TcpAcceptor; + asio::ip::tcp::endpoint m_ListenEndpoint; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + asio::local::stream_protocol::acceptor m_UnixAcceptor; +#endif + bool m_Stopped = false; + + void AddSession(std::shared_ptr<TcpProxySession> Session); + void RemoveSession(TcpProxySession* Session); + + std::atomic<uint64_t> m_TotalConnections{0}; + std::atomic<uint64_t> m_ActiveConnections{0}; + std::atomic<uint64_t> m_PeakActiveConnections{0}; + std::atomic<uint64_t> m_TotalBytesFromClient{0}; + std::atomic<uint64_t> m_TotalBytesToClient{0}; + + metrics::Meter m_RequestMeter; + metrics::Meter m_BytesMeter; + + mutable RwLock m_SessionsLock; + std::vector<std::shared_ptr<TcpProxySession>> m_Sessions; + + std::atomic<bool> m_RecordingEnabled{false}; + mutable RwLock m_RecordDirLock; + std::string m_RecordDir; + std::atomic<uint64_t> m_RecordSessionCounter{0}; +}; + +} // namespace zen diff --git a/src/zenserver/proxy/zenproxyserver.cpp b/src/zenserver/proxy/zenproxyserver.cpp new file mode 100644 index 000000000..1fd9cd2c4 --- /dev/null +++ b/src/zenserver/proxy/zenproxyserver.cpp @@ -0,0 +1,517 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenproxyserver.h" + +#include "frontend/frontend.h" +#include "proxy/httpproxystats.h" + +#include <zenhttp/httpapiservice.h> + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/memory/llm.h> +#include <zencore/scopeguard.h> +#include <zencore/sentryintegration.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/windows.h> +#include <zenutil/service.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// Configurator + +void +ZenProxyServerConfigurator::AddCliOptions(cxxopts::Options& Options) +{ + Options.add_option("proxy", + "", + "proxy-map", + "Proxy mapping (see documentation for full format)", + cxxopts::value<std::vector<std::string>>(m_RawProxyMappings), + ""); + + Options.parse_positional({"proxy-map"}); + Options.show_positional_help(); +} + +void +ZenProxyServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenProxyServerConfigurator::ApplyOptions(cxxopts::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenProxyServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) +{ + ZEN_UNUSED(LuaOptions); +} + +static ProxyMapping +ParseProxyMapping(const std::string& Raw) +{ + // Preferred format using "=" as the listen/target separator: + // listen_spec=target_spec + // where listen_spec is [addr:]port or unix:path + // and target_spec is host:port or unix:path + // + // Examples: + // 9000=127.0.0.1:8558 (TCP -> TCP) + // 10.0.0.1:9000=10.0.0.2:8558 (TCP -> TCP) + // 9000=unix:/tmp/target.sock (TCP -> Unix) + // 9000=unix:C:\Users\foo\zen.sock (TCP -> Unix, Windows path) + // unix:/tmp/listen.sock=localhost:8558 (Unix -> TCP) + // unix:C:\foo\l.sock=unix:C:\foo\t.sock (Unix -> Unix, Windows paths) + // + // Legacy format using colon-only separators (no "=" present): + // [listen_addr:]listen_port:target_host:target_port (TCP -> TCP) + // [listen_addr:]listen_port:unix:target_socket_path (TCP -> Unix) + // unix:listen_socket_path:target_host:target_port (Unix -> TCP, path must not contain colons) + // unix:listen_socket_path:unix:target_socket_path (Unix -> Unix, listen path must not contain colons) + + auto ThrowBadMapping = [&](std::string_view Detail) { + throw OptionParseException(fmt::format("invalid proxy mapping '{}': {}", Raw, Detail), ""); + }; + + auto ParsePort = [&](std::string_view Field) -> uint16_t { + std::optional<uint16_t> Port = ParseInt<uint16_t>(Field); + if (!Port) + { + ThrowBadMapping(fmt::format("'{}' is not a valid port number", Field)); + } + return *Port; + }; + + auto RequireNonEmpty = [&](std::string_view Value, std::string_view Label) { + if (Value.empty()) + { + ThrowBadMapping(fmt::format("empty {}", Label)); + } + }; + + // Parse a listen spec: [addr:]port or unix:path + auto ParseListenSpec = [&](std::string_view Spec, ProxyMapping& Out) { + if (Spec.substr(0, 5) == "unix:") + { + Out.ListenUnixSocket = Spec.substr(5); + RequireNonEmpty(Out.ListenUnixSocket, "listen unix socket path"); + } + else + { + size_t ColonPos = Spec.find(':'); + if (ColonPos == std::string_view::npos) + { + Out.ListenPort = ParsePort(Spec); + } + else + { + Out.ListenAddress = Spec.substr(0, ColonPos); + Out.ListenPort = ParsePort(Spec.substr(ColonPos + 1)); + } + } + }; + + // Parse a target spec: host:port or unix:path + auto ParseTargetSpec = [&](std::string_view Spec, ProxyMapping& Out) { + if (Spec.substr(0, 5) == "unix:") + { + Out.TargetUnixSocket = Spec.substr(5); + RequireNonEmpty(Out.TargetUnixSocket, "target unix socket path"); + } + else + { + size_t ColonPos = Spec.rfind(':'); + if (ColonPos == std::string_view::npos) + { + ThrowBadMapping("target must be host:port or unix:path"); + } + Out.TargetHost = Spec.substr(0, ColonPos); + Out.TargetPort = ParsePort(Spec.substr(ColonPos + 1)); + } + }; + + ProxyMapping Mapping; + + // Check for the "=" separator first. + size_t EqPos = Raw.find('='); + if (EqPos != std::string::npos) + { + std::string_view ListenSpec = std::string_view(Raw).substr(0, EqPos); + std::string_view TargetSpec = std::string_view(Raw).substr(EqPos + 1); + + RequireNonEmpty(ListenSpec, "listen spec"); + RequireNonEmpty(TargetSpec, "target spec"); + + ParseListenSpec(ListenSpec, Mapping); + ParseTargetSpec(TargetSpec, Mapping); + return Mapping; + } + + // Legacy colon-only format. Extract fields left-to-right; when we encounter the + // "unix" keyword, everything after the next colon is the socket path taken verbatim. + // Listen-side unix socket paths must not contain colons in this format. + + auto RequireColon = [&](size_t From) -> size_t { + size_t Pos = Raw.find(':', From); + if (Pos == std::string::npos) + { + ThrowBadMapping("expected [listen_addr:]listen_port:target_host:target_port or use '=' separator"); + } + return Pos; + }; + + size_t Pos1 = RequireColon(0); + std::string Field1 = Raw.substr(0, Pos1); + + size_t Pos2 = RequireColon(Pos1 + 1); + std::string Field2 = Raw.substr(Pos1 + 1, Pos2 - Pos1 - 1); + + // unix:listen_path:... + if (Field1 == "unix") + { + Mapping.ListenUnixSocket = Field2; + RequireNonEmpty(Mapping.ListenUnixSocket, "listen unix socket path"); + + ParseTargetSpec(std::string_view(Raw).substr(Pos2 + 1), Mapping); + return Mapping; + } + + // listen_port:unix:target_socket_path + if (Field2 == "unix") + { + Mapping.ListenPort = ParsePort(Field1); + Mapping.TargetUnixSocket = Raw.substr(Pos2 + 1); + RequireNonEmpty(Mapping.TargetUnixSocket, "target unix socket path"); + return Mapping; + } + + size_t Pos3 = Raw.find(':', Pos2 + 1); + if (Pos3 == std::string::npos) + { + // listen_port:target_host:target_port + Mapping.ListenPort = ParsePort(Field1); + Mapping.TargetHost = Field2; + Mapping.TargetPort = ParsePort(std::string_view(Raw).substr(Pos2 + 1)); + return Mapping; + } + + std::string Field3 = Raw.substr(Pos2 + 1, Pos3 - Pos2 - 1); + + // listen_addr:listen_port:unix:target_socket_path + if (Field3 == "unix") + { + Mapping.ListenAddress = Field1; + Mapping.ListenPort = ParsePort(Field2); + Mapping.TargetUnixSocket = Raw.substr(Pos3 + 1); + RequireNonEmpty(Mapping.TargetUnixSocket, "target unix socket path"); + return Mapping; + } + + // listen_addr:listen_port:target_host:target_port + std::string Field4 = Raw.substr(Pos3 + 1); + if (Field4.find(':') != std::string::npos) + { + ThrowBadMapping("expected [listen_addr:]listen_port:target_host:target_port or use '=' separator"); + } + + Mapping.ListenAddress = Field1; + Mapping.ListenPort = ParsePort(Field2); + Mapping.TargetHost = Field3; + Mapping.TargetPort = ParsePort(Field4); + return Mapping; +} + +void +ZenProxyServerConfigurator::ValidateOptions() +{ + if (m_ServerOptions.BasePort == 0) + { + m_ServerOptions.BasePort = ZenProxyServerConfig::kDefaultProxyPort; + } + + if (m_ServerOptions.DataDir.empty()) + { + std::filesystem::path SystemRoot = m_ServerOptions.SystemRootDir; + if (SystemRoot.empty()) + { + SystemRoot = PickDefaultSystemRootDirectory(); + } + if (!SystemRoot.empty()) + { + m_ServerOptions.DataDir = SystemRoot / "Proxy"; + } + } + + for (const std::string& Raw : m_RawProxyMappings) + { + // The mode keyword "proxy" from argv[1] gets captured as a positional + // argument — skip it. + if (Raw == "proxy") + { + continue; + } + + m_ServerOptions.ProxyMappings.push_back(ParseProxyMapping(Raw)); + } +} + +////////////////////////////////////////////////////////////////////////// +// ZenProxyServer + +ZenProxyServer::ZenProxyServer() +{ +} + +ZenProxyServer::~ZenProxyServer() +{ + Cleanup(); +} + +int +ZenProxyServer::Initialize(const ZenProxyServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry) +{ + ZEN_TRACE_CPU("ZenProxyServer::Initialize"); + ZEN_MEMSCOPE(GetZenserverTag()); + + ZEN_INFO(ZEN_APP_NAME " initializing in PROXY server mode"); + + const int EffectiveBasePort = ZenServerBase::Initialize(ServerConfig, ServerEntry); + if (EffectiveBasePort < 0) + { + return EffectiveBasePort; + } + + for (const ProxyMapping& Mapping : ServerConfig.ProxyMappings) + { + auto Service = std::make_unique<TcpProxyService>(m_ProxyIoContext, Mapping); + Service->Start(); + m_ProxyServices.push_back(std::move(Service)); + } + + // Keep the io_context alive even when there is no pending work, so that + // worker threads don't exit prematurely between async operations. + m_ProxyIoWorkGuard = std::make_unique<asio::io_context::work>(m_ProxyIoContext); + + // Start proxy I/O worker threads. Use a modest thread count — proxy work is + // I/O-bound so we don't need a thread per core, but having more than one + // avoids head-of-line blocking when many connections are active. + unsigned int ThreadCount = std::max(GetHardwareConcurrency() / 4, 4u); + + for (unsigned int i = 0; i < ThreadCount; ++i) + { + m_ProxyIoThreads.emplace_back([this, i] { + ExtendableStringBuilder<32> ThreadName; + ThreadName << "proxy_io_" << i; + SetCurrentThreadName(ThreadName); + m_ProxyIoContext.run(); + }); + } + + ZEN_INFO("proxy I/O thread pool started with {} threads", ThreadCount); + + m_ApiService = std::make_unique<HttpApiService>(*m_Http); + m_Http->RegisterService(*m_ApiService); + + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService); + m_Http->RegisterService(*m_FrontendService); + + std::string DefaultRecordDir = (m_DataRoot / "recordings").string(); + m_ProxyStatsService = std::make_unique<HttpProxyStatsService>(m_ProxyServices, m_StatsService, std::move(DefaultRecordDir)); + m_Http->RegisterService(*m_ProxyStatsService); + + EnsureIoRunner(); + + ZenServerBase::Finalize(); + + return EffectiveBasePort; +} + +void +ZenProxyServer::Run() +{ + if (m_ProcessMonitor.IsActive()) + { + CheckOwnerPid(); + } + + if (!m_TestMode) + { + // clang-format off + ZEN_INFO(R"(__________ __________ )" "\n" + R"(\____ /____ ____ \______ \_______ _______ ______.__. )" "\n" + R"( / // __ \ / \ | ___/\_ __ \/ _ \ \/ < | | )" "\n" + R"( / /\ ___/| | \ | | | | \( <_> > < \___ | )" "\n" + R"(/_______ \___ >___| / |____| |__| \____/__/\_ \/ ____| )" "\n" + R"( \/ \/ \/ \/\/ )"); + // clang-format on + + ExtendableStringBuilder<256> BuildOptions; + GetBuildOptions(BuildOptions, '\n'); + ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + } + + ZEN_INFO(ZEN_APP_NAME " now running as PROXY (pid: {})", GetCurrentProcessId()); + +#if ZEN_PLATFORM_WINDOWS + if (zen::windows::IsRunningOnWine()) + { + ZEN_INFO("detected Wine session - " ZEN_APP_NAME " is not formally tested on Wine and may therefore not work or perform well"); + } +#endif + +#if ZEN_USE_SENTRY + ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED"); + if (m_UseSentry) + { + SentryIntegration::ClearCaches(); + } +#endif + + const bool IsInteractiveMode = IsInteractiveSession(); + + SetNewState(kRunning); + + OnReady(); + + m_Http->Run(IsInteractiveMode); + + SetNewState(kShuttingDown); + + ZEN_INFO(ZEN_APP_NAME " exiting"); +} + +void +ZenProxyServer::Cleanup() +{ + ZEN_TRACE_CPU("ZenProxyServer::Cleanup"); + ZEN_INFO(ZEN_APP_NAME " cleaning up"); + try + { + for (auto& Service : m_ProxyServices) + { + Service->Stop(); + } + + m_ProxyIoWorkGuard.reset(); + m_ProxyIoContext.stop(); + for (auto& Thread : m_ProxyIoThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_ProxyIoThreads.clear(); + m_ProxyServices.clear(); + + m_IoContext.stop(); + if (m_IoRunner.joinable()) + { + m_IoRunner.join(); + } + + m_ProxyStatsService.reset(); + m_FrontendService.reset(); + m_ApiService.reset(); + + ShutdownServices(); + if (m_Http) + { + m_Http->Close(); + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("exception thrown during Cleanup() in {}: '{}'", ZEN_APP_NAME, Ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// ZenProxyServerMain + +ZenProxyServerMain::ZenProxyServerMain(ZenProxyServerConfig& ServerOptions) : ZenServerMain(ServerOptions), m_ServerOptions(ServerOptions) +{ +} + +void +ZenProxyServerMain::DoRun(ZenServerState::ZenServerEntry* Entry) +{ + ZenProxyServer Server; + Server.SetServerMode("Proxy"); + Server.SetDataRoot(m_ServerOptions.DataDir); + Server.SetContentRoot(m_ServerOptions.ContentDir); + Server.SetTestMode(m_ServerOptions.IsTest); + Server.SetDedicatedMode(m_ServerOptions.IsDedicated); + + const int EffectiveBasePort = Server.Initialize(m_ServerOptions, Entry); + if (EffectiveBasePort == -1) + { + std::exit(1); + } + + Entry->EffectiveListenPort = uint16_t(EffectiveBasePort); + if (EffectiveBasePort != m_ServerOptions.BasePort) + { + ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort); + m_ServerOptions.BasePort = EffectiveBasePort; + } + + std::unique_ptr<std::thread> ShutdownThread; + std::unique_ptr<NamedEvent> ShutdownEvent; + + ExtendableStringBuilder<64> ShutdownEventName; + ShutdownEventName << "Zen_" << m_ServerOptions.BasePort << "_Shutdown"; + ShutdownEvent.reset(new NamedEvent{ShutdownEventName}); + + ShutdownThread.reset(new std::thread{[&] { + SetCurrentThreadName("shutdown_mon"); + + ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}' for process {}", ShutdownEventName, zen::GetCurrentProcessId()); + + if (ShutdownEvent->Wait()) + { + ZEN_INFO("shutdown signal for pid {} received", zen::GetCurrentProcessId()); + Server.RequestExit(0); + } + else + { + ZEN_INFO("shutdown signal wait() failed"); + } + }}); + + auto CleanupShutdown = MakeGuard([&ShutdownEvent, &ShutdownThread] { + ReportServiceStatus(ServiceStatus::Stopping); + + if (ShutdownEvent) + { + ShutdownEvent->Set(); + } + if (ShutdownThread && ShutdownThread->joinable()) + { + ShutdownThread->join(); + } + }); + + Server.SetIsReadyFunc([&] { + std::error_code Ec; + m_LockFile.Update(MakeLockData(true), Ec); + ReportServiceStatus(ServiceStatus::Running); + NotifyReady(); + }); + + Server.Run(); +} + +} // namespace zen diff --git a/src/zenserver/proxy/zenproxyserver.h b/src/zenserver/proxy/zenproxyserver.h new file mode 100644 index 000000000..7dad748cf --- /dev/null +++ b/src/zenserver/proxy/zenproxyserver.h @@ -0,0 +1,96 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenserver.h" + +#include "proxy/tcpproxy.h" + +#include <memory> +#include <thread> +#include <vector> + +namespace zen { +class HttpApiService; +class HttpFrontendService; +class HttpProxyStatsService; +} // namespace zen + +namespace cxxopts { +class Options; +} +namespace zen::LuaConfig { +struct Options; +} + +namespace zen { + +struct ZenProxyServerConfig : public ZenServerConfig +{ + static constexpr int kDefaultProxyPort = 8118; + + std::vector<ProxyMapping> ProxyMappings; +}; + +struct ZenProxyServerConfigurator : public ZenServerConfiguratorBase +{ + ZenProxyServerConfigurator(ZenProxyServerConfig& ServerOptions) + : ZenServerConfiguratorBase(ServerOptions) + , m_ServerOptions(ServerOptions) + { + } + + ~ZenProxyServerConfigurator() = default; + +private: + virtual void AddCliOptions(cxxopts::Options& Options) override; + virtual void AddConfigOptions(LuaConfig::Options& Options) override; + virtual void ApplyOptions(cxxopts::Options& Options) override; + virtual void OnConfigFileParsed(LuaConfig::Options& LuaOptions) override; + virtual void ValidateOptions() override; + + ZenProxyServerConfig& m_ServerOptions; + + std::vector<std::string> m_RawProxyMappings; +}; + +class ZenProxyServerMain : public ZenServerMain +{ +public: + ZenProxyServerMain(ZenProxyServerConfig& ServerOptions); + virtual void DoRun(ZenServerState::ZenServerEntry* Entry) override; + + ZenProxyServerMain(const ZenProxyServerMain&) = delete; + ZenProxyServerMain& operator=(const ZenProxyServerMain&) = delete; + + typedef ZenProxyServerConfig Config; + typedef ZenProxyServerConfigurator Configurator; + +private: + ZenProxyServerConfig& m_ServerOptions; +}; + +class ZenProxyServer : public ZenServerBase +{ + ZenProxyServer& operator=(ZenProxyServer&&) = delete; + ZenProxyServer(ZenProxyServer&&) = delete; + +public: + ZenProxyServer(); + ~ZenProxyServer(); + + int Initialize(const ZenProxyServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry); + void Run(); + void Cleanup(); + +private: + asio::io_context m_ProxyIoContext; + std::unique_ptr<asio::io_context::work> m_ProxyIoWorkGuard; + std::vector<std::thread> m_ProxyIoThreads; + std::vector<std::unique_ptr<TcpProxyService>> m_ProxyServices; + std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<HttpFrontendService> m_FrontendService; + std::unique_ptr<HttpProxyStatsService> m_ProxyStatsService; +}; + +} // namespace zen |