aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-02-27 17:13:40 +0100
committerGitHub Enterprise <[email protected]>2026-02-27 17:13:40 +0100
commit0a41fd42aa43080fbc991e7d976dde70aeaec594 (patch)
tree765ce661d98b3659a58091afcaad587f03f4bea9 /src
parentadd sentry-sdk logger (#793) (diff)
downloadzen-0a41fd42aa43080fbc991e7d976dde70aeaec594.tar.xz
zen-0a41fd42aa43080fbc991e7d976dde70aeaec594.zip
add full WebSocket (RFC 6455) client/server support for zenhttp (#792)
* This branch adds full WebSocket (RFC 6455) support to the HTTP server layer, covering both transport backends, a client, and tests. - **`websocket.h`** -- Core interfaces: `WebSocketOpcode`, `WebSocketMessage`, `WebSocketConnection` (ref-counted), and `IWebSocketHandler`. Services opt in to WebSocket support by implementing `IWebSocketHandler` alongside their existing `HttpService`. - **`httpwsclient.h`** -- `HttpWsClient`: an ASIO-backed `ws://` client with both standalone (own thread) and shared `io_context` modes. Supports connect timeout and optional auth token injection via `IWsClientHandler` callbacks. - **`wsasio.cpp/h`** -- `WsAsioConnection`: WebSocket over ASIO TCP. Takes over the socket after the HTTP 101 handshake and runs an async read/write loop with a queued write path (guarded by `RwLock`). - **`wshttpsys.cpp/h`** -- `WsHttpSysConnection`: WebSocket over http.sys opaque-mode connections (Windows only). Uses `HttpReceiveRequestEntityBody` / `HttpSendResponseEntityBody` via IOCP, sharing the same threadpool as normal http.sys traffic. Self-ref lifetime management ensures graceful drain of outstanding async ops. - **`httpsys_iocontext.h`** -- Tagged `OVERLAPPED` wrapper (`HttpSysIoContext`) used to distinguish normal HTTP transactions from WebSocket read/write completions in the single IOCP callback. - **`wsframecodec.cpp/h`** -- `WsFrameCodec`: static helpers for parsing (unmasked and masked) and building (unmasked server frames and masked client frames) RFC 6455 frames across all three payload length encodings (7-bit, 16-bit, 64-bit). Also computes `Sec-WebSocket-Accept` keys. - **`clients/httpwsclient.cpp`** -- `HttpWsClient::Impl`: ASIO-based client that performs the HTTP upgrade handshake, then hands off to the frame codec for the read loop. Manages its own `io_context` thread or plugs into an external one. - **`httpasio.cpp`** -- ASIO server now detects `Upgrade: websocket` requests, checks the matching `HttpService` for `IWebSocketHandler` via `dynamic_cast`, performs the RFC 6455 handshake (101 response), and spins up a `WsAsioConnection`. - **`httpsys.cpp`** -- Same upgrade detection and handshake logic for the http.sys backend, using `WsHttpSysConnection` and `HTTP_SEND_RESPONSE_FLAG_OPAQUE`. - **`httpparser.cpp/h`** -- Extended to surface the `Upgrade` / `Connection` / `Sec-WebSocket-Key` headers needed by the handshake. - **`httpcommon.h`** -- Minor additions (probably new header constants or response codes for the WS upgrade). - **`httpserver.h`** -- Small interface changes to support WebSocket registration. - **`zenhttp.cpp` / `xmake.lua`** -- New source files wired in; build config updated. - **Unit tests** (`websocket.framecodec`): round-trip encode/decode for text, binary, close frames; all three payload sizes; masked and unmasked variants; RFC 6455 `Sec-WebSocket-Accept` test vector. - **Integration tests** (`websocket.integration`): full ASIO server tests covering handshake (101), normal HTTP coexistence, echo, server-push broadcast, client close handshake, ping/pong auto-response, sequential messages, and rejection of upgrades on non-WS services. - **Client tests** (`websocket.client`): `HttpWsClient` connect+echo+close, connection failure (bad port -> close code 1006), and server-initiated close. * changed HttpRequestParser::ParseCurrentHeader to use switch instead of if/else chain * remove spurious printf --------- Co-authored-by: Stefan Boberg <[email protected]>
Diffstat (limited to 'src')
-rw-r--r--src/zencore/filesystem.cpp1
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp568
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h7
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h3
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h79
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h65
-rw-r--r--src/zenhttp/servers/httpasio.cpp49
-rw-r--r--src/zenhttp/servers/httpparser.cpp148
-rw-r--r--src/zenhttp/servers/httpparser.h7
-rw-r--r--src/zenhttp/servers/httpsys.cpp180
-rw-r--r--src/zenhttp/servers/httpsys_iocontext.h40
-rw-r--r--src/zenhttp/servers/wsasio.cpp297
-rw-r--r--src/zenhttp/servers/wsasio.h71
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp229
-rw-r--r--src/zenhttp/servers/wsframecodec.h74
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp466
-rw-r--r--src/zenhttp/servers/wshttpsys.h104
-rw-r--r--src/zenhttp/servers/wstest.cpp922
-rw-r--r--src/zenhttp/xmake.lua1
-rw-r--r--src/zenhttp/zenhttp.cpp1
20 files changed, 3203 insertions, 109 deletions
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
index 553897407..03398860b 100644
--- a/src/zencore/filesystem.cpp
+++ b/src/zencore/filesystem.cpp
@@ -3533,7 +3533,6 @@ TEST_CASE("PathBuilder")
Path.Reset();
Path.Append(fspath(L"/\u0119oo/"));
Path /= L"bar";
- printf("%ls\n", Path.ToPath().c_str());
CHECK(Path.ToView() == L"/\u0119oo/bar");
CHECK(Path.ToPath() == L"\\\u0119oo\\bar");
# endif
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..36a6f081b
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,568 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ if (m_Socket)
+ {
+ asio::error_code Ec;
+ m_Socket->close(Ec);
+ }
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext);
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ if (m_Socket)
+ {
+ asio::error_code CloseEc;
+ m_Socket->close(CloseEc);
+ }
+ }
+ });
+
+ asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ asio::async_write(*m_Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::unique_ptr<asio::io_context::work> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ bool m_CloseSent = false;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index bc18549c9..8fca35ac5 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept
return IsHttpSuccessCode(int(HttpCode));
}
+[[nodiscard]] inline bool
+IsHttpOk(HttpResponseCode HttpCode) noexcept
+{
+ return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted ||
+ HttpCode == HttpResponseCode::NoContent;
+}
+
std::string_view ToString(HttpResponseCode HttpCode);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 00cbc6c14..fee932daa 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -462,6 +462,7 @@ struct IHttpStatsService
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
};
-void http_forcelink(); // internal
+void http_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
new file mode 100644
index 000000000..926ec1e3d
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/websocket.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+/**
+ * Callback interface for WebSocket client events
+ *
+ * Separate from the server-side IWebSocketHandler because the caller
+ * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ */
+class IWsClientHandler
+{
+public:
+ virtual ~IWsClientHandler() = default;
+
+ virtual void OnWsOpen() = 0;
+ virtual void OnWsMessage(const WebSocketMessage& Msg) = 0;
+ virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0;
+};
+
+struct HttpWsClientSettings
+{
+ std::string LogCategory = "wsclient";
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+};
+
+/**
+ * WebSocket client over TCP (ws:// scheme)
+ *
+ * Uses ASIO for async I/O. Two construction modes:
+ * - Internal io_context + background thread (standalone use)
+ * - External io_context (shared event loop, no internal thread)
+ *
+ * Thread-safe for SendText/SendBinary/Close.
+ */
+class HttpWsClient
+{
+public:
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {});
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {});
+
+ ~HttpWsClient();
+
+ HttpWsClient(const HttpWsClient&) = delete;
+ HttpWsClient& operator=(const HttpWsClient&) = delete;
+
+ void Connect();
+ void SendText(std::string_view Text);
+ void SendBinary(std::span<const uint8_t> Data);
+ void Close(uint16_t Code = 1000, std::string_view Reason = {});
+ bool IsOpen() const;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..7a6fb33dd
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+#include <zencore/iobuffer.h>
+
+#include <cstdint>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+enum class WebSocketOpcode : uint8_t
+{
+ kText = 0x1,
+ kBinary = 0x2,
+ kClose = 0x8,
+ kPing = 0x9,
+ kPong = 0xA
+};
+
+struct WebSocketMessage
+{
+ WebSocketOpcode Opcode;
+ IoBuffer Payload;
+ uint16_t CloseCode = 0;
+};
+
+/**
+ * Represents an active WebSocket connection
+ *
+ * Derived classes implement the actual transport (e.g. ASIO sockets).
+ * Instances are reference-counted so that both the service layer and
+ * the async read/write loop can share ownership.
+ */
+class WebSocketConnection : public RefCounted
+{
+public:
+ virtual ~WebSocketConnection() = default;
+
+ virtual void SendText(std::string_view Text) = 0;
+ virtual void SendBinary(std::span<const uint8_t> Data) = 0;
+ virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0;
+ virtual bool IsOpen() const = 0;
+};
+
+/**
+ * Interface for services that accept WebSocket upgrades
+ *
+ * An HttpService may additionally implement this interface to indicate
+ * it supports WebSocket connections. The HTTP server checks for this
+ * via dynamic_cast when it sees an Upgrade: websocket request.
+ */
+class IWebSocketHandler
+{
+public:
+ virtual ~IWebSocketHandler() = default;
+
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
+ virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 0c0238886..8c2dcd116 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -14,6 +14,8 @@
#include <zenhttp/httpserver.h>
#include "httpparser.h"
+#include "wsasio.h"
+#include "wsframecodec.h"
#include <EASTL/fixed_vector.h>
@@ -1159,6 +1161,53 @@ HttpServerConnection::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ // WebSocket upgrade detection must happen before the keep-alive check below,
+ // because Upgrade requests have "Connection: Upgrade" which the HTTP parser
+ // treats as non-keep-alive, causing a premature shutdown of the receive side.
+ if (m_RequestData.IsWebSocketUpgrade())
+ {
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service);
+ if (WsHandler && !m_RequestData.SecWebSocketKey().empty())
+ {
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey());
+
+ auto ResponseStr = std::make_shared<std::string>();
+ ResponseStr->reserve(256);
+ ResponseStr->append(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ");
+ ResponseStr->append(AcceptKey);
+ ResponseStr->append("\r\n\r\n");
+
+ // Send the 101 response on the current socket, then hand the socket off
+ // to a WsAsioConnection for the WebSocket protocol.
+ asio::async_write(*m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Ref<WsAsioConnection> WsConn(new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
+
+ m_RequestState = RequestState::kDone;
+ return;
+ }
+ }
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index f0485aa25..3b1229375 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -12,14 +12,17 @@ namespace zen {
using namespace std::literals;
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
-static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv);
+static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv);
+static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -143,45 +146,62 @@ HttpRequestParser::ParseCurrentHeader()
const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
- if (HeaderHash == HashContentLength)
+ switch (HeaderHash)
{
- m_ContentLengthHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAuthorization)
- {
- m_AuthorizationHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::TryFromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = CurrentHeaderIndex;
+ case HashContentLength:
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAccept:
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashContentType:
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAuthorization:
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSession:
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
+ break;
+
+ case HashRequest:
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ break;
+
+ case HashExpect:
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ break;
+
+ case HashRange:
+ m_RangeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashUpgrade:
+ m_UpgradeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketKey:
+ m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketVersion:
+ m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ default:
+ break;
}
}
@@ -361,14 +381,17 @@ HttpRequestParser::ResetState()
m_HeaderEntries.clear();
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_AuthorizationHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
+ m_UpgradeHeaderIndex = -1;
+ m_SecWebSocketKeyHeaderIndex = -1;
+ m_SecWebSocketVersionHeaderIndex = -1;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
m_HeaderData.clear();
m_NormalizedUrl.clear();
@@ -425,4 +448,21 @@ HttpRequestParser::OnMessageComplete()
}
}
+bool
+HttpRequestParser::IsWebSocketUpgrade() const
+{
+ std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex);
+ if (Upgrade.empty())
+ {
+ return false;
+ }
+
+ // Case-insensitive check for "websocket"
+ if (Upgrade.size() != 9)
+ {
+ return false;
+ }
+ return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
+}
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index ff56ca970..d40a5aeb0 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -48,6 +48,10 @@ struct HttpRequestParser
std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+ std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); }
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); }
+ bool IsWebSocketUpgrade() const;
+
private:
struct HeaderRange
{
@@ -86,6 +90,9 @@ private:
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
int8_t m_AuthorizationHeaderIndex;
+ int8_t m_UpgradeHeaderIndex;
+ int8_t m_SecWebSocketKeyHeaderIndex;
+ int8_t m_SecWebSocketVersionHeaderIndex;
HttpVerb m_RequestVerb;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index e93ae4853..23d57af57 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -156,6 +156,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -380,7 +384,7 @@ public:
PTP_IO Iocp();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -397,8 +401,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -1555,7 +1559,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
{
@@ -2111,64 +2131,118 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
+
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Ref<WsHttpSysConnection> WsConn(
+ new WsHttpSysConnection(RequestQueueHandle, RequestId, *WsHandler, Transaction().Iocp()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
- ZEN_INFO("");
+ ZEN_WARN("WebSocket 101 send failed: {}", SendResult);
+
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h
new file mode 100644
index 000000000..4f8a97012
--- /dev/null
+++ b/src/zenhttp/servers/httpsys_iocontext.h
@@ -0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+
+# include <cstdint>
+
+namespace zen {
+
+/**
+ * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch
+ *
+ * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection
+ * (for WebSocket read/write) embed this struct. The single IoCompletionCallback
+ * bound to the request queue uses the ContextType tag to dispatch to the correct
+ * handler.
+ *
+ * The Overlapped member must be first so that CONTAINING_RECORD works to recover
+ * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool.
+ */
+struct HttpSysIoContext
+{
+ OVERLAPPED Overlapped{};
+
+ enum class Type : uint8_t
+ {
+ kTransaction,
+ kWebSocketRead,
+ kWebSocketWrite,
+ } ContextType = Type::kTransaction;
+
+ void* Owner = nullptr;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
new file mode 100644
index 000000000..dfc1eac38
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -0,0 +1,297 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsasio.h"
+#include "wsframecodec.h"
+
+#include <zencore/logging.h>
+
+namespace zen::asio_http {
+
+static LoggerRef
+WsLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler)
+: m_Socket(std::move(Socket))
+, m_Handler(Handler)
+{
+}
+
+WsAsioConnection::~WsAsioConnection()
+{
+ m_IsOpen.store(false);
+}
+
+void
+WsAsioConnection::Start()
+{
+ EnqueueRead();
+}
+
+bool
+WsAsioConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Read loop
+//
+
+void
+WsAsioConnection::EnqueueRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
+ Self->OnDataReceived(Ec, ByteCount);
+ });
+}
+
+void
+WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+}
+
+void
+WsAsioConnection::ProcessReceivedData()
+{
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Shut down the socket
+ std::error_code ShutdownEc;
+ m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
+ m_Socket->close(ShutdownEc);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Write queue
+//
+
+void
+WsAsioConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+}
+
+void
+WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsAsioConnection::FlushWriteQueue()
+{
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ // Move Frame into a shared_ptr so we can create the buffer and capture ownership
+ // in the same async_write call without evaluation order issues.
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
+}
+
+void
+WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
new file mode 100644
index 000000000..a638ea836
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.h
@@ -0,0 +1,71 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+namespace zen::asio_http {
+
+/**
+ * WebSocket connection over an ASIO TCP socket
+ *
+ * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * and runs an async read/write loop to exchange WebSocket frames.
+ *
+ * Lifetime is managed solely through intrusive reference counting (RefCounted).
+ * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
+ * connection alive for the duration of the async operation. The service layer
+ * also holds a Ref<WebSocketConnection>.
+ */
+class WsAsioConnection : public WebSocketConnection
+{
+public:
+ WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler);
+ ~WsAsioConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and the 101 response has been sent.
+ */
+ void Start();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+private:
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void ProcessReceivedData();
+
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount);
+
+ void DoClose(uint16_t Code, std::string_view Reason);
+
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ IWebSocketHandler& m_Handler;
+ asio::streambuf m_ReadBuffer;
+
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{true};
+ bool m_CloseSent = false;
+};
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
new file mode 100644
index 000000000..a4c5e0f16
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -0,0 +1,229 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/sha1.h>
+
+#include <cstring>
+#include <random>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+WsFrameParseResult
+WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
+{
+ // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
+ if (Size < 2)
+ {
+ return {};
+ }
+
+ const bool Fin = (Data[0] & 0x80) != 0;
+ const uint8_t OpcodeRaw = Data[0] & 0x0F;
+ const bool Masked = (Data[1] & 0x80) != 0;
+ uint64_t PayloadLen = Data[1] & 0x7F;
+
+ size_t HeaderSize = 2;
+
+ if (PayloadLen == 126)
+ {
+ if (Size < 4)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]);
+ HeaderSize = 4;
+ }
+ else if (PayloadLen == 127)
+ {
+ if (Size < 10)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) |
+ (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]);
+ HeaderSize = 10;
+ }
+
+ const size_t MaskSize = Masked ? 4 : 0;
+ const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen;
+
+ if (Size < TotalFrame)
+ {
+ return {};
+ }
+
+ const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr;
+ const uint8_t* PayloadData = Data + HeaderSize + MaskSize;
+
+ WsFrameParseResult Result;
+ Result.IsValid = true;
+ Result.BytesConsumed = TotalFrame;
+ Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
+ Result.Fin = Fin;
+
+ Result.Payload.resize(static_cast<size_t>(PayloadLen));
+ if (PayloadLen > 0)
+ {
+ std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen));
+
+ if (Masked)
+ {
+ for (size_t i = 0; i < Result.Payload.size(); ++i)
+ {
+ Result.Payload[i] ^= MaskKey[i & 3];
+ }
+ }
+ }
+
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (server-to-client, no masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length (no mask bit for server frames)
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ Frame.insert(Frame.end(), Payload.begin(), Payload.end());
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (client-to-server, with masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Generate random 4-byte mask key
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ uint32_t MaskValue = s_Rng();
+ uint8_t MaskKey[4];
+ std::memcpy(MaskKey, &MaskValue, 4);
+
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < PayloadLen; ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2)
+//
+
+static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+std::string
+WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey)
+{
+ // Concatenate client key with the magic GUID
+ std::string Combined;
+ Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size());
+ Combined.append(ClientKey);
+ Combined.append(kWebSocketMagicGuid);
+
+ // SHA1 hash
+ SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size());
+
+ // Base64 encode the 20-byte hash
+ char Base64Buf[Base64::GetEncodedDataSize(20) + 1];
+ uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf);
+ Base64Buf[EncodedLen] = '\0';
+
+ return std::string(Base64Buf, EncodedLen);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
new file mode 100644
index 000000000..2d90b6fa1
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/**
+ * Result of attempting to parse a single WebSocket frame from a byte buffer
+ */
+struct WsFrameParseResult
+{
+ bool IsValid = false; // true if a complete frame was successfully parsed
+ size_t BytesConsumed = 0; // number of bytes consumed from the input buffer
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ bool Fin = false;
+ std::vector<uint8_t> Payload;
+};
+
+/**
+ * RFC 6455 WebSocket frame codec
+ *
+ * Provides static helpers for parsing client-to-server frames (which are
+ * always masked) and building server-to-client frames (which are never masked).
+ */
+struct WsFrameCodec
+{
+ /**
+ * Try to parse one complete frame from the front of the buffer.
+ *
+ * Returns a result with IsValid == false and BytesConsumed == 0 when
+ * there is not enough data yet. The caller should accumulate more data
+ * and retry.
+ */
+ static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);
+
+ /**
+ * Build a server-to-client frame (no masking)
+ */
+ static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a close frame with a status code and optional reason string
+ */
+ static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Build a client-to-server frame (with masking per RFC 6455)
+ */
+ static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a masked close frame with status code and optional reason
+ */
+ static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2
+ *
+ * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
+ */
+ static std::string ComputeAcceptKey(std::string_view ClientKey);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
new file mode 100644
index 000000000..3f0f0b447
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -0,0 +1,466 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wshttpsys.h"
+
+#if ZEN_WITH_HTTPSYS
+
+# include "wsframecodec.h"
+
+# include <zencore/logging.h>
+
+namespace zen {
+
+static LoggerRef
+WsHttpSysLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws_httpsys");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp)
+: m_RequestQueueHandle(RequestQueueHandle)
+, m_RequestId(RequestId)
+, m_Handler(Handler)
+, m_Iocp(Iocp)
+, m_ReadBuffer(8192)
+{
+ m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead;
+ m_ReadIoContext.Owner = this;
+ m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite;
+ m_WriteIoContext.Owner = this;
+}
+
+WsHttpSysConnection::~WsHttpSysConnection()
+{
+ ZEN_ASSERT(m_OutstandingOps.load() == 0);
+
+ if (m_IsOpen.exchange(false))
+ {
+ Disconnect();
+ }
+}
+
+void
+WsHttpSysConnection::Start()
+{
+ m_SelfRef = Ref<WsHttpSysConnection>(this);
+ IssueAsyncRead();
+}
+
+void
+WsHttpSysConnection::Shutdown()
+{
+ m_ShutdownRequested.store(true, std::memory_order_relaxed);
+
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+bool
+WsHttpSysConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async read path
+//
+
+void
+WsHttpSysConnection::IssueAsyncRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed))
+ {
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ 0, // Flags
+ m_ReadBuffer.data(),
+ (ULONG)m_ReadBuffer.size(),
+ nullptr, // BytesRead (ignored for async)
+ &m_ReadIoContext.Overlapped);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "read issue failed");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef
+ Ref<WsHttpSysConnection> Guard(this);
+
+ if (IoResult != NO_ERROR)
+ {
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ if (IoResult == ERROR_HANDLE_EOF)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection closed");
+ }
+ else if (IoResult != ERROR_OPERATION_ABORTED)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ if (NumberOfBytesTransferred > 0)
+ {
+ m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred);
+ ProcessReceivedData();
+ }
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ IssueAsyncRead();
+ }
+ else
+ {
+ MaybeReleaseSelfRef();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+void
+WsHttpSysConnection::ProcessReceivedData()
+{
+ while (!m_Accumulated.empty())
+ {
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ // Remove consumed bytes
+ m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+ Disconnect();
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async write path
+//
+
+void
+WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ bool ShouldFlush = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.push_back(std::move(Frame));
+
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ }
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsHttpSysConnection::FlushWriteQueue()
+{
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+
+ m_CurrentWriteBuffer = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk));
+ m_WriteChunk.DataChunkType = HttpDataChunkFromMemory;
+ m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data();
+ m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size();
+
+ ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_MORE_DATA,
+ 1,
+ &m_WriteChunk,
+ nullptr,
+ nullptr,
+ 0,
+ &m_WriteIoContext.Overlapped,
+ nullptr);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+ m_CurrentWriteBuffer.clear();
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ // Hold a transient ref to prevent mid-callback destruction
+ Ref<WsHttpSysConnection> Guard(this);
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+ m_CurrentWriteBuffer.clear();
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Send interface
+//
+
+void
+WsHttpSysConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent)
+ {
+ m_CloseSent = true;
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Lifetime management
+//
+
+void
+WsHttpSysConnection::MaybeReleaseSelfRef()
+{
+ if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ m_SelfRef = nullptr;
+ }
+}
+
+void
+WsHttpSysConnection::Disconnect()
+{
+ // Send final empty body with DISCONNECT to tell http.sys the connection is done
+ HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_DISCONNECT,
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ nullptr,
+ nullptr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
new file mode 100644
index 000000000..ab0ca381a
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -0,0 +1,104 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include "httpsys_iocontext.h"
+
+#include <zencore/thread.h>
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <http.h>
+
+# include <atomic>
+# include <deque>
+# include <vector>
+
+namespace zen {
+
+/**
+ * WebSocket connection over an http.sys opaque-mode connection
+ *
+ * After the 101 Switching Protocols response is sent with
+ * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the
+ * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody /
+ * HttpSendResponseEntityBody using the original RequestId.
+ *
+ * All I/O is performed asynchronously via the same IOCP threadpool used
+ * for normal http.sys traffic, eliminating per-connection threads.
+ *
+ * Lifetime is managed through intrusive reference counting (RefCounted).
+ * A self-reference (m_SelfRef) is held from Start() until all outstanding
+ * async operations have drained, preventing premature destruction.
+ */
+class WsHttpSysConnection : public WebSocketConnection
+{
+public:
+ WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp);
+ ~WsHttpSysConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and after the 101 response has been sent.
+ */
+ void Start();
+
+ /**
+ * Shut down the connection. Cancels pending I/O; IOCP completions
+ * will fire with ERROR_OPERATION_ABORTED and drain naturally.
+ */
+ void Shutdown();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+ // Called from IoCompletionCallback via tagged dispatch
+ void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+ void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+private:
+ void IssueAsyncRead();
+ void ProcessReceivedData();
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void DoClose(uint16_t Code, std::string_view Reason);
+ void Disconnect();
+ void MaybeReleaseSelfRef();
+
+ HANDLE m_RequestQueueHandle;
+ HTTP_REQUEST_ID m_RequestId;
+ IWebSocketHandler& m_Handler;
+ PTP_IO m_Iocp;
+
+ // Tagged OVERLAPPED contexts for concurrent read and write
+ HttpSysIoContext m_ReadIoContext{};
+ HttpSysIoContext m_WriteIoContext{};
+
+ // Read state
+ std::vector<uint8_t> m_ReadBuffer;
+ std::vector<uint8_t> m_Accumulated;
+
+ // Write state
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ std::vector<uint8_t> m_CurrentWriteBuffer;
+ HTTP_DATA_CHUNK m_WriteChunk{};
+ bool m_IsWriting = false;
+
+ // Lifetime management
+ std::atomic<int32_t> m_OutstandingOps{0};
+ Ref<WsHttpSysConnection> m_SelfRef;
+ std::atomic<bool> m_ShutdownRequested{false};
+ std::atomic<bool> m_IsOpen{true};
+ bool m_CloseSent = false;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
new file mode 100644
index 000000000..95f8587df
--- /dev/null
+++ b/src/zenhttp/servers/wstest.cpp
@@ -0,0 +1,922 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenhttp/websocket.h>
+
+# include "httpasio.h"
+# include "wsframecodec.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_PLATFORM_WINDOWS
+# include <winsock2.h>
+# else
+# include <poll.h>
+# include <sys/socket.h>
+# endif
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+# include <atomic>
+# include <chrono>
+# include <cstring>
+# include <random>
+# include <string>
+# include <string_view>
+# include <thread>
+# include <vector>
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests: WsFrameCodec
+//
+
+TEST_CASE("websocket.framecodec")
+{
+ SUBCASE("ComputeAcceptKey RFC 6455 test vector")
+ {
+ // RFC 6455 section 4.2.2 example
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ // Server frames are unmasked — TryParseFrame should handle them
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+
+ SUBCASE("TryParseFrame - partial data returns invalid")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ // Pass only 1 byte — not enough for a frame header
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
+ CHECK_FALSE(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, 0u);
+ }
+
+ SUBCASE("TryParseFrame - empty payload")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK(Result.Payload.empty());
+ }
+
+ SUBCASE("TryParseFrame - masked client frame")
+ {
+ // Build a masked frame manually as a client would send
+ // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
+ uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
+ uint8_t MaskedPayload[5] = {};
+ const char* Original = "Hello";
+ for (int i = 0; i < 5; ++i)
+ {
+ MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4];
+ }
+
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x81); // FIN + text
+ Frame.push_back(0x85); // MASK + len=5
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), 5u);
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, masked WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ // Verify mask bit is set
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildMaskedCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: WebSocket over ASIO
+//
+
+namespace {
+
+ /**
+ * Helper: Build a masked client-to-server frame per RFC 6455
+ */
+ std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Mask key (use a fixed key for deterministic tests)
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
+ std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
+ {
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ }
+
+ std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
+ {
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+ }
+
+ /**
+ * Test service that implements IWebSocketHandler
+ */
+ struct WsTestService : public HttpService, public IWebSocketHandler
+ {
+ const char* BaseUri() const override { return "/wstest/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
+ }
+
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+ m_OpenCount.fetch_add(1);
+
+ m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+
+ // Echo the message back
+ Conn.SendText(Text);
+ }
+ }
+
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+
+ m_ConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_Connections.erase(It, m_Connections.end());
+ });
+ }
+
+ void SendToAll(std::string_view Text)
+ {
+ RwLock::SharedLockScope _(m_ConnectionsLock);
+ for (auto& Conn : m_Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Text);
+ }
+ }
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+
+ RwLock m_ConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_Connections;
+ };
+
+ /**
+ * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
+ *
+ * Returns true on success (101 response), false otherwise.
+ */
+ bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
+ {
+ // Send HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << Path << " HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ // Read the response (look for "101")
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ return Response.find("101") != std::string::npos;
+ }
+
+ /**
+ * Helper: Read a single server-to-client frame from a socket
+ *
+ * Uses a background thread with a synchronous ASIO read and a timeout.
+ */
+ WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000)
+ {
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Result;
+ std::atomic<bool> Done{false};
+
+ std::thread Reader([&] {
+ while (!Done.load())
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ break;
+ }
+
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (Frame.IsValid)
+ {
+ Result = std::move(Frame);
+ Done.store(true);
+ return;
+ }
+ }
+ });
+
+ auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs);
+ while (!Done.load() && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ if (!Done.load())
+ {
+ // Timeout — cancel the read
+ asio::error_code Ec;
+ Sock.cancel(Ec);
+ }
+
+ if (Reader.joinable())
+ {
+ Reader.join();
+ }
+
+ return Result;
+ }
+
+} // anonymous namespace
+
+TEST_CASE("websocket.integration")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ // Give server a moment to start accepting
+ Sleep(100);
+
+ SUBCASE("handshake succeeds with 101")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ CHECK(Ok);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_OpenCount.load(), 1);
+
+ Sock.close();
+ }
+
+ SUBCASE("normal HTTP still works alongside WebSocket service")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Send a normal HTTP GET (not upgrade)
+ std::string HttpReq = fmt::format(
+ "GET /wstest/hello HTTP/1.1\r\n"
+ "Host: 127.0.0.1:{}\r\n"
+ "Connection: close\r\n"
+ "\r\n",
+ Port);
+
+ asio::write(Sock, asio::buffer(HttpReq));
+
+ asio::streambuf ResponseBuf;
+ asio::error_code Ec;
+ asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec);
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+ CHECK(Response.find("200") != std::string::npos);
+ }
+
+ SUBCASE("echo message roundtrip")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a text message (masked, as client)
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
+ asio::write(Sock, asio::buffer(Frame));
+
+ // Read the echo reply
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "ping test"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "ping test");
+
+ Sock.close();
+ }
+
+ SUBCASE("server push to client")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Server pushes a message
+ TestService.SendToAll("server says hello");
+
+ WsFrameParseResult Frame = ReadOneFrame(Sock);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "server says hello"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("client close handshake")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send close frame
+ std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
+ asio::write(Sock, asio::buffer(CloseFrame));
+
+ // Server should echo close back
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_CloseCount.load(), 1);
+ CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple concurrent connections")
+ {
+ constexpr int NumClients = 5;
+
+ asio::io_context IoCtx;
+ std::vector<asio::ip::tcp::socket> Sockets;
+
+ for (int i = 0; i < NumClients; ++i)
+ {
+ Sockets.emplace_back(IoCtx);
+ Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
+ REQUIRE(Ok);
+ }
+
+ Sleep(100);
+ CHECK_EQ(TestService.m_OpenCount.load(), NumClients);
+
+ // Broadcast from server
+ TestService.SendToAll("broadcast");
+
+ // Each client should receive the message
+ for (int i = 0; i < NumClients; ++i)
+ {
+ WsFrameParseResult Frame = ReadOneFrame(Sockets[i]);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "broadcast"sv);
+ }
+
+ // Close all
+ for (auto& S : Sockets)
+ {
+ S.close();
+ }
+ }
+
+ SUBCASE("service without IWebSocketHandler rejects upgrade")
+ {
+ // Register a plain HTTP service (no WebSocket)
+ struct PlainService : public HttpService
+ {
+ const char* BaseUri() const override { return "/plain/"; }
+ void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
+ };
+
+ PlainService Plain;
+ Server->RegisterService(Plain);
+
+ Sleep(50);
+
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Attempt WebSocket upgrade on the plain service
+ ExtendableStringBuilder<512> Request;
+ Request << "GET /plain/ws HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ // Should NOT get 101 — should fall through to normal request handling
+ CHECK(Response.find("101") == std::string::npos);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping/pong auto-response")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a ping frame with payload "test"
+ std::string_view PingPayload = "test";
+ std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());
+ std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
+ asio::write(Sock, asio::buffer(PingFrame));
+
+ // Should receive a pong with the same payload
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
+ CHECK_EQ(Reply.Payload.size(), 4u);
+ std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(PongText, "test"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple messages in sequence")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ std::string Msg = fmt::format("message {}", i);
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
+ asio::write(Sock, asio::buffer(Frame));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, Msg);
+ }
+
+ CHECK_EQ(TestService.m_MessageCount.load(), 10);
+
+ Sock.close();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: HttpWsClient
+//
+
+namespace {
+
+ struct TestWsClientHandler : public IWsClientHandler
+ {
+ void OnWsOpen() override { m_OpenCount.fetch_add(1); }
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+ }
+ }
+
+ void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+ };
+
+} // anonymous namespace
+
+TEST_CASE("websocket.client")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7576, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello from client");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello from client");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ // The server echoes the close frame, which triggers OnWsClose on the client side
+ // with the server's close code. Allow the connection to settle.
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+
+ SUBCASE("connect to bad port")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = "ws://127.0.0.1:1/wstest/ws";
+
+ HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ }
+
+ SUBCASE("server-initiated close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+
+ // Copy connections then close them outside the lock to avoid deadlocking
+ // with OnWebSocketClose which acquires an exclusive lock
+ std::vector<Ref<WebSocketConnection>> Conns;
+ TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; });
+ for (auto& Conn : Conns)
+ {
+ Conn->Close(1001, "going away");
+ }
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+void
+websocket_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 78876d21b..e8f87b668 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -6,6 +6,7 @@ target('zenhttp')
add_headerfiles("**.h")
add_files("**.cpp")
add_files("servers/httpsys.cpp", {unity_ignored=true})
+ add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr")
add_packages("http_parser", "json11")
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index ad14ecb8d..3ac8eea8d 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -19,6 +19,7 @@ zenhttp_forcelinktests()
httpclient_test_forcelink();
forcelink_packageformat();
passwordsecurity_forcelink();
+ websocket_forcelink();
}
} // namespace zen