// Copyright Epic Games, Inc. All Rights Reserved. #include #include "../servers/wsframecodec.h" #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #include #include #include namespace zen { ////////////////////////////////////////////////////////////////////////// struct HttpWsClient::Impl { Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) : m_Handler(Handler) , m_Settings(Settings) , m_Log(logging::Get(Settings.LogCategory)) , m_OwnedIoContext(std::make_unique()) , m_IoContext(*m_OwnedIoContext) { ParseUrl(Url); } Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) : m_Handler(Handler) , m_Settings(Settings) , m_Log(logging::Get(Settings.LogCategory)) , m_IoContext(IoContext) { ParseUrl(Url); } ~Impl() { // Release work guard so io_context::run() can return m_WorkGuard.reset(); // Close the socket to cancel pending async ops if (m_Socket) { asio::error_code Ec; m_Socket->close(Ec); } if (m_IoThread.joinable()) { m_IoThread.join(); } } void ParseUrl(std::string_view Url) { // Expected format: ws://host:port/path if (Url.substr(0, 5) == "ws://") { Url.remove_prefix(5); } auto SlashPos = Url.find('/'); std::string_view HostPort; if (SlashPos != std::string_view::npos) { HostPort = Url.substr(0, SlashPos); m_Path = std::string(Url.substr(SlashPos)); } else { HostPort = Url; m_Path = "/"; } auto ColonPos = HostPort.find(':'); if (ColonPos != std::string_view::npos) { m_Host = std::string(HostPort.substr(0, ColonPos)); m_Port = std::string(HostPort.substr(ColonPos + 1)); } else { m_Host = std::string(HostPort); m_Port = "80"; } } void Connect() { if (m_OwnedIoContext) { m_WorkGuard = std::make_unique(m_IoContext); m_IoThread = std::thread([this] { m_IoContext.run(); }); } asio::post(m_IoContext, [this] { DoResolve(); }); } void DoResolve() { m_Resolver = std::make_unique(m_IoContext); m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { if (Ec) { ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); m_Handler.OnWsClose(1006, "resolve failed"); return; } DoConnect(Results); }); } void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) { m_Socket = std::make_unique(m_IoContext); // Start connect timeout timer m_Timer = std::make_unique(m_IoContext, m_Settings.ConnectTimeout); m_Timer->async_wait([this](const asio::error_code& Ec) { if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) { ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); if (m_Socket) { asio::error_code CloseEc; m_Socket->close(CloseEc); } } }); asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { if (Ec) { m_Timer->cancel(); ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); m_Handler.OnWsClose(1006, "connect failed"); return; } DoHandshake(); }); } void DoHandshake() { // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) uint8_t KeyBytes[16]; { static thread_local std::mt19937 s_Rng(std::random_device{}()); for (int i = 0; i < 4; ++i) { uint32_t Val = s_Rng(); std::memcpy(KeyBytes + i * 4, &Val, 4); } } char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); KeyBase64[KeyLen] = '\0'; m_WebSocketKey = std::string(KeyBase64, KeyLen); // Build the HTTP upgrade request ExtendableStringBuilder<512> Request; Request << "GET " << m_Path << " HTTP/1.1\r\n" << "Host: " << m_Host << ":" << m_Port << "\r\n" << "Upgrade: websocket\r\n" << "Connection: Upgrade\r\n" << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set if (m_Settings.AccessTokenProvider) { HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); if (Token.IsValid()) { Request << "Authorization: Bearer " << Token.Value << "\r\n"; } } Request << "\r\n"; std::string_view ReqStr = Request.ToView(); m_HandshakeBuffer = std::make_shared(ReqStr); asio::async_write(*m_Socket, asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), [this](const asio::error_code& Ec, std::size_t) { if (Ec) { m_Timer->cancel(); ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); m_Handler.OnWsClose(1006, "handshake write failed"); return; } DoReadHandshakeResponse(); }); } void DoReadHandshakeResponse() { asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { m_Timer->cancel(); if (Ec) { ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); m_Handler.OnWsClose(1006, "handshake read failed"); return; } // Parse the response const auto& Data = m_ReadBuffer.data(); std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); // Consume the headers from the read buffer (any extra data stays for frame parsing) auto HeaderEnd = Response.find("\r\n\r\n"); if (HeaderEnd != std::string::npos) { m_ReadBuffer.consume(HeaderEnd + 4); } // Validate 101 response if (Response.find("101") == std::string::npos) { ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); m_Handler.OnWsClose(1006, "handshake rejected"); return; } // Validate Sec-WebSocket-Accept std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); if (Response.find(ExpectedAccept) == std::string::npos) { ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); m_Handler.OnWsClose(1006, "invalid accept key"); return; } m_IsOpen.store(true); m_Handler.OnWsOpen(); EnqueueRead(); }); } ////////////////////////////////////////////////////////////////////////// // // Read loop // void EnqueueRead() { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { OnDataReceived(Ec); }); } void OnDataReceived(const asio::error_code& Ec) { if (Ec) { if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) { ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); } if (m_IsOpen.exchange(false)) { m_Handler.OnWsClose(1006, "connection lost"); } return; } ProcessReceivedData(); if (m_IsOpen.load(std::memory_order_relaxed)) { EnqueueRead(); } } void ProcessReceivedData() { while (m_ReadBuffer.size() > 0) { const auto& InputBuffer = m_ReadBuffer.data(); const auto* RawData = static_cast(InputBuffer.data()); const auto Size = InputBuffer.size(); WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); if (!Frame.IsValid) { break; } m_ReadBuffer.consume(Frame.BytesConsumed); switch (Frame.Opcode) { case WebSocketOpcode::kText: case WebSocketOpcode::kBinary: { WebSocketMessage Msg; Msg.Opcode = Frame.Opcode; Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); m_Handler.OnWsMessage(Msg); break; } case WebSocketOpcode::kPing: { // Auto-respond with masked pong std::vector PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); EnqueueWrite(std::move(PongFrame)); break; } case WebSocketOpcode::kPong: break; case WebSocketOpcode::kClose: { uint16_t Code = 1000; std::string_view Reason; if (Frame.Payload.size() >= 2) { Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); if (Frame.Payload.size() > 2) { Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); } } // Echo masked close frame if we haven't sent one yet if (!m_CloseSent) { m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); EnqueueWrite(std::move(CloseFrame)); } m_IsOpen.store(false); m_Handler.OnWsClose(Code, Reason); return; } default: ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); break; } } } ////////////////////////////////////////////////////////////////////////// // // Write queue // void EnqueueWrite(std::vector Frame) { bool ShouldFlush = false; m_WriteLock.WithExclusiveLock([&] { m_WriteQueue.push_back(std::move(Frame)); if (!m_IsWriting) { m_IsWriting = true; ShouldFlush = true; } }); if (ShouldFlush) { FlushWriteQueue(); } } void FlushWriteQueue() { std::vector Frame; m_WriteLock.WithExclusiveLock([&] { if (m_WriteQueue.empty()) { m_IsWriting = false; return; } Frame = std::move(m_WriteQueue.front()); m_WriteQueue.pop_front(); }); if (Frame.empty()) { return; } auto OwnedFrame = std::make_shared>(std::move(Frame)); asio::async_write(*m_Socket, asio::buffer(OwnedFrame->data(), OwnedFrame->size()), [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); } void OnWriteComplete(const asio::error_code& Ec) { if (Ec) { if (Ec != asio::error::operation_aborted) { ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); } m_WriteLock.WithExclusiveLock([&] { m_IsWriting = false; m_WriteQueue.clear(); }); if (m_IsOpen.exchange(false)) { m_Handler.OnWsClose(1006, "write error"); } return; } FlushWriteQueue(); } ////////////////////////////////////////////////////////////////////////// // // Public operations // void SendText(std::string_view Text) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::span Payload(reinterpret_cast(Text.data()), Text.size()); std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); EnqueueWrite(std::move(Frame)); } void SendBinary(std::span Data) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); EnqueueWrite(std::move(Frame)); } void DoClose(uint16_t Code, std::string_view Reason) { if (!m_IsOpen.exchange(false)) { return; } if (!m_CloseSent) { m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); EnqueueWrite(std::move(CloseFrame)); } } IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; std::string m_Host; std::string m_Port; std::string m_Path; // io_context: owned (standalone) or external (shared) std::unique_ptr m_OwnedIoContext; asio::io_context& m_IoContext; std::unique_ptr m_WorkGuard; std::thread m_IoThread; // Connection state std::unique_ptr m_Resolver; std::unique_ptr m_Socket; std::unique_ptr m_Timer; asio::streambuf m_ReadBuffer; std::string m_WebSocketKey; std::shared_ptr m_HandshakeBuffer; // Write queue RwLock m_WriteLock; std::deque> m_WriteQueue; bool m_IsWriting = false; std::atomic m_IsOpen{false}; bool m_CloseSent = false; }; ////////////////////////////////////////////////////////////////////////// HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) : m_Impl(std::make_unique(Url, Handler, Settings)) { } HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) : m_Impl(std::make_unique(Url, Handler, IoContext, Settings)) { } HttpWsClient::~HttpWsClient() = default; void HttpWsClient::Connect() { m_Impl->Connect(); } void HttpWsClient::SendText(std::string_view Text) { m_Impl->SendText(Text); } void HttpWsClient::SendBinary(std::span Data) { m_Impl->SendBinary(Data); } void HttpWsClient::Close(uint16_t Code, std::string_view Reason) { m_Impl->DoClose(Code, Reason); } bool HttpWsClient::IsOpen() const { return m_Impl->m_IsOpen.load(std::memory_order_relaxed); } } // namespace zen