// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include #include ZEN_THIRD_PARTY_INCLUDES_END namespace zen::websocket { using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; /////////////////////////////////////////////////////////////////////////////// namespace http_header { static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; static constexpr std::string_view Upgrade = "Upgrade"sv; } // namespace http_header /////////////////////////////////////////////////////////////////////////////// enum class ParseMessageStatus : uint32_t { kError, kContinue, kDone, }; struct ParseMessageResult { ParseMessageStatus Status{}; size_t ByteCount{}; std::optional Reason; }; class MessageParser { public: virtual ~MessageParser() = default; ParseMessageResult ParseMessage(MemoryView Msg); void Reset(); protected: MessageParser() = default; virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; SimpleBinaryWriter m_Stream; }; ParseMessageResult MessageParser::ParseMessage(MemoryView Msg) { return OnParseMessage(Msg); } void MessageParser::Reset() { OnReset(); m_Stream.Clear(); } /////////////////////////////////////////////////////////////////////////////// enum class HttpMessageParserType { kRequest, kResponse, kBoth }; class HttpMessageParser final : public MessageParser { public: using HttpHeaders = std::unordered_map; HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {} virtual ~HttpMessageParser() = default; int32_t StatusCode() const { return m_Parser.status_code; } bool IsUpgrade() const { return m_Parser.upgrade != 0; } HttpHeaders& Headers() { return m_Headers; } MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } std::string_view StatusText() const { return std::string_view(reinterpret_cast(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); } bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; int OnMessageBegin(); int OnUrl(MemoryView Url); int OnStatus(MemoryView Status); int OnHeaderField(MemoryView HeaderField); int OnHeaderValue(MemoryView HeaderValue); int OnHeadersComplete(); int OnBody(MemoryView Body); int OnMessageComplete(); struct StreamEntry { uint64_t Offset{}; uint64_t Size{}; }; struct HeaderStreamEntry { StreamEntry Field{}; StreamEntry Value{}; }; HttpMessageParserType m_Type; http_parser m_Parser; StreamEntry m_UrlEntry; StreamEntry m_StatusEntry; StreamEntry m_BodyEntry; HeaderStreamEntry m_CurrentHeader; std::vector m_HeaderEntries; HttpHeaders m_Headers; bool m_IsMsgComplete{false}; static http_parser_settings ParserSettings; }; http_parser_settings HttpMessageParser::ParserSettings = { .on_message_begin = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageBegin(); }, .on_url = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnUrl(MemoryView(Data, Size)); }, .on_status = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnStatus(MemoryView(Data, Size)); }, .on_header_field = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderField(MemoryView(Data, Size)); }, .on_header_value = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, .on_headers_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnHeadersComplete(); }, .on_body = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnBody(MemoryView(Data, Size)); }, .on_message_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageComplete(); }}; ParseMessageResult HttpMessageParser::OnParseMessage(MemoryView Msg) { const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast(Msg.GetData()), Msg.GetSize()); auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; if (m_Parser.http_errno != 0) { Status = ParseMessageStatus::kError; } return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; } void HttpMessageParser::OnReset() { http_parser_init(&m_Parser, m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE : HTTP_BOTH); m_Parser.data = this; m_UrlEntry = {}; m_StatusEntry = {}; m_CurrentHeader = {}; m_BodyEntry = {}; m_IsMsgComplete = false; m_HeaderEntries.clear(); } int HttpMessageParser::OnMessageBegin() { ZEN_ASSERT(m_IsMsgComplete == false); ZEN_ASSERT(m_HeaderEntries.empty()); ZEN_ASSERT(m_Headers.empty()); return 0; } int HttpMessageParser::OnStatus(MemoryView Status) { m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; m_Stream.Write(Status); return 0; } int HttpMessageParser::OnUrl(MemoryView Url) { m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; m_Stream.Write(Url); return 0; } int HttpMessageParser::OnHeaderField(MemoryView HeaderField) { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } if (m_CurrentHeader.Field.Size == 0) { m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Field.Size += HeaderField.GetSize(); m_Stream.Write(HeaderField); return 0; } int HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) { if (m_CurrentHeader.Value.Size == 0) { m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Value.Size += HeaderValue.GetSize(); m_Stream.Write(HeaderValue); return 0; } int HttpMessageParser::OnHeadersComplete() { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } m_Headers.clear(); m_Headers.reserve(m_HeaderEntries.size()); const char* StreamData = reinterpret_cast(m_Stream.Data()); for (const auto& Entry : m_HeaderEntries) { auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); m_Headers.try_emplace(std::move(Field), std::move(Value)); } return 0; } int HttpMessageParser::OnBody(MemoryView Body) { m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; m_Stream.Write(Body); return 0; } int HttpMessageParser::OnMessageComplete() { m_IsMsgComplete = true; return 0; } bool HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) { static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; OutAcceptHash = std::string(); if (m_Headers.contains(http_header::SecWebSocketKey) == false) { OutReason = "Missing header Sec-WebSocket-Key"; return false; } if (m_Headers.contains(http_header::Upgrade) == false) { OutReason = "Missing header Upgrade"; return false; } ExtendableStringBuilder<128> Sb; Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; SHA1Stream HashStream; HashStream.Append(Sb.Data(), Sb.Size()); SHA1 Hash = HashStream.GetHash(); OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); return true; } /////////////////////////////////////////////////////////////////////////////// class WebSocketMessageParser final : public MessageParser { public: WebSocketMessageParser() : MessageParser() {} bool TryLoadMessage(CbPackage& OutMsg); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; SimpleBinaryWriter m_HeaderStream; WebSocketMessageHeader m_Header; }; ParseMessageResult WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset(); m_HeaderStream.Write(Msg.Left(RemaingHeaderSize)); Msg.RightChopInline(RemaingHeaderSize); if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header); if (IsValidHeader == false) { return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; } } if (Msg.GetSize() == 0) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingContentSize)); const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void WebSocketMessageParser::OnReset() { m_HeaderStream.Clear(); m_Header = {}; } bool WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) { const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength; if (IsParsed) { BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); return OutMsg.TryLoad(Reader); } return false; } /////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t { kDisconnected, kHandshaking, kConnected }; /////////////////////////////////////////////////////////////////////////////// class WsConnectionId { static std::atomic_uint32_t WsConnectionCounter; public: WsConnectionId() = default; uint32_t Value() const { return m_Value; } auto operator<=>(const WsConnectionId& RHS) const = default; static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } private: WsConnectionId(uint32_t Value) : m_Value(Value) {} uint32_t m_Value{}; }; std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; class WsServer; /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this { public: WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr Socket) : m_Server(Server) , m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) , m_Status() { m_RemoteAddr = m_Socket->remote_endpoint().address().to_string(); } ~WsConnection(); std::shared_ptr AsShared() { return shared_from_this(); } WsConnectionId Id() const { return m_Id; } std::string_view RemoteAddr() const { return m_RemoteAddr; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } WsConnectionState Close(); WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } WsConnectionState SetState(WsConnectionState NewState) { return static_cast(m_Status.exchange(uint32_t(NewState))); } MessageParser* Parser() { return m_MsgParser.get(); } void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } private: WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_MsgParser; TimePoint m_StartTime; asio::streambuf m_ReadBuffer; std::string m_RemoteAddr; std::atomic_uint32_t m_Status; }; WsConnectionState WsConnection::Close() { using enum WsConnectionState; const auto PrevState = SetState(kDisconnected); if (PrevState != kDisconnected && m_Socket->is_open()) { m_Socket->close(); } return PrevState; } /////////////////////////////////////////////////////////////////////////////// class WsThreadPool { public: WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} void Start(uint32_t ThreadCount); void Stop(); private: asio::io_service& m_IoSvc; std::vector m_Threads; std::atomic_bool m_Running{false}; }; void WsThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); m_Running = true; for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { for (;;) { if (m_Running == false) { break; } try { m_IoSvc.run(); } catch (std::exception& Err) { ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); } } ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); }); } } void WsThreadPool::Stop() { if (m_Running) { m_Running = false; for (std::thread& Thread : m_Threads) { if (Thread.joinable()) { Thread.join(); } } m_Threads.clear(); } } /////////////////////////////////////////////////////////////////////////////// class WsServer final : public WebSocketServer { public: WsServer() = default; virtual ~WsServer() { Shutdown(); } virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; private: friend class WsConnection; void AcceptConnection(); void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); void RemoveConnection(const WsConnectionId Id); void ReadMessage(std::shared_ptr Connection); void RouteMessage(const CbPackage& Msg); struct IdHasher { size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } }; using ConnectionMap = std::unordered_map, IdHasher>; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; std::unique_ptr m_ThreadPool; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::atomic_bool m_Running{}; }; WsConnection::~WsConnection() { m_Server.RemoveConnection(m_Id); } bool WsServer::Run(const WebSocketServerOptions& Options) { m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); if (Ec) { ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); return false; } m_Acceptor->listen(); m_Running = true; ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", Options.Port); AcceptConnection(); m_ThreadPool = std::make_unique(m_IoSvc); m_ThreadPool->Start(Options.ThreadCount); return true; } void WsServer::Shutdown() { if (m_Running) { ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); m_Running = false; m_Acceptor->close(); m_Acceptor.reset(); m_IoSvc.stop(); m_ThreadPool->Stop(); } } void WsServer::AcceptConnection() { auto Socket = std::make_unique(m_IoSvc); asio::ip::tcp::socket& SocketRef = *Socket.get(); m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value()); } else { auto ConnId = WsConnectionId::New(); auto Connection = std::make_shared(*this, ConnId, std::move(ConnectedSocket)); ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value()); { std::unique_lock _(m_ConnMutex); m_Connections[ConnId] = Connection; } auto Parser = std::make_unique(HttpMessageParserType::kRequest); Parser->Reset(); Connection->SetParser(std::move(Parser)); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); } if (m_Running) { AcceptConnection(); } }); } void WsServer::CloseConnection(std::shared_ptr Connection, const std::error_code& Ec) { if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) { if (Ec) { ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); } else { ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); } } const WsConnectionId Id = Connection->Id(); { std::unique_lock _(m_ConnMutex); if (m_Connections.contains(Id)) { m_Connections.erase(Id); } } } void WsServer::RemoveConnection(const WsConnectionId Id) { ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value()); } void WsServer::ReadMessage(std::shared_ptr Connection) { Connection->ReadBuffer().prepare(64 << 10); asio::async_read( Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); } ZEN_LOG_DEBUG(LogWebSocket, "reading {}B from connection '#{} {}'", ByteCount, Connection->Id().Value(), Connection->RemoteAddr()); using enum WsConnectionState; switch (Connection->State()) { case kHandshaking: { HttpMessageParser& Parser = *reinterpret_cast(Connection->Parser()); asio::const_buffer Buffer = Connection->ReadBuffer().data(); ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); } if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", Connection->Id().Value(), Connection->RemoteAddr()); return CloseConnection(Connection, std::error_code()); } if (Parser.IsUpgrade() == false) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", Connection->Id().Value(), Connection->RemoteAddr()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { return CloseConnection(Connection, WriteEc); } Connection->Parser()->Reset(); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); }); } ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); std::string AcceptHash; std::string Reason; const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); if (ValidHandshake == false) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { return CloseConnection(Connection, WriteEc); } Connection->Parser()->Reset(); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); }); } ExtendableStringBuilder<128> Sb; Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; Sb << "Upgrade: websocket\r\n"sv; Sb << "Connection: Upgrade\r\n"sv; // TODO: Verify protocol if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) { Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] << "\r\n"; } Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; Sb << "\r\n"sv; ZEN_LOG_DEBUG(LogWebSocket, "accepting handshake from connection '#{} {}'", Connection->Id().Value(), Connection->RemoteAddr()); std::string Response = Sb.ToString(); Buffer = asio::buffer(Response); async_write(Connection->Socket(), Buffer, [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { if (WriteEc) { ZEN_LOG_DEBUG(LogWebSocket, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), WriteEc.message()); return CloseConnection(Connection, WriteEc); } ZEN_LOG_DEBUG(LogWebSocket, "handshake ({}B) with connection '#{} {}' OK", ByteCount, Connection->Id().Value(), Connection->RemoteAddr()); Connection->SetParser(std::make_unique()); Connection->SetState(kConnected); }); } break; case kConnected: { // for (;;) //{ // if (Connection->ReadBuffer().size() == 0) // { // break; // } // WsMessageParser& MessageParser = Connection->MessageParser(); // size_t ConsumedBytes{}; // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); // Connection->ReadBuffer().consume(ConsumedBytes); // if (Ok == false) // { // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'", // Connection->Id().Value()); MessageParser.Reset(); // } // if (Ok == false || MessageParser.IsComplete() == false) // { // continue; // } // CbPackage Message; // if (MessageParser.TryLoadMessage(Message) == false) // { // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'", // Connection->Id().Value()); continue; // } // RouteMessage(Message); //} // ReadMessage(Connection); } break; default: break; }; }); } void WsServer::RouteMessage(const CbPackage& Msg) { ZEN_UNUSED(Msg); ZEN_LOG_DEBUG(LogWebSocket, "routing message"); } /////////////////////////////////////////////////////////////////////////////// class WsClient final : public WebSocketClient { public: WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {} virtual ~WsClient() { Disconnect(); } virtual bool Connect(const WebSocketConnectInfo& Info) override; virtual void Disconnect() override; virtual bool IsConnected() const { return false; } virtual WebSocketState State() const { return static_cast(m_State.load()); } virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } void TriggerEvent(WebSocketEvent Evt); void BeginRead(); spdlog::logger& Log() { return m_Logger; } asio::io_context& m_IoCtx; spdlog::logger& m_Logger; std::unique_ptr m_Socket; std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; }; bool WsClient::Connect(const WebSocketConnectInfo& Info) { if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected) { return true; } SetState(WebSocketState::kConnecting); try { asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); m_Socket = std::make_unique(m_IoCtx, Endpoint.protocol()); m_Socket->connect(Endpoint); m_Host = m_Socket->remote_endpoint().address().to_string(); m_Port = Info.Port; ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port); } catch (std::exception& Err) { ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); SetState(WebSocketState::kFailedToConnect); m_Socket.reset(); TriggerEvent(WebSocketEvent::kDisconnected); return false; } ExtendableStringBuilder<128> Sb; Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; Sb << "Host: " << Info.Host << "\r\n"sv; Sb << "Upgrade: websocket\r\n"sv; Sb << "Connection: upgrade\r\n"sv; Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; if (Info.Protocols.empty() == false) { Sb << "Sec-WebSocket-Protocol: "sv; for (size_t Idx = 0; const auto& Protocol : Info.Protocols) { if (Idx++) { Sb << ", "; } Sb << Protocol; } } Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; Sb << "\r\n"; std::string HandshakeRequest = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(HandshakeRequest); ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port); m_MsgParser = std::make_unique(HttpMessageParserType::kResponse); m_MsgParser->Reset(); async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { if (Ec) { ZEN_ERROR("write data FAILED, reason '{}'", Ec.message()); Disconnect(); } else { BeginRead(); } }); return true; } void WsClient::Disconnect() { if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) { ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port); if (m_Socket && m_Socket->is_open()) { m_Socket->close(); m_Socket.reset(); } TriggerEvent(WebSocketEvent::kDisconnected); } } void WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { m_EventCallbacks[static_cast(Evt)] = Cb; } void WsClient::TriggerEvent(WebSocketEvent Evt) { const uint32_t Index = static_cast(Evt); if (m_EventCallbacks[Index]) { m_EventCallbacks[Index](); } } void WsClient::BeginRead() { m_ReadBuffer.prepare(64 << 10); async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) { if (Ec) { ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); Disconnect(); } else { ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port); switch (State()) { case WebSocketState::kConnecting: { ZEN_ASSERT(m_MsgParser.get() != nullptr); HttpMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); asio::const_buffer Buffer = m_ReadBuffer.data(); ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); m_ReadBuffer.consume(size_t(Result.ByteCount)); if (Result.Status == ParseMessageStatus::kError) { ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); return Disconnect(); } if (Result.Status == ParseMessageStatus::kContinue) { return BeginRead(); } ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); if (Parser.StatusCode() != 101) { ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", m_Host, m_Port, Parser.StatusText(), Parser.StatusCode()); return Disconnect(); } ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); m_MsgParser = std::make_unique(); SetState(WebSocketState::kConnected); TriggerEvent(WebSocketEvent::kConnected); BeginRead(); } break; case WebSocketState::kConnected: { BeginRead(); } break; }; } }); } } // namespace zen::websocket namespace zen { bool WebSocketMessageHeader::IsValid() const { return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; } bool WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) { if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) { return false; } void* Dst = &OutHeader; memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); return OutHeader.IsValid(); } std::unique_ptr WebSocketServer::Create() { return std::make_unique(); } std::unique_ptr WebSocketClient::Create(asio::io_context& IoCtx) { return std::make_unique(IoCtx); } } // namespace zen