diff options
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 318 |
1 files changed, 170 insertions, 148 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index b800892d2..eb01e010e 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -438,14 +438,6 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) } /////////////////////////////////////////////////////////////////////////////// -enum class WsConnectionState : uint32_t -{ - kDisconnected, - kHandshaking, - kConnected -}; - -/////////////////////////////////////////////////////////////////////////////// class WsConnectionId { static std::atomic_uint32_t WsConnectionCounter; @@ -467,53 +459,46 @@ private: std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; -class WsServer; - /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this<WsConnection> { public: - WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) - : m_Server(Server) - , m_Id(Id) + WsConnection(WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) - , m_Status() + , m_State() { - m_RemoteAddr = m_Socket->remote_endpoint().address().to_string(); } - ~WsConnection(); + ~WsConnection() = default; std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } WsConnectionId Id() const { return m_Id; } - std::string_view RemoteAddr() const { return m_RemoteAddr; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - WsConnectionState Close(); - WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } - WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); } - - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } private: - WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<MessageParser> m_MsgParser; TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; - std::string m_RemoteAddr; - std::atomic_uint32_t m_Status; }; -WsConnectionState +WebSocketState WsConnection::Close() { - using enum WsConnectionState; + using enum WebSocketState; const auto PrevState = SetState(kDisconnected); @@ -607,10 +592,9 @@ private: void AcceptConnection(); void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); - void RemoveConnection(const WsConnectionId Id); void ReadMessage(std::shared_ptr<WsConnection> Connection); - void RouteMessage(const CbPackage& Msg); + void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg); struct IdHasher { @@ -627,11 +611,6 @@ private: std::atomic_bool m_Running{}; }; -WsConnection::~WsConnection() -{ - m_Server.RemoveConnection(m_Id); -} - bool WsServer::Run(const WebSocketServerOptions& Options) { @@ -692,25 +671,21 @@ WsServer::AcceptConnection() m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value()); + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); } else { - auto ConnId = WsConnectionId::New(); - auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket)); + auto Connection = std::make_shared<WsConnection>(WsConnectionId::New(), std::move(ConnectedSocket)); - ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value()); + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); { std::unique_lock _(m_ConnMutex); - m_Connections[ConnId] = Connection; + m_Connections[Connection->Id()] = Connection; } - auto Parser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest); - Parser->Reset(); - - Connection->SetParser(std::move(Parser)); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); } @@ -725,7 +700,7 @@ WsServer::AcceptConnection() void WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) { - if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) { if (Ec) { @@ -753,12 +728,6 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e } void -WsServer::RemoveConnection(const WsConnectionId Id) -{ - ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value()); -} - -void WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) { Connection->ReadBuffer().prepare(64 << 10); @@ -773,13 +742,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(LogWebSocket, - "reading {}B from connection '#{} {}'", - ByteCount, - Connection->Id().Value(), - Connection->RemoteAddr()); - - using enum WsConnectionState; + using enum WebSocketState; switch (Connection->State()) { @@ -823,7 +786,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) } Connection->Parser()->Reset(); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); @@ -853,7 +816,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) } Connection->Parser()->Reset(); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); @@ -910,42 +873,46 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) case kConnected: { - // for (;;) - //{ - // if (Connection->ReadBuffer().size() == 0) - // { - // break; - // } - - // WsMessageParser& MessageParser = Connection->MessageParser(); - - // size_t ConsumedBytes{}; - // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); - - // Connection->ReadBuffer().consume(ConsumedBytes); - - // if (Ok == false) - // { - // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'", - // Connection->Id().Value()); MessageParser.Reset(); - // } - - // if (Ok == false || MessageParser.IsComplete() == false) - // { - // continue; - // } - - // CbPackage Message; - // if (MessageParser.TryLoadMessage(Message) == false) - // { - // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'", - // Connection->Id().Value()); continue; - // } - - // RouteMessage(Message); - //} - - // ReadMessage(Connection); + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); + + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount); + + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + MessageData.RightChopInline(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(MessageData.IsEmpty()); + + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + CbPackage Message; + if (Parser.TryLoadMessage(Message) == false) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); + + return CloseConnection(Connection, std::error_code()); + } + + RouteMessage(Connection, Message); + + Parser.Reset(); + } + + Connection->ReadBuffer().consume(ByteCount); + + ReadMessage(Connection); } break; @@ -956,9 +923,9 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) } void -WsServer::RouteMessage(const CbPackage& Msg) +WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg) { - ZEN_UNUSED(Msg); + ZEN_UNUSED(Connection, Msg); ZEN_LOG_DEBUG(LogWebSocket, "routing message"); } @@ -976,11 +943,13 @@ public: virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; + virtual void OnMessage(MessageCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } void TriggerEvent(WebSocketEvent Evt); - void BeginRead(); + void ReadMessage(); + void RouteMessage(CbPackage&& Msg); spdlog::logger& Log() { return m_Logger; } asio::io_context& m_IoCtx; @@ -989,6 +958,7 @@ private: std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; + MessageCallback m_MsgCallback; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; @@ -997,12 +967,12 @@ private: bool WsClient::Connect(const WebSocketConnectInfo& Info) { - if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected) + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { return true; } - SetState(WebSocketState::kConnecting); + SetState(WebSocketState::kHandshaking); try { @@ -1020,7 +990,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) { ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); - SetState(WebSocketState::kFailedToConnect); + SetState(WebSocketState::kError); m_Socket.reset(); TriggerEvent(WebSocketEvent::kDisconnected); @@ -1068,7 +1038,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) } else { - BeginRead(); + ReadMessage(); } }); @@ -1095,7 +1065,13 @@ WsClient::Disconnect() void WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { - m_EventCallbacks[static_cast<uint32_t>(Evt)] = Cb; + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); +} + +void +WsClient::OnMessage(MessageCallback&& Cb) +{ + m_MsgCallback = std::move(Cb); } void @@ -1110,7 +1086,7 @@ WsClient::TriggerEvent(WebSocketEvent Evt) } void -WsClient::BeginRead() +WsClient::ReadMessage() { m_ReadBuffer.prepare(64 << 10); @@ -1119,71 +1095,117 @@ WsClient::BeginRead() { ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); - Disconnect(); + return Disconnect(); } - else + + switch (State()) { - ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port); + case WebSocketState::kHandshaking: + { + ZEN_ASSERT(m_MsgParser.get() != nullptr); - switch (State()) - { - case WebSocketState::kConnecting: + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get()); + + asio::const_buffer Buffer = m_ReadBuffer.data(); + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + m_ReadBuffer.consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) { - ZEN_ASSERT(m_MsgParser.get() != nullptr); + ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get()); + return Disconnect(); + } - asio::const_buffer Buffer = m_ReadBuffer.data(); - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(); + } - m_ReadBuffer.consume(size_t(Result.ByteCount)); + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + if (Parser.StatusCode() != 101) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", + m_Host, + m_Port, + Parser.StatusText(), + Parser.StatusCode()); - return Disconnect(); - } + return Disconnect(); + } - if (Result.Status == ParseMessageStatus::kContinue) - { - return BeginRead(); - } + ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + m_MsgParser = std::make_unique<WebSocketMessageParser>(); + + SetState(WebSocketState::kConnected); + TriggerEvent(WebSocketEvent::kConnected); + + ReadMessage(); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(m_MsgParser.get()); - if (Parser.StatusCode() != 101) + MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount); + + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + MessageData.RightChopInline(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", - m_Host, - m_Port, - Parser.StatusText(), - Parser.StatusCode()); + ZEN_ASSERT(MessageData.IsEmpty()); - return Disconnect(); + return ReadMessage(); } - ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - m_MsgParser = std::make_unique<WebSocketMessageParser>(); + Parser.Reset(); - SetState(WebSocketState::kConnected); - TriggerEvent(WebSocketEvent::kConnected); + continue; + } - BeginRead(); - } - break; + CbPackage Message; + if (Parser.TryLoadMessage(Message)) + { + RouteMessage(std::move(Message)); + } + else + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); + } - case WebSocketState::kConnected: - { - BeginRead(); + Parser.Reset(); } - break; - }; - } + + m_ReadBuffer.consume(ByteCount); + + ReadMessage(); + } + break; + }; }); } +void +WsClient::RouteMessage(CbPackage&& Msg) +{ + if (m_MsgCallback) + { + m_MsgCallback(Msg); + } +} + } // namespace zen::websocket namespace zen { |