// Copyright Epic Games, Inc. All Rights Reserved. #include "wsasio.h" #include "wsframecodec.h" #include namespace zen::asio_http { static LoggerRef WsLog() { static LoggerRef g_Logger = logging::Get("ws"); return g_Logger; } ////////////////////////////////////////////////////////////////////////// WsAsioConnection::WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler) : m_Socket(std::move(Socket)) , m_Handler(Handler) { } WsAsioConnection::~WsAsioConnection() { m_IsOpen.store(false); } void WsAsioConnection::Start() { EnqueueRead(); } bool WsAsioConnection::IsOpen() const { return m_IsOpen.load(std::memory_order_relaxed); } ////////////////////////////////////////////////////////////////////////// // // Read loop // void WsAsioConnection::EnqueueRead() { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } Ref Self(this); asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnDataReceived(Ec, ByteCount); }); } void WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) { ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); } if (m_IsOpen.exchange(false)) { m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); } return; } ProcessReceivedData(); if (m_IsOpen.load(std::memory_order_relaxed)) { EnqueueRead(); } } void WsAsioConnection::ProcessReceivedData() { while (m_ReadBuffer.size() > 0) { const auto& InputBuffer = m_ReadBuffer.data(); const auto* Data = static_cast(InputBuffer.data()); const auto Size = InputBuffer.size(); WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); if (!Frame.IsValid) { break; // not enough data yet } m_ReadBuffer.consume(Frame.BytesConsumed); switch (Frame.Opcode) { case WebSocketOpcode::kText: case WebSocketOpcode::kBinary: { WebSocketMessage Msg; Msg.Opcode = Frame.Opcode; Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); m_Handler.OnWebSocketMessage(*this, Msg); break; } case WebSocketOpcode::kPing: { // Auto-respond with pong carrying the same payload std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); EnqueueWrite(std::move(PongFrame)); break; } case WebSocketOpcode::kPong: // Unsolicited pong — ignore per RFC 6455 break; case WebSocketOpcode::kClose: { uint16_t Code = 1000; std::string_view Reason; if (Frame.Payload.size() >= 2) { Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); if (Frame.Payload.size() > 2) { Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); } } // Echo close frame back if we haven't sent one yet if (!m_CloseSent) { m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); EnqueueWrite(std::move(CloseFrame)); } m_IsOpen.store(false); m_Handler.OnWebSocketClose(*this, Code, Reason); // Shut down the socket std::error_code ShutdownEc; m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); m_Socket->close(ShutdownEc); return; } default: ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); break; } } } ////////////////////////////////////////////////////////////////////////// // // Write queue // void WsAsioConnection::SendText(std::string_view Text) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::span Payload(reinterpret_cast(Text.data()), Text.size()); std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); EnqueueWrite(std::move(Frame)); } void WsAsioConnection::SendBinary(std::span Data) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); EnqueueWrite(std::move(Frame)); } void WsAsioConnection::Close(uint16_t Code, std::string_view Reason) { DoClose(Code, Reason); } void WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) { if (!m_IsOpen.exchange(false)) { return; } if (!m_CloseSent) { m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); EnqueueWrite(std::move(CloseFrame)); } m_Handler.OnWebSocketClose(*this, Code, Reason); } void WsAsioConnection::EnqueueWrite(std::vector Frame) { bool ShouldFlush = false; m_WriteLock.WithExclusiveLock([&] { m_WriteQueue.push_back(std::move(Frame)); if (!m_IsWriting) { m_IsWriting = true; ShouldFlush = true; } }); if (ShouldFlush) { FlushWriteQueue(); } } void WsAsioConnection::FlushWriteQueue() { std::vector Frame; m_WriteLock.WithExclusiveLock([&] { if (m_WriteQueue.empty()) { m_IsWriting = false; return; } Frame = std::move(m_WriteQueue.front()); m_WriteQueue.pop_front(); }); if (Frame.empty()) { return; } Ref Self(this); // Move Frame into a shared_ptr so we can create the buffer and capture ownership // in the same async_write call without evaluation order issues. auto OwnedFrame = std::make_shared>(std::move(Frame)); asio::async_write(*m_Socket, asio::buffer(OwnedFrame->data(), OwnedFrame->size()), [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); } void WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { if (Ec != asio::error::operation_aborted) { ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); } m_WriteLock.WithExclusiveLock([&] { m_IsWriting = false; m_WriteQueue.clear(); }); if (m_IsOpen.exchange(false)) { m_Handler.OnWebSocketClose(*this, 1006, "write error"); } return; } FlushWriteQueue(); } } // namespace zen::asio_http