// Copyright Epic Games, Inc. All Rights Reserved. #include "wshttpsys.h" #if ZEN_WITH_HTTPSYS # include "wsframecodec.h" # include namespace zen { static LoggerRef WsHttpSysLog() { static LoggerRef g_Logger = logging::Get("ws_httpsys"); return g_Logger; } ////////////////////////////////////////////////////////////////////////// WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp) : m_RequestQueueHandle(RequestQueueHandle) , m_RequestId(RequestId) , m_Handler(Handler) , m_Iocp(Iocp) , m_ReadBuffer(8192) { m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; m_ReadIoContext.Owner = this; m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; m_WriteIoContext.Owner = this; } WsHttpSysConnection::~WsHttpSysConnection() { ZEN_ASSERT(m_OutstandingOps.load() == 0); if (m_IsOpen.exchange(false)) { Disconnect(); } } void WsHttpSysConnection::Start() { m_SelfRef = Ref(this); IssueAsyncRead(); } void WsHttpSysConnection::Shutdown() { m_ShutdownRequested.store(true, std::memory_order_relaxed); if (!m_IsOpen.exchange(false)) { return; } // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } bool WsHttpSysConnection::IsOpen() const { return m_IsOpen.load(std::memory_order_relaxed); } ////////////////////////////////////////////////////////////////////////// // // Async read path // void WsHttpSysConnection::IssueAsyncRead() { if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) { MaybeReleaseSelfRef(); return; } m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); StartThreadpoolIo(m_Iocp); ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, m_RequestId, 0, // Flags m_ReadBuffer.data(), (ULONG)m_ReadBuffer.size(), nullptr, // BytesRead (ignored for async) &m_ReadIoContext.Overlapped); if (Result != NO_ERROR && Result != ERROR_IO_PENDING) { CancelThreadpoolIo(m_Iocp); m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); if (m_IsOpen.exchange(false)) { m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); } MaybeReleaseSelfRef(); } } void WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef Ref Guard(this); if (IoResult != NO_ERROR) { m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); if (m_IsOpen.exchange(false)) { if (IoResult == ERROR_HANDLE_EOF) { m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); } else if (IoResult != ERROR_OPERATION_ABORTED) { m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); } } MaybeReleaseSelfRef(); return; } if (NumberOfBytesTransferred > 0) { m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); ProcessReceivedData(); } m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); if (m_IsOpen.load(std::memory_order_relaxed)) { IssueAsyncRead(); } else { MaybeReleaseSelfRef(); } } ////////////////////////////////////////////////////////////////////////// // // Frame parsing // void WsHttpSysConnection::ProcessReceivedData() { while (!m_Accumulated.empty()) { WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); if (!Frame.IsValid) { break; // not enough data yet } // Remove consumed bytes m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); switch (Frame.Opcode) { case WebSocketOpcode::kText: case WebSocketOpcode::kBinary: { WebSocketMessage Msg; Msg.Opcode = Frame.Opcode; Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); m_Handler.OnWebSocketMessage(*this, Msg); break; } case WebSocketOpcode::kPing: { // Auto-respond with pong carrying the same payload std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); EnqueueWrite(std::move(PongFrame)); break; } case WebSocketOpcode::kPong: // Unsolicited pong — ignore per RFC 6455 break; case WebSocketOpcode::kClose: { uint16_t Code = 1000; std::string_view Reason; if (Frame.Payload.size() >= 2) { Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); if (Frame.Payload.size() > 2) { Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); } } // Echo close frame back if we haven't sent one yet { bool ShouldSendClose = false; { RwLock::ExclusiveLockScope _(m_WriteLock); if (!m_CloseSent) { m_CloseSent = true; ShouldSendClose = true; } } if (ShouldSendClose) { std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); EnqueueWrite(std::move(CloseFrame)); } } m_IsOpen.store(false); m_Handler.OnWebSocketClose(*this, Code, Reason); Disconnect(); return; } default: ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); break; } } } ////////////////////////////////////////////////////////////////////////// // // Async write path // void WsHttpSysConnection::EnqueueWrite(std::vector Frame) { bool ShouldFlush = false; { RwLock::ExclusiveLockScope _(m_WriteLock); m_WriteQueue.push_back(std::move(Frame)); if (!m_IsWriting) { m_IsWriting = true; ShouldFlush = true; } } if (ShouldFlush) { FlushWriteQueue(); } } void WsHttpSysConnection::FlushWriteQueue() { { RwLock::ExclusiveLockScope _(m_WriteLock); if (m_WriteQueue.empty()) { m_IsWriting = false; return; } m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); m_WriteQueue.pop_front(); } m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); StartThreadpoolIo(m_Iocp); ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, m_RequestId, HTTP_SEND_RESPONSE_FLAG_MORE_DATA, 1, &m_WriteChunk, nullptr, nullptr, 0, &m_WriteIoContext.Overlapped, nullptr); if (Result != NO_ERROR && Result != ERROR_IO_PENDING) { CancelThreadpoolIo(m_Iocp); m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); { RwLock::ExclusiveLockScope _(m_WriteLock); m_WriteQueue.clear(); m_IsWriting = false; } m_CurrentWriteBuffer.clear(); if (m_IsOpen.exchange(false)) { m_Handler.OnWebSocketClose(*this, 1006, "write error"); } MaybeReleaseSelfRef(); } } void WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { ZEN_UNUSED(NumberOfBytesTransferred); // Hold a transient ref to prevent mid-callback destruction Ref Guard(this); m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); m_CurrentWriteBuffer.clear(); if (IoResult != NO_ERROR) { ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); { RwLock::ExclusiveLockScope _(m_WriteLock); m_WriteQueue.clear(); m_IsWriting = false; } if (m_IsOpen.exchange(false)) { m_Handler.OnWebSocketClose(*this, 1006, "write error"); } MaybeReleaseSelfRef(); return; } FlushWriteQueue(); } ////////////////////////////////////////////////////////////////////////// // // Send interface // void WsHttpSysConnection::SendText(std::string_view Text) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::span Payload(reinterpret_cast(Text.data()), Text.size()); std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); EnqueueWrite(std::move(Frame)); } void WsHttpSysConnection::SendBinary(std::span Data) { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); EnqueueWrite(std::move(Frame)); } void WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) { DoClose(Code, Reason); } void WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) { if (!m_IsOpen.exchange(false)) { return; } { bool ShouldSendClose = false; { RwLock::ExclusiveLockScope _(m_WriteLock); if (!m_CloseSent) { m_CloseSent = true; ShouldSendClose = true; } } if (ShouldSendClose) { std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); EnqueueWrite(std::move(CloseFrame)); } } m_Handler.OnWebSocketClose(*this, Code, Reason); // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } ////////////////////////////////////////////////////////////////////////// // // Lifetime management // void WsHttpSysConnection::MaybeReleaseSelfRef() { if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) { m_SelfRef = nullptr; } } void WsHttpSysConnection::Disconnect() { // Send final empty body with DISCONNECT to tell http.sys the connection is done HttpSendResponseEntityBody(m_RequestQueueHandle, m_RequestId, HTTP_SEND_RESPONSE_FLAG_DISCONNECT, 0, nullptr, nullptr, nullptr, 0, nullptr, nullptr); } } // namespace zen #endif // ZEN_WITH_HTTPSYS