diff options
| author | Per Larsson <[email protected]> | 2022-02-21 14:22:38 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-21 14:22:38 +0100 |
| commit | fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125 (patch) | |
| tree | 50be00e62f1e0d3a0521d08d5c7f00fe7787e014 /zenhttp/websocketasio.cpp | |
| parent | Basic websocket service and test. (diff) | |
| download | zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.tar.xz zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.zip | |
Refactored websocket message.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 516 |
1 files changed, 341 insertions, 175 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index f6f58f38c..13d0177ee 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -3,7 +3,7 @@ #include <zenhttp/websocket.h> #include <zencore/base64.h> -#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinarybuilder.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -75,7 +75,7 @@ protected: virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; - SimpleBinaryWriter m_Stream; + BinaryWriter m_Stream; }; ParseMessageResult @@ -89,7 +89,7 @@ MessageParser::Reset() { OnReset(); - m_Stream.Clear(); + m_Stream.Reset(); } /////////////////////////////////////////////////////////////////////////////// @@ -374,13 +374,13 @@ class WebSocketMessageParser final : public MessageParser public: WebSocketMessageParser() : MessageParser() {} - bool TryLoadMessage(CbPackage& OutMsg); + WebSocketMessage ConsumeMessage(); private: virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; - WebSocketMessageHeader m_Header; + WebSocketMessage m_Message; }; ParseMessageResult @@ -388,65 +388,76 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) { const uint64_t PrevOffset = m_Stream.CurrentOffset(); - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { - const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset(); + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); - Msg.RightChopInline(RemaingHeaderSize); - - if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header); + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); if (IsValidHeader == false) { - return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; } + + Msg += RemaingHeaderSize; } - ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader)); + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); if (Msg.IsEmpty()) { return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - const uint64_t RemaingContentSize = - Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize()); + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingMessageSize)); + + const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); + + if (IsComplete) + { + BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize)); - m_Stream.Write(Msg.Left(RemaingContentSize)); + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } - const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength - ? ParseMessageStatus::kDone - : ParseMessageStatus::kContinue; + m_Message.SetBody(std::move(Pkg)); + } - return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void WebSocketMessageParser::OnReset() { - m_Header = {}; + m_Message = WebSocketMessage(); } -bool -WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() { - const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader); + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); - if (IsParsed) - { - BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader))); - - return OutMsg.TryLoad(Reader); - } - - return false; + return Msg; } /////////////////////////////////////////////////////////////////////////////// @@ -574,10 +585,15 @@ public: WsServer() = default; virtual ~WsServer() { Shutdown(); } - virtual void RegisterService(WebSocketService& Service) override; virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; - virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; private: friend class WsConnection; @@ -586,14 +602,17 @@ private: void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); void ReadMessage(std::shared_ptr<WsConnection> Connection); - void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); struct IdHasher { size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } }; - using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; asio::io_service m_IoSvc; std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; @@ -601,6 +620,8 @@ private: ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; std::atomic_bool m_Running{}; }; @@ -664,34 +685,32 @@ WsServer::Shutdown() } void -WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) { - std::shared_ptr<WsConnection> Connection; - - { - std::unique_lock _(m_ConnMutex); - - if (auto It = m_Connections.find(Id); It != m_Connections.end()) - { - Connection = It->second; - } - } + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} - BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); - ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr()); + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); - async_write(Connection->Socket(), - asio::buffer(Buffer.Data(), Buffer.Size()), - [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - CloseConnection(Connection, Ec); - } - }); + SendMessage(std::move(Response)); } void @@ -768,7 +787,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), - [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); @@ -911,20 +930,15 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) { WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); - MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Connection->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return ReadMessage(Connection); - } + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { @@ -933,20 +947,19 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) return CloseConnection(Connection, std::error_code()); } - CbPackage Message; - if (Parser.TryLoadMessage(Message) == false) + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'"); - - return CloseConnection(Connection, std::error_code()); + ZEN_ASSERT(RemainingBytes == 0); + continue; } - RouteMessage(Connection, Message); - + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); - } - Connection->ReadBuffer().consume(ByteCount); + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } ReadMessage(Connection); } @@ -959,19 +972,114 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) } void -WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg) +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) { - ZEN_LOG_DEBUG(LogWebSocket, "routing message"); + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; - for (auto Server : m_Services) { - if (Server->HandleMessage(Connection->Id(), Msg)) + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) { - return; + Connection = It->second; } } - ZEN_LOG_WARN(LogWebSocket, "unhandled message"); + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + + CloseConnection(Connection, Ec); + } + }); + } } /////////////////////////////////////////////////////////////////////////////// @@ -984,15 +1092,14 @@ public: std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } - virtual bool Connect(const WebSocketConnectInfo& Info) override; - virtual void Disconnect() override; - virtual bool IsConnected() const { return false; } - virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } - virtual void SendMsg(CbPackage&& Msg) override; - virtual void SendMsg(CbObject&& Msg) override; + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } - virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; - virtual void OnMessage(MessageCallback&& Cb) override; + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; private: WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } @@ -1001,7 +1108,9 @@ private: asio::streambuf& ReadBuffer() { return m_ReadBuffer; } void TriggerEvent(WebSocketEvent Evt); void ReadMessage(); - void RouteMessage(CbPackage&& Msg); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; asio::io_context& m_IoCtx; WebSocketId m_Id; @@ -1009,18 +1118,21 @@ private: std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; EventCallback m_EventCallbacks[3]; - MessageCallback m_MsgCallback; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; std::atomic_uint32_t m_State; std::string m_Host; int16_t m_Port{}; }; -bool +std::future<bool> WsClient::Connect(const WebSocketConnectInfo& Info) { if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) { - return true; + return m_ConnectPromise.get_future(); } SetState(WebSocketState::kHandshaking); @@ -1046,7 +1158,9 @@ WsClient::Connect(const WebSocketConnectInfo& Info) TriggerEvent(WebSocketEvent::kDisconnected); - return false; + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); } ExtendableStringBuilder<128> Sb; @@ -1093,7 +1207,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info) } }); - return true; + return m_ConnectPromise.get_future(); } void @@ -1110,48 +1224,64 @@ WsClient::Disconnect() } TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } } } -void -WsClient::SendMsg(CbPackage&& Msg) +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) { + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + BinaryWriter Writer; - WebSocketMessageHeader::Write(Writer, Msg); + Request.Save(Writer); - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); - ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size()); + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); - async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) { + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { if (Ec) { - ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message()); + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); Self->Disconnect(); } }); -} - -void -WsClient::SendMsg(CbObject&& Msg) -{ - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); - WsClient::SendMsg(std::move(Pkg)); + return FutureResponse; } void -WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +WsClient::OnNotification(NotificationCallback&& Cb) { - m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); + m_NotificationCallback = std::move(Cb); } void -WsClient::OnMessage(MessageCallback&& Cb) +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) { - m_MsgCallback = std::move(Cb); + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); } void @@ -1199,6 +1329,8 @@ WsClient::ReadMessage() { ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1216,6 +1348,8 @@ WsClient::ReadMessage() Parser.StatusText(), Parser.StatusCode()); + Self->m_ConnectPromise.set_value(false); + return Self->Disconnect(); } @@ -1225,6 +1359,8 @@ WsClient::ReadMessage() Self->SetState(WebSocketState::kConnected); Self->ReadMessage(); Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); } break; @@ -1232,44 +1368,36 @@ WsClient::ReadMessage() { WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + uint64_t RemainingBytes = Self->ReadBuffer().size(); - while (MessageData.IsEmpty() == false) + while (RemainingBytes > 0) { - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - MessageData.RightChopInline(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(MessageData.IsEmpty()); + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); - return Self->ReadMessage(); - } + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); Parser.Reset(); - continue; } - CbPackage Message; - if (Parser.TryLoadMessage(Message)) - { - Self->RouteMessage(std::move(Message)); - } - else + if (Result.Status == ParseMessageStatus::kContinue) { - ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'"); + ZEN_ASSERT(RemainingBytes == 0); + continue; } + WebSocketMessage Message = Parser.ConsumeMessage(); Parser.Reset(); + + Self->RouteMessage(std::move(Message)); } - Self->ReadBuffer().consume(ByteCount); Self->ReadMessage(); } break; @@ -1278,12 +1406,39 @@ WsClient::ReadMessage() } void -WsClient::RouteMessage(CbPackage&& Msg) +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) { - if (m_MsgCallback) + switch (RoutedMessage.MessageType()) { - m_MsgCallback(Msg); - } + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + }; } } // namespace zen::websocket @@ -1292,67 +1447,78 @@ namespace zen { std::atomic_uint32_t WebSocketId::NextId{1}; -void -WebSocketService::Configure(WebSocketServer& Server) +bool +WebSocketMessage::Header::IsValid() const { - ZEN_ASSERT(m_Server == nullptr); - - m_Server = &Server; + return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; } +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + void -WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg) +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) { - ZEN_ASSERT(m_Server != nullptr); - - m_Server->PublishMessage(Id, std::move(Msg)); + m_Header.MessageType = MessageType; } void -WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg) +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) { - ZEN_ASSERT(m_Server != nullptr); - CbPackage Pkg; - Pkg.SetObject(std::move(Msg)); + Pkg.SetObject(Body); - m_Server->PublishMessage(Id, std::move(Pkg)); + SetBody(std::move(Pkg)); } -bool -WebSocketMessageHeader::IsValid() const +void +WebSocketMessage::Save(BinaryWriter& Writer) { - return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + m_Header.Crc32 = 1; // TODO + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } bool -WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) +WebSocketMessage::TryLoadHeader(MemoryView Memory) { - if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) + if (Memory.GetSize() < HeaderSize) { return false; } - void* Dst = &OutHeader; - memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); - return OutHeader.IsValid(); + return m_Header.IsValid(); } void -WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg) +WebSocketService::Configure(WebSocketServer& Server) { - WebSocketMessageHeader Header; - - Writer.Write(&Header, sizeof(WebSocketMessageHeader)); - Msg.Save(Writer); + ZEN_ASSERT(m_SocketServer == nullptr); - Header.Magic = WebSocketMessageHeader::ExpectedMagic; - Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader); - Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO + m_SocketServer = &Server; - MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader)); - HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader))); + RegisterHandlers(Server); } std::unique_ptr<WebSocketServer> |