// Copyright Epic Games, Inc. All Rights Reserved. #include "wsframecodec.h" #include #include #include #include namespace zen { ////////////////////////////////////////////////////////////////////////// // // Frame parsing // WsFrameParseResult WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size, bool RequireMask) { // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) if (Size < 2) { return {}; } const bool Fin = (Data[0] & 0x80) != 0; const uint8_t RsvBits = Data[0] & 0x70; const uint8_t OpcodeRaw = Data[0] & 0x0F; const bool Masked = (Data[1] & 0x80) != 0; const uint8_t ShortLength = Data[1] & 0x7F; uint64_t PayloadLen = ShortLength; const bool IsControlFrame = (OpcodeRaw & 0x08) != 0; // RFC 6455 section 5.2: RSV1/2/3 must be zero unless a negotiated extension // defines them. We do not negotiate any extensions, so any non-zero RSV bit // is a protocol violation. if (RsvBits != 0) { WsFrameParseResult Error; Error.Status = WsFrameParseStatus::kProtocolError; return Error; } // RFC 6455 section 5.5: control frames (Close / Ping / Pong and any opcode // in 0x8..0xF) MUST NOT be fragmented and MUST have a payload of 125 bytes // or less. Rejecting fragmented or oversized control frames prevents a // peer from tying up unbounded memory inside an auto-pong, and closes off // a class of smuggling tricks where handlers might observe partial control // payloads. if (IsControlFrame && (!Fin || ShortLength > 125)) { WsFrameParseResult Error; Error.Status = WsFrameParseStatus::kProtocolError; return Error; } // RFC 6455 section 5.1: a server MUST close the connection upon receiving an // unmasked client frame. Signal this distinctly from "need more data" so the // server close path can trigger a 1002 close rather than stalling for bytes // that will never satisfy the parse. if (RequireMask && !Masked) { WsFrameParseResult Error; Error.Status = WsFrameParseStatus::kProtocolError; return Error; } size_t HeaderSize = 2; if (PayloadLen == 126) { if (Size < 4) { return {}; } PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); HeaderSize = 4; } else if (PayloadLen == 127) { if (Size < 10) { return {}; } PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); HeaderSize = 10; } // Reject frames with unreasonable payload sizes to bound per-connection // memory. Parsers accumulate the whole frame before dispatch (see the // read loops in wsasio.cpp / wshttpsys.cpp), so this cap also bounds the // accumulator: a peer that advertises a large frame and streams bytes // slowly cannot grow buffers past this limit. 4 MB is well above anything // the monitoring / stats endpoints produce; raise it if a legitimate use // case emerges. static constexpr uint64_t kMaxPayloadSize = 4 * 1024 * 1024; // 4 MB if (PayloadLen > kMaxPayloadSize) { WsFrameParseResult Error; Error.Status = WsFrameParseStatus::kProtocolError; return Error; } const size_t MaskSize = Masked ? 4 : 0; const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; if (Size < TotalFrame) { return {}; } const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; const uint8_t* PayloadData = Data + HeaderSize + MaskSize; WsFrameParseResult Result; Result.Status = WsFrameParseStatus::kValid; Result.IsValid = true; Result.BytesConsumed = TotalFrame; Result.Opcode = static_cast(OpcodeRaw); Result.Fin = Fin; Result.Payload.resize(static_cast(PayloadLen)); if (PayloadLen > 0) { std::memcpy(Result.Payload.data(), PayloadData, static_cast(PayloadLen)); if (Masked) { for (size_t i = 0; i < Result.Payload.size(); ++i) { Result.Payload[i] ^= MaskKey[i & 3]; } } } return Result; } ////////////////////////////////////////////////////////////////////////// // // Frame building (server-to-client, no masking) // std::vector WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span Payload) { std::vector Frame; const size_t PayloadLen = Payload.size(); // FIN + opcode Frame.push_back(0x80 | static_cast(Opcode)); // Payload length (no mask bit for server frames) if (PayloadLen < 126) { Frame.push_back(static_cast(PayloadLen)); } else if (PayloadLen <= 0xFFFF) { Frame.push_back(126); Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); Frame.push_back(static_cast(PayloadLen & 0xFF)); } else { Frame.push_back(127); for (int i = 7; i >= 0; --i) { Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); } } Frame.insert(Frame.end(), Payload.begin(), Payload.end()); return Frame; } std::vector WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) { std::vector Payload; Payload.push_back(static_cast((Code >> 8) & 0xFF)); Payload.push_back(static_cast(Code & 0xFF)); Payload.insert(Payload.end(), Reason.begin(), Reason.end()); return BuildFrame(WebSocketOpcode::kClose, Payload); } ////////////////////////////////////////////////////////////////////////// // // Frame building (client-to-server, with masking) // std::vector WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload) { std::vector Frame; const size_t PayloadLen = Payload.size(); // FIN + opcode Frame.push_back(0x80 | static_cast(Opcode)); // Payload length with mask bit set if (PayloadLen < 126) { Frame.push_back(0x80 | static_cast(PayloadLen)); } else if (PayloadLen <= 0xFFFF) { Frame.push_back(0x80 | 126); Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); Frame.push_back(static_cast(PayloadLen & 0xFF)); } else { Frame.push_back(0x80 | 127); for (int i = 7; i >= 0; --i) { Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); } } // Generate random 4-byte mask key static thread_local std::mt19937 s_Rng(std::random_device{}()); uint32_t MaskValue = s_Rng(); uint8_t MaskKey[4]; std::memcpy(MaskKey, &MaskValue, 4); Frame.insert(Frame.end(), MaskKey, MaskKey + 4); // Masked payload for (size_t i = 0; i < PayloadLen; ++i) { Frame.push_back(Payload[i] ^ MaskKey[i & 3]); } return Frame; } std::vector WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) { std::vector Payload; Payload.push_back(static_cast((Code >> 8) & 0xFF)); Payload.push_back(static_cast(Code & 0xFF)); Payload.insert(Payload.end(), Reason.begin(), Reason.end()); return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); } ////////////////////////////////////////////////////////////////////////// // // Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) // static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; std::string WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) { // Concatenate client key with the magic GUID std::string Combined; Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); Combined.append(ClientKey); Combined.append(kWebSocketMagicGuid); // SHA1 hash SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); // Base64 encode the 20-byte hash char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); Base64Buf[EncodedLen] = '\0'; return std::string(Base64Buf, EncodedLen); } } // namespace zen