diff options
| author | Stefan Boberg <[email protected]> | 2026-02-27 17:13:40 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-02-27 17:13:40 +0100 |
| commit | 0a41fd42aa43080fbc991e7d976dde70aeaec594 (patch) | |
| tree | 765ce661d98b3659a58091afcaad587f03f4bea9 /src/zenhttp/servers/wsframecodec.cpp | |
| parent | add sentry-sdk logger (#793) (diff) | |
| download | zen-0a41fd42aa43080fbc991e7d976dde70aeaec594.tar.xz zen-0a41fd42aa43080fbc991e7d976dde70aeaec594.zip | |
add full WebSocket (RFC 6455) client/server support for zenhttp (#792)
* This branch adds full WebSocket (RFC 6455) support to the HTTP server layer, covering both transport backends, a client, and tests.
- **`websocket.h`** -- Core interfaces: `WebSocketOpcode`, `WebSocketMessage`, `WebSocketConnection` (ref-counted), and `IWebSocketHandler`. Services opt in to WebSocket support by implementing `IWebSocketHandler` alongside their existing `HttpService`.
- **`httpwsclient.h`** -- `HttpWsClient`: an ASIO-backed `ws://` client with both standalone (own thread) and shared `io_context` modes. Supports connect timeout and optional auth token injection via `IWsClientHandler` callbacks.
- **`wsasio.cpp/h`** -- `WsAsioConnection`: WebSocket over ASIO TCP. Takes over the socket after the HTTP 101 handshake and runs an async read/write loop with a queued write path (guarded by `RwLock`).
- **`wshttpsys.cpp/h`** -- `WsHttpSysConnection`: WebSocket over http.sys opaque-mode connections (Windows only). Uses `HttpReceiveRequestEntityBody` / `HttpSendResponseEntityBody` via IOCP, sharing the same threadpool as normal http.sys traffic. Self-ref lifetime management ensures graceful drain of outstanding async ops.
- **`httpsys_iocontext.h`** -- Tagged `OVERLAPPED` wrapper (`HttpSysIoContext`) used to distinguish normal HTTP transactions from WebSocket read/write completions in the single IOCP callback.
- **`wsframecodec.cpp/h`** -- `WsFrameCodec`: static helpers for parsing (unmasked and masked) and building (unmasked server frames and masked client frames) RFC 6455 frames across all three payload length encodings (7-bit, 16-bit, 64-bit). Also computes `Sec-WebSocket-Accept` keys.
- **`clients/httpwsclient.cpp`** -- `HttpWsClient::Impl`: ASIO-based client that performs the HTTP upgrade handshake, then hands off to the frame codec for the read loop. Manages its own `io_context` thread or plugs into an external one.
- **`httpasio.cpp`** -- ASIO server now detects `Upgrade: websocket` requests, checks the matching `HttpService` for `IWebSocketHandler` via `dynamic_cast`, performs the RFC 6455 handshake (101 response), and spins up a `WsAsioConnection`.
- **`httpsys.cpp`** -- Same upgrade detection and handshake logic for the http.sys backend, using `WsHttpSysConnection` and `HTTP_SEND_RESPONSE_FLAG_OPAQUE`.
- **`httpparser.cpp/h`** -- Extended to surface the `Upgrade` / `Connection` / `Sec-WebSocket-Key` headers needed by the handshake.
- **`httpcommon.h`** -- Minor additions (probably new header constants or response codes for the WS upgrade).
- **`httpserver.h`** -- Small interface changes to support WebSocket registration.
- **`zenhttp.cpp` / `xmake.lua`** -- New source files wired in; build config updated.
- **Unit tests** (`websocket.framecodec`): round-trip encode/decode for text, binary, close frames; all three payload sizes; masked and unmasked variants; RFC 6455 `Sec-WebSocket-Accept` test vector.
- **Integration tests** (`websocket.integration`): full ASIO server tests covering handshake (101), normal HTTP coexistence, echo, server-push broadcast, client close handshake, ping/pong auto-response, sequential messages, and rejection of upgrades on non-WS services.
- **Client tests** (`websocket.client`): `HttpWsClient` connect+echo+close, connection failure (bad port -> close code 1006), and server-initiated close.
* changed HttpRequestParser::ParseCurrentHeader to use switch instead of if/else chain
* remove spurious printf
---------
Co-authored-by: Stefan Boberg <[email protected]>
Diffstat (limited to 'src/zenhttp/servers/wsframecodec.cpp')
| -rw-r--r-- | src/zenhttp/servers/wsframecodec.cpp | 229 |
1 files changed, 229 insertions, 0 deletions
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..a4c5e0f16 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,229 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/sha1.h> + +#include <cstring> +#include <random> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast<size_t>(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (client-to-server, with masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen |