// Copyright Epic Games, Inc. All Rights Reserved. #include "wsframecodec.h" #include #include #include #include namespace zen { ////////////////////////////////////////////////////////////////////////// // // Frame parsing // WsFrameParseResult WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) { // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) if (Size < 2) { return {}; } const bool Fin = (Data[0] & 0x80) != 0; const uint8_t OpcodeRaw = Data[0] & 0x0F; const bool Masked = (Data[1] & 0x80) != 0; uint64_t PayloadLen = Data[1] & 0x7F; size_t HeaderSize = 2; if (PayloadLen == 126) { if (Size < 4) { return {}; } PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); HeaderSize = 4; } else if (PayloadLen == 127) { if (Size < 10) { return {}; } PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); HeaderSize = 10; } const size_t MaskSize = Masked ? 4 : 0; const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; if (Size < TotalFrame) { return {}; } const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; const uint8_t* PayloadData = Data + HeaderSize + MaskSize; WsFrameParseResult Result; Result.IsValid = true; Result.BytesConsumed = TotalFrame; Result.Opcode = static_cast(OpcodeRaw); Result.Fin = Fin; Result.Payload.resize(static_cast(PayloadLen)); if (PayloadLen > 0) { std::memcpy(Result.Payload.data(), PayloadData, static_cast(PayloadLen)); if (Masked) { for (size_t i = 0; i < Result.Payload.size(); ++i) { Result.Payload[i] ^= MaskKey[i & 3]; } } } return Result; } ////////////////////////////////////////////////////////////////////////// // // Frame building (server-to-client, no masking) // std::vector WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span Payload) { std::vector Frame; const size_t PayloadLen = Payload.size(); // FIN + opcode Frame.push_back(0x80 | static_cast(Opcode)); // Payload length (no mask bit for server frames) if (PayloadLen < 126) { Frame.push_back(static_cast(PayloadLen)); } else if (PayloadLen <= 0xFFFF) { Frame.push_back(126); Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); Frame.push_back(static_cast(PayloadLen & 0xFF)); } else { Frame.push_back(127); for (int i = 7; i >= 0; --i) { Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); } } Frame.insert(Frame.end(), Payload.begin(), Payload.end()); return Frame; } std::vector WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) { std::vector Payload; Payload.push_back(static_cast((Code >> 8) & 0xFF)); Payload.push_back(static_cast(Code & 0xFF)); Payload.insert(Payload.end(), Reason.begin(), Reason.end()); return BuildFrame(WebSocketOpcode::kClose, Payload); } ////////////////////////////////////////////////////////////////////////// // // Frame building (client-to-server, with masking) // std::vector WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload) { std::vector Frame; const size_t PayloadLen = Payload.size(); // FIN + opcode Frame.push_back(0x80 | static_cast(Opcode)); // Payload length with mask bit set if (PayloadLen < 126) { Frame.push_back(0x80 | static_cast(PayloadLen)); } else if (PayloadLen <= 0xFFFF) { Frame.push_back(0x80 | 126); Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); Frame.push_back(static_cast(PayloadLen & 0xFF)); } else { Frame.push_back(0x80 | 127); for (int i = 7; i >= 0; --i) { Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); } } // Generate random 4-byte mask key static thread_local std::mt19937 s_Rng(std::random_device{}()); uint32_t MaskValue = s_Rng(); uint8_t MaskKey[4]; std::memcpy(MaskKey, &MaskValue, 4); Frame.insert(Frame.end(), MaskKey, MaskKey + 4); // Masked payload for (size_t i = 0; i < PayloadLen; ++i) { Frame.push_back(Payload[i] ^ MaskKey[i & 3]); } return Frame; } std::vector WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) { std::vector Payload; Payload.push_back(static_cast((Code >> 8) & 0xFF)); Payload.push_back(static_cast(Code & 0xFF)); Payload.insert(Payload.end(), Reason.begin(), Reason.end()); return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); } ////////////////////////////////////////////////////////////////////////// // // Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) // static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; std::string WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) { // Concatenate client key with the magic GUID std::string Combined; Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); Combined.append(ClientKey); Combined.append(kWebSocketMagicGuid); // SHA1 hash SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); // Base64 encode the 20-byte hash char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); Base64Buf[EncodedLen] = '\0'; return std::string(Base64Buf, EncodedLen); } } // namespace zen