aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-10 17:27:26 +0100
committerGitHub Enterprise <[email protected]>2026-03-10 17:27:26 +0100
commitd0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch)
tree2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src/zenhttp/servers
parentchangelog entry which was inadvertently omitted from PR merge (diff)
downloadzen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz
zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency. ### HttpClient Backend Overhaul - Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines) as an alternative to the cpr-based one - Made HttpClient backend **configurable at runtime** via constructor arguments and `-httpclient=...` CLI option (for zen, zenserver, and tests) - Extended HttpClient test suite to cover multipart/content-range scenarios ### Unix Domain Socket Support - Added Unix domain socket support to **httpasio** (server side) - Added Unix domain socket support to **HttpClient** - Added Unix domain socket support to **HttpWsClient** (WebSocket client) - Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>` to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch ### HTTPS Support - Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL) - Added **basic HTTPS support for http.sys** (Windows) - Implemented HTTPS test for httpasio - Split `InitializeServer` into smaller sub-functions for http.sys ### Other Notable Changes - Improved **zenhttp-test stability** with dynamic port allocation - Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED) - Fatal signal/exception handlers for backtrace generation in tests - Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src/zenhttp/servers')
-rw-r--r--src/zenhttp/servers/asio_socket_traits.h54
-rw-r--r--src/zenhttp/servers/httpasio.cpp687
-rw-r--r--src/zenhttp/servers/httpasio.h6
-rw-r--r--src/zenhttp/servers/httpsys.cpp409
-rw-r--r--src/zenhttp/servers/httpsys.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp64
-rw-r--r--src/zenhttp/servers/wsasio.h43
-rw-r--r--src/zenhttp/servers/wstest.cpp73
8 files changed, 1060 insertions, 280 deletions
diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h
new file mode 100644
index 000000000..25aeaa24e
--- /dev/null
+++ b/src/zenhttp/servers/asio_socket_traits.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_http {
+
+/**
+ * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets.
+ * SSL sockets need lowest_layer() access and have different shutdown semantics.
+ */
+template<typename SocketType>
+struct SocketTraits
+{
+ /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because
+ /// those bypass the encryption layer. This flag lets templated code fall back
+ /// to reading-into-memory for SSL connections.
+ static constexpr bool IsSslSocket = false;
+
+ static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); }
+
+ static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); }
+};
+
+#if ZEN_USE_OPENSSL
+using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>;
+
+template<>
+struct SocketTraits<SslSocket>
+{
+ static constexpr bool IsSslSocket = true;
+
+ static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SslSocket& S, std::error_code& Ec)
+ {
+ // Best-effort SSL close_notify, then TCP shutdown
+ S.shutdown(Ec);
+ S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec);
+ }
+
+ static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); }
+};
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index f5178ebe8..ee8e71256 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpasio.h"
+#include "asio_socket_traits.h"
#include "httptracer.h"
#include <zencore/except.h>
@@ -35,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#include <asio.hpp>
#include <asio/stream_file.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#define ASIO_VERBOSE_TRACE 0
@@ -144,7 +151,17 @@ using namespace std::literals;
struct HttpAcceptor;
struct HttpResponse;
-struct HttpServerConnection;
+template<typename SocketType>
+struct HttpServerConnectionT;
+using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+struct UnixAcceptor;
+using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>;
+#endif
+#if ZEN_USE_OPENSSL
+struct HttpsAcceptor;
+using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>;
+#endif
inline LoggerRef
InitLogger()
@@ -176,9 +193,9 @@ Log()
#endif
#if ZEN_USE_TRANSMITFILE
-template<typename Handler>
+template<typename Handler, typename SocketType>
void
-TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
+TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
{
# if ZEN_BUILD_DEBUG
const uint64_t FileSize = FileSizeFromHandle(FileHandle);
@@ -511,11 +528,20 @@ public:
bool IsLoopbackOnly() const;
+ int GetEffectiveHttpsPort() const;
+
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
- std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor;
+#endif
+#if ZEN_USE_OPENSSL
+ std::unique_ptr<asio::ssl::context> m_SslContext;
+ std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor;
+#endif
+ std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -573,6 +599,7 @@ public:
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
bool m_IsLocalMachineRequest;
+ bool m_AllowZeroCopyFileSend = true;
std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -595,6 +622,8 @@ public:
~HttpResponse() = default;
+ void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
+
/**
* Initialize the response for sending a payload made up of multiple blobs
*
@@ -636,7 +665,7 @@ public:
bool ChunkHandled = false;
#if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE
- if (OwnedBuffer.IsWholeFile())
+ if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile())
{
if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef))
{
@@ -751,7 +780,8 @@ public:
return m_Headers;
}
- void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
+ template<typename SocketType>
+ void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
{
ZEN_ASSERT(m_State == State::kInitialized);
@@ -761,10 +791,11 @@ public:
m_SendCb = std::move(Token);
m_State = State::kSending;
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
- void SendNextChunk(asio::ip::tcp::socket& TcpSocket)
+ template<typename SocketType>
+ void SendNextChunk(SocketType& Socket)
{
ZEN_ASSERT(m_State == State::kSending);
@@ -781,12 +812,12 @@ public:
auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); };
- asio::defer(TcpSocket.get_executor(), std::move(CompletionToken));
+ asio::defer(Socket.get_executor(), std::move(CompletionToken));
return;
}
- auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) {
+ auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) {
ZEN_ASSERT(m_State == State::kSending);
m_TotalBytesSent += ByteCount;
@@ -797,7 +828,7 @@ public:
}
else
{
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
};
@@ -811,25 +842,21 @@ public:
Io.Ref.FileRef.FileChunkSize);
#if ZEN_USE_TRANSMITFILE
- TransmitFileAsync(TcpSocket,
+ TransmitFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize),
OnCompletion);
+ return;
#elif ZEN_USE_ASYNC_SENDFILE
- SendFileAsync(TcpSocket,
+ SendFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
Io.Ref.FileRef.FileChunkSize,
64 * 1024,
OnCompletion);
-#else
- // This should never occur unless we compile with one
- // of the options above
- ZEN_WARN("invalid file reference in response");
-#endif
-
return;
+#endif
}
// Send as many consecutive non-file references as possible in one asio operation
@@ -850,7 +877,7 @@ public:
++m_IoVecCursor;
}
- asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
+ asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
}
private:
@@ -863,12 +890,13 @@ private:
kFailed
};
- uint32_t m_RequestNumber = 0;
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- State m_State = State::kUninitialized;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ uint32_t m_RequestNumber = 0;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ bool m_AllowZeroCopyFileSend = true;
+ State m_State = State::kUninitialized;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -895,12 +923,13 @@ private:
//////////////////////////////////////////////////////////////////////////
-struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection>
+template<typename SocketType>
+struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>>
{
- HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
- ~HttpServerConnection();
+ HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket);
+ ~HttpServerConnectionT();
- std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+ std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); }
// HttpConnectionBase implementation
@@ -962,12 +991,13 @@ private:
RwLock m_ActiveResponsesLock;
std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<SocketType> m_Socket;
};
std::atomic<uint32_t> g_ConnectionIdCounter{0};
-HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket)
: m_Server(Server)
, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
, m_Socket(std::move(Socket))
@@ -975,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq
ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
}
-HttpServerConnection::~HttpServerConnection()
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::~HttpServerConnectionT()
{
RwLock::ExclusiveLockScope _(m_ActiveResponsesLock);
ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
}
+template<typename SocketType>
void
-HttpServerConnection::HandleNewRequest()
+HttpServerConnectionT<SocketType>::HandleNewRequest()
{
EnqueueRead();
}
+template<typename SocketType>
void
-HttpServerConnection::TerminateConnection()
+HttpServerConnectionT<SocketType>::TerminateConnection()
{
if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated)
{
@@ -1001,12 +1034,13 @@ HttpServerConnection::TerminateConnection()
// Terminating, we don't care about any errors when closing socket
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_both, Ec);
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
}
+template<typename SocketType>
void
-HttpServerConnection::EnqueueRead()
+HttpServerConnectionT<SocketType>::EnqueueRead()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1027,8 +1061,9 @@ HttpServerConnection::EnqueueRead()
[Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
}
+template<typename SocketType>
void
-HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1086,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+template<typename SocketType>
void
-HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
- [[maybe_unused]] std::size_t ByteCount,
- [[maybe_unused]] uint32_t RequestNumber,
- HttpResponse* ResponseToPop)
+HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec,
+ [[maybe_unused]] std::size_t ByteCount,
+ [[maybe_unused]] uint32_t RequestNumber,
+ HttpResponse* ResponseToPop)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1144,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
}
}
+template<typename SocketType>
void
-HttpServerConnection::CloseConnection()
+HttpServerConnectionT<SocketType>::CloseConnection()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1157,23 +1194,24 @@ HttpServerConnection::CloseConnection()
m_RequestState = RequestState::kDone;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
}
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket close ERROR, reason '{}'", Ec.message());
}
}
+template<typename SocketType>
void
-HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
- std::string_view StatusLine,
- std::string_view Headers,
- std::string_view Body)
+HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
{
ExtendableStringBuilder<256> ResponseBuilder;
ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
@@ -1194,15 +1232,16 @@ HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
asio::async_write(
- *m_Socket.get(),
+ *m_Socket,
Buffer,
[Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
});
}
+template<typename SocketType>
void
-HttpServerConnection::HandleRequest()
+HttpServerConnectionT<SocketType>::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1229,24 +1268,25 @@ HttpServerConnection::HandleRequest()
ResponseStr->append("\r\n\r\n");
// Send the 101 response on the current socket, then hand the socket off
- // to a WsAsioConnection for the WebSocket protocol.
- asio::async_write(*m_Socket,
- asio::buffer(ResponseStr->data(), ResponseStr->size()),
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
- return;
- }
-
- Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
- Ref<WsAsioConnection> WsConn(
- new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
- Ref<WebSocketConnection> WsConnRef(WsConn.Get());
-
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
- WsConn->Start();
- });
+ // to a WsAsioConnectionT for the WebSocket protocol.
+ asio::async_write(
+ *m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ using WsConnType = WsAsioConnectionT<SocketType>;
+ Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
m_RequestState = RequestState::kDone;
return;
@@ -1260,7 +1300,7 @@ HttpServerConnection::HandleRequest()
m_RequestState = RequestState::kWritingFinal;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
@@ -1280,15 +1320,36 @@ HttpServerConnection::HandleRequest()
m_Server.m_HttpServer->MarkRequest();
- auto RemoteEndpoint = m_Socket->remote_endpoint();
- bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ bool IsLocalConnection = true;
+ std::string RemoteAddress;
+
+ if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>)
+ {
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#if ZEN_USE_OPENSSL
+ else if constexpr (std::is_same_v<SocketType, SslSocket>)
+ {
+ auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint();
+ IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#endif
+ else
+ {
+ RemoteAddress = "unix";
+ }
HttpAsioServerRequest Request(m_RequestData,
*Service,
m_RequestData.Body(),
RequestNumber,
IsLocalConnection,
- RemoteEndpoint.address().to_string());
+ std::move(RemoteAddress));
+
+ Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket;
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
@@ -1439,14 +1500,23 @@ HttpServerConnection::HandleRequest()
}
//////////////////////////////////////////////////////////////////////////
+// Base class for TCP acceptors that handles socket setup, port binding
+// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support.
+// Subclasses only need to implement OnAccept() to handle new connections.
-struct HttpAcceptor
+struct TcpAcceptorBase
{
- HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ TcpAcceptorBase(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ uint16_t BasePort,
+ bool ForceLoopback,
+ bool AllowPortProbing,
+ std::string_view Label)
: m_Server(Server)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
, m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4())
+ , m_Label(Label)
{
const bool IsUsingIPv6 = IsIPv6Capable();
if (!IsUsingIPv6)
@@ -1455,7 +1525,6 @@ struct HttpAcceptor
}
#if ZEN_PLATFORM_WINDOWS
- // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
@@ -1468,83 +1537,54 @@ struct HttpAcceptor
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- std::string BoundBaseUrl;
if (IsUsingIPv6)
{
- BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
+ BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
}
else
{
- ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only");
-
- BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
+ ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label);
+ BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
}
+ }
- if (!IsValid())
- {
- return;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor.native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-
- if (m_UseAlternateProtocolAcceptor)
- {
- NativeSocket = m_AlternateProtocolAcceptor.native_handle();
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
- }
-#endif
- m_Acceptor.listen();
+ virtual ~TcpAcceptorBase()
+ {
+ m_Acceptor.close();
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.listen();
+ m_AlternateProtocolAcceptor.close();
}
-
- ZEN_INFO("Started asio server at '{}", BoundBaseUrl);
}
- ~HttpAcceptor()
+ void Start()
{
- m_Acceptor.close();
+ ZEN_ASSERT(!m_IsStopped);
+ InitAcceptLoop(m_Acceptor);
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.close();
+ InitAcceptLoop(m_AlternateProtocolAcceptor);
}
}
+ void StopAccepting() { m_IsStopped = true; }
+
+ uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
+ bool IsValid() const { return m_IsValid; }
+
+protected:
+ /// Called for each accepted TCP socket. Subclasses create the appropriate connection type.
+ virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0;
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+
+private:
template<typename AddressType>
- std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
{
uint16_t EffectivePort = BasePort;
@@ -1571,7 +1611,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback())
{
- ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort);
+ ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort);
BindAddress = AddressType::loopback();
@@ -1585,7 +1625,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::address_in_use)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1601,7 +1641,8 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ m_Label,
BindErrorCode.message());
EffectivePort = 0;
@@ -1617,7 +1658,7 @@ struct HttpAcceptor
{
for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1625,14 +1666,13 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
-
- return {};
+ ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message());
+ return;
}
if (EffectivePort != BasePort)
{
- ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort);
+ ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort);
}
if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
@@ -1642,55 +1682,64 @@ struct HttpAcceptor
// IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does
// IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1)
- m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode);
+ asio::error_code AltEc;
+ m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc);
- if (BindErrorCode)
+ if (AltEc)
{
- ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort);
+ ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort);
}
else
{
m_UseAlternateProtocolAcceptor = true;
- ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts",
- "localhost",
- EffectivePort);
}
}
}
- m_IsValid = true;
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
- if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
- {
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort);
- }
- else
+ if (m_UseAlternateProtocolAcceptor)
{
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort);
+ NativeSocket = m_AlternateProtocolAcceptor.native_handle();
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
}
- }
-
- void Start()
- {
- ZEN_MEMSCOPE(GetHttpasioTag());
+#endif
- ZEN_ASSERT(!m_IsStopped);
- InitAcceptInternal(m_Acceptor);
+ m_Acceptor.listen();
if (m_UseAlternateProtocolAcceptor)
{
- InitAcceptInternal(m_AlternateProtocolAcceptor);
+ m_AlternateProtocolAcceptor.listen();
}
- }
- void StopAccepting() { m_IsStopped = true; }
-
- int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
- bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
-
- bool IsValid() const { return m_IsValid; }
+ m_IsValid = true;
+ ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port());
+ }
-private:
- void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor)
+ void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor)
{
auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
@@ -1698,29 +1747,19 @@ private:
Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- Acceptor.local_endpoint().address().to_string(),
- Acceptor.local_endpoint().port(),
- Ec.message());
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message());
+ }
}
else
{
- // New connection established, pass socket ownership into connection object
- // and initiate request handling loop. The connection lifetime is
- // managed by the async read/write loop by passing the shared
- // reference to the callbacks.
-
- Socket->set_option(asio::ip::tcp::no_delay(true));
- Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
- Conn->HandleNewRequest();
+ OnAccept(std::move(Socket));
}
if (!m_IsStopped.load())
{
- InitAcceptInternal(Acceptor);
+ InitAcceptLoop(Acceptor);
}
else
{
@@ -1728,21 +1767,204 @@ private:
Acceptor.close(CloseEc);
if (CloseEc)
{
- ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
+ ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message());
}
}
});
}
- HttpAsioServerImpl& m_Server;
- asio::io_service& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
asio::ip::tcp::acceptor m_AlternateProtocolAcceptor;
bool m_UseAlternateProtocolAcceptor{false};
bool m_IsValid{false};
std::atomic<bool> m_IsStopped{false};
+ std::string_view m_Label;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor final : TcpAcceptorBase
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP")
+ {
+ }
+
+ int GetAcceptPort() const { return GetPort(); }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
};
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+
+//////////////////////////////////////////////////////////////////////////
+
+struct UnixAcceptor
+{
+ UnixAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, const std::string& SocketPath)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService)
+ , m_SocketPath(SocketPath)
+ {
+ // Remove any stale socket file from a previous run
+ std::filesystem::remove(m_SocketPath);
+
+ asio::local::stream_protocol::endpoint Endpoint(m_SocketPath);
+
+ asio::error_code Ec;
+ m_Acceptor.open(Endpoint.protocol(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to open unix domain socket: {}", Ec.message());
+ return;
+ }
+
+ m_Acceptor.bind(Endpoint, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_IsValid = true;
+ ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath);
+ }
+
+ ~UnixAcceptor()
+ {
+ asio::error_code Ec;
+ m_Acceptor.close(Ec);
+ std::filesystem::remove(m_SocketPath);
+ }
+
+ void Start()
+ {
+ ZEN_ASSERT(!m_IsStopped);
+ InitAccept();
+ }
+
+ void StopAccepting() { m_IsStopped = true; }
+
+ bool IsValid() const { return m_IsValid; }
+
+private:
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService);
+ asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ std::error_code CloseEc;
+ m_Acceptor.close(CloseEc);
+ }
+ });
+ }
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::local::stream_protocol::acceptor m_Acceptor;
+ std::string m_SocketPath;
+ bool m_IsValid{false};
+ std::atomic<bool> m_IsStopped{false};
+};
+
+#endif // ASIO_HAS_LOCAL_SOCKETS
+
+#if ZEN_USE_OPENSSL
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpsAcceptor final : TcpAcceptorBase
+{
+ HttpsAcceptor(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ asio::ssl::context& SslContext,
+ uint16_t Port,
+ bool ForceLoopback,
+ bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS")
+ , m_SslContext(SslContext)
+ {
+ }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ // Wrap accepted TCP socket in an SSL stream and perform the handshake
+ auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext);
+
+ SslSocket& SslRef = *SslSocketPtr;
+ SslRef.async_handshake(asio::ssl::stream_base::server,
+ [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable {
+ if (HandshakeEc)
+ {
+ ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message());
+ std::error_code Ec;
+ SslSocket->lowest_layer().close(Ec);
+ return;
+ }
+
+ auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket));
+ Conn->HandleNewRequest();
+ });
+ }
+
+private:
+ asio::ssl::context& m_SslContext;
+};
+
+#endif // ZEN_USE_OPENSSL
+
+int
+HttpAsioServerImpl::GetEffectiveHttpsPort() const
+{
+#if ZEN_USE_OPENSSL
+ return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0;
+#else
+ return 0;
+#endif
+}
+
//////////////////////////////////////////////////////////////////////////
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
@@ -1860,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -1873,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -1883,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
@@ -1942,6 +2167,51 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config)
m_Acceptor->Start();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!Config.UnixSocketPath.empty())
+ {
+ m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath));
+
+ if (m_UnixAcceptor->IsValid())
+ {
+ m_UnixAcceptor->Start();
+ }
+ else
+ {
+ m_UnixAcceptor.reset();
+ }
+ }
+#endif
+
+#if ZEN_USE_OPENSSL
+ if (!Config.CertFile.empty() && !Config.KeyFile.empty())
+ {
+ m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server);
+ m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+ asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1);
+ m_SslContext->use_certificate_chain_file(Config.CertFile);
+ m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem);
+
+ ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile);
+
+ m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this,
+ m_IoService,
+ *m_SslContext,
+ gsl::narrow<uint16_t>(Config.HttpsPort),
+ Config.ForceLoopback,
+ /*AllowPortProbing*/ !Config.IsDedicatedServer));
+
+ if (m_HttpsAcceptor->IsValid())
+ {
+ m_HttpsAcceptor->Start();
+ }
+ else
+ {
+ m_HttpsAcceptor.reset();
+ }
+ }
+#endif
+
// This should consist of a set of minimum threads and grow on demand to
// meet concurrency needs? Right now we end up allocating a large number
// of threads even if we never end up using all of them, which seems
@@ -1990,6 +2260,18 @@ HttpAsioServerImpl::Stop()
{
m_Acceptor->StopAccepting();
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixAcceptor)
+ {
+ m_UnixAcceptor->StopAccepting();
+ }
+#endif
+#if ZEN_USE_OPENSSL
+ if (m_HttpsAcceptor)
+ {
+ m_HttpsAcceptor->StopAccepting();
+ }
+#endif
m_IoService.stop();
for (auto& Thread : m_ThreadPool)
{
@@ -1999,7 +2281,23 @@ HttpAsioServerImpl::Stop()
}
}
m_ThreadPool.clear();
+
+ // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket
+ // connections) so that their captured Ref<> pointers are released while the
+ // io_service and its epoll reactor are still alive. Without this, sockets
+ // held by external code (e.g. IWebSocketHandler connection lists) can outlive
+ // the reactor and crash during deregistration.
+ m_IoService.restart();
+ m_IoService.poll();
+
m_Acceptor.reset();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ m_UnixAcceptor.reset();
+#endif
+#if ZEN_USE_OPENSSL
+ m_HttpsAcceptor.reset();
+ m_SslContext.reset();
+#endif
}
void
@@ -2166,6 +2464,13 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config);
+#if ZEN_USE_OPENSSL
+ if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(EffectiveHttpsPort);
+ }
+#endif
+
return m_BasePort;
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index 3ec1141a7..5adf4d5e8 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -11,6 +11,12 @@ struct AsioConfig
unsigned int ThreadCount = 0;
bool ForceLoopback = false;
bool IsDedicatedServer = false;
+ std::string UnixSocketPath;
+#if ZEN_USE_OPENSSL
+ int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS
+ std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled)
+ std::string KeyFile; // PEM private key file
+#endif
};
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index dfe6bb6aa..83b98013e 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -116,6 +116,12 @@ public:
private:
int InitializeServer(int BasePort);
+ bool CreateSessionAndUrlGroup();
+ bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris);
+ int RegisterHttpUrls(int BasePort);
+ bool RegisterHttpsUrls();
+ bool CreateRequestQueue(int EffectivePort);
+ bool SetupIoCompletionPort();
void Cleanup();
void StartServer();
@@ -125,6 +131,9 @@ private:
void RegisterService(const char* Endpoint, HttpService& Service);
void UnregisterService(const char* Endpoint, HttpService& Service);
+ bool BindSslCertificate(int Port);
+ void UnbindSslCertificate();
+
private:
LoggerRef m_Log;
LoggerRef m_RequestLog;
@@ -140,7 +149,10 @@ private:
RwLock m_AsyncWorkPoolInitLock;
std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/
+ bool m_DidAutoBindCert = false;
+ int m_HttpsPort = 0;
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
HANDLE m_RequestQueueHandle = 0;
@@ -1082,39 +1094,63 @@ HttpSysServer::OnClose()
}
}
-int
-HttpSysServer::InitializeServer(int BasePort)
+bool
+HttpSysServer::CreateSessionAndUrlGroup()
{
- ZEN_MEMSCOPE(GetHttpsysTag());
-
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
- WideToUtf8(WildcardUrlPath),
- GetSystemErrorAsString(Result),
- Result);
+ ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
+ ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
+ return true;
+}
+
+bool
+HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris)
+{
+ using namespace std::literals;
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrl;
+ LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (Result == NO_ERROR)
+ {
+ ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl));
+ OutUris.push_back(LocalUrl.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return !OutUris.empty();
+}
+
+int
+HttpSysServer::RegisterHttpUrls(int BasePort)
+{
+ using namespace std::literals;
+
m_BaseUris.clear();
const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
@@ -1122,6 +1158,11 @@ HttpSysServer::InitializeServer(int BasePort)
int EffectivePort = BasePort;
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ ULONG Result;
+
if (m_InitialConfig.ForceLoopback)
{
// Force trigger of opening using local port
@@ -1177,11 +1218,11 @@ HttpSysServer::InitializeServer(int BasePort)
{
if (AllowLocalOnly)
{
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
+ // If we can't register the wildcard path, we fall back to local paths.
+ // Local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
// during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
+ // port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
if (!m_InitialConfig.ForceLoopback)
@@ -1246,7 +1287,7 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- if (m_BaseUris.empty())
+ if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0)
{
ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
WideToUtf8(WildcardUrlPath),
@@ -1256,16 +1297,104 @@ HttpSysServer::InitializeServer(int BasePort)
return 0;
}
+ return EffectivePort;
+}
+
+bool
+HttpSysServer::RegisterHttpsUrls()
+{
+ using namespace std::literals;
+
+ const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer;
+ const int HttpsPort = m_InitialConfig.HttpsPort;
+
+ // If HTTPS-only mode, remove HTTP URLs and clear base URIs
+ if (m_InitialConfig.HttpsOnly)
+ {
+ for (const std::wstring& Uri : m_BaseUris)
+ {
+ HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0);
+ }
+ m_BaseUris.clear();
+ }
+
+ // Auto-bind certificate if thumbprint is provided
+ if (!m_InitialConfig.CertThumbprint.empty())
+ {
+ if (!BindSslCertificate(HttpsPort))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort);
+ }
+
+ // Register HTTPS URLs using same pattern as HTTP
+
+ WideStringBuilder<64> HttpsWildcard;
+ HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv;
+
+ ULONG HttpsResult = NO_ERROR;
+
+ if (m_InitialConfig.ForceLoopback)
+ {
+ HttpsResult = ERROR_ACCESS_DENIED;
+ }
+ else
+ {
+ HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ if (HttpsResult == NO_ERROR)
+ {
+ m_HttpsBaseUris.push_back(HttpsWildcard.c_str());
+ }
+ else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly)
+ {
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register HTTPS handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.",
+ WideToUtf8(HttpsWildcard));
+ }
+
+ RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris);
+ }
+ else if (HttpsResult != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})",
+ WideToUtf8(HttpsWildcard),
+ GetSystemErrorAsString(HttpsResult),
+ HttpsResult);
+ return false;
+ }
+
+ if (m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort);
+ return false;
+ }
+
+ m_HttpsPort = HttpsPort;
+ return true;
+}
+
+bool
+HttpSysServer::CreateRequestQueue(int EffectivePort)
+{
HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
WideStringBuilder<64> QueueName;
QueueName << "zenserver_" << EffectivePort;
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ QueueName.c_str(),
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
+ ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ QueueName.c_str(),
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
if (Result != NO_ERROR)
{
@@ -1274,7 +1403,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
HttpBindingInfo.Flags.Present = 1;
@@ -1289,7 +1418,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
// Configure rejection method. Default is to drop the connection, it's better if we
@@ -1323,22 +1452,77 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- // Create I/O completion port
+ return true;
+}
+bool
+HttpSysServer::SetupIoCompletionPort()
+{
std::error_code ErrorCode;
m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
if (ErrorCode)
{
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message());
+ return false;
+ }
+ m_IsOk = true;
+
+ if (!m_BaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+ if (!m_HttpsBaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front()));
+ }
+
+ return true;
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
+ m_IsOk = false;
+
+ if (!CreateSessionAndUrlGroup())
+ {
return 0;
}
- else
+
+ int EffectivePort = RegisterHttpUrls(BasePort);
+
+ if (m_InitialConfig.HttpsPort > 0)
+ {
+ if (!RegisterHttpsUrls())
+ {
+ return 0;
+ }
+ }
+
+ if (m_BaseUris.empty() && m_HttpsBaseUris.empty())
{
- m_IsOk = true;
+ ZEN_ERROR("No HTTP or HTTPS listeners could be registered");
+ return 0;
+ }
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ if (!CreateRequestQueue(EffectivePort))
+ {
+ return 0;
+ }
+
+ if (!SetupIoCompletionPort())
+ {
+ return 0;
+ }
+
+ // When HTTPS-only, return the HTTPS port as the effective port
+ if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0)
+ {
+ return m_HttpsPort;
}
return EffectivePort;
@@ -1349,6 +1533,8 @@ HttpSysServer::Cleanup()
{
++m_IsShuttingDown;
+ UnbindSslCertificate();
+
if (m_RequestQueueHandle)
{
HttpCloseRequestQueue(m_RequestQueueHandle);
@@ -1368,6 +1554,105 @@ HttpSysServer::Cleanup()
}
}
+// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings
+static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}};
+
+bool
+HttpSysServer::BindSslCertificate(int Port)
+{
+ const std::string& Thumbprint = m_InitialConfig.CertThumbprint;
+ if (Thumbprint.size() != 40)
+ {
+ ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size());
+ return false;
+ }
+
+ BYTE CertHash[20] = {};
+ if (!ParseHexBytes(Thumbprint, CertHash))
+ {
+ ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters");
+ return false;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(Port));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str());
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+ SslConfig.ParamDesc.pSslHash = CertHash;
+ SslConfig.ParamDesc.SslHashLength = sizeof(CertHash);
+ SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str());
+ SslConfig.ParamDesc.AppId = ZenServerSslAppId;
+
+ ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result == ERROR_ALREADY_EXISTS)
+ {
+ // Remove existing binding and retry
+ HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {};
+ DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr);
+
+ Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+ }
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR(
+ "Failed to bind SSL certificate to port {}: {} ({:#x}). "
+ "This operation may require running as administrator.",
+ Port,
+ GetSystemErrorAsString(Result),
+ Result);
+ return false;
+ }
+
+ m_DidAutoBindCert = true;
+ m_HttpsPort = Port;
+
+ ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})",
+ Port,
+ Thumbprint.substr(0, 8),
+ m_InitialConfig.CertStoreName);
+
+ return true;
+}
+
+void
+HttpSysServer::UnbindSslCertificate()
+{
+ if (!m_DidAutoBindCert)
+ {
+ return;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result);
+ }
+ else
+ {
+ ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort);
+ }
+
+ m_DidAutoBindCert = false;
+}
+
WorkerThreadPool&
HttpSysServer::WorkPool()
{
@@ -1495,19 +1780,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
// Convert to wide string
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
+ auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
{
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ std::wstring Url16 = BaseUri + PathUtf16;
- return;
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ return;
+ }
}
- }
+ };
+
+ RegisterWithBaseUris(m_BaseUris);
+ RegisterWithBaseUris(m_HttpsBaseUris);
}
void
@@ -1522,19 +1811,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
+ auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
}
- }
+ };
+
+ UnregisterFromBaseUris(m_BaseUris);
+ UnregisterFromBaseUris(m_HttpsBaseUris);
}
//////////////////////////////////////////////////////////////////////////
@@ -2422,6 +2714,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
+ if (m_HttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(m_HttpsPort);
+ }
+
StartServer();
return EffectivePort;
diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h
index b2fe7475b..ca465ad00 100644
--- a/src/zenhttp/servers/httpsys.h
+++ b/src/zenhttp/servers/httpsys.h
@@ -22,6 +22,10 @@ struct HttpSysConfig
bool IsRequestLoggingEnabled = false;
bool IsDedicatedServer = false;
bool ForceLoopback = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
};
Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config);
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index b2543277a..5ae48f5b3 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "wsasio.h"
+#include "asio_socket_traits.h"
#include "wsframecodec.h"
#include <zencore/logging.h>
@@ -17,14 +18,16 @@ WsLog()
//////////////////////////////////////////////////////////////////////////
-WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server)
: m_Socket(std::move(Socket))
, m_Handler(Handler)
, m_HttpServer(Server)
{
}
-WsAsioConnection::~WsAsioConnection()
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::~WsAsioConnectionT()
{
m_IsOpen.store(false);
if (m_HttpServer)
@@ -33,14 +36,16 @@ WsAsioConnection::~WsAsioConnection()
}
}
+template<typename SocketType>
void
-WsAsioConnection::Start()
+WsAsioConnectionT<SocketType>::Start()
{
EnqueueRead();
}
+template<typename SocketType>
bool
-WsAsioConnection::IsOpen() const
+WsAsioConnectionT<SocketType>::IsOpen() const
{
return m_IsOpen.load(std::memory_order_relaxed);
}
@@ -50,23 +55,25 @@ WsAsioConnection::IsOpen() const
// Read loop
//
+template<typename SocketType>
void
-WsAsioConnection::EnqueueRead()
+WsAsioConnectionT<SocketType>::EnqueueRead()
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
Self->OnDataReceived(Ec, ByteCount);
});
}
+template<typename SocketType>
void
-WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -90,8 +97,9 @@ WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] st
}
}
+template<typename SocketType>
void
-WsAsioConnection::ProcessReceivedData()
+WsAsioConnectionT<SocketType>::ProcessReceivedData()
{
while (m_ReadBuffer.size() > 0)
{
@@ -162,8 +170,8 @@ WsAsioConnection::ProcessReceivedData()
// Shut down the socket
std::error_code ShutdownEc;
- m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
- m_Socket->close(ShutdownEc);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc);
+ SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc);
return;
}
@@ -179,8 +187,9 @@ WsAsioConnection::ProcessReceivedData()
// Write queue
//
+template<typename SocketType>
void
-WsAsioConnection::SendText(std::string_view Text)
+WsAsioConnectionT<SocketType>::SendText(std::string_view Text)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -192,8 +201,9 @@ WsAsioConnection::SendText(std::string_view Text)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -204,14 +214,16 @@ WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason)
{
DoClose(Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason)
{
if (!m_IsOpen.exchange(false))
{
@@ -227,8 +239,9 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame)
{
if (m_HttpServer)
{
@@ -252,8 +265,9 @@ WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
}
}
+template<typename SocketType>
void
-WsAsioConnection::FlushWriteQueue()
+WsAsioConnectionT<SocketType>::FlushWriteQueue()
{
std::vector<uint8_t> Frame;
@@ -272,7 +286,7 @@ WsAsioConnection::FlushWriteQueue()
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
// Move Frame into a shared_ptr so we can create the buffer and capture ownership
// in the same async_write call without evaluation order issues.
@@ -283,8 +297,9 @@ WsAsioConnection::FlushWriteQueue()
[Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
}
+template<typename SocketType>
void
-WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -308,4 +323,17 @@ WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] s
FlushWriteQueue();
}
+//////////////////////////////////////////////////////////////////////////
+// Explicit template instantiations
+
+template class WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+template class WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
index e8bb3b1d2..64602ee46 100644
--- a/src/zenhttp/servers/wsasio.h
+++ b/src/zenhttp/servers/wsasio.h
@@ -8,6 +8,12 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#include <deque>
@@ -21,22 +27,23 @@ class HttpServer;
namespace zen::asio_http {
/**
- * WebSocket connection over an ASIO TCP socket
+ * WebSocket connection over an ASIO stream socket
*
- * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * Templated on SocketType to support both TCP and Unix domain sockets.
+ * Owns the socket (moved from HttpServerConnection after the 101 handshake)
* and runs an async read/write loop to exchange WebSocket frames.
*
* Lifetime is managed solely through intrusive reference counting (RefCounted).
- * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
- * connection alive for the duration of the async operation. The service layer
- * also holds a Ref<WebSocketConnection>.
+ * The async read/write callbacks capture Ref<> to keep the connection alive
+ * for the duration of the async operation. The service layer also holds a
+ * Ref<WebSocketConnection>.
*/
-
-class WsAsioConnection : public WebSocketConnection
+template<typename SocketType>
+class WsAsioConnectionT : public WebSocketConnection
{
public:
- WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
- ~WsAsioConnection() override;
+ WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnectionT() override;
/**
* Start the async read loop. Must be called once after construction
@@ -61,10 +68,10 @@ private:
void DoClose(uint16_t Code, std::string_view Reason);
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- IWebSocketHandler& m_Handler;
- zen::HttpServer* m_HttpServer;
- asio::streambuf m_ReadBuffer;
+ std::unique_ptr<SocketType> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
RwLock m_WriteLock;
std::deque<std::vector<uint8_t>> m_WriteQueue;
@@ -74,4 +81,14 @@ private:
std::atomic<bool> m_CloseSent{false};
};
+using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 2134e4ff1..042afd8ff 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -485,7 +485,7 @@ TEST_CASE("websocket.integration")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7575, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -797,7 +797,7 @@ TEST_CASE("websocket.client")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7576, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -913,6 +913,75 @@ TEST_CASE("websocket.client")
}
}
+TEST_CASE("websocket.client.unixsocket")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "ws.sock").string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close over unix socket")
+ {
+ TestWsClientHandler Handler;
+ HttpWsClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello over unix socket");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello over unix socket");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
TEST_SUITE_END();
void