diff options
| author | Stefan Boberg <[email protected]> | 2026-03-10 17:27:26 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-10 17:27:26 +0100 |
| commit | d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch) | |
| tree | 2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src/zenhttp/servers | |
| parent | changelog entry which was inadvertently omitted from PR merge (diff) | |
| download | zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip | |
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency.
### HttpClient Backend Overhaul
- Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines)
as an alternative to the cpr-based one
- Made HttpClient backend **configurable at runtime** via constructor arguments
and `-httpclient=...` CLI option (for zen, zenserver, and tests)
- Extended HttpClient test suite to cover multipart/content-range scenarios
### Unix Domain Socket Support
- Added Unix domain socket support to **httpasio** (server side)
- Added Unix domain socket support to **HttpClient**
- Added Unix domain socket support to **HttpWsClient** (WebSocket client)
- Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>`
to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch
### HTTPS Support
- Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL)
- Added **basic HTTPS support for http.sys** (Windows)
- Implemented HTTPS test for httpasio
- Split `InitializeServer` into smaller sub-functions for http.sys
### Other Notable Changes
- Improved **zenhttp-test stability** with dynamic port allocation
- Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED)
- Fatal signal/exception handlers for backtrace generation in tests
- Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src/zenhttp/servers')
| -rw-r--r-- | src/zenhttp/servers/asio_socket_traits.h | 54 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 687 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.h | 6 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 409 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.h | 4 | ||||
| -rw-r--r-- | src/zenhttp/servers/wsasio.cpp | 64 | ||||
| -rw-r--r-- | src/zenhttp/servers/wsasio.h | 43 | ||||
| -rw-r--r-- | src/zenhttp/servers/wstest.cpp | 73 |
8 files changed, 1060 insertions, 280 deletions
diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h new file mode 100644 index 000000000..25aeaa24e --- /dev/null +++ b/src/zenhttp/servers/asio_socket_traits.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_http { + +/** + * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets. + * SSL sockets need lowest_layer() access and have different shutdown semantics. + */ +template<typename SocketType> +struct SocketTraits +{ + /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because + /// those bypass the encryption layer. This flag lets templated code fall back + /// to reading-into-memory for SSL connections. + static constexpr bool IsSslSocket = false; + + static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); } + + static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); } +}; + +#if ZEN_USE_OPENSSL +using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; + +template<> +struct SocketTraits<SslSocket> +{ + static constexpr bool IsSslSocket = true; + + static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SslSocket& S, std::error_code& Ec) + { + // Best-effort SSL close_notify, then TCP shutdown + S.shutdown(Ec); + S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec); + } + + static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); } +}; +#endif + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index f5178ebe8..ee8e71256 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpasio.h" +#include "asio_socket_traits.h" #include "httptracer.h" #include <zencore/except.h> @@ -35,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START #endif #include <asio.hpp> #include <asio/stream_file.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #define ASIO_VERBOSE_TRACE 0 @@ -144,7 +151,17 @@ using namespace std::literals; struct HttpAcceptor; struct HttpResponse; -struct HttpServerConnection; +template<typename SocketType> +struct HttpServerConnectionT; +using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>; +#if defined(ASIO_HAS_LOCAL_SOCKETS) +struct UnixAcceptor; +using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>; +#endif +#if ZEN_USE_OPENSSL +struct HttpsAcceptor; +using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>; +#endif inline LoggerRef InitLogger() @@ -176,9 +193,9 @@ Log() #endif #if ZEN_USE_TRANSMITFILE -template<typename Handler> +template<typename Handler, typename SocketType> void -TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) +TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) { # if ZEN_BUILD_DEBUG const uint64_t FileSize = FileSizeFromHandle(FileHandle); @@ -511,11 +528,20 @@ public: bool IsLoopbackOnly() const; + int GetEffectiveHttpsPort() const; + asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; - std::vector<std::thread> m_ThreadPool; - std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor; +#endif +#if ZEN_USE_OPENSSL + std::unique_ptr<asio::ssl::context> m_SslContext; + std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor; +#endif + std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -573,6 +599,7 @@ public: uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; bool m_IsLocalMachineRequest; + bool m_AllowZeroCopyFileSend = true; std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -595,6 +622,8 @@ public: ~HttpResponse() = default; + void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } + /** * Initialize the response for sending a payload made up of multiple blobs * @@ -636,7 +665,7 @@ public: bool ChunkHandled = false; #if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE - if (OwnedBuffer.IsWholeFile()) + if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile()) { if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef)) { @@ -751,7 +780,8 @@ public: return m_Headers; } - void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) + template<typename SocketType> + void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) { ZEN_ASSERT(m_State == State::kInitialized); @@ -761,10 +791,11 @@ public: m_SendCb = std::move(Token); m_State = State::kSending; - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } - void SendNextChunk(asio::ip::tcp::socket& TcpSocket) + template<typename SocketType> + void SendNextChunk(SocketType& Socket) { ZEN_ASSERT(m_State == State::kSending); @@ -781,12 +812,12 @@ public: auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); }; - asio::defer(TcpSocket.get_executor(), std::move(CompletionToken)); + asio::defer(Socket.get_executor(), std::move(CompletionToken)); return; } - auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) { + auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) { ZEN_ASSERT(m_State == State::kSending); m_TotalBytesSent += ByteCount; @@ -797,7 +828,7 @@ public: } else { - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } }; @@ -811,25 +842,21 @@ public: Io.Ref.FileRef.FileChunkSize); #if ZEN_USE_TRANSMITFILE - TransmitFileAsync(TcpSocket, + TransmitFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize), OnCompletion); + return; #elif ZEN_USE_ASYNC_SENDFILE - SendFileAsync(TcpSocket, + SendFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, Io.Ref.FileRef.FileChunkSize, 64 * 1024, OnCompletion); -#else - // This should never occur unless we compile with one - // of the options above - ZEN_WARN("invalid file reference in response"); -#endif - return; +#endif } // Send as many consecutive non-file references as possible in one asio operation @@ -850,7 +877,7 @@ public: ++m_IoVecCursor; } - asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); + asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); } private: @@ -863,12 +890,13 @@ private: kFailed }; - uint32_t m_RequestNumber = 0; - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - State m_State = State::kUninitialized; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + uint32_t m_RequestNumber = 0; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + bool m_AllowZeroCopyFileSend = true; + State m_State = State::kUninitialized; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -895,12 +923,13 @@ private: ////////////////////////////////////////////////////////////////////////// -struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> +template<typename SocketType> +struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>> { - HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); - ~HttpServerConnection(); + HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket); + ~HttpServerConnectionT(); - std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); } // HttpConnectionBase implementation @@ -962,12 +991,13 @@ private: RwLock m_ActiveResponsesLock; std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<SocketType> m_Socket; }; std::atomic<uint32_t> g_ConnectionIdCounter{0}; -HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +template<typename SocketType> +HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket) : m_Server(Server) , m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) , m_Socket(std::move(Socket)) @@ -975,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } -HttpServerConnection::~HttpServerConnection() +template<typename SocketType> +HttpServerConnectionT<SocketType>::~HttpServerConnectionT() { RwLock::ExclusiveLockScope _(m_ActiveResponsesLock); ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } +template<typename SocketType> void -HttpServerConnection::HandleNewRequest() +HttpServerConnectionT<SocketType>::HandleNewRequest() { EnqueueRead(); } +template<typename SocketType> void -HttpServerConnection::TerminateConnection() +HttpServerConnectionT<SocketType>::TerminateConnection() { if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) { @@ -1001,12 +1034,13 @@ HttpServerConnection::TerminateConnection() // Terminating, we don't care about any errors when closing socket std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); - m_Socket->close(Ec); + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); } +template<typename SocketType> void -HttpServerConnection::EnqueueRead() +HttpServerConnectionT<SocketType>::EnqueueRead() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1027,8 +1061,9 @@ HttpServerConnection::EnqueueRead() [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); } +template<typename SocketType> void -HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1086,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } +template<typename SocketType> void -HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, - [[maybe_unused]] std::size_t ByteCount, - [[maybe_unused]] uint32_t RequestNumber, - HttpResponse* ResponseToPop) +HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec, + [[maybe_unused]] std::size_t ByteCount, + [[maybe_unused]] uint32_t RequestNumber, + HttpResponse* ResponseToPop) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1144,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, } } +template<typename SocketType> void -HttpServerConnection::CloseConnection() +HttpServerConnectionT<SocketType>::CloseConnection() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1157,23 +1194,24 @@ HttpServerConnection::CloseConnection() m_RequestState = RequestState::kDone; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); } - m_Socket->close(Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); } } +template<typename SocketType> void -HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, - std::string_view StatusLine, - std::string_view Headers, - std::string_view Body) +HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) { ExtendableStringBuilder<256> ResponseBuilder; ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; @@ -1194,15 +1232,16 @@ HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); asio::async_write( - *m_Socket.get(), + *m_Socket, Buffer, [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); }); } +template<typename SocketType> void -HttpServerConnection::HandleRequest() +HttpServerConnectionT<SocketType>::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1229,24 +1268,25 @@ HttpServerConnection::HandleRequest() ResponseStr->append("\r\n\r\n"); // Send the 101 response on the current socket, then hand the socket off - // to a WsAsioConnection for the WebSocket protocol. - asio::async_write(*m_Socket, - asio::buffer(ResponseStr->data(), ResponseStr->size()), - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); - return; - } - - Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); - Ref<WsAsioConnection> WsConn( - new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); - Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); - WsConn->Start(); - }); + // to a WsAsioConnectionT for the WebSocket protocol. + asio::async_write( + *m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + using WsConnType = WsAsioConnectionT<SocketType>; + Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); m_RequestState = RequestState::kDone; return; @@ -1260,7 +1300,7 @@ HttpServerConnection::HandleRequest() m_RequestState = RequestState::kWritingFinal; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { @@ -1280,15 +1320,36 @@ HttpServerConnection::HandleRequest() m_Server.m_HttpServer->MarkRequest(); - auto RemoteEndpoint = m_Socket->remote_endpoint(); - bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + bool IsLocalConnection = true; + std::string RemoteAddress; + + if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>) + { + auto RemoteEndpoint = m_Socket->remote_endpoint(); + IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#if ZEN_USE_OPENSSL + else if constexpr (std::is_same_v<SocketType, SslSocket>) + { + auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint(); + IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#endif + else + { + RemoteAddress = "unix"; + } HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection, - RemoteEndpoint.address().to_string()); + std::move(RemoteAddress)); + + Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket; ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); @@ -1439,14 +1500,23 @@ HttpServerConnection::HandleRequest() } ////////////////////////////////////////////////////////////////////////// +// Base class for TCP acceptors that handles socket setup, port binding +// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support. +// Subclasses only need to implement OnAccept() to handle new connections. -struct HttpAcceptor +struct TcpAcceptorBase { - HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + TcpAcceptorBase(HttpAsioServerImpl& Server, + asio::io_service& IoService, + uint16_t BasePort, + bool ForceLoopback, + bool AllowPortProbing, + std::string_view Label) : m_Server(Server) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) , m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4()) + , m_Label(Label) { const bool IsUsingIPv6 = IsIPv6Capable(); if (!IsUsingIPv6) @@ -1455,7 +1525,6 @@ struct HttpAcceptor } #if ZEN_PLATFORM_WINDOWS - // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address; m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); @@ -1468,83 +1537,54 @@ struct HttpAcceptor #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - std::string BoundBaseUrl; if (IsUsingIPv6) { - BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); + BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); } else { - ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only"); - - BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); + ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label); + BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); } + } - if (!IsValid()) - { - return; - } - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor.native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - - if (m_UseAlternateProtocolAcceptor) - { - NativeSocket = m_AlternateProtocolAcceptor.native_handle(); - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - } -#endif - m_Acceptor.listen(); + virtual ~TcpAcceptorBase() + { + m_Acceptor.close(); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.listen(); + m_AlternateProtocolAcceptor.close(); } - - ZEN_INFO("Started asio server at '{}", BoundBaseUrl); } - ~HttpAcceptor() + void Start() { - m_Acceptor.close(); + ZEN_ASSERT(!m_IsStopped); + InitAcceptLoop(m_Acceptor); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.close(); + InitAcceptLoop(m_AlternateProtocolAcceptor); } } + void StopAccepting() { m_IsStopped = true; } + + uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } + bool IsValid() const { return m_IsValid; } + +protected: + /// Called for each accepted TCP socket. Subclasses create the appropriate connection type. + virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0; + + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + +private: template<typename AddressType> - std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) { uint16_t EffectivePort = BasePort; @@ -1571,7 +1611,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback()) { - ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort); + ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort); BindAddress = AddressType::loopback(); @@ -1585,7 +1625,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::address_in_use) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1601,7 +1641,8 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + m_Label, BindErrorCode.message()); EffectivePort = 0; @@ -1617,7 +1658,7 @@ struct HttpAcceptor { for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1625,14 +1666,13 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - - return {}; + ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message()); + return; } if (EffectivePort != BasePort) { - ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort); + ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort); } if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) @@ -1642,55 +1682,64 @@ struct HttpAcceptor // IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does // IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1) - m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode); + asio::error_code AltEc; + m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc); - if (BindErrorCode) + if (AltEc) { - ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort); + ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort); } else { m_UseAlternateProtocolAcceptor = true; - ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", - "localhost", - EffectivePort); } } } - m_IsValid = true; +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); - if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) - { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort); - } - else + if (m_UseAlternateProtocolAcceptor) { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort); + NativeSocket = m_AlternateProtocolAcceptor.native_handle(); + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); } - } - - void Start() - { - ZEN_MEMSCOPE(GetHttpasioTag()); +#endif - ZEN_ASSERT(!m_IsStopped); - InitAcceptInternal(m_Acceptor); + m_Acceptor.listen(); if (m_UseAlternateProtocolAcceptor) { - InitAcceptInternal(m_AlternateProtocolAcceptor); + m_AlternateProtocolAcceptor.listen(); } - } - void StopAccepting() { m_IsStopped = true; } - - int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); } - bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } - - bool IsValid() const { return m_IsValid; } + m_IsValid = true; + ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port()); + } -private: - void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor) + void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor) { auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); @@ -1698,29 +1747,19 @@ private: Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - Acceptor.local_endpoint().address().to_string(), - Acceptor.local_endpoint().port(), - Ec.message()); + if (!m_IsStopped.load()) + { + ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message()); + } } else { - // New connection established, pass socket ownership into connection object - // and initiate request handling loop. The connection lifetime is - // managed by the async read/write loop by passing the shared - // reference to the callbacks. - - Socket->set_option(asio::ip::tcp::no_delay(true)); - Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); - Conn->HandleNewRequest(); + OnAccept(std::move(Socket)); } if (!m_IsStopped.load()) { - InitAcceptInternal(Acceptor); + InitAcceptLoop(Acceptor); } else { @@ -1728,21 +1767,204 @@ private: Acceptor.close(CloseEc); if (CloseEc) { - ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message()); } } }); } - HttpAsioServerImpl& m_Server; - asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; asio::ip::tcp::acceptor m_AlternateProtocolAcceptor; bool m_UseAlternateProtocolAcceptor{false}; bool m_IsValid{false}; std::atomic<bool> m_IsStopped{false}; + std::string_view m_Label; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor final : TcpAcceptorBase +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP") + { + } + + int GetAcceptPort() const { return GetPort(); } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } }; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + +////////////////////////////////////////////////////////////////////////// + +struct UnixAcceptor +{ + UnixAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, const std::string& SocketPath) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService) + , m_SocketPath(SocketPath) + { + // Remove any stale socket file from a previous run + std::filesystem::remove(m_SocketPath); + + asio::local::stream_protocol::endpoint Endpoint(m_SocketPath); + + asio::error_code Ec; + m_Acceptor.open(Endpoint.protocol(), Ec); + if (Ec) + { + ZEN_WARN("failed to open unix domain socket: {}", Ec.message()); + return; + } + + m_Acceptor.bind(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec); + if (Ec) + { + ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_IsValid = true; + ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath); + } + + ~UnixAcceptor() + { + asio::error_code Ec; + m_Acceptor.close(Ec); + std::filesystem::remove(m_SocketPath); + } + + void Start() + { + ZEN_ASSERT(!m_IsStopped); + InitAccept(); + } + + void StopAccepting() { m_IsStopped = true; } + + bool IsValid() const { return m_IsValid; } + +private: + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService); + asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + if (!m_IsStopped.load()) + { + ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message()); + } + } + else + { + auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + } + }); + } + + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::local::stream_protocol::acceptor m_Acceptor; + std::string m_SocketPath; + bool m_IsValid{false}; + std::atomic<bool> m_IsStopped{false}; +}; + +#endif // ASIO_HAS_LOCAL_SOCKETS + +#if ZEN_USE_OPENSSL + +////////////////////////////////////////////////////////////////////////// + +struct HttpsAcceptor final : TcpAcceptorBase +{ + HttpsAcceptor(HttpAsioServerImpl& Server, + asio::io_service& IoService, + asio::ssl::context& SslContext, + uint16_t Port, + bool ForceLoopback, + bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS") + , m_SslContext(SslContext) + { + } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + // Wrap accepted TCP socket in an SSL stream and perform the handshake + auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext); + + SslSocket& SslRef = *SslSocketPtr; + SslRef.async_handshake(asio::ssl::stream_base::server, + [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable { + if (HandshakeEc) + { + ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message()); + std::error_code Ec; + SslSocket->lowest_layer().close(Ec); + return; + } + + auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket)); + Conn->HandleNewRequest(); + }); + } + +private: + asio::ssl::context& m_SslContext; +}; + +#endif // ZEN_USE_OPENSSL + +int +HttpAsioServerImpl::GetEffectiveHttpsPort() const +{ +#if ZEN_USE_OPENSSL + return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0; +#else + return 0; +#endif +} + ////////////////////////////////////////////////////////////////////////// HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, @@ -1860,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -1873,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -1883,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); @@ -1942,6 +2167,51 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config) m_Acceptor->Start(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!Config.UnixSocketPath.empty()) + { + m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath)); + + if (m_UnixAcceptor->IsValid()) + { + m_UnixAcceptor->Start(); + } + else + { + m_UnixAcceptor.reset(); + } + } +#endif + +#if ZEN_USE_OPENSSL + if (!Config.CertFile.empty() && !Config.KeyFile.empty()) + { + m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server); + m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 | + asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + m_SslContext->use_certificate_chain_file(Config.CertFile); + m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem); + + ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile); + + m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this, + m_IoService, + *m_SslContext, + gsl::narrow<uint16_t>(Config.HttpsPort), + Config.ForceLoopback, + /*AllowPortProbing*/ !Config.IsDedicatedServer)); + + if (m_HttpsAcceptor->IsValid()) + { + m_HttpsAcceptor->Start(); + } + else + { + m_HttpsAcceptor.reset(); + } + } +#endif + // This should consist of a set of minimum threads and grow on demand to // meet concurrency needs? Right now we end up allocating a large number // of threads even if we never end up using all of them, which seems @@ -1990,6 +2260,18 @@ HttpAsioServerImpl::Stop() { m_Acceptor->StopAccepting(); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixAcceptor) + { + m_UnixAcceptor->StopAccepting(); + } +#endif +#if ZEN_USE_OPENSSL + if (m_HttpsAcceptor) + { + m_HttpsAcceptor->StopAccepting(); + } +#endif m_IoService.stop(); for (auto& Thread : m_ThreadPool) { @@ -1999,7 +2281,23 @@ HttpAsioServerImpl::Stop() } } m_ThreadPool.clear(); + + // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket + // connections) so that their captured Ref<> pointers are released while the + // io_service and its epoll reactor are still alive. Without this, sockets + // held by external code (e.g. IWebSocketHandler connection lists) can outlive + // the reactor and crash during deregistration. + m_IoService.restart(); + m_IoService.poll(); + m_Acceptor.reset(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + m_UnixAcceptor.reset(); +#endif +#if ZEN_USE_OPENSSL + m_HttpsAcceptor.reset(); + m_SslContext.reset(); +#endif } void @@ -2166,6 +2464,13 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config); +#if ZEN_USE_OPENSSL + if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0) + { + SetEffectiveHttpsPort(EffectiveHttpsPort); + } +#endif + return m_BasePort; } diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index 3ec1141a7..5adf4d5e8 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -11,6 +11,12 @@ struct AsioConfig unsigned int ThreadCount = 0; bool ForceLoopback = false; bool IsDedicatedServer = false; + std::string UnixSocketPath; +#if ZEN_USE_OPENSSL + int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS + std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled) + std::string KeyFile; // PEM private key file +#endif }; Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index dfe6bb6aa..83b98013e 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -116,6 +116,12 @@ public: private: int InitializeServer(int BasePort); + bool CreateSessionAndUrlGroup(); + bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris); + int RegisterHttpUrls(int BasePort); + bool RegisterHttpsUrls(); + bool CreateRequestQueue(int EffectivePort); + bool SetupIoCompletionPort(); void Cleanup(); void StartServer(); @@ -125,6 +131,9 @@ private: void RegisterService(const char* Endpoint, HttpService& Service); void UnregisterService(const char* Endpoint, HttpService& Service); + bool BindSslCertificate(int Port); + void UnbindSslCertificate(); + private: LoggerRef m_Log; LoggerRef m_RequestLog; @@ -140,7 +149,10 @@ private: RwLock m_AsyncWorkPoolInitLock; std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; - std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/ + bool m_DidAutoBindCert = false; + int m_HttpsPort = 0; HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; HANDLE m_RequestQueueHandle = 0; @@ -1082,39 +1094,63 @@ HttpSysServer::OnClose() } } -int -HttpSysServer::InitializeServer(int BasePort) +bool +HttpSysServer::CreateSessionAndUrlGroup() { - ZEN_MEMSCOPE(GetHttpsysTag()); - - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", - WideToUtf8(WildcardUrlPath), - GetSystemErrorAsString(Result), - Result); + ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); + ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } + return true; +} + +bool +HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris) +{ + using namespace std::literals; + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrl; + LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (Result == NO_ERROR) + { + ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl)); + OutUris.push_back(LocalUrl.c_str()); + } + else + { + break; + } + } + + return !OutUris.empty(); +} + +int +HttpSysServer::RegisterHttpUrls(int BasePort) +{ + using namespace std::literals; + m_BaseUris.clear(); const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer; @@ -1122,6 +1158,11 @@ HttpSysServer::InitializeServer(int BasePort) int EffectivePort = BasePort; + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + ULONG Result; + if (m_InitialConfig.ForceLoopback) { // Force trigger of opening using local port @@ -1177,11 +1218,11 @@ HttpSysServer::InitializeServer(int BasePort) { if (AllowLocalOnly) { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh + // If we can't register the wildcard path, we fall back to local paths. + // Local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: + // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> if (!m_InitialConfig.ForceLoopback) @@ -1246,7 +1287,7 @@ HttpSysServer::InitializeServer(int BasePort) } } - if (m_BaseUris.empty()) + if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0) { ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), @@ -1256,16 +1297,104 @@ HttpSysServer::InitializeServer(int BasePort) return 0; } + return EffectivePort; +} + +bool +HttpSysServer::RegisterHttpsUrls() +{ + using namespace std::literals; + + const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer; + const int HttpsPort = m_InitialConfig.HttpsPort; + + // If HTTPS-only mode, remove HTTP URLs and clear base URIs + if (m_InitialConfig.HttpsOnly) + { + for (const std::wstring& Uri : m_BaseUris) + { + HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0); + } + m_BaseUris.clear(); + } + + // Auto-bind certificate if thumbprint is provided + if (!m_InitialConfig.CertThumbprint.empty()) + { + if (!BindSslCertificate(HttpsPort)) + { + return false; + } + } + else + { + ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort); + } + + // Register HTTPS URLs using same pattern as HTTP + + WideStringBuilder<64> HttpsWildcard; + HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv; + + ULONG HttpsResult = NO_ERROR; + + if (m_InitialConfig.ForceLoopback) + { + HttpsResult = ERROR_ACCESS_DENIED; + } + else + { + HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + if (HttpsResult == NO_ERROR) + { + m_HttpsBaseUris.push_back(HttpsWildcard.c_str()); + } + else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly) + { + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register HTTPS handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.", + WideToUtf8(HttpsWildcard)); + } + + RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris); + } + else if (HttpsResult != NO_ERROR) + { + ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})", + WideToUtf8(HttpsWildcard), + GetSystemErrorAsString(HttpsResult), + HttpsResult); + return false; + } + + if (m_HttpsBaseUris.empty()) + { + ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort); + return false; + } + + m_HttpsPort = HttpsPort; + return true; +} + +bool +HttpSysServer::CreateRequestQueue(int EffectivePort) +{ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; WideStringBuilder<64> QueueName; QueueName << "zenserver_" << EffectivePort; - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ QueueName.c_str(), - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); + ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); if (Result != NO_ERROR) { @@ -1274,7 +1403,7 @@ HttpSysServer::InitializeServer(int BasePort) GetSystemErrorAsString(Result), Result); - return 0; + return false; } HttpBindingInfo.Flags.Present = 1; @@ -1289,7 +1418,7 @@ HttpSysServer::InitializeServer(int BasePort) GetSystemErrorAsString(Result), Result); - return 0; + return false; } // Configure rejection method. Default is to drop the connection, it's better if we @@ -1323,22 +1452,77 @@ HttpSysServer::InitializeServer(int BasePort) } } - // Create I/O completion port + return true; +} +bool +HttpSysServer::SetupIoCompletionPort() +{ std::error_code ErrorCode; m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); if (ErrorCode) { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message()); + return false; + } + m_IsOk = true; + + if (!m_BaseUris.empty()) + { + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + if (!m_HttpsBaseUris.empty()) + { + ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front())); + } + + return true; +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + ZEN_MEMSCOPE(GetHttpsysTag()); + + m_IsOk = false; + + if (!CreateSessionAndUrlGroup()) + { return 0; } - else + + int EffectivePort = RegisterHttpUrls(BasePort); + + if (m_InitialConfig.HttpsPort > 0) + { + if (!RegisterHttpsUrls()) + { + return 0; + } + } + + if (m_BaseUris.empty() && m_HttpsBaseUris.empty()) { - m_IsOk = true; + ZEN_ERROR("No HTTP or HTTPS listeners could be registered"); + return 0; + } - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + if (!CreateRequestQueue(EffectivePort)) + { + return 0; + } + + if (!SetupIoCompletionPort()) + { + return 0; + } + + // When HTTPS-only, return the HTTPS port as the effective port + if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0) + { + return m_HttpsPort; } return EffectivePort; @@ -1349,6 +1533,8 @@ HttpSysServer::Cleanup() { ++m_IsShuttingDown; + UnbindSslCertificate(); + if (m_RequestQueueHandle) { HttpCloseRequestQueue(m_RequestQueueHandle); @@ -1368,6 +1554,105 @@ HttpSysServer::Cleanup() } } +// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings +static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}}; + +bool +HttpSysServer::BindSslCertificate(int Port) +{ + const std::string& Thumbprint = m_InitialConfig.CertThumbprint; + if (Thumbprint.size() != 40) + { + ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size()); + return false; + } + + BYTE CertHash[20] = {}; + if (!ParseHexBytes(Thumbprint, CertHash)) + { + ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters"); + return false; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(Port)); + Address.sin_addr.s_addr = INADDR_ANY; + + const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str()); + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + SslConfig.ParamDesc.pSslHash = CertHash; + SslConfig.ParamDesc.SslHashLength = sizeof(CertHash); + SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str()); + SslConfig.ParamDesc.AppId = ZenServerSslAppId; + + ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result == ERROR_ALREADY_EXISTS) + { + // Remove existing binding and retry + HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {}; + DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr); + + Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + } + + if (Result != NO_ERROR) + { + ZEN_ERROR( + "Failed to bind SSL certificate to port {}: {} ({:#x}). " + "This operation may require running as administrator.", + Port, + GetSystemErrorAsString(Result), + Result); + return false; + } + + m_DidAutoBindCert = true; + m_HttpsPort = Port; + + ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})", + Port, + Thumbprint.substr(0, 8), + m_InitialConfig.CertStoreName); + + return true; +} + +void +HttpSysServer::UnbindSslCertificate() +{ + if (!m_DidAutoBindCert) + { + return; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort)); + Address.sin_addr.s_addr = INADDR_ANY; + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result != NO_ERROR) + { + ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result); + } + else + { + ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort); + } + + m_DidAutoBindCert = false; +} + WorkerThreadPool& HttpSysServer::WorkPool() { @@ -1495,19 +1780,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) // Convert to wide string - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) + auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + std::wstring Url16 = BaseUri + PathUtf16; - return; + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + return; + } } - } + }; + + RegisterWithBaseUris(m_BaseUris); + RegisterWithBaseUris(m_HttpsBaseUris); } void @@ -1522,19 +1811,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; + auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } } - } + }; + + UnregisterFromBaseUris(m_BaseUris); + UnregisterFromBaseUris(m_HttpsBaseUris); } ////////////////////////////////////////////////////////////////////////// @@ -2422,6 +2714,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir) ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { + if (m_HttpsPort > 0) + { + SetEffectiveHttpsPort(m_HttpsPort); + } + StartServer(); return EffectivePort; diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h index b2fe7475b..ca465ad00 100644 --- a/src/zenhttp/servers/httpsys.h +++ b/src/zenhttp/servers/httpsys.h @@ -22,6 +22,10 @@ struct HttpSysConfig bool IsRequestLoggingEnabled = false; bool IsDedicatedServer = false; bool ForceLoopback = false; + int HttpsPort = 0; // 0 = HTTPS disabled + std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding + std::string CertStoreName = "MY"; // Windows certificate store name + bool HttpsOnly = false; // When true, disable HTTP listener }; Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config); diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index b2543277a..5ae48f5b3 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "wsasio.h" +#include "asio_socket_traits.h" #include "wsframecodec.h" #include <zencore/logging.h> @@ -17,14 +18,16 @@ WsLog() ////////////////////////////////////////////////////////////////////////// -WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server) +template<typename SocketType> +WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server) : m_Socket(std::move(Socket)) , m_Handler(Handler) , m_HttpServer(Server) { } -WsAsioConnection::~WsAsioConnection() +template<typename SocketType> +WsAsioConnectionT<SocketType>::~WsAsioConnectionT() { m_IsOpen.store(false); if (m_HttpServer) @@ -33,14 +36,16 @@ WsAsioConnection::~WsAsioConnection() } } +template<typename SocketType> void -WsAsioConnection::Start() +WsAsioConnectionT<SocketType>::Start() { EnqueueRead(); } +template<typename SocketType> bool -WsAsioConnection::IsOpen() const +WsAsioConnectionT<SocketType>::IsOpen() const { return m_IsOpen.load(std::memory_order_relaxed); } @@ -50,23 +55,25 @@ WsAsioConnection::IsOpen() const // Read loop // +template<typename SocketType> void -WsAsioConnection::EnqueueRead() +WsAsioConnectionT<SocketType>::EnqueueRead() { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } - Ref<WsAsioConnection> Self(this); + Ref<WsAsioConnectionT> Self(this); asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnDataReceived(Ec, ByteCount); }); } +template<typename SocketType> void -WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { @@ -90,8 +97,9 @@ WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] st } } +template<typename SocketType> void -WsAsioConnection::ProcessReceivedData() +WsAsioConnectionT<SocketType>::ProcessReceivedData() { while (m_ReadBuffer.size() > 0) { @@ -162,8 +170,8 @@ WsAsioConnection::ProcessReceivedData() // Shut down the socket std::error_code ShutdownEc; - m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); - m_Socket->close(ShutdownEc); + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc); + SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc); return; } @@ -179,8 +187,9 @@ WsAsioConnection::ProcessReceivedData() // Write queue // +template<typename SocketType> void -WsAsioConnection::SendText(std::string_view Text) +WsAsioConnectionT<SocketType>::SendText(std::string_view Text) { if (!m_IsOpen.load(std::memory_order_relaxed)) { @@ -192,8 +201,9 @@ WsAsioConnection::SendText(std::string_view Text) EnqueueWrite(std::move(Frame)); } +template<typename SocketType> void -WsAsioConnection::SendBinary(std::span<const uint8_t> Data) +WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data) { if (!m_IsOpen.load(std::memory_order_relaxed)) { @@ -204,14 +214,16 @@ WsAsioConnection::SendBinary(std::span<const uint8_t> Data) EnqueueWrite(std::move(Frame)); } +template<typename SocketType> void -WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason) { DoClose(Code, Reason); } +template<typename SocketType> void -WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason) { if (!m_IsOpen.exchange(false)) { @@ -227,8 +239,9 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); } +template<typename SocketType> void -WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) +WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame) { if (m_HttpServer) { @@ -252,8 +265,9 @@ WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) } } +template<typename SocketType> void -WsAsioConnection::FlushWriteQueue() +WsAsioConnectionT<SocketType>::FlushWriteQueue() { std::vector<uint8_t> Frame; @@ -272,7 +286,7 @@ WsAsioConnection::FlushWriteQueue() return; } - Ref<WsAsioConnection> Self(this); + Ref<WsAsioConnectionT> Self(this); // Move Frame into a shared_ptr so we can create the buffer and capture ownership // in the same async_write call without evaluation order issues. @@ -283,8 +297,9 @@ WsAsioConnection::FlushWriteQueue() [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); } +template<typename SocketType> void -WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { @@ -308,4 +323,17 @@ WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] s FlushWriteQueue(); } +////////////////////////////////////////////////////////////////////////// +// Explicit template instantiations + +template class WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +template class WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + } // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h index e8bb3b1d2..64602ee46 100644 --- a/src/zenhttp/servers/wsasio.h +++ b/src/zenhttp/servers/wsasio.h @@ -8,6 +8,12 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #include <deque> @@ -21,22 +27,23 @@ class HttpServer; namespace zen::asio_http { /** - * WebSocket connection over an ASIO TCP socket + * WebSocket connection over an ASIO stream socket * - * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * Templated on SocketType to support both TCP and Unix domain sockets. + * Owns the socket (moved from HttpServerConnection after the 101 handshake) * and runs an async read/write loop to exchange WebSocket frames. * * Lifetime is managed solely through intrusive reference counting (RefCounted). - * The async read/write callbacks capture Ref<WsAsioConnection> to keep the - * connection alive for the duration of the async operation. The service layer - * also holds a Ref<WebSocketConnection>. + * The async read/write callbacks capture Ref<> to keep the connection alive + * for the duration of the async operation. The service layer also holds a + * Ref<WebSocketConnection>. */ - -class WsAsioConnection : public WebSocketConnection +template<typename SocketType> +class WsAsioConnectionT : public WebSocketConnection { public: - WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server); - ~WsAsioConnection() override; + WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server); + ~WsAsioConnectionT() override; /** * Start the async read loop. Must be called once after construction @@ -61,10 +68,10 @@ private: void DoClose(uint16_t Code, std::string_view Reason); - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - IWebSocketHandler& m_Handler; - zen::HttpServer* m_HttpServer; - asio::streambuf m_ReadBuffer; + std::unique_ptr<SocketType> m_Socket; + IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; + asio::streambuf m_ReadBuffer; RwLock m_WriteLock; std::deque<std::vector<uint8_t>> m_WriteQueue; @@ -74,4 +81,14 @@ private: std::atomic<bool> m_CloseSent{false}; }; +using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + } // namespace zen::asio_http diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 2134e4ff1..042afd8ff 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -485,7 +485,7 @@ TEST_CASE("websocket.integration") Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); - int Port = Server->Initialize(7575, TmpDir.Path()); + int Port = Server->Initialize(0, TmpDir.Path()); REQUIRE(Port != 0); Server->RegisterService(TestService); @@ -797,7 +797,7 @@ TEST_CASE("websocket.client") Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); - int Port = Server->Initialize(7576, TmpDir.Path()); + int Port = Server->Initialize(0, TmpDir.Path()); REQUIRE(Port != 0); Server->RegisterService(TestService); @@ -913,6 +913,75 @@ TEST_CASE("websocket.client") } } +TEST_CASE("websocket.client.unixsocket") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + std::string SocketPath = (TmpDir.Path() / "ws.sock").string(); + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close over unix socket") + { + TestWsClientHandler Handler; + HttpWsClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + + HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello over unix socket"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello over unix socket"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } +} + TEST_SUITE_END(); void |