From 386f56cd3b7f06fc30318adcbdc0753ddc02c127 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Tue, 15 Feb 2022 14:07:40 +0100 Subject: Renamed asio web socket impl. --- zenhttp/websocketasio.cpp | 683 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 683 insertions(+) create mode 100644 zenhttp/websocketasio.cpp (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..bb3999780 --- /dev/null +++ b/zenhttp/websocketasio.cpp @@ -0,0 +1,683 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +#include +#include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_ws { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +struct HttpParser +{ + HttpParser() + { + http_parser_init(&Parser, HTTP_REQUEST); + Parser.data = this; + } + + size_t Parse(asio::const_buffer Buffer) + { + return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast(Buffer.data()), Buffer.size()); + } + + void GetHeaders(std::unordered_map& OutHeaders) + { + OutHeaders.reserve(HeaderEntries.size()); + + for (const auto& E : HeaderEntries) + { + auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + + OutHeaders[Name] = Value; + } + } + + std::string ValidateWebSocketHandshake(std::unordered_map& Headers, std::string& OutReason) + { + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + std::string AcceptHash; + + if (Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return AcceptHash; + } + + if (Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return AcceptHash; + } + + ExtendableStringBuilder<128> Sb; + Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); + + return AcceptHash; + } + + static void Initialize() + { + ParserSettings = {.on_message_begin = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast(P->data); + + Parser.Url = UrlEntry{}; + Parser.CurrentHeader = HeaderEntry{}; + Parser.IsUpgrade = false; + Parser.IsComplete = false; + + Parser.HeaderStream.Clear(); + Parser.HeaderEntries.clear(); + + return 0; + }, + .on_url = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast(P->data); + + Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Size = Size; + + Parser.HeaderStream.Append(Data, uint32_t(Size)); + + return 0; + }, + .on_header_field = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + if (Parser.CurrentHeader.Name.Size == 0) + { + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Name.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_header_value = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast(P->data); + + if (Parser.CurrentHeader.Value.Size == 0) + { + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Value.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_headers_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + Parser.IsUpgrade = P->upgrade > 0; + + return 0; + }, + .on_message_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast(P->data); + Parser.IsComplete = true; + Parser.IsUpgrade = P->upgrade > 0; + return 0; + }}; + } + + struct MemStreamEntry + { + size_t Offset{}; + size_t Size{}; + }; + + class MemStream + { + public: + MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} + + void Append(const char* Data, size_t Size) + { + const size_t NewSize = m_Size + Size; + + if (NewSize > m_Buf.size()) + { + m_Buf.resize(m_Buf.size() + m_BlockSize); + } + + memcpy(m_Buf.data() + m_Size, Data, Size); + m_Size += Size; + } + + const char* Data() const { return m_Buf.data(); } + size_t Pos() const { return m_Size; } + void Clear() { m_Size = 0; } + + private: + std::vector m_Buf; + size_t m_Size{}; + size_t m_BlockSize{}; + }; + + using UrlEntry = MemStreamEntry; + + struct HeaderEntry + { + MemStreamEntry Name; + MemStreamEntry Value; + }; + + static http_parser_settings ParserSettings; + + http_parser Parser; + MemStream HeaderStream; + std::vector HeaderEntries; + HeaderEntry CurrentHeader{}; + UrlEntry Url{}; + bool IsUpgrade = false; + bool IsComplete = false; +}; + +http_parser_settings HttpParser::ParserSettings; + +/////////////////////////////////////////////////////////////////////////////// +enum class WsConnectionState : uint32_t +{ + kDisconnected, + kHandshaking, + kConnected +}; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnectionId +{ + static std::atomic_uint32_t WsConnectionCounter; + +public: + WsConnectionId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WsConnectionId& RHS) const = default; + + static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } + +private: + WsConnectionId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; + +class WsServer; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this +{ +public: + WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr Socket) + : m_Server(Server) + , m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_Status() + { + } + + ~WsConnection(); + + WsConnectionId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + std::shared_ptr AsShared() { return shared_from_this(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + HttpParser& ParserHttp() { return *m_HttpParser; } + WsConnectionState Close(); + WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } + + WsConnectionState SetState(WsConnectionState NewState) { return static_cast(m_Status.exchange(uint32_t(NewState))); } + + void InitializeHttpParser() { m_HttpParser = std::make_unique(); } + void ReleaseHttpParser() { m_HttpParser.reset(); } + +private: + WsServer& m_Server; + WsConnectionId m_Id; + std::unique_ptr m_Socket; + std::unique_ptr m_HttpParser; + TimePoint m_StartTime; + std::atomic_uint32_t m_Status; + asio::streambuf m_ReadBuffer; +}; + +WsConnectionState +WsConnection::Close() +{ + using enum WsConnectionState; + + const auto PrevState = SetState(kDisconnected); + + if (PrevState != kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector m_Threads; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + + ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer() = default; + virtual ~WsServer() { Shutdown(); } + + virtual bool Run(const WebSocketServerOptions& Options) override; + virtual void Shutdown() override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(WsConnection& Connection, const std::error_code& Ec); + void RemoveConnection(WsConnection& Connection); + void ReadConnection(WsConnection& Connection); + + struct IdHasher + { + size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map, IdHasher>; + + asio::io_service m_IoSvc; + std::unique_ptr m_Acceptor; + std::unique_ptr m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::atomic_bool m_Running{}; +}; + +WsConnection::~WsConnection() +{ + m_Server.RemoveConnection(*this); +} + +bool +WsServer::Run(const WebSocketServerOptions& Options) +{ + HttpParser::Initialize(); + + m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique(m_IoSvc); + m_ThreadPool->Start(Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(WsLog, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); + } + else + { + auto ConnId = WsConnectionId::New(); + + ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); + + auto Connection = std::make_shared(*this, ConnId, std::move(ConnectedSocket)); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[ConnId] = Connection; + } + + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadConnection(*Connection); + } + + if (m_Running) + { + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec) +{ + if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(WsLog, + "closing connection '{}' ERROR, reason '{}' error code '{}'", + Connection.Id().Value(), + Ec.message(), + Ec.value()); + } + else + { + ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value()); + } + } +} + +void +WsServer::RemoveConnection(WsConnection& Connection) +{ + ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value()); +} + +void +WsServer::ReadConnection(WsConnection& Connection) +{ + Connection.ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection.Socket(), + Connection.ReadBuffer(), + asio::transfer_at_least(1), + [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value()); + + using enum WsConnectionState; + + switch (Connection.State()) + { + case kHandshaking: + { + HttpParser& Parser = Connection.ParserHttp(); + const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data()); + Connection.ReadBuffer().consume(Consumed); + + if (Parser.IsComplete == false) + { + return ReadConnection(Connection); + } + + if (Parser.IsUpgrade == false) + { + ZEN_LOG_DEBUG(WsLog, + "handshake with connection '{}' FAILED, reason 'not an upgrade request'", + Connection.Id().Value()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + std::unordered_map Headers; + Parser.GetHeaders(Headers); + + std::string Reason; + std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); + + if (AcceptHash.empty()) + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Headers.contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + std::string Response = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(Response); + + ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value()); + + async_write(Connection.Socket(), + Buffer, + [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value()); + + Connection.ReleaseHttpParser(); + Connection.SetState(kConnected); + + ReadConnection(Connection); + } + }); + } + break; + + case kConnected: + { + // TODO: Implement RPC API + } + break; + + default: + break; + }; + }); +} + +} // namespace zen::asio_ws + +namespace zen { + +std::unique_ptr +WebSocketServer::Create() +{ + return std::make_unique(); +} + +} // namespace zen -- cgit v1.2.3 From 87bb9700722e8319aa58484bba03e398dedede87 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Wed, 16 Feb 2022 12:32:27 +0100 Subject: Added websocket message parser. --- zenhttp/websocketasio.cpp | 297 +++++++++++++++++++++++++++++++++------------- 1 file changed, 215 insertions(+), 82 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index bb3999780..ad8434a5a 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -3,9 +3,12 @@ #include #include +#include +#include #include #include #include +#include #include #include @@ -61,8 +64,8 @@ struct HttpParser for (const auto& E : HeaderEntries) { - auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size); OutHeaders[Name] = Value; } @@ -120,10 +123,10 @@ struct HttpParser [](http_parser* P, const char* Data, size_t Size) { HttpParser& Parser = *reinterpret_cast(P->data); - Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Offset = Parser.HeaderStream.CurrentOffset(); Parser.Url.Size = Size; - Parser.HeaderStream.Append(Data, uint32_t(Size)); + Parser.HeaderStream.Write(Data, uint32_t(Size)); return 0; }, @@ -139,12 +142,12 @@ struct HttpParser if (Parser.CurrentHeader.Name.Size == 0) { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Name.Size += Size; - Parser.HeaderStream.Append(Data, Size); + Parser.HeaderStream.Write(Data, Size); return 0; }, @@ -154,12 +157,12 @@ struct HttpParser if (Parser.CurrentHeader.Value.Size == 0) { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Value.Size += Size; - Parser.HeaderStream.Append(Data, Size); + Parser.HeaderStream.Write(Data, Size); return 0; }, @@ -192,34 +195,6 @@ struct HttpParser size_t Size{}; }; - class MemStream - { - public: - MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} - - void Append(const char* Data, size_t Size) - { - const size_t NewSize = m_Size + Size; - - if (NewSize > m_Buf.size()) - { - m_Buf.resize(m_Buf.size() + m_BlockSize); - } - - memcpy(m_Buf.data() + m_Size, Data, Size); - m_Size += Size; - } - - const char* Data() const { return m_Buf.data(); } - size_t Pos() const { return m_Size; } - void Clear() { m_Size = 0; } - - private: - std::vector m_Buf; - size_t m_Size{}; - size_t m_BlockSize{}; - }; - using UrlEntry = MemStreamEntry; struct HeaderEntry @@ -231,7 +206,7 @@ struct HttpParser static http_parser_settings ParserSettings; http_parser Parser; - MemStream HeaderStream; + SimpleBinaryWriter HeaderStream; std::vector HeaderEntries; HeaderEntry CurrentHeader{}; UrlEntry Url{}; @@ -241,6 +216,93 @@ struct HttpParser http_parser_settings HttpParser::ParserSettings; +/////////////////////////////////////////////////////////////////////////////// +class WsMessageParser +{ +public: + WsMessageParser() {} + + void Reset() + { + m_Header.reset(); + m_Stream.Clear(); + } + + bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes) + { + if (m_Header.has_value()) + { + OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size()); + + m_Stream.Write(Buffer.data(), OutConsumedBytes); + + return true; + } + + const size_t PrevOffset = m_Stream.CurrentOffset(); + const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size()); + const size_t RemainingBytes = Buffer.size() - BytesToWrite; + + m_Stream.Write(Buffer.data(), BytesToWrite); + + if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader)) + { + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return true; + } + + zen::WebSocketMessageHeader Header; + if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false) + { + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return false; + } + + m_Header = Header; + + if (RemainingBytes > 0) + { + const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes); + + m_Stream.Write(reinterpret_cast(Buffer.data()) + BytesToWrite, RemainingBytesToWrite); + } + + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return true; + } + + bool IsComplete() + { + if (m_Header.has_value()) + { + const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(); + + return RemainingBytes == 0; + } + + return false; + } + + bool TryLoadMessage(CbPackage& OutPackage) + { + if (IsComplete()) + { + BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + + return OutPackage.TryLoad(Reader); + } + + return false; + } + +private: + SimpleBinaryWriter m_Stream{64 << 10}; + std::optional m_Header; +}; + /////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t { @@ -294,6 +356,7 @@ public: std::shared_ptr AsShared() { return shared_from_this(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } HttpParser& ParserHttp() { return *m_HttpParser; } + WsMessageParser& MessageParser() { return m_MsgParser; } WsConnectionState Close(); WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } @@ -307,6 +370,7 @@ private: WsConnectionId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_HttpParser; + WsMessageParser m_MsgParser; TimePoint m_StartTime; std::atomic_uint32_t m_Status; asio::streambuf m_ReadBuffer; @@ -392,9 +456,11 @@ private: friend class WsConnection; void AcceptConnection(); - void CloseConnection(WsConnection& Connection, const std::error_code& Ec); - void RemoveConnection(WsConnection& Connection); - void ReadConnection(WsConnection& Connection); + void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); + void RemoveConnection(const WsConnectionId Id); + + void ReadMessage(std::shared_ptr Connection); + void RouteMessage(const CbPackage& Msg); struct IdHasher { @@ -413,7 +479,7 @@ private: WsConnection::~WsConnection() { - m_Server.RemoveConnection(*this); + m_Server.RemoveConnection(m_Id); } bool @@ -496,7 +562,7 @@ WsServer::AcceptConnection() Connection->InitializeHttpParser(); Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(*Connection); + ReadMessage(Connection); } if (m_Running) @@ -507,84 +573,87 @@ WsServer::AcceptConnection() } void -WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec) +WsServer::CloseConnection(std::shared_ptr Connection, const std::error_code& Ec) { - if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected) + if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) { if (Ec) { - ZEN_LOG_INFO(WsLog, - "closing connection '{}' ERROR, reason '{}' error code '{}'", - Connection.Id().Value(), - Ec.message(), - Ec.value()); + ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); } else { - ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value()); + ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value()); } } + + const WsConnectionId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + m_Connections.erase(Id); + } } void -WsServer::RemoveConnection(WsConnection& Connection) +WsServer::RemoveConnection(const WsConnectionId Id) { - ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value()); + ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value()); } void -WsServer::ReadConnection(WsConnection& Connection) +WsServer::ReadMessage(std::shared_ptr Connection) { - Connection.ReadBuffer().prepare(64 << 10); + Connection->ReadBuffer().prepare(64 << 10); asio::async_read( - Connection.Socket(), - Connection.ReadBuffer(), + Connection->Socket(), + Connection->ReadBuffer(), asio::transfer_at_least(1), - [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value()); using enum WsConnectionState; - switch (Connection.State()) + switch (Connection->State()) { case kHandshaking: { - HttpParser& Parser = Connection.ParserHttp(); - const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data()); - Connection.ReadBuffer().consume(Consumed); + HttpParser& Parser = Connection->ParserHttp(); + const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data()); + Connection->ReadBuffer().consume(Consumed); if (Parser.IsComplete == false) { - return ReadConnection(Connection); + return ReadMessage(Connection); } if (Parser.IsUpgrade == false) { ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason 'not an upgrade request'", - Connection.Id().Value()); + Connection->Id().Value()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; - return async_write(Connection.Socket(), + return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), - [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { - Connection.InitializeHttpParser(); - Connection.SetState(WsConnectionState::kHandshaking); + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -597,11 +666,11 @@ WsServer::ReadConnection(WsConnection& Connection) if (AcceptHash.empty()) { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason); + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; - return async_write(Connection.Socket(), + return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) @@ -610,10 +679,11 @@ WsServer::ReadConnection(WsConnection& Connection) } else { - Connection.InitializeHttpParser(); - Connection.SetState(WsConnectionState::kHandshaking); + // TODO: Always close connection? + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -636,23 +706,24 @@ WsServer::ReadConnection(WsConnection& Connection) std::string Response = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(Response); - ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value()); - async_write(Connection.Socket(), + async_write(Connection->Socket(), Buffer, - [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value()); - Connection.ReleaseHttpParser(); - Connection.SetState(kConnected); + Connection->ReleaseHttpParser(); + Connection->SetState(kConnected); + Connection->MessageParser().Reset(); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -660,7 +731,42 @@ WsServer::ReadConnection(WsConnection& Connection) case kConnected: { - // TODO: Implement RPC API + for (;;) + { + if (Connection->ReadBuffer().size() == 0) + { + break; + } + + WsMessageParser& MessageParser = Connection->MessageParser(); + + size_t ConsumedBytes{}; + const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + + Connection->ReadBuffer().consume(ConsumedBytes); + + if (Ok == false) + { + ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value()); + MessageParser.Reset(); + } + + if (Ok == false || MessageParser.IsComplete() == false) + { + continue; + } + + CbPackage Message; + if (MessageParser.TryLoadMessage(Message) == false) + { + ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value()); + continue; + } + + RouteMessage(Message); + } + + ReadMessage(Connection); } break; @@ -670,10 +776,37 @@ WsServer::ReadConnection(WsConnection& Connection) }); } +void +WsServer::RouteMessage(const CbPackage& Msg) +{ + ZEN_UNUSED(Msg); + ZEN_LOG_DEBUG(WsLog, "routing message"); +} + } // namespace zen::asio_ws namespace zen { +bool +WebSocketMessageHeader::IsValid() const +{ + return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; +} + +bool +WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) +{ + if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) + { + return false; + } + + void* Dst = &OutHeader; + memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); + + return OutHeader.IsValid(); +} + std::unique_ptr WebSocketServer::Create() { -- cgit v1.2.3 From 4b9bac3c5baf7633cd51cffcf8e63cb5527ddb36 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Fri, 18 Feb 2022 06:56:20 +0100 Subject: Simple websocket client/server test. --- zenhttp/websocketasio.cpp | 1047 +++++++++++++++++++++++++++++++-------------- 1 file changed, 727 insertions(+), 320 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index ad8434a5a..b800892d2 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -1,12 +1,13 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include +#include #include #include #include #include #include +#include #include #include #include @@ -25,11 +26,11 @@ ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END -namespace zen::asio_ws { +namespace zen::websocket { using namespace std::literals; -ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; @@ -45,263 +46,396 @@ namespace http_header { } // namespace http_header /////////////////////////////////////////////////////////////////////////////// -struct HttpParser +enum class ParseMessageStatus : uint32_t { - HttpParser() + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + SimpleBinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Clear(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {} + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const { - http_parser_init(&Parser, HTTP_REQUEST); - Parser.data = this; + return std::string_view(reinterpret_cast(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); } - size_t Parse(asio::const_buffer Buffer) + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry { - return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast(Buffer.data()), Buffer.size()); - } + uint64_t Offset{}; + uint64_t Size{}; + }; - void GetHeaders(std::unordered_map& OutHeaders) + struct HeaderStreamEntry { - OutHeaders.reserve(HeaderEntries.size()); + StreamEntry Field{}; + StreamEntry Value{}; + }; - for (const auto& E : HeaderEntries) - { - auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size); + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; - OutHeaders[Name] = Value; - } - } + static http_parser_settings ParserSettings; +}; - std::string ValidateWebSocketHandshake(std::unordered_map& Headers, std::string& OutReason) +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageComplete(); }}; + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) { - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + Status = ParseMessageStatus::kError; + } - std::string AcceptHash; + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} - if (Headers.contains(http_header::SecWebSocketKey) == false) - { - OutReason = "Missing header Sec-WebSocket-Key"; - return AcceptHash; - } +void +HttpMessageParser::OnReset() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; - if (Headers.contains(http_header::Upgrade) == false) - { - OutReason = "Missing header Upgrade"; - return AcceptHash; - } + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; - ExtendableStringBuilder<128> Sb; - Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; + m_HeaderEntries.clear(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; - SHA1 Hash = HashStream.GetHash(); + m_Stream.Write(Url); - AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); + return 0; +} - return AcceptHash; +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; } - static void Initialize() + if (m_CurrentHeader.Field.Size == 0) { - ParserSettings = {.on_message_begin = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast(P->data); - - Parser.Url = UrlEntry{}; - Parser.CurrentHeader = HeaderEntry{}; - Parser.IsUpgrade = false; - Parser.IsComplete = false; - - Parser.HeaderStream.Clear(); - Parser.HeaderEntries.clear(); - - return 0; - }, - .on_url = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast(P->data); - - Parser.Url.Offset = Parser.HeaderStream.CurrentOffset(); - Parser.Url.Size = Size; - - Parser.HeaderStream.Write(Data, uint32_t(Size)); - - return 0; - }, - .on_header_field = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - if (Parser.CurrentHeader.Name.Size == 0) - { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset(); - } - - Parser.CurrentHeader.Name.Size += Size; - - Parser.HeaderStream.Write(Data, Size); - - return 0; - }, - .on_header_value = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast(P->data); - - if (Parser.CurrentHeader.Value.Size == 0) - { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset(); - } - - Parser.CurrentHeader.Value.Size += Size; - - Parser.HeaderStream.Write(Data, Size); - - return 0; - }, - .on_headers_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - Parser.IsUpgrade = P->upgrade > 0; - - return 0; - }, - .on_message_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast(P->data); - Parser.IsComplete = true; - Parser.IsUpgrade = P->upgrade > 0; - return 0; - }}; + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); } - struct MemStreamEntry + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) { - size_t Offset{}; - size_t Size{}; - }; + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); - using UrlEntry = MemStreamEntry; + m_Stream.Write(HeaderValue); - struct HeaderEntry + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) { - MemStreamEntry Name; - MemStreamEntry Value; - }; + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } - static http_parser_settings ParserSettings; + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); - http_parser Parser; - SimpleBinaryWriter HeaderStream; - std::vector HeaderEntries; - HeaderEntry CurrentHeader{}; - UrlEntry Url{}; - bool IsUpgrade = false; - bool IsComplete = false; -}; + const char* StreamData = reinterpret_cast(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); -http_parser_settings HttpParser::ParserSettings; + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } -/////////////////////////////////////////////////////////////////////////////// -class WsMessageParser + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) { -public: - WsMessageParser() {} + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); - void Reset() + if (m_Headers.contains(http_header::SecWebSocketKey) == false) { - m_Header.reset(); - m_Stream.Clear(); + OutReason = "Missing header Sec-WebSocket-Key"; + return false; } - bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes) + if (m_Headers.contains(http_header::Upgrade) == false) { - if (m_Header.has_value()) - { - OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size()); + OutReason = "Missing header Upgrade"; + return false; + } - m_Stream.Write(Buffer.data(), OutConsumedBytes); + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; - return true; - } + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); - const size_t PrevOffset = m_Stream.CurrentOffset(); - const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size()); - const size_t RemainingBytes = Buffer.size() - BytesToWrite; + SHA1 Hash = HashStream.GetHash(); - m_Stream.Write(Buffer.data(), BytesToWrite); + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); - if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader)) - { - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + return true; +} - return true; - } +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} - zen::WebSocketMessageHeader Header; - if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false) - { - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + bool TryLoadMessage(CbPackage& OutMsg); - return false; - } +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; - m_Header = Header; + SimpleBinaryWriter m_HeaderStream; + WebSocketMessageHeader m_Header; +}; - if (RemainingBytes > 0) - { - const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes); +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + const uint64_t PrevOffset = m_Stream.CurrentOffset(); - m_Stream.Write(reinterpret_cast(Buffer.data()) + BytesToWrite, RemainingBytesToWrite); - } + if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + { + const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset(); - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + m_HeaderStream.Write(Msg.Left(RemaingHeaderSize)); - return true; - } + Msg.RightChopInline(RemaingHeaderSize); - bool IsComplete() - { - if (m_Header.has_value()) + if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { - const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(); - - return RemainingBytes == 0; + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - return false; + const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header); + + if (IsValidHeader == false) + { + return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; + } } - bool TryLoadMessage(CbPackage& OutPackage) + if (Msg.GetSize() == 0) { - if (IsComplete()) - { - BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } - return OutPackage.TryLoad(Reader); - } + const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset(); - return false; + m_Stream.Write(Msg.Left(RemaingContentSize)); + + const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_HeaderStream.Clear(); + m_Header = {}; +} + +bool +WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) +{ + const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength; + + if (IsParsed) + { + BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + + return OutMsg.TryLoad(Reader); } -private: - SimpleBinaryWriter m_Stream{64 << 10}; - std::optional m_Header; -}; + return false; +} /////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t @@ -346,34 +480,34 @@ public: , m_StartTime(Clock::now()) , m_Status() { + m_RemoteAddr = m_Socket->remote_endpoint().address().to_string(); } ~WsConnection(); - WsConnectionId Id() const { return m_Id; } - asio::ip::tcp::socket& Socket() { return *m_Socket; } - TimePoint StartTime() const { return m_StartTime; } std::shared_ptr AsShared() { return shared_from_this(); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - HttpParser& ParserHttp() { return *m_HttpParser; } - WsMessageParser& MessageParser() { return m_MsgParser; } - WsConnectionState Close(); - WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } + WsConnectionId Id() const { return m_Id; } + std::string_view RemoteAddr() const { return m_RemoteAddr; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WsConnectionState Close(); + WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } WsConnectionState SetState(WsConnectionState NewState) { return static_cast(m_Status.exchange(uint32_t(NewState))); } - void InitializeHttpParser() { m_HttpParser = std::make_unique(); } - void ReleaseHttpParser() { m_HttpParser.reset(); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } private: WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr m_Socket; - std::unique_ptr m_HttpParser; - WsMessageParser m_MsgParser; + std::unique_ptr m_MsgParser; TimePoint m_StartTime; - std::atomic_uint32_t m_Status; asio::streambuf m_ReadBuffer; + std::string m_RemoteAddr; + std::atomic_uint32_t m_Status; }; WsConnectionState @@ -402,6 +536,7 @@ public: private: asio::io_service& m_IoSvc; std::vector m_Threads; + std::atomic_bool m_Running{false}; }; void @@ -409,18 +544,28 @@ WsThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); - ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { - try - { - m_IoSvc.run(); - } - catch (std::exception& Err) + for (;;) { - ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } } ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); @@ -431,15 +576,20 @@ WsThreadPool::Start(uint32_t ThreadCount) void WsThreadPool::Stop() { - for (std::thread& Thread : m_Threads) + if (m_Running) { - if (Thread.joinable()) + m_Running = false; + + for (std::thread& Thread : m_Threads) { - Thread.join(); + if (Thread.joinable()) + { + Thread.join(); + } } - } - m_Threads.clear(); + m_Threads.clear(); + } } /////////////////////////////////////////////////////////////////////////////// @@ -485,8 +635,6 @@ WsConnection::~WsConnection() bool WsServer::Run(const WebSocketServerOptions& Options) { - HttpParser::Initialize(); - m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); @@ -500,7 +648,7 @@ WsServer::Run(const WebSocketServerOptions& Options) if (Ec) { - ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); return false; } @@ -508,7 +656,7 @@ WsServer::Run(const WebSocketServerOptions& Options) m_Acceptor->listen(); m_Running = true; - ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", Options.Port); AcceptConnection(); @@ -523,7 +671,7 @@ WsServer::Shutdown() { if (m_Running) { - ZEN_LOG_INFO(WsLog, "websocket server shutting down"); + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); m_Running = false; @@ -544,22 +692,24 @@ WsServer::AcceptConnection() m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value()); } else { - auto ConnId = WsConnectionId::New(); - - ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); - + auto ConnId = WsConnectionId::New(); auto Connection = std::make_shared(*this, ConnId, std::move(ConnectedSocket)); + ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value()); + { std::unique_lock _(m_ConnMutex); m_Connections[ConnId] = Connection; } - Connection->InitializeHttpParser(); + auto Parser = std::make_unique(HttpMessageParserType::kRequest); + Parser->Reset(); + + Connection->SetParser(std::move(Parser)); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); @@ -579,11 +729,15 @@ WsServer::CloseConnection(std::shared_ptr Connection, const std::e { if (Ec) { - ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); + ZEN_LOG_INFO(LogWebSocket, + "connection '{}' closed, ERROR '{}' error code '{}'", + Connection->Id().Value(), + Ec.message(), + Ec.value()); } else { - ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value()); + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); } } @@ -591,14 +745,17 @@ WsServer::CloseConnection(std::shared_ptr Connection, const std::e { std::unique_lock _(m_ConnMutex); - m_Connections.erase(Id); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } } } void WsServer::RemoveConnection(const WsConnectionId Id) { - ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value()); + ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value()); } void @@ -616,7 +773,11 @@ WsServer::ReadMessage(std::shared_ptr Connection) return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value()); + ZEN_LOG_DEBUG(LogWebSocket, + "reading {}B from connection '#{} {}'", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); using enum WsConnectionState; @@ -624,20 +785,32 @@ WsServer::ReadMessage(std::shared_ptr Connection) { case kHandshaking: { - HttpParser& Parser = Connection->ParserHttp(); - const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data()); - Connection->ReadBuffer().consume(Consumed); + HttpMessageParser& Parser = *reinterpret_cast(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - if (Parser.IsComplete == false) + if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); } - if (Parser.IsUpgrade == false) + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) { - ZEN_LOG_DEBUG(WsLog, - "handshake with connection '{}' FAILED, reason 'not an upgrade request'", - Connection->Id().Value()); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; @@ -646,27 +819,28 @@ WsServer::ReadMessage(std::shared_ptr Connection) [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + return CloseConnection(Connection, WriteEc); } - else - { - Connection->InitializeHttpParser(); - Connection->SetState(WsConnectionState::kHandshaking); - ReadMessage(Connection); - } + Connection->Parser()->Reset(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadMessage(Connection); }); } - std::unordered_map Headers; - Parser.GetHeaders(Headers); + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + std::string AcceptHash; std::string Reason; - std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); - if (AcceptHash.empty()) + if (ValidHandshake == false) { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; @@ -675,16 +849,13 @@ WsServer::ReadMessage(std::shared_ptr Connection) [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + return CloseConnection(Connection, WriteEc); } - else - { - // TODO: Always close connection? - Connection->InitializeHttpParser(); - Connection->SetState(WsConnectionState::kHandshaking); - ReadMessage(Connection); - } + Connection->Parser()->Reset(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadMessage(Connection); }); } @@ -695,95 +866,325 @@ WsServer::ReadMessage(std::shared_ptr Connection) Sb << "Connection: Upgrade\r\n"sv; // TODO: Verify protocol - if (Headers.contains(http_header::SecWebSocketProtocol)) + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) { - Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; } Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; Sb << "\r\n"sv; - std::string Response = Sb.ToString(); - asio::const_buffer Buffer = asio::buffer(Response); + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); - ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value()); + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); async_write(Connection->Socket(), Buffer, - [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); } - else - { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value()); - Connection->ReleaseHttpParser(); - Connection->SetState(kConnected); - Connection->MessageParser().Reset(); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); - ReadMessage(Connection); - } + Connection->SetParser(std::make_unique()); + Connection->SetState(kConnected); }); } break; case kConnected: { - for (;;) - { - if (Connection->ReadBuffer().size() == 0) - { - break; - } + // for (;;) + //{ + // if (Connection->ReadBuffer().size() == 0) + // { + // break; + // } + + // WsMessageParser& MessageParser = Connection->MessageParser(); + + // size_t ConsumedBytes{}; + // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + + // Connection->ReadBuffer().consume(ConsumedBytes); + + // if (Ok == false) + // { + // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'", + // Connection->Id().Value()); MessageParser.Reset(); + // } + + // if (Ok == false || MessageParser.IsComplete() == false) + // { + // continue; + // } + + // CbPackage Message; + // if (MessageParser.TryLoadMessage(Message) == false) + // { + // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'", + // Connection->Id().Value()); continue; + // } + + // RouteMessage(Message); + //} + + // ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(const CbPackage& Msg) +{ + ZEN_UNUSED(Msg); + ZEN_LOG_DEBUG(LogWebSocket, "routing message"); +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {} + + virtual ~WsClient() { Disconnect(); } + + virtual bool Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast(m_State.load()); } + + virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } + void TriggerEvent(WebSocketEvent Evt); + void BeginRead(); + spdlog::logger& Log() { return m_Logger; } + + asio::io_context& m_IoCtx; + spdlog::logger& m_Logger; + std::unique_ptr m_Socket; + std::unique_ptr m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +bool +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected) + { + return true; + } + + SetState(WebSocketState::kConnecting); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kFailedToConnect); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_ERROR("write data FAILED, reason '{}'", Ec.message()); + + Disconnect(); + } + else + { + BeginRead(); + } + }); + + return true; +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + } +} + +void +WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast(Evt)] = Cb; +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::BeginRead() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) { + if (Ec) + { + ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); + + Disconnect(); + } + else + { + ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port); + + switch (State()) + { + case WebSocketState::kConnecting: + { + ZEN_ASSERT(m_MsgParser.get() != nullptr); - WsMessageParser& MessageParser = Connection->MessageParser(); + HttpMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); - size_t ConsumedBytes{}; - const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + asio::const_buffer Buffer = m_ReadBuffer.data(); + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - Connection->ReadBuffer().consume(ConsumedBytes); + m_ReadBuffer.consume(size_t(Result.ByteCount)); - if (Ok == false) - { - ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value()); - MessageParser.Reset(); - } + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + + return Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return BeginRead(); + } - if (Ok == false || MessageParser.IsComplete() == false) - { - continue; - } + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - CbPackage Message; - if (MessageParser.TryLoadMessage(Message) == false) - { - ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value()); - continue; - } + if (Parser.StatusCode() != 101) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", + m_Host, + m_Port, + Parser.StatusText(), + Parser.StatusCode()); - RouteMessage(Message); + return Disconnect(); } - ReadMessage(Connection); + ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + + m_MsgParser = std::make_unique(); + + SetState(WebSocketState::kConnected); + TriggerEvent(WebSocketEvent::kConnected); + + BeginRead(); } break; - default: + case WebSocketState::kConnected: + { + BeginRead(); + } break; }; - }); -} - -void -WsServer::RouteMessage(const CbPackage& Msg) -{ - ZEN_UNUSED(Msg); - ZEN_LOG_DEBUG(WsLog, "routing message"); + } + }); } -} // namespace zen::asio_ws +} // namespace zen::websocket namespace zen { @@ -810,7 +1211,13 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade std::unique_ptr WebSocketServer::Create() { - return std::make_unique(); + return std::make_unique(); +} + +std::unique_ptr +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_unique(IoCtx); } } // namespace zen -- cgit v1.2.3 From b063a8c2fccbdbf73e1a7730cf26a97cec475636 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Fri, 18 Feb 2022 09:31:22 +0100 Subject: Route websocket message. --- zenhttp/websocketasio.cpp | 318 +++++++++++++++++++++++++--------------------- 1 file changed, 170 insertions(+), 148 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index b800892d2..eb01e010e 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -437,14 +437,6 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) return false; } -/////////////////////////////////////////////////////////////////////////////// -enum class WsConnectionState : uint32_t -{ - kDisconnected, - kHandshaking, - kConnected -}; - /////////////////////////////////////////////////////////////////////////////// class WsConnectionId { @@ -467,53 +459,46 @@ private: std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; -class WsServer; - /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this { public: - WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr Socket) - : m_Server(Server) - , m_Id(Id) + WsConnection(WsConnectionId Id, std::unique_ptr Socket) + : m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) - , m_Status() + , m_State() { - m_RemoteAddr = m_Socket->remote_endpoint().address().to_string(); } - ~WsConnection(); + ~WsConnection() = default; std::shared_ptr AsShared() { return shared_from_this(); } WsConnectionId Id() const { return m_Id; } - std::string_view RemoteAddr() const { return m_RemoteAddr; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - WsConnectionState Close(); - WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } - WsConnectionState SetState(WsConnectionState NewState) { return static_cast(m_Status.exchange(uint32_t(NewState))); } - - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } + WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } private: - WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr m_Socket; - std::unique_ptr m_MsgParser; TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; - std::string m_RemoteAddr; - std::atomic_uint32_t m_Status; }; -WsConnectionState +WebSocketState WsConnection::Close() { - using enum WsConnectionState; + using enum WebSocketState; const auto PrevState = SetState(kDisconnected); @@ -607,10 +592,9 @@ private: void AcceptConnection(); void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); - void RemoveConnection(const WsConnectionId Id); void ReadMessage(std::shared_ptr Connection); - void RouteMessage(const CbPackage& Msg); + void RouteMessage(std::shared_ptr Connection, const CbPackage& Msg); struct IdHasher { @@ -627,11 +611,6 @@ private: std::atomic_bool m_Running{}; }; -WsConnection::~WsConnection() -{ - m_Server.RemoveConnection(m_Id); -} - bool WsServer::Run(const WebSocketServerOptions& Options) { @@ -692,25 +671,21 @@ WsServer::AcceptConnection() m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value()); + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); } else { - auto ConnId = WsConnectionId::New(); - auto Connection = std::make_shared(*this, ConnId, std::move(ConnectedSocket)); + auto Connection = std::make_shared(WsConnectionId::New(), std::move(ConnectedSocket)); - ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value()); + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); { std::unique_lock _(m_ConnMutex); - m_Connections[ConnId] = Connection; + m_Connections[Connection->Id()] = Connection; } - auto Parser = std::make_unique(HttpMessageParserType::kRequest); - Parser->Reset(); - - Connection->SetParser(std::move(Parser)); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetParser(std::make_unique(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); } @@ -725,7 +700,7 @@ WsServer::AcceptConnection() void WsServer::CloseConnection(std::shared_ptr Connection, const std::error_code& Ec) { - if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) { if (Ec) { @@ -752,12 +727,6 @@ WsServer::CloseConnection(std::shared_ptr Connection, const std::e } } -void -WsServer::RemoveConnection(const WsConnectionId Id) -{ - ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value()); -} - void WsServer::ReadMessage(std::shared_ptr Connection) { @@ -773,13 +742,7 @@ WsServer::ReadMessage(std::shared_ptr Connection) return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(LogWebSocket, - "reading {}B from connection '#{} {}'", - ByteCount, - Connection->Id().Value(), - Connection->RemoteAddr()); - - using enum WsConnectionState; + using enum WebSocketState; switch (Connection->State()) { @@ -823,7 +786,7 @@ WsServer::ReadMessage(std::shared_ptr Connection) } Connection->Parser()->Reset(); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); @@ -853,7 +816,7 @@ WsServer::ReadMessage(std::shared_ptr Connection) } Connection->Parser()->Reset(); - Connection->SetState(WsConnectionState::kHandshaking); + Connection->SetState(WebSocketState::kHandshaking); ReadMessage(Connection); }); @@ -910,42 +873,46 @@ WsServer::ReadMessage(std::shared_ptr Connection) case kConnected: { - // for (;;) - //{ - // if (Connection->ReadBuffer().size() == 0) - // { - // break; - // } - - // WsMessageParser& MessageParser = Connection->MessageParser(); - - // size_t ConsumedBytes{}; - // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); - - // Connection->ReadBuffer().consume(ConsumedBytes); - - // if (Ok == false) - // { - // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'", - // Connection->Id().Value()); MessageParser.Reset(); - // } - - // if (Ok == false || MessageParser.IsComplete() == false) - // { - // continue; - // } - - // CbPackage Message; - // if (MessageParser.TryLoadMessage(Message) == false) - // { - // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'", - // Connection->Id().Value()); continue; - // } - - // RouteMessage(Message); - //} - - // ReadMessage(Connection); + WebSocketMessageParser& Parser = *reinterpret_cast(Connection->Parser()); + + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount); + + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + MessageData.RightChopInline(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(MessageData.IsEmpty()); + + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + CbPackage Message; + if (Parser.TryLoadMessage(Message) == false) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); + + return CloseConnection(Connection, std::error_code()); + } + + RouteMessage(Connection, Message); + + Parser.Reset(); + } + + Connection->ReadBuffer().consume(ByteCount); + + ReadMessage(Connection); } break; @@ -956,9 +923,9 @@ WsServer::ReadMessage(std::shared_ptr Connection) } void -WsServer::RouteMessage(const CbPackage& Msg) +WsServer::RouteMessage(std::shared_ptr Connection, const CbPackage& Msg) { - ZEN_UNUSED(Msg); + ZEN_UNUSED(Connection, Msg); ZEN_LOG_DEBUG(LogWebSocket, "routing message"); } @@ -976,11 +943,13 @@ public: virtual WebSocketState State() const { return static_cast(m_State.load()); } virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; + virtual void OnMessage(MessageCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } void TriggerEvent(WebSocketEvent Evt); - void BeginRead(); + void ReadMessage(); + void RouteMessage(CbPackage&& Msg); spdlog::logger& Log() { return m_Logger; } asio::io_context& m_IoCtx; @@ -989,6 +958,7 @@ private: std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; + MessageCallback m_MsgCallback; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; @@ -997,12 +967,12 @@ private: bool WsClient::Connect(const WebSocketConnectInfo& Info) { - if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected) + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { return true; } - SetState(WebSocketState::kConnecting); + SetState(WebSocketState::kHandshaking); try { @@ -1020,7 +990,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) { ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); - SetState(WebSocketState::kFailedToConnect); + SetState(WebSocketState::kError); m_Socket.reset(); TriggerEvent(WebSocketEvent::kDisconnected); @@ -1068,7 +1038,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) } else { - BeginRead(); + ReadMessage(); } }); @@ -1095,7 +1065,13 @@ WsClient::Disconnect() void WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { - m_EventCallbacks[static_cast(Evt)] = Cb; + m_EventCallbacks[static_cast(Evt)] = std::move(Cb); +} + +void +WsClient::OnMessage(MessageCallback&& Cb) +{ + m_MsgCallback = std::move(Cb); } void @@ -1110,7 +1086,7 @@ WsClient::TriggerEvent(WebSocketEvent Evt) } void -WsClient::BeginRead() +WsClient::ReadMessage() { m_ReadBuffer.prepare(64 << 10); @@ -1119,71 +1095,117 @@ WsClient::BeginRead() { ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); - Disconnect(); + return Disconnect(); } - else + + switch (State()) { - ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port); + case WebSocketState::kHandshaking: + { + ZEN_ASSERT(m_MsgParser.get() != nullptr); - switch (State()) - { - case WebSocketState::kConnecting: + HttpMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); + + asio::const_buffer Buffer = m_ReadBuffer.data(); + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + m_ReadBuffer.consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) { - ZEN_ASSERT(m_MsgParser.get() != nullptr); + ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); - HttpMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); + return Disconnect(); + } - asio::const_buffer Buffer = m_ReadBuffer.data(); - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(); + } - m_ReadBuffer.consume(size_t(Result.ByteCount)); + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + if (Parser.StatusCode() != 101) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", + m_Host, + m_Port, + Parser.StatusText(), + Parser.StatusCode()); - return Disconnect(); - } + return Disconnect(); + } - if (Result.Status == ParseMessageStatus::kContinue) - { - return BeginRead(); - } + ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + m_MsgParser = std::make_unique(); + + SetState(WebSocketState::kConnected); + TriggerEvent(WebSocketEvent::kConnected); + + ReadMessage(); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); - if (Parser.StatusCode() != 101) + MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount); + + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + MessageData.RightChopInline(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", - m_Host, - m_Port, - Parser.StatusText(), - Parser.StatusCode()); + ZEN_ASSERT(MessageData.IsEmpty()); - return Disconnect(); + return ReadMessage(); } - ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - m_MsgParser = std::make_unique(); + Parser.Reset(); - SetState(WebSocketState::kConnected); - TriggerEvent(WebSocketEvent::kConnected); + continue; + } - BeginRead(); - } - break; + CbPackage Message; + if (Parser.TryLoadMessage(Message)) + { + RouteMessage(std::move(Message)); + } + else + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); + } - case WebSocketState::kConnected: - { - BeginRead(); + Parser.Reset(); } - break; - }; - } + + m_ReadBuffer.consume(ByteCount); + + ReadMessage(); + } + break; + }; }); } +void +WsClient::RouteMessage(CbPackage&& Msg) +{ + if (m_MsgCallback) + { + m_MsgCallback(Msg); + } +} + } // namespace zen::websocket namespace zen { -- cgit v1.2.3 From 3a64ffe3595c20f82de876a90f026dcc36c75258 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Fri, 18 Feb 2022 10:12:43 +0100 Subject: Web socket client is shared between I/O thead and client. --- zenhttp/websocketasio.cpp | 237 ++++++++++++++++++++++++---------------------- 1 file changed, 126 insertions(+), 111 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index eb01e010e..1952c97a2 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -32,6 +32,8 @@ using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; @@ -104,7 +106,8 @@ class HttpMessageParser final : public MessageParser public: using HttpHeaders = std::unordered_map; - HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {} + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + virtual ~HttpMessageParser() = default; int32_t StatusCode() const { return m_Parser.status_code; } @@ -120,6 +123,7 @@ public: bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); private: + void Initialize(); virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; int OnMessageBegin(); @@ -183,6 +187,25 @@ http_parser_settings HttpMessageParser::ParserSettings = { .on_message_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageComplete(); }}; +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + ParseMessageResult HttpMessageParser::OnParseMessage(MemoryView Msg) { @@ -201,20 +224,7 @@ HttpMessageParser::OnParseMessage(MemoryView Msg) void HttpMessageParser::OnReset() { - http_parser_init(&m_Parser, - m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST - : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE - : HTTP_BOTH); - m_Parser.data = this; - - m_UrlEntry = {}; - m_StatusEntry = {}; - m_CurrentHeader = {}; - m_BodyEntry = {}; - - m_IsMsgComplete = false; - - m_HeaderEntries.clear(); + Initialize(); } int @@ -930,13 +940,15 @@ WsServer::RouteMessage(std::shared_ptr Connection, const CbPackage } /////////////////////////////////////////////////////////////////////////////// -class WsClient final : public WebSocketClient +class WsClient final : public WebSocketClient, public std::enable_shared_from_this { public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {} + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {} virtual ~WsClient() { Disconnect(); } + std::shared_ptr AsShared() { return shared_from_this(); } + virtual bool Connect(const WebSocketConnectInfo& Info) override; virtual void Disconnect() override; virtual bool IsConnected() const { return false; } @@ -946,14 +958,16 @@ public: virtual void OnMessage(MessageCallback&& Cb) override; private: - WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } - void TriggerEvent(WebSocketEvent Evt); - void ReadMessage(); - void RouteMessage(CbPackage&& Msg); - spdlog::logger& Log() { return m_Logger; } + WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(CbPackage&& Msg); asio::io_context& m_IoCtx; - spdlog::logger& m_Logger; + WsConnectionId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; @@ -984,11 +998,11 @@ WsClient::Connect(const WebSocketConnectInfo& Info) m_Host = m_Socket->remote_endpoint().address().to_string(); m_Port = Info.Port; - ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port); + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); } catch (std::exception& Err) { - ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); SetState(WebSocketState::kError); m_Socket.reset(); @@ -1024,21 +1038,21 @@ WsClient::Connect(const WebSocketConnectInfo& Info) std::string HandshakeRequest = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(HandshakeRequest); - ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port); + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); m_MsgParser = std::make_unique(HttpMessageParserType::kResponse); m_MsgParser->Reset(); - async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { if (Ec) { - ZEN_ERROR("write data FAILED, reason '{}'", Ec.message()); + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); - Disconnect(); + Self->Disconnect(); } else { - ReadMessage(); + Self->ReadMessage(); } }); @@ -1050,7 +1064,7 @@ WsClient::Disconnect() { if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) { - ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port); + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); if (m_Socket && m_Socket->is_open()) { @@ -1090,111 +1104,112 @@ WsClient::ReadMessage() { m_ReadBuffer.prepare(64 << 10); - async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); - - return Disconnect(); - } + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); - switch (State()) - { - case WebSocketState::kHandshaking: - { - ZEN_ASSERT(m_MsgParser.get() != nullptr); + return Self->Disconnect(); + } - HttpMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); + const WebSocketState State = Self->State(); - asio::const_buffer Buffer = m_ReadBuffer.data(); - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast(Self->Parser()); - m_ReadBuffer.consume(size_t(Result.ByteCount)); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + ParseMessageResult Result = Parser.ParseMessage(MessageData); - return Disconnect(); - } + Self->ReadBuffer().consume(size_t(Result.ByteCount)); - if (Result.Status == ParseMessageStatus::kContinue) - { - return ReadMessage(); - } + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + return Self->Disconnect(); + } - if (Parser.StatusCode() != 101) - { - ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", - m_Host, - m_Port, - Parser.StatusText(), - Parser.StatusCode()); + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } - return Disconnect(); - } + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); - m_MsgParser = std::make_unique(); + return Self->Disconnect(); + } - SetState(WebSocketState::kConnected); - TriggerEvent(WebSocketEvent::kConnected); + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); - ReadMessage(); - } - break; + Self->SetParser(std::make_unique()); + Self->SetState(WebSocketState::kConnected); + Self->TriggerEvent(WebSocketEvent::kConnected); - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast(m_MsgParser.get()); + Self->ReadMessage(); + } + break; - MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount); + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast(Self->Parser()); - while (MessageData.IsEmpty() == false) - { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - MessageData.RightChopInline(Result.ByteCount); + while (MessageData.IsEmpty() == false) + { + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MessageData.RightChopInline(Result.ByteCount); - return ReadMessage(); - } + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(MessageData.IsEmpty()); - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + return Self->ReadMessage(); + } - Parser.Reset(); + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - continue; - } + Parser.Reset(); - CbPackage Message; - if (Parser.TryLoadMessage(Message)) - { - RouteMessage(std::move(Message)); - } - else - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); - } + continue; + } - Parser.Reset(); - } + CbPackage Message; + if (Parser.TryLoadMessage(Message)) + { + Self->RouteMessage(std::move(Message)); + } + else + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'"); + } - m_ReadBuffer.consume(ByteCount); + Parser.Reset(); + } - ReadMessage(); - } - break; - }; - }); + Self->ReadBuffer().consume(ByteCount); + Self->ReadMessage(); + } + break; + } + }); } void @@ -1236,10 +1251,10 @@ WebSocketServer::Create() return std::make_unique(); } -std::unique_ptr +std::shared_ptr WebSocketClient::Create(asio::io_context& IoCtx) { - return std::make_unique(IoCtx); + return std::make_shared(IoCtx); } } // namespace zen -- cgit v1.2.3 From 08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Fri, 18 Feb 2022 14:48:41 +0100 Subject: Basic websocket service and test. --- zenhttp/websocketasio.cpp | 212 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 161 insertions(+), 51 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index 1952c97a2..f6f58f38c 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -381,7 +380,6 @@ private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; - SimpleBinaryWriter m_HeaderStream; WebSocketMessageHeader m_Header; }; @@ -390,20 +388,20 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); - if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { - const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset(); + const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset(); - m_HeaderStream.Write(Msg.Left(RemaingHeaderSize)); + m_Stream.Write(Msg.Left(RemaingHeaderSize)); Msg.RightChopInline(RemaingHeaderSize); - if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header); + const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header); if (IsValidHeader == false) { @@ -411,16 +409,21 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) } } - if (Msg.GetSize() == 0) + ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader)); + + if (Msg.IsEmpty()) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset(); + const uint64_t RemaingContentSize = + Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize()); m_Stream.Write(Msg.Left(RemaingContentSize)); - const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength + ? ParseMessageStatus::kDone + : ParseMessageStatus::kContinue; return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } @@ -428,18 +431,17 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) void WebSocketMessageParser::OnReset() { - m_HeaderStream.Clear(); m_Header = {}; } bool WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) { - const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength; + const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader); if (IsParsed) { - BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader))); return OutMsg.TryLoad(Reader); } @@ -447,33 +449,11 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) return false; } -/////////////////////////////////////////////////////////////////////////////// -class WsConnectionId -{ - static std::atomic_uint32_t WsConnectionCounter; - -public: - WsConnectionId() = default; - - uint32_t Value() const { return m_Value; } - - auto operator<=>(const WsConnectionId& RHS) const = default; - - static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } - -private: - WsConnectionId(uint32_t Value) : m_Value(Value) {} - - uint32_t m_Value{}; -}; - -std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; - /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this { public: - WsConnection(WsConnectionId Id, std::unique_ptr Socket) + WsConnection(WebSocketId Id, std::unique_ptr Socket) : m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) @@ -485,7 +465,7 @@ public: std::shared_ptr AsShared() { return shared_from_this(); } - WsConnectionId Id() const { return m_Id; } + WebSocketId Id() const { return m_Id; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } WebSocketState State() const { return static_cast(m_State.load(std::memory_order_relaxed)); } @@ -497,7 +477,7 @@ public: void SetParser(std::unique_ptr&& Parser) { m_MsgParser = std::move(Parser); } private: - WsConnectionId m_Id; + WebSocketId m_Id; std::unique_ptr m_Socket; TimePoint m_StartTime; std::atomic_uint32_t m_State; @@ -594,8 +574,10 @@ public: WsServer() = default; virtual ~WsServer() { Shutdown(); } + virtual void RegisterService(WebSocketService& Service) override; virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; + virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override; private: friend class WsConnection; @@ -608,19 +590,28 @@ private: struct IdHasher { - size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; - using ConnectionMap = std::unordered_map, IdHasher>; + using ConnectionMap = std::unordered_map, IdHasher>; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; std::unique_ptr m_ThreadPool; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; + std::vector m_Services; std::atomic_bool m_Running{}; }; +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + bool WsServer::Run(const WebSocketServerOptions& Options) { @@ -672,6 +663,37 @@ WsServer::Shutdown() } } +void +WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg) +{ + std::shared_ptr Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Id); It != m_Connections.end()) + { + Connection = It->second; + } + } + + BinaryWriter Writer; + WebSocketMessageHeader::Write(Writer, Msg); + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + CloseConnection(Connection, Ec); + } + }); +} + void WsServer::AcceptConnection() { @@ -685,7 +707,7 @@ WsServer::AcceptConnection() } else { - auto Connection = std::make_shared(WsConnectionId::New(), std::move(ConnectedSocket)); + auto Connection = std::make_shared(WebSocketId::New(), std::move(ConnectedSocket)); ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); @@ -726,7 +748,7 @@ WsServer::CloseConnection(std::shared_ptr Connection, const std::e } } - const WsConnectionId Id = Connection->Id(); + const WebSocketId Id = Connection->Id(); { std::unique_lock _(m_ConnMutex); @@ -763,6 +785,8 @@ WsServer::ReadMessage(std::shared_ptr Connection) ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + Connection->ReadBuffer().consume(Result.ByteCount); + if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); @@ -770,10 +794,10 @@ WsServer::ReadMessage(std::shared_ptr Connection) if (Result.Status == ParseMessageStatus::kError) { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", - Connection->Id().Value(), - Connection->RemoteAddr()); + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); return CloseConnection(Connection, std::error_code()); } @@ -877,6 +901,8 @@ WsServer::ReadMessage(std::shared_ptr Connection) Connection->SetParser(std::make_unique()); Connection->SetState(kConnected); + + ReadMessage(Connection); }); } break; @@ -935,15 +961,24 @@ WsServer::ReadMessage(std::shared_ptr Connection) void WsServer::RouteMessage(std::shared_ptr Connection, const CbPackage& Msg) { - ZEN_UNUSED(Connection, Msg); ZEN_LOG_DEBUG(LogWebSocket, "routing message"); + + for (auto Server : m_Services) + { + if (Server->HandleMessage(Connection->Id(), Msg)) + { + return; + } + } + + ZEN_LOG_WARN(LogWebSocket, "unhandled message"); } /////////////////////////////////////////////////////////////////////////////// class WsClient final : public WebSocketClient, public std::enable_shared_from_this { public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {} + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} virtual ~WsClient() { Disconnect(); } @@ -953,6 +988,8 @@ public: virtual void Disconnect() override; virtual bool IsConnected() const { return false; } virtual WebSocketState State() const { return static_cast(m_State.load()); } + virtual void SendMsg(CbPackage&& Msg) override; + virtual void SendMsg(CbObject&& Msg) override; virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; virtual void OnMessage(MessageCallback&& Cb) override; @@ -967,7 +1004,7 @@ private: void RouteMessage(CbPackage&& Msg); asio::io_context& m_IoCtx; - WsConnectionId m_Id; + WebSocketId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; @@ -1076,6 +1113,35 @@ WsClient::Disconnect() } } +void +WsClient::SendMsg(CbPackage&& Msg) +{ + BinaryWriter Writer; + WebSocketMessageHeader::Write(Writer, Msg); + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); +} + +void +WsClient::SendMsg(CbObject&& Msg) +{ + CbPackage Pkg; + Pkg.SetObject(std::move(Msg)); + + WsClient::SendMsg(std::move(Pkg)); +} + void WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { @@ -1157,9 +1223,8 @@ WsClient::ReadMessage() Self->SetParser(std::make_unique()); Self->SetState(WebSocketState::kConnected); - Self->TriggerEvent(WebSocketEvent::kConnected); - Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); } break; @@ -1225,6 +1290,35 @@ WsClient::RouteMessage(CbPackage&& Msg) namespace zen { +std::atomic_uint32_t WebSocketId::NextId{1}; + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_Server == nullptr); + + m_Server = &Server; +} + +void +WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg) +{ + ZEN_ASSERT(m_Server != nullptr); + + m_Server->PublishMessage(Id, std::move(Msg)); +} + +void +WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg) +{ + ZEN_ASSERT(m_Server != nullptr); + + CbPackage Pkg; + Pkg.SetObject(std::move(Msg)); + + m_Server->PublishMessage(Id, std::move(Pkg)); +} + bool WebSocketMessageHeader::IsValid() const { @@ -1245,6 +1339,22 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade return OutHeader.IsValid(); } +void +WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg) +{ + WebSocketMessageHeader Header; + + Writer.Write(&Header, sizeof(WebSocketMessageHeader)); + Msg.Save(Writer); + + Header.Magic = WebSocketMessageHeader::ExpectedMagic; + Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader); + Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO + + MutableMemoryView HeaderView(const_cast(Writer.Data()), sizeof(WebSocketMessageHeader)); + HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader))); +} + std::unique_ptr WebSocketServer::Create() { -- cgit v1.2.3 From fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Mon, 21 Feb 2022 14:22:38 +0100 Subject: Refactored websocket message. --- zenhttp/websocketasio.cpp | 516 ++++++++++++++++++++++++++++++---------------- 1 file changed, 341 insertions(+), 175 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index f6f58f38c..13d0177ee 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include @@ -75,7 +75,7 @@ protected: virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; - SimpleBinaryWriter m_Stream; + BinaryWriter m_Stream; }; ParseMessageResult @@ -89,7 +89,7 @@ MessageParser::Reset() { OnReset(); - m_Stream.Clear(); + m_Stream.Reset(); } /////////////////////////////////////////////////////////////////////////////// @@ -374,13 +374,13 @@ class WebSocketMessageParser final : public MessageParser public: WebSocketMessageParser() : MessageParser() {} - bool TryLoadMessage(CbPackage& OutMsg); + WebSocketMessage ConsumeMessage(); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; - WebSocketMessageHeader m_Header; + WebSocketMessage m_Message; }; ParseMessageResult @@ -388,65 +388,76 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { - const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset(); + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); - Msg.RightChopInline(RemaingHeaderSize); - - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header); + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); if (IsValidHeader == false) { - return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; } + + Msg += RemaingHeaderSize; } - ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader)); + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); if (Msg.IsEmpty()) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const uint64_t RemaingContentSize = - Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize()); + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingMessageSize)); + + const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); + + if (IsComplete) + { + BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize)); - m_Stream.Write(Msg.Left(RemaingContentSize)); + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } - const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength - ? ParseMessageStatus::kDone - : ParseMessageStatus::kContinue; + m_Message.SetBody(std::move(Pkg)); + } - return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void WebSocketMessageParser::OnReset() { - m_Header = {}; + m_Message = WebSocketMessage(); } -bool -WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() { - const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader); + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); - if (IsParsed) - { - BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader))); - - return OutMsg.TryLoad(Reader); - } - - return false; + return Msg; } /////////////////////////////////////////////////////////////////////////////// @@ -574,10 +585,15 @@ public: WsServer() = default; virtual ~WsServer() { Shutdown(); } - virtual void RegisterService(WebSocketService& Service) override; virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; - virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; private: friend class WsConnection; @@ -586,14 +602,17 @@ private: void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); void ReadMessage(std::shared_ptr Connection); - void RouteMessage(std::shared_ptr Connection, const CbPackage& Msg); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); struct IdHasher { size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; - using ConnectionMap = std::unordered_map, IdHasher>; + using ConnectionMap = std::unordered_map, IdHasher>; + using RequestHandlerMap = std::unordered_map; + using NotificationHandlerMap = std::unordered_map>; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; @@ -601,6 +620,8 @@ private: ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::vector m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; std::atomic_bool m_Running{}; }; @@ -664,34 +685,32 @@ WsServer::Shutdown() } void -WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) { - std::shared_ptr Connection; - - { - std::unique_lock _(m_ConnMutex); - - if (auto It = m_Connections.find(Id); It != m_Connections.end()) - { - Connection = It->second; - } - } + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector()); + Result.first->second.push_back(&Service); +} - BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); - ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr()); + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); - async_write(Connection->Socket(), - asio::buffer(Buffer.Data(), Buffer.Size()), - [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - CloseConnection(Connection, Ec); - } - }); + SendMessage(std::move(Response)); } void @@ -768,7 +787,7 @@ WsServer::ReadMessage(std::shared_ptr Connection) Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), - [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); @@ -911,20 +930,15 @@ WsServer::ReadMessage(std::shared_ptr Connection) { WebSocketMessageParser& Parser = *reinterpret_cast(Connection->Parser()); - MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Connection->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return ReadMessage(Connection); - } + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { @@ -933,20 +947,19 @@ WsServer::ReadMessage(std::shared_ptr Connection) return CloseConnection(Connection, std::error_code()); } - CbPackage Message; - if (Parser.TryLoadMessage(Message) == false) + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); - - return CloseConnection(Connection, std::error_code()); + ZEN_ASSERT(RemainingBytes == 0); + continue; } - RouteMessage(Connection, Message); - + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); - } - Connection->ReadBuffer().consume(ByteCount); + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } ReadMessage(Connection); } @@ -959,19 +972,114 @@ WsServer::ReadMessage(std::shared_ptr Connection) } void -WsServer::RouteMessage(std::shared_ptr Connection, const CbPackage& Msg) +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) { - ZEN_LOG_DEBUG(LogWebSocket, "routing message"); + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr Connection; - for (auto Server : m_Services) { - if (Server->HandleMessage(Connection->Id(), Msg)) + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) { - return; + Connection = It->second; } } - ZEN_LOG_WARN(LogWebSocket, "unhandled message"); + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + + CloseConnection(Connection, Ec); + } + }); + } } /////////////////////////////////////////////////////////////////////////////// @@ -984,15 +1092,14 @@ public: std::shared_ptr AsShared() { return shared_from_this(); } - virtual bool Connect(const WebSocketConnectInfo& Info) override; - virtual void Disconnect() override; - virtual bool IsConnected() const { return false; } - virtual WebSocketState State() const { return static_cast(m_State.load()); } - virtual void SendMsg(CbPackage&& Msg) override; - virtual void SendMsg(CbObject&& Msg) override; + virtual std::future Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast(m_State.load()); } - virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; - virtual void OnMessage(MessageCallback&& Cb) override; + virtual std::future SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast(m_State.exchange(uint32_t(NewState))); } @@ -1001,7 +1108,9 @@ private: asio::streambuf& ReadBuffer() { return m_ReadBuffer; } void TriggerEvent(WebSocketEvent Evt); void ReadMessage(); - void RouteMessage(CbPackage&& Msg); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map>; asio::io_context& m_IoCtx; WebSocketId m_Id; @@ -1009,18 +1118,21 @@ private: std::unique_ptr m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; - MessageCallback m_MsgCallback; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise m_ConnectPromise; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; }; -bool +std::future WsClient::Connect(const WebSocketConnectInfo& Info) { if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { - return true; + return m_ConnectPromise.get_future(); } SetState(WebSocketState::kHandshaking); @@ -1046,7 +1158,9 @@ WsClient::Connect(const WebSocketConnectInfo& Info) TriggerEvent(WebSocketEvent::kDisconnected); - return false; + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); } ExtendableStringBuilder<128> Sb; @@ -1093,7 +1207,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) } }); - return true; + return m_ConnectPromise.get_future(); } void @@ -1110,48 +1224,64 @@ WsClient::Disconnect() } TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } } } -void -WsClient::SendMsg(CbPackage&& Msg) +std::future +WsClient::SendRequest(WebSocketMessage&& Request) { + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); + Request.Save(Writer); - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + std::future FutureResponse; + + { + std::unique_lock _(m_RequestMutex); - ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size()); + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise()); + ZEN_ASSERT(Result.second); - async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) { + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { if (Ec) { - ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message()); + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); Self->Disconnect(); } }); -} - -void -WsClient::SendMsg(CbObject&& Msg) -{ - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); - WsClient::SendMsg(std::move(Pkg)); + return FutureResponse; } void -WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +WsClient::OnNotification(NotificationCallback&& Cb) { - m_EventCallbacks[static_cast(Evt)] = std::move(Cb); + m_NotificationCallback = std::move(Cb); } void -WsClient::OnMessage(MessageCallback&& Cb) +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { - m_MsgCallback = std::move(Cb); + m_EventCallbacks[static_cast(Evt)] = std::move(Cb); } void @@ -1199,6 +1329,8 @@ WsClient::ReadMessage() { ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1216,6 +1348,8 @@ WsClient::ReadMessage() Parser.StatusText(), Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1225,6 +1359,8 @@ WsClient::ReadMessage() Self->SetState(WebSocketState::kConnected); Self->ReadMessage(); Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); } break; @@ -1232,44 +1368,36 @@ WsClient::ReadMessage() { WebSocketMessageParser& Parser = *reinterpret_cast(Self->Parser()); - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Self->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return Self->ReadMessage(); - } + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); Parser.Reset(); - continue; } - CbPackage Message; - if (Parser.TryLoadMessage(Message)) - { - Self->RouteMessage(std::move(Message)); - } - else + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'"); + ZEN_ASSERT(RemainingBytes == 0); + continue; } + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); + + Self->RouteMessage(std::move(Message)); } - Self->ReadBuffer().consume(ByteCount); Self->ReadMessage(); } break; @@ -1278,12 +1406,39 @@ WsClient::ReadMessage() } void -WsClient::RouteMessage(CbPackage&& Msg) +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) { - if (m_MsgCallback) + switch (RoutedMessage.MessageType()) { - m_MsgCallback(Msg); - } + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + }; } } // namespace zen::websocket @@ -1292,67 +1447,78 @@ namespace zen { std::atomic_uint32_t WebSocketId::NextId{1}; -void -WebSocketService::Configure(WebSocketServer& Server) +bool +WebSocketMessage::Header::IsValid() const { - ZEN_ASSERT(m_Server == nullptr); - - m_Server = &Server; + return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; } +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + void -WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) { - ZEN_ASSERT(m_Server != nullptr); - - m_Server->PublishMessage(Id, std::move(Msg)); + m_Header.MessageType = MessageType; } void -WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg) +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) { - ZEN_ASSERT(m_Server != nullptr); - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); + Pkg.SetObject(Body); - m_Server->PublishMessage(Id, std::move(Pkg)); + SetBody(std::move(Pkg)); } -bool -WebSocketMessageHeader::IsValid() const +void +WebSocketMessage::Save(BinaryWriter& Writer) { - return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + m_Header.Crc32 = 1; // TODO + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } bool -WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) +WebSocketMessage::TryLoadHeader(MemoryView Memory) { - if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) + if (Memory.GetSize() < HeaderSize) { return false; } - void* Dst = &OutHeader; - memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); - return OutHeader.IsValid(); + return m_Header.IsValid(); } void -WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg) +WebSocketService::Configure(WebSocketServer& Server) { - WebSocketMessageHeader Header; - - Writer.Write(&Header, sizeof(WebSocketMessageHeader)); - Msg.Save(Writer); + ZEN_ASSERT(m_SocketServer == nullptr); - Header.Magic = WebSocketMessageHeader::ExpectedMagic; - Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader); - Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO + m_SocketServer = &Server; - MutableMemoryView HeaderView(const_cast(Writer.Data()), sizeof(WebSocketMessageHeader)); - HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader))); + RegisterHandlers(Server); } std::unique_ptr -- cgit v1.2.3 From 41782efc63d7f88525596d6724a1bb86d6fdcfa4 Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Mon, 21 Feb 2022 15:00:02 +0100 Subject: Added option to enable websockets. --- zenhttp/websocketasio.cpp | 51 ++++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 25 deletions(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index 13d0177ee..c2ce7ca64 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -582,10 +582,10 @@ WsThreadPool::Stop() class WsServer final : public WebSocketServer { public: - WsServer() = default; + WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} virtual ~WsServer() { Shutdown(); } - virtual bool Run(const WebSocketServerOptions& Options) override; + virtual bool Run() override; virtual void Shutdown() override; virtual void RegisterService(WebSocketService& Service) override; @@ -614,6 +614,7 @@ private: using RequestHandlerMap = std::unordered_map; using NotificationHandlerMap = std::unordered_map>; + WebSocketServerOptions m_Options; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; std::unique_ptr m_ThreadPool; @@ -634,7 +635,7 @@ WsServer::RegisterService(WebSocketService& Service) } bool -WsServer::Run(const WebSocketServerOptions& Options) +WsServer::Run() { m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); @@ -645,7 +646,7 @@ WsServer::Run(const WebSocketServerOptions& Options) m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); asio::error_code Ec; - m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); if (Ec) { @@ -657,12 +658,12 @@ WsServer::Run(const WebSocketServerOptions& Options) m_Acceptor->listen(); m_Running = true; - ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", Options.Port); + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); AcceptConnection(); m_ThreadPool = std::make_unique(m_IoSvc); - m_ThreadPool->Start(Options.ThreadCount); + m_ThreadPool->Start(m_Options.ThreadCount); return true; } @@ -720,29 +721,29 @@ WsServer::AcceptConnection() asio::ip::tcp::socket& SocketRef = *Socket.get(); m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { - if (Ec) - { - ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); - } - else + if (m_Running) { - auto Connection = std::make_shared(WebSocketId::New(), std::move(ConnectedSocket)); - - ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); - + if (Ec) { - std::unique_lock _(m_ConnMutex); - m_Connections[Connection->Id()] = Connection; + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); } + else + { + auto Connection = std::make_shared(WebSocketId::New(), std::move(ConnectedSocket)); - Connection->SetParser(std::make_unique(HttpMessageParserType::kRequest)); - Connection->SetState(WebSocketState::kHandshaking); + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); - ReadMessage(Connection); - } + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id()] = Connection; + } + + Connection->SetParser(std::make_unique(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + } - if (m_Running) - { AcceptConnection(); } }); @@ -1522,9 +1523,9 @@ WebSocketService::Configure(WebSocketServer& Server) } std::unique_ptr -WebSocketServer::Create() +WebSocketServer::Create(const WebSocketServerOptions& Options) { - return std::make_unique(); + return std::make_unique(Options); } std::shared_ptr -- cgit v1.2.3 From 75fbf7811d2059d0a9677dd868d3e3f2147b64ae Mon Sep 17 00:00:00 2001 From: Per Larsson Date: Mon, 21 Feb 2022 15:10:06 +0100 Subject: Removed optional offset for GetView. --- zenhttp/websocketasio.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'zenhttp/websocketasio.cpp') diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index c2ce7ca64..1a95b12bc 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -428,7 +428,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) if (IsComplete) { - BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize)); + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); CbPackage Pkg; if (Pkg.TryLoad(Reader) == false) -- cgit v1.2.3