diff options
| author | Per Larsson <[email protected]> | 2022-02-21 14:22:38 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-21 14:22:38 +0100 |
| commit | fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125 (patch) | |
| tree | 50be00e62f1e0d3a0521d08d5c7f00fe7787e014 | |
| parent | Basic websocket service and test. (diff) | |
| download | zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.tar.xz zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.zip | |
Refactored websocket message.
| -rw-r--r-- | zencore/include/zencore/stream.h | 48 | ||||
| -rw-r--r-- | zencore/stream.cpp | 19 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/websocket.h | 156 | ||||
| -rw-r--r-- | zenhttp/websocketasio.cpp | 516 | ||||
| -rw-r--r-- | zenserver-test/zenserver-test.cpp | 54 | ||||
| -rw-r--r-- | zenserver/testing/httptest.cpp | 31 | ||||
| -rw-r--r-- | zenserver/testing/httptest.h | 4 |
7 files changed, 523 insertions, 305 deletions
diff --git a/zencore/include/zencore/stream.h b/zencore/include/zencore/stream.h index 6d7e7d19f..54d7e1014 100644 --- a/zencore/include/zencore/stream.h +++ b/zencore/include/zencore/stream.h @@ -27,12 +27,29 @@ public: m_Offset += ByteCount; } + inline void Write(MemoryView Memory) { Write(Memory.GetData(), Memory.GetSize()); } + inline uint64_t CurrentOffset() const { return m_Offset; } inline const uint8_t* Data() const { return m_Buffer.data(); } inline const uint8_t* GetData() const { return m_Buffer.data(); } inline uint64_t Size() const { return m_Buffer.size(); } inline uint64_t GetSize() const { return m_Buffer.size(); } + void Reset(); + + inline MemoryView GetView(uint64_t Offset = 0) const + { + MemoryView View(m_Buffer.data(), m_Offset); + View.RightChopInline(Offset); + return View; + } + + inline MutableMemoryView GetMutableView(uint64_t Offset = 0) + { + MutableMemoryView View(m_Buffer.data(), m_Offset); + View.RightChopInline(Offset); + return View; + } private: RwLock m_Lock; @@ -49,37 +66,6 @@ MakeMemoryView(const BinaryWriter& Stream) } /** - * Non thread-safe stream writer - */ - -class SimpleBinaryWriter -{ - static constexpr uint32_t DefaultBlockSize = 64; - -public: - SimpleBinaryWriter(uint32_t BlockSize = DefaultBlockSize) : m_BlockSize(BlockSize), m_Offset{0} {} - ~SimpleBinaryWriter() = default; - - void Write(MemoryView Memory); - void Write(const void* Data, size_t Size) { Write(MemoryView(Data, Size)); } - void Clear(); - - inline uint64_t CurrentOffset() const { return m_Offset; } - - inline const uint8_t* Data() const { return m_Buffer.data(); } - inline const uint8_t* GetData() const { return m_Buffer.data(); } - inline uint64_t Size() const { return m_Buffer.size(); } - inline uint64_t GetSize() const { return m_Buffer.size(); } - - MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); } - -private: - std::vector<uint8_t> m_Buffer; - uint64_t m_Offset; - uint32_t m_BlockSize; -}; - -/** * Binary stream reader */ diff --git a/zencore/stream.cpp b/zencore/stream.cpp index 36953363f..8faf90af2 100644 --- a/zencore/stream.cpp +++ b/zencore/stream.cpp @@ -26,25 +26,10 @@ BinaryWriter::Write(const void* data, size_t ByteCount, uint64_t Offset) } void -SimpleBinaryWriter::Write(MemoryView Memory) +BinaryWriter::Reset() { - const uint64_t NeededSize = m_Offset + Memory.GetSize(); - - if (NeededSize > m_Buffer.size()) - { - const size_t NewCapacity = RoundUp(NeededSize, m_BlockSize); - - m_Buffer.resize(NewCapacity); - } - - memcpy(m_Buffer.data() + m_Offset, Memory.GetData(), Memory.GetSize()); - - m_Offset += Memory.GetSize(); -} + RwLock::ExclusiveLockScope _(m_Lock); -void -SimpleBinaryWriter::Clear() -{ m_Buffer.clear(); m_Offset = 0; } diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h index 1ab9b4804..336d98b42 100644 --- a/zenhttp/include/zenhttp/websocket.h +++ b/zenhttp/include/zenhttp/websocket.h @@ -1,10 +1,13 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zencore/compactbinarypackage.h> #include <zencore/memory.h> #include <compare> #include <functional> +#include <future> #include <memory> +#include <optional> #pragma once @@ -14,10 +17,11 @@ class io_context; namespace zen { -class CbPackage; -class CbObject; class BinaryWriter; +/** + * A unique socket ID. + */ class WebSocketId { static std::atomic_uint32_t NextId; @@ -37,45 +41,132 @@ private: uint32_t m_Value{}; }; +/** + * Type of web socket message. + */ +enum class WebSocketMessageType : uint8_t +{ + kInvalid, + kNotification, + kRequest, + kResponse +}; + +/** + * Web socket message. + */ +class WebSocketMessage +{ + struct Header + { + static constexpr uint32_t HeaderMagic = 0x7a776d68; // zwmh + + uint64_t MessageSize{}; + uint32_t Magic{HeaderMagic}; + uint32_t CorrelationId{}; + uint32_t Crc32{}; + WebSocketMessageType MessageType{}; + uint8_t Reserved[3] = {0}; + + bool IsValid() const; + }; + + static_assert(sizeof Header == 24); + + static std::atomic_uint32_t NextCorrelationId; + +public: + static constexpr size_t HeaderSize = sizeof(Header); + + WebSocketMessage() = default; + + WebSocketId SocketId() const { return m_SocketId; } + void SetSocketId(WebSocketId Id) { m_SocketId = Id; } + void SetMessageType(WebSocketMessageType MessageType); + WebSocketMessageType MessageType() const { return m_Header.MessageType; } + uint64_t MessageSize() const { return m_Header.MessageSize; } + void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } + uint32_t CorrelationId() const { return m_Header.CorrelationId; } + + const CbPackage& Body() const { return m_Body.value(); } + void SetBody(CbPackage&& Body); + void SetBody(CbObject&& Body); + bool HasBody() const { return m_Body.has_value(); } + + void Save(BinaryWriter& Writer); + bool TryLoadHeader(MemoryView Memory); + + bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } + +private: + Header m_Header{}; + WebSocketId m_SocketId{}; + std::optional<CbPackage> m_Body; +}; + class WebSocketServer; +/** + * Base class for handling web socket requests and notifications from connected client(s). + */ class WebSocketService { public: virtual ~WebSocketService() = default; - virtual bool HandleMessage(WebSocketId Id, const CbPackage& Msg) = 0; - void Configure(WebSocketServer& Server); + void Configure(WebSocketServer& Server); + + virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } + virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } protected: WebSocketService() = default; - void PublishMessage(WebSocketId Id, CbPackage&& Msg); - void PublishMessage(WebSocketId Id, CbObject&& Msg); + virtual void RegisterHandlers(WebSocketServer& Server) = 0; + + WebSocketServer& SocketServer() + { + ZEN_ASSERT(m_SocketServer); + return *m_SocketServer; + } private: - WebSocketServer* m_Server = nullptr; + WebSocketServer* m_SocketServer{}; }; +/** + * Server options. + */ struct WebSocketServerOptions { uint16_t Port = 8848; uint32_t ThreadCount = 1; }; +/** + * The web socket server manages client connections and routing of requests and notifications. + */ class WebSocketServer { public: virtual ~WebSocketServer() = default; - virtual void RegisterService(WebSocketService& Service) = 0; - virtual bool Run(const WebSocketServerOptions& Options) = 0; - virtual void Shutdown() = 0; - virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) = 0; + virtual bool Run(const WebSocketServerOptions& Options) = 0; + virtual void Shutdown() = 0; + + virtual void RegisterService(WebSocketService& Service) = 0; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; + + virtual void SendNotification(WebSocketMessage&& Notification) = 0; + virtual void SendResponse(WebSocketMessage&& Response) = 0; static std::unique_ptr<WebSocketServer> Create(); }; +/** + * The state of the web socket. + */ enum class WebSocketState : uint32_t { kNone, @@ -85,6 +176,9 @@ enum class WebSocketState : uint32_t kError }; +/** + * Type of web socket client event. + */ enum class WebSocketEvent : uint32_t { kConnected, @@ -92,6 +186,9 @@ enum class WebSocketEvent : uint32_t kError }; +/** + * Web socket client connection info. + */ struct WebSocketConnectInfo { std::string Host; @@ -101,40 +198,27 @@ struct WebSocketConnectInfo uint16_t Version{13}; }; +/** + * A connection to a web socket server for sending requests and listening for notifications. + */ class WebSocketClient { public: - using EventCallback = std::function<void()>; - using MessageCallback = std::function<void(CbPackage&)>; + using EventCallback = std::function<void()>; + using NotificationCallback = std::function<void(WebSocketMessage&&)>; virtual ~WebSocketClient() = default; - virtual bool Connect(const WebSocketConnectInfo& Info) = 0; - virtual void Disconnect() = 0; - virtual bool IsConnected() const = 0; - virtual WebSocketState State() const = 0; - virtual void SendMsg(CbObject&& Msg) = 0; - virtual void SendMsg(CbPackage&& Msg) = 0; + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; + virtual void Disconnect() = 0; + virtual bool IsConnected() const = 0; + virtual WebSocketState State() const = 0; - virtual void On(WebSocketEvent Evt, EventCallback&& Cb) = 0; - virtual void OnMessage(MessageCallback&& Cb) = 0; + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; + virtual void OnNotification(NotificationCallback&& Cb) = 0; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); }; -struct WebSocketMessageHeader -{ - static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh - - uint32_t Magic{}; - uint64_t ContentLength{}; - uint32_t Crc32{}; - - bool IsValid() const; - - static bool Read(MemoryView Memory, WebSocketMessageHeader& OutHeader); - - static void Write(BinaryWriter& Writer, const CbPackage& Msg); -}; - } // namespace zen diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index f6f58f38c..13d0177ee 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -3,7 +3,7 @@ #include <zenhttp/websocket.h> #include <zencore/base64.h> -#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinarybuilder.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -75,7 +75,7 @@ protected: virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; - SimpleBinaryWriter m_Stream; + BinaryWriter m_Stream; }; ParseMessageResult @@ -89,7 +89,7 @@ MessageParser::Reset() { OnReset(); - m_Stream.Clear(); + m_Stream.Reset(); } /////////////////////////////////////////////////////////////////////////////// @@ -374,13 +374,13 @@ class WebSocketMessageParser final : public MessageParser public: WebSocketMessageParser() : MessageParser() {} - bool TryLoadMessage(CbPackage& OutMsg); + WebSocketMessage ConsumeMessage(); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; - WebSocketMessageHeader m_Header; + WebSocketMessage m_Message; }; ParseMessageResult @@ -388,65 +388,76 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { - const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset(); + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); - Msg.RightChopInline(RemaingHeaderSize); - - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header); + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); if (IsValidHeader == false) { - return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; } + + Msg += RemaingHeaderSize; } - ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader)); + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); if (Msg.IsEmpty()) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const uint64_t RemaingContentSize = - Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize()); + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingMessageSize)); + + const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); + + if (IsComplete) + { + BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize)); - m_Stream.Write(Msg.Left(RemaingContentSize)); + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } - const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength - ? ParseMessageStatus::kDone - : ParseMessageStatus::kContinue; + m_Message.SetBody(std::move(Pkg)); + } - return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void WebSocketMessageParser::OnReset() { - m_Header = {}; + m_Message = WebSocketMessage(); } -bool -WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() { - const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader); + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); - if (IsParsed) - { - BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader))); - - return OutMsg.TryLoad(Reader); - } - - return false; + return Msg; } /////////////////////////////////////////////////////////////////////////////// @@ -574,10 +585,15 @@ public: WsServer() = default; virtual ~WsServer() { Shutdown(); } - virtual void RegisterService(WebSocketService& Service) override; virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; - virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; private: friend class WsConnection; @@ -586,14 +602,17 @@ private: void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); void ReadMessage(std::shared_ptr<WsConnection> Connection); - void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); struct IdHasher { size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; - using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; asio::io_service m_IoSvc; std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; @@ -601,6 +620,8 @@ private: ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; std::atomic_bool m_Running{}; }; @@ -664,34 +685,32 @@ WsServer::Shutdown() } void -WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) { - std::shared_ptr<WsConnection> Connection; - - { - std::unique_lock _(m_ConnMutex); - - if (auto It = m_Connections.find(Id); It != m_Connections.end()) - { - Connection = It->second; - } - } + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} - BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); - ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr()); + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); - async_write(Connection->Socket(), - asio::buffer(Buffer.Data(), Buffer.Size()), - [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - CloseConnection(Connection, Ec); - } - }); + SendMessage(std::move(Response)); } void @@ -768,7 +787,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), - [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); @@ -911,20 +930,15 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) { WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); - MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Connection->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return ReadMessage(Connection); - } + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { @@ -933,20 +947,19 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) return CloseConnection(Connection, std::error_code()); } - CbPackage Message; - if (Parser.TryLoadMessage(Message) == false) + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); - - return CloseConnection(Connection, std::error_code()); + ZEN_ASSERT(RemainingBytes == 0); + continue; } - RouteMessage(Connection, Message); - + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); - } - Connection->ReadBuffer().consume(ByteCount); + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } ReadMessage(Connection); } @@ -959,19 +972,114 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) } void -WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg) +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) { - ZEN_LOG_DEBUG(LogWebSocket, "routing message"); + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; - for (auto Server : m_Services) { - if (Server->HandleMessage(Connection->Id(), Msg)) + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) { - return; + Connection = It->second; } } - ZEN_LOG_WARN(LogWebSocket, "unhandled message"); + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + + CloseConnection(Connection, Ec); + } + }); + } } /////////////////////////////////////////////////////////////////////////////// @@ -984,15 +1092,14 @@ public: std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } - virtual bool Connect(const WebSocketConnectInfo& Info) override; - virtual void Disconnect() override; - virtual bool IsConnected() const { return false; } - virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } - virtual void SendMsg(CbPackage&& Msg) override; - virtual void SendMsg(CbObject&& Msg) override; + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } - virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; - virtual void OnMessage(MessageCallback&& Cb) override; + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } @@ -1001,7 +1108,9 @@ private: asio::streambuf& ReadBuffer() { return m_ReadBuffer; } void TriggerEvent(WebSocketEvent Evt); void ReadMessage(); - void RouteMessage(CbPackage&& Msg); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; asio::io_context& m_IoCtx; WebSocketId m_Id; @@ -1009,18 +1118,21 @@ private: std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; - MessageCallback m_MsgCallback; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; }; -bool +std::future<bool> WsClient::Connect(const WebSocketConnectInfo& Info) { if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { - return true; + return m_ConnectPromise.get_future(); } SetState(WebSocketState::kHandshaking); @@ -1046,7 +1158,9 @@ WsClient::Connect(const WebSocketConnectInfo& Info) TriggerEvent(WebSocketEvent::kDisconnected); - return false; + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); } ExtendableStringBuilder<128> Sb; @@ -1093,7 +1207,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) } }); - return true; + return m_ConnectPromise.get_future(); } void @@ -1110,48 +1224,64 @@ WsClient::Disconnect() } TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } } } -void -WsClient::SendMsg(CbPackage&& Msg) +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) { + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); + Request.Save(Writer); - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); - ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size()); + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); - async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) { + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { if (Ec) { - ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message()); + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); Self->Disconnect(); } }); -} - -void -WsClient::SendMsg(CbObject&& Msg) -{ - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); - WsClient::SendMsg(std::move(Pkg)); + return FutureResponse; } void -WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +WsClient::OnNotification(NotificationCallback&& Cb) { - m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); + m_NotificationCallback = std::move(Cb); } void -WsClient::OnMessage(MessageCallback&& Cb) +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { - m_MsgCallback = std::move(Cb); + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); } void @@ -1199,6 +1329,8 @@ WsClient::ReadMessage() { ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1216,6 +1348,8 @@ WsClient::ReadMessage() Parser.StatusText(), Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1225,6 +1359,8 @@ WsClient::ReadMessage() Self->SetState(WebSocketState::kConnected); Self->ReadMessage(); Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); } break; @@ -1232,44 +1368,36 @@ WsClient::ReadMessage() { WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Self->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return Self->ReadMessage(); - } + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); Parser.Reset(); - continue; } - CbPackage Message; - if (Parser.TryLoadMessage(Message)) - { - Self->RouteMessage(std::move(Message)); - } - else + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'"); + ZEN_ASSERT(RemainingBytes == 0); + continue; } + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); + + Self->RouteMessage(std::move(Message)); } - Self->ReadBuffer().consume(ByteCount); Self->ReadMessage(); } break; @@ -1278,12 +1406,39 @@ WsClient::ReadMessage() } void -WsClient::RouteMessage(CbPackage&& Msg) +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) { - if (m_MsgCallback) + switch (RoutedMessage.MessageType()) { - m_MsgCallback(Msg); - } + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + }; } } // namespace zen::websocket @@ -1292,67 +1447,78 @@ namespace zen { std::atomic_uint32_t WebSocketId::NextId{1}; -void -WebSocketService::Configure(WebSocketServer& Server) +bool +WebSocketMessage::Header::IsValid() const { - ZEN_ASSERT(m_Server == nullptr); - - m_Server = &Server; + return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; } +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + void -WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) { - ZEN_ASSERT(m_Server != nullptr); - - m_Server->PublishMessage(Id, std::move(Msg)); + m_Header.MessageType = MessageType; } void -WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg) +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) { - ZEN_ASSERT(m_Server != nullptr); - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); + Pkg.SetObject(Body); - m_Server->PublishMessage(Id, std::move(Pkg)); + SetBody(std::move(Pkg)); } -bool -WebSocketMessageHeader::IsValid() const +void +WebSocketMessage::Save(BinaryWriter& Writer) { - return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + m_Header.Crc32 = 1; // TODO + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } bool -WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) +WebSocketMessage::TryLoadHeader(MemoryView Memory) { - if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) + if (Memory.GetSize() < HeaderSize) { return false; } - void* Dst = &OutHeader; - memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); - return OutHeader.IsValid(); + return m_Header.IsValid(); } void -WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg) +WebSocketService::Configure(WebSocketServer& Server) { - WebSocketMessageHeader Header; - - Writer.Write(&Header, sizeof(WebSocketMessageHeader)); - Msg.Save(Writer); + ZEN_ASSERT(m_SocketServer == nullptr); - Header.Magic = WebSocketMessageHeader::ExpectedMagic; - Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader); - Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO + m_SocketServer = &Server; - MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader)); - HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader))); + RegisterHandlers(Server); } std::unique_ptr<WebSocketServer> diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index 73048b504..78829a2d1 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -2083,8 +2083,9 @@ TEST_CASE("http.package") TEST_CASE("websocket.basic") { - std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); - const uint16_t PortNumber = 13337; + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto MaxWaitTime = std::chrono::seconds(5); ZenServerInstance Inst(TestEnv); Inst.SetTestDir(TestDir); @@ -2095,43 +2096,30 @@ TEST_CASE("websocket.basic") IoDispatcher IoDispatcher(IoCtx); auto WebSocket = WebSocketClient::Create(IoCtx); - std::atomic_bool Done{false}; - WebSocketEvent Event; - - WebSocket->On(WebSocketEvent::kConnected, [&]() { - CbObjectWriter Req; - Req.BeginObject("Header"); - Req << "Method"sv - << "TestHelloZen"sv; - Req << "CorrelationId" << uint64_t(1); - Req.EndObject(); - - WebSocket->SendMsg(Req.Save()); - }); + auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"}); + IoDispatcher.Run(); - WebSocket->On(WebSocketEvent::kDisconnected, [&]() { - Event = WebSocketEvent::kDisconnected; - Done = true; - }); + ConnectFuture.wait_for(MaxWaitTime); + CHECK(ConnectFuture.get()); - CbPackage Response; - WebSocket->OnMessage([&](const CbPackage& Msg) { - Response = Msg; - Done = true; - }); + for (size_t Idx = 0; Idx < 10; Idx++) + { + CbObjectWriter Request; + Request << "Method"sv + << "SayHello"sv; - WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"}); + WebSocketMessage RequestMsg; + RequestMsg.SetMessageType(WebSocketMessageType::kRequest); + RequestMsg.SetBody(Request.Save()); - IoDispatcher.Run(); + auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg)); + ResponseFuture.wait_for(MaxWaitTime); - while (Done == false && IoDispatcher.IsRunning()) - { - std::this_thread::sleep_for(std::chrono::seconds(2)); - }; + CbObject Response = ResponseFuture.get().Body().GetObject(); + std::string_view Message = Response["Result"].AsString(); - CbObject ResponseObject = Response.GetObject(); - std::string_view Message = ResponseObject["Result"].AsString(); - CHECK(Message == "Hello Friend!!"sv); + CHECK(Message == "Hello Friend!!"sv); + } WebSocket->Disconnect(); diff --git a/zenserver/testing/httptest.cpp b/zenserver/testing/httptest.cpp index 41a4f064b..10b69c469 100644 --- a/zenserver/testing/httptest.cpp +++ b/zenserver/testing/httptest.cpp @@ -8,6 +8,8 @@ namespace zen { +using namespace std::literals; + HttpTestingService::HttpTestingService() { m_Router.RegisterRoute( @@ -136,29 +138,34 @@ HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return (InsertResult.first->second = new PackageHandler(*this, RequestId)).Get(); } -bool -HttpTestingService::HandleMessage(WebSocketId SocketId, const CbPackage& Msg) +void +HttpTestingService::RegisterHandlers(WebSocketServer& Server) { - using namespace std::literals; + Server.RegisterRequestHandler("SayHello"sv, *this); +} - CbObject Request = Msg.GetObject(); +bool +HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg) +{ + CbObjectView Request = RequestMsg.Body().GetObject(); - CbObjectView Header = Request["Header"sv].AsObjectView(); - std::string_view Method = Header["Method"].AsString(); - const uint64_t CorrelationId = Header["CorrelationId"].AsUInt64(); + std::string_view Method = Request["Method"].AsString(); - if (Method != "TestHelloZen"sv) + if (Method != "SayHello"sv) { return false; } CbObjectWriter Response; - Response.BeginObject("Header"); - Response << "CorrelationId"sv << CorrelationId; - Response.EndObject(); Response.AddString("Result"sv, "Hello Friend!!"); - PublishMessage(SocketId, Response.Save()); + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId()); + ResponseMsg.SetSocketId(RequestMsg.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SocketServer().SendResponse(std::move(ResponseMsg)); return true; } diff --git a/zenserver/testing/httptest.h b/zenserver/testing/httptest.h index 267d59b36..57d2d63f3 100644 --- a/zenserver/testing/httptest.h +++ b/zenserver/testing/httptest.h @@ -23,7 +23,6 @@ public: virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override; - virtual bool HandleMessage(WebSocketId SocketId, const CbPackage& Msg) override; class PackageHandler : public IHttpPackageHandler { @@ -42,6 +41,9 @@ public: }; private: + virtual void RegisterHandlers(WebSocketServer& Server) override; + virtual bool HandleRequest(const WebSocketMessage& Request) override; + HttpRequestRouter m_Router; std::atomic<uint32_t> m_Counter{0}; metrics::OperationTiming m_TimingStats; |