// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include #include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # include #endif namespace zen::websocket { using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; /////////////////////////////////////////////////////////////////////////////// namespace http_header { static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; static constexpr std::string_view Upgrade = "Upgrade"sv; } // namespace http_header /////////////////////////////////////////////////////////////////////////////// enum class ParseMessageStatus : uint32_t { kError, kContinue, kDone, }; struct ParseMessageResult { ParseMessageStatus Status{}; size_t ByteCount{}; std::optional Reason; }; class MessageParser { public: virtual ~MessageParser() = default; ParseMessageResult ParseMessage(MemoryView Msg); void Reset(); protected: MessageParser() = default; virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; BinaryWriter m_Stream; }; ParseMessageResult MessageParser::ParseMessage(MemoryView Msg) { return OnParseMessage(Msg); } void MessageParser::Reset() { OnReset(); m_Stream.Reset(); } /////////////////////////////////////////////////////////////////////////////// enum class HttpMessageParserType { kRequest, kResponse, kBoth }; class HttpMessageParser final : public MessageParser { public: using HttpHeaders = std::unordered_map; HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } virtual ~HttpMessageParser() = default; int32_t StatusCode() const { return m_Parser.status_code; } bool IsUpgrade() const { return m_Parser.upgrade != 0; } HttpHeaders& Headers() { return m_Headers; } MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } std::string_view StatusText() const { return std::string_view(reinterpret_cast(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); } bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); private: void Initialize(); virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; int OnMessageBegin(); int OnUrl(MemoryView Url); int OnStatus(MemoryView Status); int OnHeaderField(MemoryView HeaderField); int OnHeaderValue(MemoryView HeaderValue); int OnHeadersComplete(); int OnBody(MemoryView Body); int OnMessageComplete(); struct StreamEntry { uint64_t Offset{}; uint64_t Size{}; }; struct HeaderStreamEntry { StreamEntry Field{}; StreamEntry Value{}; }; HttpMessageParserType m_Type; http_parser m_Parser; StreamEntry m_UrlEntry; StreamEntry m_StatusEntry; StreamEntry m_BodyEntry; HeaderStreamEntry m_CurrentHeader; std::vector m_HeaderEntries; HttpHeaders m_Headers; bool m_IsMsgComplete{false}; static http_parser_settings ParserSettings; }; http_parser_settings HttpMessageParser::ParserSettings = { .on_message_begin = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageBegin(); }, .on_url = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnUrl(MemoryView(Data, Size)); }, .on_status = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnStatus(MemoryView(Data, Size)); }, .on_header_field = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderField(MemoryView(Data, Size)); }, .on_header_value = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, .on_headers_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnHeadersComplete(); }, .on_body = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnBody(MemoryView(Data, Size)); }, .on_message_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageComplete(); }}; void HttpMessageParser::Initialize() { http_parser_init(&m_Parser, m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE : HTTP_BOTH); m_Parser.data = this; m_UrlEntry = {}; m_StatusEntry = {}; m_CurrentHeader = {}; m_BodyEntry = {}; m_IsMsgComplete = false; m_HeaderEntries.clear(); } ParseMessageResult HttpMessageParser::OnParseMessage(MemoryView Msg) { const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast(Msg.GetData()), Msg.GetSize()); auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; if (m_Parser.http_errno != 0) { Status = ParseMessageStatus::kError; } return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; } void HttpMessageParser::OnReset() { Initialize(); } int HttpMessageParser::OnMessageBegin() { ZEN_ASSERT(m_IsMsgComplete == false); ZEN_ASSERT(m_HeaderEntries.empty()); ZEN_ASSERT(m_Headers.empty()); return 0; } int HttpMessageParser::OnStatus(MemoryView Status) { m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; m_Stream.Write(Status); return 0; } int HttpMessageParser::OnUrl(MemoryView Url) { m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; m_Stream.Write(Url); return 0; } int HttpMessageParser::OnHeaderField(MemoryView HeaderField) { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } if (m_CurrentHeader.Field.Size == 0) { m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Field.Size += HeaderField.GetSize(); m_Stream.Write(HeaderField); return 0; } int HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) { if (m_CurrentHeader.Value.Size == 0) { m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Value.Size += HeaderValue.GetSize(); m_Stream.Write(HeaderValue); return 0; } int HttpMessageParser::OnHeadersComplete() { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } m_Headers.clear(); m_Headers.reserve(m_HeaderEntries.size()); const char* StreamData = reinterpret_cast(m_Stream.Data()); for (const auto& Entry : m_HeaderEntries) { auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); m_Headers.try_emplace(std::move(Field), std::move(Value)); } return 0; } int HttpMessageParser::OnBody(MemoryView Body) { m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; m_Stream.Write(Body); return 0; } int HttpMessageParser::OnMessageComplete() { m_IsMsgComplete = true; return 0; } bool HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) { static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; OutAcceptHash = std::string(); if (m_Headers.contains(http_header::SecWebSocketKey) == false) { OutReason = "Missing header Sec-WebSocket-Key"; return false; } if (m_Headers.contains(http_header::Upgrade) == false) { OutReason = "Missing header Upgrade"; return false; } ExtendableStringBuilder<128> Sb; Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; SHA1Stream HashStream; HashStream.Append(Sb.Data(), Sb.Size()); SHA1 Hash = HashStream.GetHash(); OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); return true; } /////////////////////////////////////////////////////////////////////////////// class WebSocketMessageParser final : public MessageParser { public: WebSocketMessageParser() : MessageParser() {} WebSocketMessage ConsumeMessage(); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; WebSocketMessage m_Message; }; ParseMessageResult WebSocketMessageParser::OnParseMessage(MemoryView Msg) { ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); const uint64_t PrevOffset = m_Stream.CurrentOffset(); if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); Msg += RemaingHeaderSize; if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); if (IsValidHeader == false) { OnReset(); return {.Status = ParseMessageStatus::kError, .ByteCount = m_Stream.CurrentOffset() - PrevOffset, .Reason = std::string("Invalid websocket message header")}; } if (m_Message.MessageSize() == 0) { return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } } ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); if (Msg.IsEmpty() == false) { const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingMessageSize)); } auto Status = ParseMessageStatus::kContinue; if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) { Status = ParseMessageStatus::kDone; BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); CbPackage Pkg; if (Pkg.TryLoad(Reader) == false) { return {.Status = ParseMessageStatus::kError, .ByteCount = m_Stream.CurrentOffset() - PrevOffset, .Reason = std::string("Invalid websocket message")}; } m_Message.SetBody(std::move(Pkg)); } return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void WebSocketMessageParser::OnReset() { m_Message = WebSocketMessage(); } WebSocketMessage WebSocketMessageParser::ConsumeMessage() { WebSocketMessage Msg = std::move(m_Message); m_Message = WebSocketMessage(); return Msg; } /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this { public: WsConnection(WebSocketId Id, std::unique_ptr Socket) : m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) , m_State() { } ~WsConnection() = default; std::shared_ptr AsShared() { return shared_from_this(); } WebSocketId Id() const { return m_Id; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } WebSocketState State() const { return static_cast(m_State.load(std::memory_order_relaxed)); } std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } WebSocketState Close(); MessageParser* Parser() { return m_MsgParser.get(); } void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } std::mutex& WriteMutex() { return m_WriteMutex; } private: WebSocketId m_Id; std::unique_ptr m_Socket; TimePoint m_StartTime; std::atomic_uint32_t m_State; std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; std::mutex m_WriteMutex; }; WebSocketState WsConnection::Close() { const auto PrevState = SetState(WebSocketState::kDisconnected); if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) { m_Socket->close(); } return PrevState; } /////////////////////////////////////////////////////////////////////////////// class WsThreadPool { public: WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} void Start(uint32_t ThreadCount); void Stop(); private: asio::io_service& m_IoSvc; std::vector m_Threads; std::atomic_bool m_Running{false}; }; void WsThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); m_Running = true; for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { for (;;) { if (m_Running == false) { break; } try { m_IoSvc.run(); } catch (std::exception& Err) { ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); } } ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); }); } } void WsThreadPool::Stop() { if (m_Running) { m_Running = false; for (std::thread& Thread : m_Threads) { if (Thread.joinable()) { Thread.join(); } } m_Threads.clear(); } } /////////////////////////////////////////////////////////////////////////////// class WsServer final : public WebSocketServer { public: WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} virtual ~WsServer() { Shutdown(); } virtual bool Run() override; virtual void Shutdown() override; virtual void RegisterService(WebSocketService& Service) override; virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; virtual void SendNotification(WebSocketMessage&& Notification) override; virtual void SendResponse(WebSocketMessage&& Response) override; private: friend class WsConnection; void AcceptConnection(); void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); void ReadMessage(std::shared_ptr Connection); void RouteMessage(WebSocketMessage&& Msg); void SendMessage(WebSocketMessage&& Msg); struct IdHasher { size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; using ConnectionMap = std::unordered_map, IdHasher>; using RequestHandlerMap = std::unordered_map; using NotificationHandlerMap = std::unordered_map>; WebSocketServerOptions m_Options; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; std::unique_ptr m_ThreadPool; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::vector m_Services; RequestHandlerMap m_RequestHandlers; NotificationHandlerMap m_NotificationHandlers; std::atomic_bool m_Running{}; }; void WsServer::RegisterService(WebSocketService& Service) { m_Services.push_back(&Service); Service.Configure(*this); } bool WsServer::Run() { static constexpr size_t ReceiveBufferSize = 256 << 10; static constexpr size_t SendBufferSize = 256 << 10; m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); #if ZEN_PLATFORM_WINDOWS // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. // This must be used by both the client and server side, and is only effective in the absence of // Windows Filtering Platform (WFP) callouts which can be installed by security software. // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path SOCKET NativeSocket = m_Acceptor->native_handle(); int LoopbackOptionValue = 1; DWORD OptionNumberOfBytesReturned = 0; WSAIoctl(NativeSocket, SIO_LOOPBACK_FAST_PATH, &LoopbackOptionValue, sizeof(LoopbackOptionValue), NULL, 0, &OptionNumberOfBytesReturned, 0, 0); #endif asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); if (Ec) { ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); return false; } m_Acceptor->listen(); m_Running = true; ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); AcceptConnection(); m_ThreadPool = std::make_unique(m_IoSvc); m_ThreadPool->Start(m_Options.ThreadCount); return true; } void WsServer::Shutdown() { if (m_Running) { ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); m_Running = false; m_Acceptor->close(); m_Acceptor.reset(); m_IoSvc.stop(); m_ThreadPool->Stop(); } } void WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) { auto Result = m_NotificationHandlers.try_emplace(Key, std::vector()); Result.first->second.push_back(&Service); } void WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) { m_RequestHandlers[Key] = &Service; } void WsServer::SendNotification(WebSocketMessage&& Notification) { ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); SendMessage(std::move(Notification)); } void WsServer::SendResponse(WebSocketMessage&& Response) { ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || Response.MessageType() == WebSocketMessageType::kStreamResponse || Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); ZEN_ASSERT(Response.CorrelationId() != 0); SendMessage(std::move(Response)); } void WsServer::AcceptConnection() { auto Socket = std::make_unique(m_IoSvc); asio::ip::tcp::socket& SocketRef = *Socket.get(); m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (m_Running) { if (Ec) { ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); } else { auto Connection = std::make_shared(WebSocketId::New(), std::move(ConnectedSocket)); ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); { std::unique_lock _(m_ConnMutex); m_Connections[Connection->Id()] = Connection; } Connection->SetParser(std::make_unique(HttpMessageParserType::kRequest)); Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); } AcceptConnection(); } }); } void WsServer::CloseConnection(std::shared_ptr Connection, const std::error_code& Ec) { if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) { if (Ec) { ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); } else { ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); } } const WebSocketId Id = Connection->Id(); { std::unique_lock _(m_ConnMutex); if (m_Connections.contains(Id)) { m_Connections.erase(Id); } } } void WsServer::ReadMessage(std::shared_ptr Connection) { Connection->ReadBuffer().prepare(64 << 10); asio::async_read( Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); } switch (Connection->State()) { case WebSocketState::kHandshaking: { HttpMessageParser& Parser = *reinterpret_cast(Connection->Parser()); asio::const_buffer Buffer = Connection->ReadBuffer().data(); ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); Connection->ReadBuffer().consume(Result.ByteCount); if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); } if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWebSocket, "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", Connection->Id().Value(), Connection->RemoteAddr()); return CloseConnection(Connection, std::error_code()); } if (Parser.IsUpgrade() == false) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", Connection->Id().Value(), Connection->RemoteAddr()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { return CloseConnection(Connection, WriteEc); } Connection->Parser()->Reset(); Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); } ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); std::string AcceptHash; std::string Reason; const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); if (ValidHandshake == false) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { return CloseConnection(Connection, WriteEc); } Connection->Parser()->Reset(); Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); } ExtendableStringBuilder<128> Sb; Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; Sb << "Upgrade: websocket\r\n"sv; Sb << "Connection: Upgrade\r\n"sv; // TODO: Verify protocol if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) { Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] << "\r\n"; } Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; Sb << "\r\n"sv; ZEN_LOG_DEBUG(LogWebSocket, "accepting handshake from connection '#{} {}'", Connection->Id().Value(), Connection->RemoteAddr()); std::string Response = Sb.ToString(); Buffer = asio::buffer(Response); async_write(Connection->Socket(), Buffer, [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { if (WriteEc) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), WriteEc.message()); return CloseConnection(Connection, WriteEc); } ZEN_LOG_DEBUG(LogWebSocket, "handshake ({}B) with connection '#{} {}' OK", ByteCount, Connection->Id().Value(), Connection->RemoteAddr()); Connection->SetParser(std::make_unique()); Connection->SetState(WebSocketState::kConnected); ReadMessage(Connection); }); } break; case WebSocketState::kConnected: { WebSocketMessageParser& Parser = *reinterpret_cast(Connection->Parser()); uint64_t RemainingBytes = Connection->ReadBuffer().size(); while (RemainingBytes > 0) { MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); const ParseMessageResult Result = Parser.ParseMessage(MessageData); Connection->ReadBuffer().consume(Result.ByteCount); RemainingBytes = Connection->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); return CloseConnection(Connection, std::error_code()); } if (Result.Status == ParseMessageStatus::kContinue) { ZEN_ASSERT(RemainingBytes == 0); continue; } WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); Message.SetSocketId(Connection->Id()); RouteMessage(std::move(Message)); } ReadMessage(Connection); } break; default: break; }; }); } void WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) { switch (RoutedMessage.MessageType()) { case WebSocketMessageType::kRequest: case WebSocketMessageType::kStreamRequest: { CbObjectView Request = RoutedMessage.Body().GetObject(); std::string_view Method = Request["Method"].AsString(); bool Handled = false; bool Error = false; std::exception Exception; if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) { WebSocketService* Service = It->second; ZEN_ASSERT(Service); try { Handled = Service->HandleRequest(std::move(RoutedMessage)); } catch (std::exception& Err) { Exception = std::move(Err); Error = true; } } if (Error || Handled == false) { std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); CbObjectWriter Response; Response << "Error"sv << ErrorText; WebSocketMessage ResponseMsg; ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); ResponseMsg.SetSocketId(RoutedMessage.SocketId()); ResponseMsg.SetBody(Response.Save()); SendResponse(std::move(ResponseMsg)); } } break; case WebSocketMessageType::kNotification: { CbObjectView Notification = RoutedMessage.Body().GetObject(); std::string_view Message = Notification["Message"].AsString(); if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) { std::vector& Handlers = It->second; for (WebSocketService* Handler : Handlers) { Handler->HandleNotification(RoutedMessage); } } else { ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); } } break; default: break; }; } void WsServer::SendMessage(WebSocketMessage&& Msg) { std::shared_ptr Connection; { std::unique_lock _(m_ConnMutex); if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) { Connection = It->second; } } if (Connection.get() == nullptr) { ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); return; } if (Connection.get() != nullptr) { BinaryWriter Writer; Msg.Save(Writer); ZEN_LOG_TRACE(LogWebSocket, "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", ToString(Msg.MessageType()), Connection->Id().Value(), Msg.MessageSize(), Msg.CorrelationId(), NiceBytes(Writer.Size())); { ZEN_TRACE_CPU("WS::SendMessage"); std::unique_lock _(Connection->WriteMutex()); ZEN_TRACE_CPU("WS::WriteSocketData"); asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); } } } /////////////////////////////////////////////////////////////////////////////// class WsClient final : public WebSocketClient, public std::enable_shared_from_this { public: WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} virtual ~WsClient() { Disconnect(); } std::shared_ptr AsShared() { return shared_from_this(); } virtual std::future Connect(const WebSocketConnectInfo& Info) override; virtual void Disconnect() override; virtual bool IsConnected() const override { return false; } virtual WebSocketState State() const override { return static_cast(m_State.load()); } virtual std::future SendRequest(WebSocketMessage&& Request) override; virtual void OnNotification(NotificationCallback&& Cb) override; virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } MessageParser* Parser() { return m_MsgParser.get(); } void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } void TriggerEvent(WebSocketEvent Evt); void ReadMessage(); void RouteMessage(WebSocketMessage&& RoutedMessage); using PendingRequestMap = std::unordered_map>; asio::io_context& m_IoCtx; WebSocketId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; NotificationCallback m_NotificationCallback; PendingRequestMap m_PendingRequests; std::mutex m_RequestMutex; std::promise m_ConnectPromise; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; }; std::future WsClient::Connect(const WebSocketConnectInfo& Info) { if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { return m_ConnectPromise.get_future(); } SetState(WebSocketState::kHandshaking); try { asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); m_Socket = std::make_unique(m_IoCtx, Endpoint.protocol()); m_Socket->connect(Endpoint); m_Host = m_Socket->remote_endpoint().address().to_string(); m_Port = Info.Port; ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); } catch (std::exception& Err) { ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); SetState(WebSocketState::kError); m_Socket.reset(); TriggerEvent(WebSocketEvent::kDisconnected); m_ConnectPromise.set_value(false); return m_ConnectPromise.get_future(); } ExtendableStringBuilder<128> Sb; Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; Sb << "Host: " << Info.Host << "\r\n"sv; Sb << "Upgrade: websocket\r\n"sv; Sb << "Connection: upgrade\r\n"sv; Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; if (Info.Protocols.empty() == false) { Sb << "Sec-WebSocket-Protocol: "sv; for (size_t Idx = 0; const auto& Protocol : Info.Protocols) { if (Idx++) { Sb << ", "; } Sb << Protocol; } } Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; Sb << "\r\n"; std::string HandshakeRequest = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(HandshakeRequest); ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); m_MsgParser = std::make_unique(HttpMessageParserType::kResponse); m_MsgParser->Reset(); async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { if (Ec) { ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); Self->Disconnect(); } else { Self->ReadMessage(); } }); return m_ConnectPromise.get_future(); } void WsClient::Disconnect() { if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) { ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); if (m_Socket && m_Socket->is_open()) { m_Socket->close(); m_Socket.reset(); } TriggerEvent(WebSocketEvent::kDisconnected); { std::unique_lock _(m_RequestMutex); for (auto& Kv : m_PendingRequests) { Kv.second.set_value(WebSocketMessage()); } m_PendingRequests.clear(); } } } std::future WsClient::SendRequest(WebSocketMessage&& Request) { ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); BinaryWriter Writer; Request.Save(Writer); std::future FutureResponse; { std::unique_lock _(m_RequestMutex); auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise()); ZEN_ASSERT(Result.second); auto It = Result.first; FutureResponse = It->second.get_future(); } IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { if (Ec) { ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); Self->Disconnect(); } }); return FutureResponse; } void WsClient::OnNotification(NotificationCallback&& Cb) { m_NotificationCallback = std::move(Cb); } void WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { m_EventCallbacks[static_cast(Evt)] = std::move(Cb); } void WsClient::TriggerEvent(WebSocketEvent Evt) { const uint32_t Index = static_cast(Evt); if (m_EventCallbacks[Index]) { m_EventCallbacks[Index](); } } void WsClient::ReadMessage() { m_ReadBuffer.prepare(64 << 10); async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { const WebSocketState State = Self->State(); if (State == WebSocketState::kDisconnected) { return; } if (Ec) { ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); return Self->Disconnect(); } switch (State) { case WebSocketState::kHandshaking: { HttpMessageParser& Parser = *reinterpret_cast(Self->Parser()); MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); ParseMessageResult Result = Parser.ParseMessage(MessageData); Self->ReadBuffer().consume(size_t(Result.ByteCount)); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); Self->m_ConnectPromise.set_value(false); return Self->Disconnect(); } if (Result.Status == ParseMessageStatus::kContinue) { return Self->ReadMessage(); } ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); if (Parser.StatusCode() != 101) { ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status '{}', status code '{}'", Parser.StatusText(), Parser.StatusCode()); Self->m_ConnectPromise.set_value(false); return Self->Disconnect(); } ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); Self->SetParser(std::make_unique()); Self->SetState(WebSocketState::kConnected); Self->ReadMessage(); Self->TriggerEvent(WebSocketEvent::kConnected); Self->m_ConnectPromise.set_value(true); } break; case WebSocketState::kConnected: { WebSocketMessageParser& Parser = *reinterpret_cast(Self->Parser()); uint64_t RemainingBytes = Self->ReadBuffer().size(); while (RemainingBytes > 0) { MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); const ParseMessageResult Result = Parser.ParseMessage(MessageData); Self->ReadBuffer().consume(Result.ByteCount); RemainingBytes = Self->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); Parser.Reset(); continue; } if (Result.Status == ParseMessageStatus::kContinue) { ZEN_ASSERT(RemainingBytes == 0); continue; } WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); Self->RouteMessage(std::move(Message)); } Self->ReadMessage(); } break; default: break; } }); } void WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) { switch (RoutedMessage.MessageType()) { case WebSocketMessageType::kResponse: { std::unique_lock _(m_RequestMutex); if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) { It->second.set_value(std::move(RoutedMessage)); m_PendingRequests.erase(It); } else { ZEN_LOG_WARN(LogWsClient, "route request message FAILED, reason 'unknown correlation ID ({})'", RoutedMessage.CorrelationId()); } } break; case WebSocketMessageType::kNotification: { std::unique_lock _(m_RequestMutex); if (m_NotificationCallback) { m_NotificationCallback(std::move(RoutedMessage)); } } break; default: ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); break; }; } } // namespace zen::websocket namespace zen { std::atomic_uint32_t WebSocketId::NextId{1}; bool WebSocketMessage::Header::IsValid() const { return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); } std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; void WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) { m_Header.MessageType = MessageType; } void WebSocketMessage::SetBody(CbPackage&& Body) { m_Body = std::move(Body); } void WebSocketMessage::SetBody(CbObject&& Body) { CbPackage Pkg; Pkg.SetObject(Body); SetBody(std::move(Pkg)); } void WebSocketMessage::Save(BinaryWriter& Writer) { Writer.Write(&m_Header, HeaderSize); if (m_Body.has_value()) { const CbObject& Obj = m_Body.value().GetObject(); MemoryView View = Obj.GetBuffer().GetView(); const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); ZEN_ASSERT(ValidationResult == CbValidateError::None); m_Body.value().Save(Writer); } if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) { m_Header.CorrelationId = NextCorrelationId.fetch_add(1); } m_Header.MessageSize = Writer.Size() - HeaderSize; Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } bool WebSocketMessage::TryLoadHeader(MemoryView Memory) { if (Memory.GetSize() < HeaderSize) { return false; } MutableMemoryView HeaderView(&m_Header, HeaderSize); HeaderView.CopyFrom(Memory); return m_Header.IsValid(); } void WebSocketService::Configure(WebSocketServer& Server) { ZEN_ASSERT(m_SocketServer == nullptr); m_SocketServer = &Server; RegisterHandlers(Server); } void WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) { WebSocketMessage Message; Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); Message.SetCorrelationId(CorrelationId); Message.SetSocketId(SocketId); Message.SetBody(std::move(StreamResponse)); SocketServer().SendResponse(std::move(Message)); } void WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) { CbPackage Response; Response.SetObject(std::move(StreamResponse)); SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); } std::unique_ptr WebSocketServer::Create(const WebSocketServerOptions& Options) { return std::make_unique(Options); } std::shared_ptr WebSocketClient::Create(asio::io_context& IoCtx) { return std::make_shared(IoCtx); } } // namespace zen