diff options
| author | Per Larsson <[email protected]> | 2022-02-18 14:48:41 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-18 14:48:41 +0100 |
| commit | 08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4 (patch) | |
| tree | 8fd8e7189c9807280003eabb3040cb35db2e5cd4 /zenhttp/websocketasio.cpp | |
| parent | Web socket client is shared between I/O thead and client. (diff) | |
| download | zen-08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4.tar.xz zen-08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4.zip | |
Basic websocket service and test.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 212 |
1 files changed, 161 insertions, 51 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index 1952c97a2..f6f58f38c 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -13,7 +13,6 @@ #include <zencore/string.h> #include <chrono> -#include <compare> #include <optional> #include <shared_mutex> #include <span> @@ -381,7 +380,6 @@ private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; - SimpleBinaryWriter m_HeaderStream; WebSocketMessageHeader m_Header; }; @@ -390,20 +388,20 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); - if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { - const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset(); + const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset(); - m_HeaderStream.Write(Msg.Left(RemaingHeaderSize)); + m_Stream.Write(Msg.Left(RemaingHeaderSize)); Msg.RightChopInline(RemaingHeaderSize); - if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header); + const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header); if (IsValidHeader == false) { @@ -411,16 +409,21 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) } } - if (Msg.GetSize() == 0) + ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader)); + + if (Msg.IsEmpty()) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset(); + const uint64_t RemaingContentSize = + Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize()); m_Stream.Write(Msg.Left(RemaingContentSize)); - const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength + ? ParseMessageStatus::kDone + : ParseMessageStatus::kContinue; return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } @@ -428,18 +431,17 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) void WebSocketMessageParser::OnReset() { - m_HeaderStream.Clear(); m_Header = {}; } bool WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) { - const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength; + const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader); if (IsParsed) { - BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader))); return OutMsg.TryLoad(Reader); } @@ -448,32 +450,10 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) } /////////////////////////////////////////////////////////////////////////////// -class WsConnectionId -{ - static std::atomic_uint32_t WsConnectionCounter; - -public: - WsConnectionId() = default; - - uint32_t Value() const { return m_Value; } - - auto operator<=>(const WsConnectionId& RHS) const = default; - - static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } - -private: - WsConnectionId(uint32_t Value) : m_Value(Value) {} - - uint32_t m_Value{}; -}; - -std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; - -/////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this<WsConnection> { public: - WsConnection(WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) : m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) @@ -485,7 +465,7 @@ public: std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } - WsConnectionId Id() const { return m_Id; } + WebSocketId Id() const { return m_Id; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } @@ -497,7 +477,7 @@ public: void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } private: - WsConnectionId m_Id; + WebSocketId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; TimePoint m_StartTime; std::atomic_uint32_t m_State; @@ -594,8 +574,10 @@ public: WsServer() = default; virtual ~WsServer() { Shutdown(); } + virtual void RegisterService(WebSocketService& Service) override; virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; + virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override; private: friend class WsConnection; @@ -608,19 +590,28 @@ private: struct IdHasher { - size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; - using ConnectionMap = std::unordered_map<WsConnectionId, std::shared_ptr<WsConnection>, IdHasher>; + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; asio::io_service m_IoSvc; std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; std::unique_ptr<WsThreadPool> m_ThreadPool; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; + std::vector<WebSocketService*> m_Services; std::atomic_bool m_Running{}; }; +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + bool WsServer::Run(const WebSocketServerOptions& Options) { @@ -673,6 +664,37 @@ WsServer::Shutdown() } void +WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Id); It != m_Connections.end()) + { + Connection = It->second; + } + } + + BinaryWriter Writer; + WebSocketMessageHeader::Write(Writer, Msg); + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + CloseConnection(Connection, Ec); + } + }); +} + +void WsServer::AcceptConnection() { auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); @@ -685,7 +707,7 @@ WsServer::AcceptConnection() } else { - auto Connection = std::make_shared<WsConnection>(WsConnectionId::New(), std::move(ConnectedSocket)); + auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); @@ -726,7 +748,7 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e } } - const WsConnectionId Id = Connection->Id(); + const WebSocketId Id = Connection->Id(); { std::unique_lock _(m_ConnMutex); @@ -763,6 +785,8 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + Connection->ReadBuffer().consume(Result.ByteCount); + if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); @@ -770,10 +794,10 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) if (Result.Status == ParseMessageStatus::kError) { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", - Connection->Id().Value(), - Connection->RemoteAddr()); + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); return CloseConnection(Connection, std::error_code()); } @@ -877,6 +901,8 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) Connection->SetParser(std::make_unique<WebSocketMessageParser>()); Connection->SetState(kConnected); + + ReadMessage(Connection); }); } break; @@ -935,15 +961,24 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) void WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg) { - ZEN_UNUSED(Connection, Msg); ZEN_LOG_DEBUG(LogWebSocket, "routing message"); + + for (auto Server : m_Services) + { + if (Server->HandleMessage(Connection->Id(), Msg)) + { + return; + } + } + + ZEN_LOG_WARN(LogWebSocket, "unhandled message"); } /////////////////////////////////////////////////////////////////////////////// class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> { public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {} + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} virtual ~WsClient() { Disconnect(); } @@ -953,6 +988,8 @@ public: virtual void Disconnect() override; virtual bool IsConnected() const { return false; } virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } + virtual void SendMsg(CbPackage&& Msg) override; + virtual void SendMsg(CbObject&& Msg) override; virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; virtual void OnMessage(MessageCallback&& Cb) override; @@ -967,7 +1004,7 @@ private: void RouteMessage(CbPackage&& Msg); asio::io_context& m_IoCtx; - WsConnectionId m_Id; + WebSocketId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; @@ -1077,6 +1114,35 @@ WsClient::Disconnect() } void +WsClient::SendMsg(CbPackage&& Msg) +{ + BinaryWriter Writer; + WebSocketMessageHeader::Write(Writer, Msg); + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); +} + +void +WsClient::SendMsg(CbObject&& Msg) +{ + CbPackage Pkg; + Pkg.SetObject(std::move(Msg)); + + WsClient::SendMsg(std::move(Pkg)); +} + +void WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); @@ -1157,9 +1223,8 @@ WsClient::ReadMessage() Self->SetParser(std::make_unique<WebSocketMessageParser>()); Self->SetState(WebSocketState::kConnected); - Self->TriggerEvent(WebSocketEvent::kConnected); - Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); } break; @@ -1225,6 +1290,35 @@ WsClient::RouteMessage(CbPackage&& Msg) namespace zen { +std::atomic_uint32_t WebSocketId::NextId{1}; + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_Server == nullptr); + + m_Server = &Server; +} + +void +WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg) +{ + ZEN_ASSERT(m_Server != nullptr); + + m_Server->PublishMessage(Id, std::move(Msg)); +} + +void +WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg) +{ + ZEN_ASSERT(m_Server != nullptr); + + CbPackage Pkg; + Pkg.SetObject(std::move(Msg)); + + m_Server->PublishMessage(Id, std::move(Pkg)); +} + bool WebSocketMessageHeader::IsValid() const { @@ -1245,6 +1339,22 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade return OutHeader.IsValid(); } +void +WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg) +{ + WebSocketMessageHeader Header; + + Writer.Write(&Header, sizeof(WebSocketMessageHeader)); + Msg.Save(Writer); + + Header.Magic = WebSocketMessageHeader::ExpectedMagic; + Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader); + Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO + + MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader)); + HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader))); +} + std::unique_ptr<WebSocketServer> WebSocketServer::Create() { |