From 0a41fd42aa43080fbc991e7d976dde70aeaec594 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 27 Feb 2026 17:13:40 +0100 Subject: add full WebSocket (RFC 6455) client/server support for zenhttp (#792) * This branch adds full WebSocket (RFC 6455) support to the HTTP server layer, covering both transport backends, a client, and tests. - **`websocket.h`** -- Core interfaces: `WebSocketOpcode`, `WebSocketMessage`, `WebSocketConnection` (ref-counted), and `IWebSocketHandler`. Services opt in to WebSocket support by implementing `IWebSocketHandler` alongside their existing `HttpService`. - **`httpwsclient.h`** -- `HttpWsClient`: an ASIO-backed `ws://` client with both standalone (own thread) and shared `io_context` modes. Supports connect timeout and optional auth token injection via `IWsClientHandler` callbacks. - **`wsasio.cpp/h`** -- `WsAsioConnection`: WebSocket over ASIO TCP. Takes over the socket after the HTTP 101 handshake and runs an async read/write loop with a queued write path (guarded by `RwLock`). - **`wshttpsys.cpp/h`** -- `WsHttpSysConnection`: WebSocket over http.sys opaque-mode connections (Windows only). Uses `HttpReceiveRequestEntityBody` / `HttpSendResponseEntityBody` via IOCP, sharing the same threadpool as normal http.sys traffic. Self-ref lifetime management ensures graceful drain of outstanding async ops. - **`httpsys_iocontext.h`** -- Tagged `OVERLAPPED` wrapper (`HttpSysIoContext`) used to distinguish normal HTTP transactions from WebSocket read/write completions in the single IOCP callback. - **`wsframecodec.cpp/h`** -- `WsFrameCodec`: static helpers for parsing (unmasked and masked) and building (unmasked server frames and masked client frames) RFC 6455 frames across all three payload length encodings (7-bit, 16-bit, 64-bit). Also computes `Sec-WebSocket-Accept` keys. - **`clients/httpwsclient.cpp`** -- `HttpWsClient::Impl`: ASIO-based client that performs the HTTP upgrade handshake, then hands off to the frame codec for the read loop. Manages its own `io_context` thread or plugs into an external one. - **`httpasio.cpp`** -- ASIO server now detects `Upgrade: websocket` requests, checks the matching `HttpService` for `IWebSocketHandler` via `dynamic_cast`, performs the RFC 6455 handshake (101 response), and spins up a `WsAsioConnection`. - **`httpsys.cpp`** -- Same upgrade detection and handshake logic for the http.sys backend, using `WsHttpSysConnection` and `HTTP_SEND_RESPONSE_FLAG_OPAQUE`. - **`httpparser.cpp/h`** -- Extended to surface the `Upgrade` / `Connection` / `Sec-WebSocket-Key` headers needed by the handshake. - **`httpcommon.h`** -- Minor additions (probably new header constants or response codes for the WS upgrade). - **`httpserver.h`** -- Small interface changes to support WebSocket registration. - **`zenhttp.cpp` / `xmake.lua`** -- New source files wired in; build config updated. - **Unit tests** (`websocket.framecodec`): round-trip encode/decode for text, binary, close frames; all three payload sizes; masked and unmasked variants; RFC 6455 `Sec-WebSocket-Accept` test vector. - **Integration tests** (`websocket.integration`): full ASIO server tests covering handshake (101), normal HTTP coexistence, echo, server-push broadcast, client close handshake, ping/pong auto-response, sequential messages, and rejection of upgrades on non-WS services. - **Client tests** (`websocket.client`): `HttpWsClient` connect+echo+close, connection failure (bad port -> close code 1006), and server-initiated close. * changed HttpRequestParser::ParseCurrentHeader to use switch instead of if/else chain * remove spurious printf --------- Co-authored-by: Stefan Boberg --- src/zencore/filesystem.cpp | 1 - src/zenhttp/clients/httpwsclient.cpp | 568 ++++++++++++++++++ src/zenhttp/include/zenhttp/httpcommon.h | 7 + src/zenhttp/include/zenhttp/httpserver.h | 3 +- src/zenhttp/include/zenhttp/httpwsclient.h | 79 +++ src/zenhttp/include/zenhttp/websocket.h | 65 ++ src/zenhttp/servers/httpasio.cpp | 49 ++ src/zenhttp/servers/httpparser.cpp | 148 +++-- src/zenhttp/servers/httpparser.h | 7 + src/zenhttp/servers/httpsys.cpp | 180 ++++-- src/zenhttp/servers/httpsys_iocontext.h | 40 ++ src/zenhttp/servers/wsasio.cpp | 297 ++++++++++ src/zenhttp/servers/wsasio.h | 71 +++ src/zenhttp/servers/wsframecodec.cpp | 229 +++++++ src/zenhttp/servers/wsframecodec.h | 74 +++ src/zenhttp/servers/wshttpsys.cpp | 466 +++++++++++++++ src/zenhttp/servers/wshttpsys.h | 104 ++++ src/zenhttp/servers/wstest.cpp | 922 +++++++++++++++++++++++++++++ src/zenhttp/xmake.lua | 1 + src/zenhttp/zenhttp.cpp | 1 + 20 files changed, 3203 insertions(+), 109 deletions(-) create mode 100644 src/zenhttp/clients/httpwsclient.cpp create mode 100644 src/zenhttp/include/zenhttp/httpwsclient.h create mode 100644 src/zenhttp/include/zenhttp/websocket.h create mode 100644 src/zenhttp/servers/httpsys_iocontext.h create mode 100644 src/zenhttp/servers/wsasio.cpp create mode 100644 src/zenhttp/servers/wsasio.h create mode 100644 src/zenhttp/servers/wsframecodec.cpp create mode 100644 src/zenhttp/servers/wsframecodec.h create mode 100644 src/zenhttp/servers/wshttpsys.cpp create mode 100644 src/zenhttp/servers/wshttpsys.h create mode 100644 src/zenhttp/servers/wstest.cpp diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 553897407..03398860b 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3533,7 +3533,6 @@ TEST_CASE("PathBuilder") Path.Reset(); Path.Append(fspath(L"/\u0119oo/")); Path /= L"bar"; - printf("%ls\n", Path.ToPath().c_str()); CHECK(Path.ToView() == L"/\u0119oo/bar"); CHECK(Path.ToPath() == L"\\\u0119oo\\bar"); # endif diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..36a6f081b --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,568 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include "../servers/wsframecodec.h" + +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + if (m_Socket) + { + asio::error_code Ec; + m_Socket->close(Ec); + } + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if (m_OwnedIoContext) + { + m_WorkGuard = std::make_unique(m_IoContext); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + + void DoResolve() + { + m_Resolver = std::make_unique(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_Socket = std::make_unique(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + if (m_Socket) + { + asio::error_code CloseEc; + m_Socket->close(CloseEc); + } + } + }); + + asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared(ReqStr); + + asio::async_write(*m_Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + } + + void DoReadHandshakeResponse() + { + asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* RawData = static_cast(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with masked pong + std::vector PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void FlushWriteQueue() + { + std::vector Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + } + + void OnWriteComplete(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr m_OwnedIoContext; + asio::io_context& m_IoContext; + std::unique_ptr m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr m_Resolver; + std::unique_ptr m_Socket; + std::unique_ptr m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic m_IsOpen{false}; + bool m_CloseSent = false; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index bc18549c9..8fca35ac5 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept return IsHttpSuccessCode(int(HttpCode)); } +[[nodiscard]] inline bool +IsHttpOk(HttpResponseCode HttpCode) noexcept +{ + return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted || + HttpCode == HttpResponseCode::NoContent; +} + std::string_view ToString(HttpResponseCode HttpCode); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 00cbc6c14..fee932daa 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -462,6 +462,7 @@ struct IHttpStatsService virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; }; -void http_forcelink(); // internal +void http_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h new file mode 100644 index 000000000..926ec1e3d --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace zen { + +/** + * Callback interface for WebSocket client events + * + * Separate from the server-side IWebSocketHandler because the caller + * already owns the HttpWsClient — no Ref needed. + */ +class IWsClientHandler +{ +public: + virtual ~IWsClientHandler() = default; + + virtual void OnWsOpen() = 0; + virtual void OnWsMessage(const WebSocketMessage& Msg) = 0; + virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0; +}; + +struct HttpWsClientSettings +{ + std::string LogCategory = "wsclient"; + std::chrono::milliseconds ConnectTimeout{5000}; + std::optional> AccessTokenProvider; +}; + +/** + * WebSocket client over TCP (ws:// scheme) + * + * Uses ASIO for async I/O. Two construction modes: + * - Internal io_context + background thread (standalone use) + * - External io_context (shared event loop, no internal thread) + * + * Thread-safe for SendText/SendBinary/Close. + */ +class HttpWsClient +{ +public: + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {}); + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {}); + + ~HttpWsClient(); + + HttpWsClient(const HttpWsClient&) = delete; + HttpWsClient& operator=(const HttpWsClient&) = delete; + + void Connect(); + void SendText(std::string_view Text); + void SendBinary(std::span Data); + void Close(uint16_t Code = 1000, std::string_view Reason = {}); + bool IsOpen() const; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..7a6fb33dd --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include +#include +#include + +namespace zen { + +enum class WebSocketOpcode : uint8_t +{ + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA +}; + +struct WebSocketMessage +{ + WebSocketOpcode Opcode; + IoBuffer Payload; + uint16_t CloseCode = 0; +}; + +/** + * Represents an active WebSocket connection + * + * Derived classes implement the actual transport (e.g. ASIO sockets). + * Instances are reference-counted so that both the service layer and + * the async read/write loop can share ownership. + */ +class WebSocketConnection : public RefCounted +{ +public: + virtual ~WebSocketConnection() = default; + + virtual void SendText(std::string_view Text) = 0; + virtual void SendBinary(std::span Data) = 0; + virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; + virtual bool IsOpen() const = 0; +}; + +/** + * Interface for services that accept WebSocket upgrades + * + * An HttpService may additionally implement this interface to indicate + * it supports WebSocket connections. The HTTP server checks for this + * via dynamic_cast when it sees an Upgrade: websocket request. + */ +class IWebSocketHandler +{ +public: + virtual ~IWebSocketHandler() = default; + + virtual void OnWebSocketOpen(Ref Connection) = 0; + virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; + virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 0c0238886..8c2dcd116 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -14,6 +14,8 @@ #include #include "httpparser.h" +#include "wsasio.h" +#include "wsframecodec.h" #include @@ -1159,6 +1161,53 @@ HttpServerConnection::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); + // WebSocket upgrade detection must happen before the keep-alive check below, + // because Upgrade requests have "Connection: Upgrade" which the HTTP parser + // treats as non-keep-alive, causing a premature shutdown of the receive side. + if (m_RequestData.IsWebSocketUpgrade()) + { + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + IWebSocketHandler* WsHandler = dynamic_cast(Service); + if (WsHandler && !m_RequestData.SecWebSocketKey().empty()) + { + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey()); + + auto ResponseStr = std::make_shared(); + ResponseStr->reserve(256); + ResponseStr->append( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "); + ResponseStr->append(AcceptKey); + ResponseStr->append("\r\n\r\n"); + + // Send the 101 response on the current socket, then hand the socket off + // to a WsAsioConnection for the WebSocket protocol. + asio::async_write(*m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Ref WsConn(new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler)); + Ref WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); + + m_RequestState = RequestState::kDone; + return; + } + } + // Service doesn't support WebSocket or missing key — fall through to normal handling + } + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index f0485aa25..3b1229375 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -12,14 +12,17 @@ namespace zen { using namespace std::literals; -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); -static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); +static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); +static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // @@ -143,45 +146,62 @@ HttpRequestParser::ParseCurrentHeader() const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); - if (HeaderHash == HashContentLength) + switch (HeaderHash) { - m_ContentLengthHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAuthorization) - { - m_AuthorizationHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::TryFromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = CurrentHeaderIndex; + case HashContentLength: + m_ContentLengthHeaderIndex = CurrentHeaderIndex; + break; + + case HashAccept: + m_AcceptHeaderIndex = CurrentHeaderIndex; + break; + + case HashContentType: + m_ContentTypeHeaderIndex = CurrentHeaderIndex; + break; + + case HashAuthorization: + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + break; + + case HashSession: + m_SessionId = Oid::TryFromHexString(HeaderValue); + break; + + case HashRequest: + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + break; + + case HashExpect: + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + break; + + case HashRange: + m_RangeHeaderIndex = CurrentHeaderIndex; + break; + + case HashUpgrade: + m_UpgradeHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketKey: + m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketVersion: + m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; + break; + + default: + break; } } @@ -361,14 +381,17 @@ HttpRequestParser::ResetState() m_HeaderEntries.clear(); - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_AuthorizationHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; + m_UpgradeHeaderIndex = -1; + m_SecWebSocketKeyHeaderIndex = -1; + m_SecWebSocketVersionHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); @@ -425,4 +448,21 @@ HttpRequestParser::OnMessageComplete() } } +bool +HttpRequestParser::IsWebSocketUpgrade() const +{ + std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); + if (Upgrade.empty()) + { + return false; + } + + // Case-insensitive check for "websocket" + if (Upgrade.size() != 9) + { + return false; + } + return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; +} + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index ff56ca970..d40a5aeb0 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -48,6 +48,10 @@ struct HttpRequestParser std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); } + std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); } + bool IsWebSocketUpgrade() const; + private: struct HeaderRange { @@ -86,6 +90,9 @@ private: int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; int8_t m_AuthorizationHeaderIndex; + int8_t m_UpgradeHeaderIndex; + int8_t m_SecWebSocketKeyHeaderIndex; + int8_t m_SecWebSocketVersionHeaderIndex; HttpVerb m_RequestVerb; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index e93ae4853..23d57af57 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -156,6 +156,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include # include # pragma comment(lib, "httpapi.lib") @@ -380,7 +384,7 @@ public: PTP_IO Iocp(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -397,8 +401,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -1555,7 +1559,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { @@ -2111,64 +2131,118 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector UnknownHeaders; + + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. + OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Ref WsConn( + new WsHttpSysConnection(RequestQueueHandle, RequestId, *WsHandler, Transaction().Iocp())); + Ref WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } - ZEN_INFO(""); + ZEN_WARN("WebSocket 101 send failed: {}", SendResult); + + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h new file mode 100644 index 000000000..4f8a97012 --- /dev/null +++ b/src/zenhttp/servers/httpsys_iocontext.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include + +# include + +namespace zen { + +/** + * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch + * + * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection + * (for WebSocket read/write) embed this struct. The single IoCompletionCallback + * bound to the request queue uses the ContextType tag to dispatch to the correct + * handler. + * + * The Overlapped member must be first so that CONTAINING_RECORD works to recover + * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool. + */ +struct HttpSysIoContext +{ + OVERLAPPED Overlapped{}; + + enum class Type : uint8_t + { + kTransaction, + kWebSocketRead, + kWebSocketWrite, + } ContextType = Type::kTransaction; + + void* Owner = nullptr; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp new file mode 100644 index 000000000..dfc1eac38 --- /dev/null +++ b/src/zenhttp/servers/wsasio.cpp @@ -0,0 +1,297 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsasio.h" +#include "wsframecodec.h" + +#include + +namespace zen::asio_http { + +static LoggerRef +WsLog() +{ + static LoggerRef g_Logger = logging::Get("ws"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsAsioConnection::WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler) +: m_Socket(std::move(Socket)) +, m_Handler(Handler) +{ +} + +WsAsioConnection::~WsAsioConnection() +{ + m_IsOpen.store(false); +} + +void +WsAsioConnection::Start() +{ + EnqueueRead(); +} + +bool +WsAsioConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Read loop +// + +void +WsAsioConnection::EnqueueRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + Ref Self(this); + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { + Self->OnDataReceived(Ec, ByteCount); + }); +} + +void +WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } +} + +void +WsAsioConnection::ProcessReceivedData() +{ + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* Data = static_cast(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Shut down the socket + std::error_code ShutdownEc; + m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); + m_Socket->close(ShutdownEc); + return; + } + + default: + ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Write queue +// + +void +WsAsioConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::SendBinary(std::span Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); +} + +void +WsAsioConnection::EnqueueWrite(std::vector Frame) +{ + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsAsioConnection::FlushWriteQueue() +{ + std::vector Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + Ref Self(this); + + // Move Frame into a shared_ptr so we can create the buffer and capture ownership + // in the same async_write call without evaluation order issues. + auto OwnedFrame = std::make_shared>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); +} + +void +WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + return; + } + + FlushWriteQueue(); +} + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h new file mode 100644 index 000000000..a638ea836 --- /dev/null +++ b/src/zenhttp/servers/wsasio.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include + +namespace zen::asio_http { + +/** + * WebSocket connection over an ASIO TCP socket + * + * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * and runs an async read/write loop to exchange WebSocket frames. + * + * Lifetime is managed solely through intrusive reference counting (RefCounted). + * The async read/write callbacks capture Ref to keep the + * connection alive for the duration of the async operation. The service layer + * also holds a Ref. + */ +class WsAsioConnection : public WebSocketConnection +{ +public: + WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler); + ~WsAsioConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and the 101 response has been sent. + */ + void Start(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void ProcessReceivedData(); + + void EnqueueWrite(std::vector Frame); + void FlushWriteQueue(); + void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount); + + void DoClose(uint16_t Code, std::string_view Reason); + + std::unique_ptr m_Socket; + IWebSocketHandler& m_Handler; + asio::streambuf m_ReadBuffer; + + RwLock m_WriteLock; + std::deque> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic m_IsOpen{true}; + bool m_CloseSent = false; +}; + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..a4c5e0f16 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,229 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsframecodec.h" + +#include +#include + +#include +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span Payload) +{ + std::vector Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (client-to-server, with masking) +// + +std::vector +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload) +{ + std::vector Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h new file mode 100644 index 000000000..2d90b6fa1 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace zen { + +/** + * Result of attempting to parse a single WebSocket frame from a byte buffer + */ +struct WsFrameParseResult +{ + bool IsValid = false; // true if a complete frame was successfully parsed + size_t BytesConsumed = 0; // number of bytes consumed from the input buffer + WebSocketOpcode Opcode = WebSocketOpcode::kText; + bool Fin = false; + std::vector Payload; +}; + +/** + * RFC 6455 WebSocket frame codec + * + * Provides static helpers for parsing client-to-server frames (which are + * always masked) and building server-to-client frames (which are never masked). + */ +struct WsFrameCodec +{ + /** + * Try to parse one complete frame from the front of the buffer. + * + * Returns a result with IsValid == false and BytesConsumed == 0 when + * there is not enough data yet. The caller should accumulate more data + * and retry. + */ + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + + /** + * Build a server-to-client frame (no masking) + */ + static std::vector BuildFrame(WebSocketOpcode Opcode, std::span Payload); + + /** + * Build a close frame with a status code and optional reason string + */ + static std::vector BuildCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Build a client-to-server frame (with masking per RFC 6455) + */ + static std::vector BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload); + + /** + * Build a masked close frame with status code and optional reason + */ + static std::vector BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2 + * + * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + */ + static std::string ComputeAcceptKey(std::string_view ClientKey); +}; + +} // namespace zen diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp new file mode 100644 index 000000000..3f0f0b447 --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -0,0 +1,466 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wshttpsys.h" + +#if ZEN_WITH_HTTPSYS + +# include "wsframecodec.h" + +# include + +namespace zen { + +static LoggerRef +WsHttpSysLog() +{ + static LoggerRef g_Logger = logging::Get("ws_httpsys"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp) +: m_RequestQueueHandle(RequestQueueHandle) +, m_RequestId(RequestId) +, m_Handler(Handler) +, m_Iocp(Iocp) +, m_ReadBuffer(8192) +{ + m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; + m_ReadIoContext.Owner = this; + m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; + m_WriteIoContext.Owner = this; +} + +WsHttpSysConnection::~WsHttpSysConnection() +{ + ZEN_ASSERT(m_OutstandingOps.load() == 0); + + if (m_IsOpen.exchange(false)) + { + Disconnect(); + } +} + +void +WsHttpSysConnection::Start() +{ + m_SelfRef = Ref(this); + IssueAsyncRead(); +} + +void +WsHttpSysConnection::Shutdown() +{ + m_ShutdownRequested.store(true, std::memory_order_relaxed); + + if (!m_IsOpen.exchange(false)) + { + return; + } + + // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +bool +WsHttpSysConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Async read path +// + +void +WsHttpSysConnection::IssueAsyncRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) + { + MaybeReleaseSelfRef(); + return; + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, + m_RequestId, + 0, // Flags + m_ReadBuffer.data(), + (ULONG)m_ReadBuffer.size(), + nullptr, // BytesRead (ignored for async) + &m_ReadIoContext.Overlapped); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef + Ref Guard(this); + + if (IoResult != NO_ERROR) + { + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + if (IoResult == ERROR_HANDLE_EOF) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); + } + else if (IoResult != ERROR_OPERATION_ABORTED) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + } + + MaybeReleaseSelfRef(); + return; + } + + if (NumberOfBytesTransferred > 0) + { + m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); + ProcessReceivedData(); + } + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + IssueAsyncRead(); + } + else + { + MaybeReleaseSelfRef(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +void +WsHttpSysConnection::ProcessReceivedData() +{ + while (!m_Accumulated.empty()) + { + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + // Remove consumed bytes + m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent) + { + m_CloseSent = true; + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + Disconnect(); + return; + } + + default: + ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Async write path +// + +void +WsHttpSysConnection::EnqueueWrite(std::vector Frame) +{ + bool ShouldFlush = false; + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.push_back(std::move(Frame)); + + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + } + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsHttpSysConnection::FlushWriteQueue() +{ + { + RwLock::ExclusiveLockScope _(m_WriteLock); + + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + + m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); + m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; + m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); + m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); + + ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA, + 1, + &m_WriteChunk, + nullptr, + nullptr, + 0, + &m_WriteIoContext.Overlapped, + nullptr); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + m_CurrentWriteBuffer.clear(); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + // Hold a transient ref to prevent mid-callback destruction + Ref Guard(this); + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + m_CurrentWriteBuffer.clear(); + + if (IoResult != NO_ERROR) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + return; + } + + FlushWriteQueue(); +} + +////////////////////////////////////////////////////////////////////////// +// +// Send interface +// + +void +WsHttpSysConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::SendBinary(std::span Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent) + { + m_CloseSent = true; + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +////////////////////////////////////////////////////////////////////////// +// +// Lifetime management +// + +void +WsHttpSysConnection::MaybeReleaseSelfRef() +{ + if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) + { + m_SelfRef = nullptr; + } +} + +void +WsHttpSysConnection::Disconnect() +{ + // Send final empty body with DISCONNECT to tell http.sys the connection is done + HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT, + 0, + nullptr, + nullptr, + nullptr, + 0, + nullptr, + nullptr); +} + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h new file mode 100644 index 000000000..ab0ca381a --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.h @@ -0,0 +1,104 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include "httpsys_iocontext.h" + +#include + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include +# include + +# include +# include +# include + +namespace zen { + +/** + * WebSocket connection over an http.sys opaque-mode connection + * + * After the 101 Switching Protocols response is sent with + * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the + * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody / + * HttpSendResponseEntityBody using the original RequestId. + * + * All I/O is performed asynchronously via the same IOCP threadpool used + * for normal http.sys traffic, eliminating per-connection threads. + * + * Lifetime is managed through intrusive reference counting (RefCounted). + * A self-reference (m_SelfRef) is held from Start() until all outstanding + * async operations have drained, preventing premature destruction. + */ +class WsHttpSysConnection : public WebSocketConnection +{ +public: + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp); + ~WsHttpSysConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and after the 101 response has been sent. + */ + void Start(); + + /** + * Shut down the connection. Cancels pending I/O; IOCP completions + * will fire with ERROR_OPERATION_ABORTED and drain naturally. + */ + void Shutdown(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + + // Called from IoCompletionCallback via tagged dispatch + void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + +private: + void IssueAsyncRead(); + void ProcessReceivedData(); + void EnqueueWrite(std::vector Frame); + void FlushWriteQueue(); + void DoClose(uint16_t Code, std::string_view Reason); + void Disconnect(); + void MaybeReleaseSelfRef(); + + HANDLE m_RequestQueueHandle; + HTTP_REQUEST_ID m_RequestId; + IWebSocketHandler& m_Handler; + PTP_IO m_Iocp; + + // Tagged OVERLAPPED contexts for concurrent read and write + HttpSysIoContext m_ReadIoContext{}; + HttpSysIoContext m_WriteIoContext{}; + + // Read state + std::vector m_ReadBuffer; + std::vector m_Accumulated; + + // Write state + RwLock m_WriteLock; + std::deque> m_WriteQueue; + std::vector m_CurrentWriteBuffer; + HTTP_DATA_CHUNK m_WriteChunk{}; + bool m_IsWriting = false; + + // Lifetime management + std::atomic m_OutstandingOps{0}; + Ref m_SelfRef; + std::atomic m_ShutdownRequested{false}; + std::atomic m_IsOpen{true}; + bool m_CloseSent = false; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp new file mode 100644 index 000000000..95f8587df --- /dev/null +++ b/src/zenhttp/servers/wstest.cpp @@ -0,0 +1,922 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#if ZEN_WITH_TESTS + +# include +# include +# include + +# include +# include +# include + +# include "httpasio.h" +# include "wsframecodec.h" + +ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_PLATFORM_WINDOWS +# include +# else +# include +# include +# endif +# include +ZEN_THIRD_PARTY_INCLUDES_END + +# include +# include +# include +# include +# include +# include +# include +# include + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests: WsFrameCodec +// + +TEST_CASE("websocket.framecodec") +{ + SUBCASE("ComputeAcceptKey RFC 6455 test vector") + { + // RFC 6455 section 4.2.2 example + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - text") + { + std::string_view Text = "Hello, WebSocket!"; + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + // Server frames are unmasked — TryParseFrame should handle them + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - binary") + { + std::vector BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildFrame - medium payload (126-65535 bytes)") + { + std::vector Payload(300, 0x42); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildFrame - large payload (>65535 bytes)") + { + std::vector Payload(70000, 0xAB); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildCloseFrame roundtrip") + { + std::vector Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure"); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } + + SUBCASE("TryParseFrame - partial data returns invalid") + { + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span{}); + + // Pass only 1 byte — not enough for a frame header + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); + CHECK_FALSE(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, 0u); + } + + SUBCASE("TryParseFrame - empty payload") + { + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span{}); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK(Result.Payload.empty()); + } + + SUBCASE("TryParseFrame - masked client frame") + { + // Build a masked frame manually as a client would send + // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello" + uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D}; + uint8_t MaskedPayload[5] = {}; + const char* Original = "Hello"; + for (int i = 0; i < 5; ++i) + { + MaskedPayload[i] = static_cast(Original[i]) ^ MaskKey[i % 4]; + } + + std::vector Frame; + Frame.push_back(0x81); // FIN + text + Frame.push_back(0x85); // MASK + len=5 + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), 5u); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), 5), "Hello"sv); + } + + SUBCASE("BuildMaskedFrame roundtrip - text") + { + std::string_view Text = "Hello, masked WebSocket!"; + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + // Verify mask bit is set + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildMaskedFrame roundtrip - binary") + { + std::vector BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)") + { + std::vector Payload(300, 0x42); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)") + { + std::vector Payload(70000, 0xAB); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildMaskedCloseFrame roundtrip") + { + std::vector Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: WebSocket over ASIO +// + +namespace { + + /** + * Helper: Build a masked client-to-server frame per RFC 6455 + */ + std::vector BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload) + { + std::vector Frame; + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length with mask bit set + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + // Mask key (use a fixed key for deterministic tests) + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + + std::vector BuildMaskedTextFrame(std::string_view Text) + { + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + return BuildMaskedFrame(WebSocketOpcode::kText, Payload); + } + + std::vector BuildMaskedCloseFrame(uint16_t Code) + { + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); + } + + /** + * Test service that implements IWebSocketHandler + */ + struct WsTestService : public HttpService, public IWebSocketHandler + { + const char* BaseUri() const override { return "/wstest/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest"); + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref Connection) override + { + m_OpenCount.fetch_add(1); + + m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); + } + + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + + // Echo the message back + Conn.SendText(Text); + } + } + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + + m_ConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref& C) { + return C.Get() == &Conn; + }); + m_Connections.erase(It, m_Connections.end()); + }); + } + + void SendToAll(std::string_view Text) + { + RwLock::SharedLockScope _(m_ConnectionsLock); + for (auto& Conn : m_Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Text); + } + } + } + + std::atomic m_OpenCount{0}; + std::atomic m_MessageCount{0}; + std::atomic m_CloseCount{0}; + std::atomic m_LastCloseCode{0}; + std::string m_LastMessage; + + RwLock m_ConnectionsLock; + std::vector> m_Connections; + }; + + /** + * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket + * + * Returns true on success (101 response), false otherwise. + */ + bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port) + { + // Send HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << Path << " HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + // Read the response (look for "101") + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + return Response.find("101") != std::string::npos; + } + + /** + * Helper: Read a single server-to-client frame from a socket + * + * Uses a background thread with a synchronous ASIO read and a timeout. + */ + WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000) + { + std::vector Buffer; + WsFrameParseResult Result; + std::atomic Done{false}; + + std::thread Reader([&] { + while (!Done.load()) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (Frame.IsValid) + { + Result = std::move(Frame); + Done.store(true); + return; + } + } + }); + + auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs); + while (!Done.load() && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + if (!Done.load()) + { + // Timeout — cancel the read + asio::error_code Ec; + Sock.cancel(Ec); + } + + if (Reader.joinable()) + { + Reader.join(); + } + + return Result; + } + +} // anonymous namespace + +TEST_CASE("websocket.integration") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + // Give server a moment to start accepting + Sleep(100); + + SUBCASE("handshake succeeds with 101") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + CHECK(Ok); + + Sleep(50); + CHECK_EQ(TestService.m_OpenCount.load(), 1); + + Sock.close(); + } + + SUBCASE("normal HTTP still works alongside WebSocket service") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + // Send a normal HTTP GET (not upgrade) + std::string HttpReq = fmt::format( + "GET /wstest/hello HTTP/1.1\r\n" + "Host: 127.0.0.1:{}\r\n" + "Connection: close\r\n" + "\r\n", + Port); + + asio::write(Sock, asio::buffer(HttpReq)); + + asio::streambuf ResponseBuf; + asio::error_code Ec; + asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + CHECK(Response.find("200") != std::string::npos); + } + + SUBCASE("echo message roundtrip") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a text message (masked, as client) + std::vector Frame = BuildMaskedTextFrame("ping test"); + asio::write(Sock, asio::buffer(Frame)); + + // Read the echo reply + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "ping test"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "ping test"); + + Sock.close(); + } + + SUBCASE("server push to client") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Server pushes a message + TestService.SendToAll("server says hello"); + + WsFrameParseResult Frame = ReadOneFrame(Sock); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "server says hello"sv); + + Sock.close(); + } + + SUBCASE("client close handshake") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send close frame + std::vector CloseFrame = BuildMaskedCloseFrame(1000); + asio::write(Sock, asio::buffer(CloseFrame)); + + // Server should echo close back + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + + Sleep(50); + CHECK_EQ(TestService.m_CloseCount.load(), 1); + CHECK_EQ(TestService.m_LastCloseCode.load(), 1000); + + Sock.close(); + } + + SUBCASE("multiple concurrent connections") + { + constexpr int NumClients = 5; + + asio::io_context IoCtx; + std::vector Sockets; + + for (int i = 0; i < NumClients; ++i) + { + Sockets.emplace_back(IoCtx); + Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port); + REQUIRE(Ok); + } + + Sleep(100); + CHECK_EQ(TestService.m_OpenCount.load(), NumClients); + + // Broadcast from server + TestService.SendToAll("broadcast"); + + // Each client should receive the message + for (int i = 0; i < NumClients; ++i) + { + WsFrameParseResult Frame = ReadOneFrame(Sockets[i]); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "broadcast"sv); + } + + // Close all + for (auto& S : Sockets) + { + S.close(); + } + } + + SUBCASE("service without IWebSocketHandler rejects upgrade") + { + // Register a plain HTTP service (no WebSocket) + struct PlainService : public HttpService + { + const char* BaseUri() const override { return "/plain/"; } + void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); } + }; + + PlainService Plain; + Server->RegisterService(Plain); + + Sleep(50); + + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + // Attempt WebSocket upgrade on the plain service + ExtendableStringBuilder<512> Request; + Request << "GET /plain/ws HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + // Should NOT get 101 — should fall through to normal request handling + CHECK(Response.find("101") == std::string::npos); + + Sock.close(); + } + + SUBCASE("ping/pong auto-response") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a ping frame with payload "test" + std::string_view PingPayload = "test"; + std::span PingData(reinterpret_cast(PingPayload.data()), PingPayload.size()); + std::vector PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData); + asio::write(Sock, asio::buffer(PingFrame)); + + // Should receive a pong with the same payload + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong); + CHECK_EQ(Reply.Payload.size(), 4u); + std::string_view PongText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(PongText, "test"sv); + + Sock.close(); + } + + SUBCASE("multiple messages in sequence") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + for (int i = 0; i < 10; ++i) + { + std::string Msg = fmt::format("message {}", i); + std::vector Frame = BuildMaskedTextFrame(Msg); + asio::write(Sock, asio::buffer(Frame)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, Msg); + } + + CHECK_EQ(TestService.m_MessageCount.load(), 10); + + Sock.close(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: HttpWsClient +// + +namespace { + + struct TestWsClientHandler : public IWsClientHandler + { + void OnWsOpen() override { m_OpenCount.fetch_add(1); } + + void OnWsMessage(const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + } + } + + void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + } + + std::atomic m_OpenCount{0}; + std::atomic m_MessageCount{0}; + std::atomic m_CloseCount{0}; + std::atomic m_LastCloseCode{0}; + std::string m_LastMessage; + }; + +} // anonymous namespace + +TEST_CASE("websocket.client") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7576, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello from client"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello from client"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + // The server echoes the close frame, which triggers OnWsClose on the client side + // with the server's close code. Allow the connection to settle. + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } + + SUBCASE("connect to bad port") + { + TestWsClientHandler Handler; + std::string Url = "ws://127.0.0.1:1/wstest/ws"; + + HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1006); + CHECK_EQ(Handler.m_OpenCount.load(), 0); + } + + SUBCASE("server-initiated close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + + // Copy connections then close them outside the lock to avoid deadlocking + // with OnWebSocketClose which acquires an exclusive lock + std::vector> Conns; + TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; }); + for (auto& Conn : Conns) + { + Conn->Close(1001, "going away"); + } + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1001); + CHECK_FALSE(Client.IsOpen()); + } +} + +void +websocket_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_WITH_TESTS diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 78876d21b..e8f87b668 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -6,6 +6,7 @@ target('zenhttp') add_headerfiles("**.h") add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) + add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") add_packages("http_parser", "json11") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index ad14ecb8d..3ac8eea8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -19,6 +19,7 @@ zenhttp_forcelinktests() httpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); + websocket_forcelink(); } } // namespace zen -- cgit v1.2.3