diff options
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 237 |
1 files changed, 126 insertions, 111 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index eb01e010e..1952c97a2 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -32,6 +32,8 @@ using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; @@ -104,7 +106,8 @@ class HttpMessageParser final : public MessageParser public: using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; - HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {} + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + virtual ~HttpMessageParser() = default; int32_t StatusCode() const { return m_Parser.status_code; } @@ -120,6 +123,7 @@ public: bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); private: + void Initialize(); virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; int OnMessageBegin(); @@ -183,6 +187,25 @@ http_parser_settings HttpMessageParser::ParserSettings = { .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + ParseMessageResult HttpMessageParser::OnParseMessage(MemoryView Msg) { @@ -201,20 +224,7 @@ HttpMessageParser::OnParseMessage(MemoryView Msg) void HttpMessageParser::OnReset() { - http_parser_init(&m_Parser, - m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST - : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE - : HTTP_BOTH); - m_Parser.data = this; - - m_UrlEntry = {}; - m_StatusEntry = {}; - m_CurrentHeader = {}; - m_BodyEntry = {}; - - m_IsMsgComplete = false; - - m_HeaderEntries.clear(); + Initialize(); } int @@ -930,13 +940,15 @@ WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage } /////////////////////////////////////////////////////////////////////////////// -class WsClient final : public WebSocketClient +class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> { public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {} + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {} virtual ~WsClient() { Disconnect(); } + std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } + virtual bool Connect(const WebSocketConnectInfo& Info) override; virtual void Disconnect() override; virtual bool IsConnected() const { return false; } @@ -946,14 +958,16 @@ public: virtual void OnMessage(MessageCallback&& Cb) override; private: - WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } - void TriggerEvent(WebSocketEvent Evt); - void ReadMessage(); - void RouteMessage(CbPackage&& Msg); - spdlog::logger& Log() { return m_Logger; } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(CbPackage&& Msg); asio::io_context& m_IoCtx; - spdlog::logger& m_Logger; + WsConnectionId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; @@ -984,11 +998,11 @@ WsClient::Connect(const WebSocketConnectInfo& Info) m_Host = m_Socket->remote_endpoint().address().to_string(); m_Port = Info.Port; - ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port); + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); } catch (std::exception& Err) { - ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); SetState(WebSocketState::kError); m_Socket.reset(); @@ -1024,21 +1038,21 @@ WsClient::Connect(const WebSocketConnectInfo& Info) std::string HandshakeRequest = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(HandshakeRequest); - ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port); + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); m_MsgParser->Reset(); - async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { if (Ec) { - ZEN_ERROR("write data FAILED, reason '{}'", Ec.message()); + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); - Disconnect(); + Self->Disconnect(); } else { - ReadMessage(); + Self->ReadMessage(); } }); @@ -1050,7 +1064,7 @@ WsClient::Disconnect() { if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) { - ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port); + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); if (m_Socket && m_Socket->is_open()) { @@ -1090,111 +1104,112 @@ WsClient::ReadMessage() { m_ReadBuffer.prepare(64 << 10); - async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); - - return Disconnect(); - } + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); - switch (State()) - { - case WebSocketState::kHandshaking: - { - ZEN_ASSERT(m_MsgParser.get() != nullptr); + return Self->Disconnect(); + } - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get()); + const WebSocketState State = Self->State(); - asio::const_buffer Buffer = m_ReadBuffer.data(); - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); - m_ReadBuffer.consume(size_t(Result.ByteCount)); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + ParseMessageResult Result = Parser.ParseMessage(MessageData); - return Disconnect(); - } + Self->ReadBuffer().consume(size_t(Result.ByteCount)); - if (Result.Status == ParseMessageStatus::kContinue) - { - return ReadMessage(); - } + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + return Self->Disconnect(); + } - if (Parser.StatusCode() != 101) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", - m_Host, - m_Port, - Parser.StatusText(), - Parser.StatusCode()); + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } - return Disconnect(); - } + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); - m_MsgParser = std::make_unique<WebSocketMessageParser>(); + return Self->Disconnect(); + } - SetState(WebSocketState::kConnected); - TriggerEvent(WebSocketEvent::kConnected); + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); - ReadMessage(); - } - break; + Self->SetParser(std::make_unique<WebSocketMessageParser>()); + Self->SetState(WebSocketState::kConnected); + Self->TriggerEvent(WebSocketEvent::kConnected); - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(m_MsgParser.get()); + Self->ReadMessage(); + } + break; - MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount); + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); - while (MessageData.IsEmpty() == false) - { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - MessageData.RightChopInline(Result.ByteCount); + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MessageData.RightChopInline(Result.ByteCount); - return ReadMessage(); - } + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(MessageData.IsEmpty()); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + return Self->ReadMessage(); + } - Parser.Reset(); + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - continue; - } + Parser.Reset(); - CbPackage Message; - if (Parser.TryLoadMessage(Message)) - { - RouteMessage(std::move(Message)); - } - else - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); - } + continue; + } - Parser.Reset(); - } + CbPackage Message; + if (Parser.TryLoadMessage(Message)) + { + Self->RouteMessage(std::move(Message)); + } + else + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'"); + } - m_ReadBuffer.consume(ByteCount); + Parser.Reset(); + } - ReadMessage(); - } - break; - }; - }); + Self->ReadBuffer().consume(ByteCount); + Self->ReadMessage(); + } + break; + } + }); } void @@ -1236,10 +1251,10 @@ WebSocketServer::Create() return std::make_unique<websocket::WsServer>(); } -std::unique_ptr<WebSocketClient> +std::shared_ptr<WebSocketClient> WebSocketClient::Create(asio::io_context& IoCtx) { - return std::make_unique<websocket::WsClient>(IoCtx); + return std::make_shared<websocket::WsClient>(IoCtx); } } // namespace zen |