aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-21 14:22:38 +0100
committerPer Larsson <[email protected]>2022-02-21 14:22:38 +0100
commitfd9f9086b3ddd0c38fa87d7e49f6341dacdcc125 (patch)
tree50be00e62f1e0d3a0521d08d5c7f00fe7787e014 /zenhttp/websocketasio.cpp
parentBasic websocket service and test. (diff)
downloadzen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.tar.xz
zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.zip
Refactored websocket message.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp516
1 files changed, 341 insertions, 175 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index f6f58f38c..13d0177ee 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -3,7 +3,7 @@
#include <zenhttp/websocket.h>
#include <zencore/base64.h>
-#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinarybuilder.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -75,7 +75,7 @@ protected:
virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
virtual void OnReset() = 0;
- SimpleBinaryWriter m_Stream;
+ BinaryWriter m_Stream;
};
ParseMessageResult
@@ -89,7 +89,7 @@ MessageParser::Reset()
{
OnReset();
- m_Stream.Clear();
+ m_Stream.Reset();
}
///////////////////////////////////////////////////////////////////////////////
@@ -374,13 +374,13 @@ class WebSocketMessageParser final : public MessageParser
public:
WebSocketMessageParser() : MessageParser() {}
- bool TryLoadMessage(CbPackage& OutMsg);
+ WebSocketMessage ConsumeMessage();
private:
virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
virtual void OnReset() override;
- WebSocketMessageHeader m_Header;
+ WebSocketMessage m_Message;
};
ParseMessageResult
@@ -388,65 +388,76 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
{
const uint64_t PrevOffset = m_Stream.CurrentOffset();
- if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
- const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset();
+ const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
m_Stream.Write(Msg.Left(RemaingHeaderSize));
- Msg.RightChopInline(RemaingHeaderSize);
-
- if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header);
+ const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
if (IsValidHeader == false)
{
- return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")};
+ OnReset();
+
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message header")};
}
+
+ Msg += RemaingHeaderSize;
}
- ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader));
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
if (Msg.IsEmpty())
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const uint64_t RemaingContentSize =
- Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize());
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
+
+ const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset();
+
+ if (IsComplete)
+ {
+ BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize));
- m_Stream.Write(Msg.Left(RemaingContentSize));
+ CbPackage Pkg;
+ if (Pkg.TryLoad(Reader) == false)
+ {
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message")};
+ }
- const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength
- ? ParseMessageStatus::kDone
- : ParseMessageStatus::kContinue;
+ m_Message.SetBody(std::move(Pkg));
+ }
- return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
void
WebSocketMessageParser::OnReset()
{
- m_Header = {};
+ m_Message = WebSocketMessage();
}
-bool
-WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
+WebSocketMessage
+WebSocketMessageParser::ConsumeMessage()
{
- const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader);
+ WebSocketMessage Msg = std::move(m_Message);
+ m_Message = WebSocketMessage();
- if (IsParsed)
- {
- BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader)));
-
- return OutMsg.TryLoad(Reader);
- }
-
- return false;
+ return Msg;
}
///////////////////////////////////////////////////////////////////////////////
@@ -574,10 +585,15 @@ public:
WsServer() = default;
virtual ~WsServer() { Shutdown(); }
- virtual void RegisterService(WebSocketService& Service) override;
virtual bool Run(const WebSocketServerOptions& Options) override;
virtual void Shutdown() override;
- virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override;
+
+ virtual void RegisterService(WebSocketService& Service) override;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) override;
+ virtual void SendResponse(WebSocketMessage&& Response) override;
private:
friend class WsConnection;
@@ -586,14 +602,17 @@ private:
void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
void ReadMessage(std::shared_ptr<WsConnection> Connection);
- void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg);
+ void RouteMessage(WebSocketMessage&& Msg);
+ void SendMessage(WebSocketMessage&& Msg);
struct IdHasher
{
size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
};
- using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
+ using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
asio::io_service m_IoSvc;
std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
@@ -601,6 +620,8 @@ private:
ConnectionMap m_Connections;
std::shared_mutex m_ConnMutex;
std::vector<WebSocketService*> m_Services;
+ RequestHandlerMap m_RequestHandlers;
+ NotificationHandlerMap m_NotificationHandlers;
std::atomic_bool m_Running{};
};
@@ -664,34 +685,32 @@ WsServer::Shutdown()
}
void
-WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
{
- std::shared_ptr<WsConnection> Connection;
-
- {
- std::unique_lock _(m_ConnMutex);
-
- if (auto It = m_Connections.find(Id); It != m_Connections.end())
- {
- Connection = It->second;
- }
- }
+ auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
+ Result.first->second.push_back(&Service);
+}
- BinaryWriter Writer;
- WebSocketMessageHeader::Write(Writer, Msg);
+void
+WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
+{
+ m_RequestHandlers[Key] = &Service;
+}
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+void
+WsServer::SendNotification(WebSocketMessage&& Notification)
+{
+ ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
- ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr());
+ SendMessage(std::move(Notification));
+}
+void
+WsServer::SendResponse(WebSocketMessage&& Response)
+{
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse);
+ ZEN_ASSERT(Response.CorrelationId() != 0);
- async_write(Connection->Socket(),
- asio::buffer(Buffer.Data(), Buffer.Size()),
- [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- CloseConnection(Connection, Ec);
- }
- });
+ SendMessage(std::move(Response));
}
void
@@ -768,7 +787,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
Connection->Socket(),
Connection->ReadBuffer(),
asio::transfer_at_least(1),
- [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable {
+ [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
if (ReadEc)
{
return CloseConnection(Connection, ReadEc);
@@ -911,20 +930,15 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
{
WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
- MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount);
+ uint64_t RemainingBytes = Connection->ReadBuffer().size();
- while (MessageData.IsEmpty() == false)
+ while (RemainingBytes > 0)
{
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- MessageData.RightChopInline(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(MessageData.IsEmpty());
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
- return ReadMessage(Connection);
- }
+ Connection->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Connection->ReadBuffer().size();
if (Result.Status == ParseMessageStatus::kError)
{
@@ -933,20 +947,19 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
return CloseConnection(Connection, std::error_code());
}
- CbPackage Message;
- if (Parser.TryLoadMessage(Message) == false)
+ if (Result.Status == ParseMessageStatus::kContinue)
{
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'");
-
- return CloseConnection(Connection, std::error_code());
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
}
- RouteMessage(Connection, Message);
-
+ WebSocketMessage Message = Parser.ConsumeMessage();
Parser.Reset();
- }
- Connection->ReadBuffer().consume(ByteCount);
+ Message.SetSocketId(Connection->Id());
+
+ RouteMessage(std::move(Message));
+ }
ReadMessage(Connection);
}
@@ -959,19 +972,114 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
}
void
-WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg)
+WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
{
- ZEN_LOG_DEBUG(LogWebSocket, "routing message");
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kRequest:
+ {
+ CbObjectView Request = RoutedMessage.Body().GetObject();
+ std::string_view Method = Request["Method"].AsString();
+ bool Handled = false;
+ bool Error = false;
+ std::exception Exception;
+
+ if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
+ {
+ WebSocketService* Service = It->second;
+ ZEN_ASSERT(Service);
+
+ try
+ {
+ Handled = Service->HandleRequest(std::move(RoutedMessage));
+ }
+ catch (std::exception& Err)
+ {
+ Exception = std::move(Err);
+ Error = true;
+ }
+ }
+
+ if (Error || Handled == false)
+ {
+ std::string ErrorText = Error ? Exception.what() : std::string("Not Found");
+
+ ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
+
+ CbObjectWriter Response;
+ Response << "Error"sv << ErrorText;
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
+ ResponseMsg.SetSocketId(RoutedMessage.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SendResponse(std::move(ResponseMsg));
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ CbObjectView Notification = RoutedMessage.Body().GetObject();
+ std::string_view Message = Notification["Message"].AsString();
+
+ if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
+ {
+ std::vector<WebSocketService*>& Handlers = It->second;
+
+ for (WebSocketService* Handler : Handlers)
+ {
+ Handler->HandleNotification(RoutedMessage);
+ }
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
+ }
+ }
+ break;
+ };
+}
+
+void
+WsServer::SendMessage(WebSocketMessage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
- for (auto Server : m_Services)
{
- if (Server->HandleMessage(Connection->Id(), Msg))
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
{
- return;
+ Connection = It->second;
}
}
- ZEN_LOG_WARN(LogWebSocket, "unhandled message");
+ if (Connection.get() == nullptr)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
+ return;
+ }
+
+ if (Connection.get() != nullptr)
+ {
+ BinaryWriter Writer;
+ Msg.Save(Writer);
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(Connection->Socket(),
+ asio::buffer(Buffer.Data(), Buffer.Size()),
+ [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message());
+
+ CloseConnection(Connection, Ec);
+ }
+ });
+ }
}
///////////////////////////////////////////////////////////////////////////////
@@ -984,15 +1092,14 @@ public:
std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
- virtual bool Connect(const WebSocketConnectInfo& Info) override;
- virtual void Disconnect() override;
- virtual bool IsConnected() const { return false; }
- virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
- virtual void SendMsg(CbPackage&& Msg) override;
- virtual void SendMsg(CbObject&& Msg) override;
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const { return false; }
+ virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
- virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override;
- virtual void OnMessage(MessageCallback&& Cb) override;
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
+ virtual void OnNotification(NotificationCallback&& Cb) override;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
private:
WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
@@ -1001,7 +1108,9 @@ private:
asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
void TriggerEvent(WebSocketEvent Evt);
void ReadMessage();
- void RouteMessage(CbPackage&& Msg);
+ void RouteMessage(WebSocketMessage&& RoutedMessage);
+
+ using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
asio::io_context& m_IoCtx;
WebSocketId m_Id;
@@ -1009,18 +1118,21 @@ private:
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
EventCallback m_EventCallbacks[3];
- MessageCallback m_MsgCallback;
+ NotificationCallback m_NotificationCallback;
+ PendingRequestMap m_PendingRequests;
+ std::mutex m_RequestMutex;
+ std::promise<bool> m_ConnectPromise;
std::atomic_uint32_t m_State;
std::string m_Host;
int16_t m_Port{};
};
-bool
+std::future<bool>
WsClient::Connect(const WebSocketConnectInfo& Info)
{
if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
{
- return true;
+ return m_ConnectPromise.get_future();
}
SetState(WebSocketState::kHandshaking);
@@ -1046,7 +1158,9 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
TriggerEvent(WebSocketEvent::kDisconnected);
- return false;
+ m_ConnectPromise.set_value(false);
+
+ return m_ConnectPromise.get_future();
}
ExtendableStringBuilder<128> Sb;
@@ -1093,7 +1207,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
}
});
- return true;
+ return m_ConnectPromise.get_future();
}
void
@@ -1110,48 +1224,64 @@ WsClient::Disconnect()
}
TriggerEvent(WebSocketEvent::kDisconnected);
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ for (auto& Kv : m_PendingRequests)
+ {
+ Kv.second.set_value(WebSocketMessage());
+ }
+
+ m_PendingRequests.clear();
+ }
}
}
-void
-WsClient::SendMsg(CbPackage&& Msg)
+std::future<WebSocketMessage>
+WsClient::SendRequest(WebSocketMessage&& Request)
{
+ ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
+
BinaryWriter Writer;
- WebSocketMessageHeader::Write(Writer, Msg);
+ Request.Save(Writer);
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+ std::future<WebSocketMessage> FutureResponse;
+
+ {
+ std::unique_lock _(m_RequestMutex);
- ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size());
+ auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
+ ZEN_ASSERT(Result.second);
- async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) {
+ auto It = Result.first;
+ FutureResponse = It->second.get_future();
+ }
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
if (Ec)
{
- ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message());
+ ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
Self->Disconnect();
}
});
-}
-
-void
-WsClient::SendMsg(CbObject&& Msg)
-{
- CbPackage Pkg;
- Pkg.SetObject(std::move(Msg));
- WsClient::SendMsg(std::move(Pkg));
+ return FutureResponse;
}
void
-WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+WsClient::OnNotification(NotificationCallback&& Cb)
{
- m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+ m_NotificationCallback = std::move(Cb);
}
void
-WsClient::OnMessage(MessageCallback&& Cb)
+WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
{
- m_MsgCallback = std::move(Cb);
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
}
void
@@ -1199,6 +1329,8 @@ WsClient::ReadMessage()
{
ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
+ Self->m_ConnectPromise.set_value(false);
+
return Self->Disconnect();
}
@@ -1216,6 +1348,8 @@ WsClient::ReadMessage()
Parser.StatusText(),
Parser.StatusCode());
+ Self->m_ConnectPromise.set_value(false);
+
return Self->Disconnect();
}
@@ -1225,6 +1359,8 @@ WsClient::ReadMessage()
Self->SetState(WebSocketState::kConnected);
Self->ReadMessage();
Self->TriggerEvent(WebSocketEvent::kConnected);
+
+ Self->m_ConnectPromise.set_value(true);
}
break;
@@ -1232,44 +1368,36 @@ WsClient::ReadMessage()
{
WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
+ uint64_t RemainingBytes = Self->ReadBuffer().size();
- while (MessageData.IsEmpty() == false)
+ while (RemainingBytes > 0)
{
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- MessageData.RightChopInline(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(MessageData.IsEmpty());
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
- return Self->ReadMessage();
- }
+ Self->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Self->ReadBuffer().size();
if (Result.Status == ParseMessageStatus::kError)
{
ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
Parser.Reset();
-
continue;
}
- CbPackage Message;
- if (Parser.TryLoadMessage(Message))
- {
- Self->RouteMessage(std::move(Message));
- }
- else
+ if (Result.Status == ParseMessageStatus::kContinue)
{
- ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'");
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
}
+ WebSocketMessage Message = Parser.ConsumeMessage();
Parser.Reset();
+
+ Self->RouteMessage(std::move(Message));
}
- Self->ReadBuffer().consume(ByteCount);
Self->ReadMessage();
}
break;
@@ -1278,12 +1406,39 @@ WsClient::ReadMessage()
}
void
-WsClient::RouteMessage(CbPackage&& Msg)
+WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
{
- if (m_MsgCallback)
+ switch (RoutedMessage.MessageType())
{
- m_MsgCallback(Msg);
- }
+ case WebSocketMessageType::kResponse:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
+ {
+ It->second.set_value(std::move(RoutedMessage));
+ m_PendingRequests.erase(It);
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "route request message FAILED, reason 'unknown correlation ID ({})'",
+ RoutedMessage.CorrelationId());
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (m_NotificationCallback)
+ {
+ m_NotificationCallback(std::move(RoutedMessage));
+ }
+ }
+ break;
+ };
}
} // namespace zen::websocket
@@ -1292,67 +1447,78 @@ namespace zen {
std::atomic_uint32_t WebSocketId::NextId{1};
-void
-WebSocketService::Configure(WebSocketServer& Server)
+bool
+WebSocketMessage::Header::IsValid() const
{
- ZEN_ASSERT(m_Server == nullptr);
-
- m_Server = &Server;
+ return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4;
}
+std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
+
void
-WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
{
- ZEN_ASSERT(m_Server != nullptr);
-
- m_Server->PublishMessage(Id, std::move(Msg));
+ m_Header.MessageType = MessageType;
}
void
-WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg)
+WebSocketMessage::SetBody(CbPackage&& Body)
+{
+ m_Body = std::move(Body);
+}
+void
+WebSocketMessage::SetBody(CbObject&& Body)
{
- ZEN_ASSERT(m_Server != nullptr);
-
CbPackage Pkg;
- Pkg.SetObject(std::move(Msg));
+ Pkg.SetObject(Body);
- m_Server->PublishMessage(Id, std::move(Pkg));
+ SetBody(std::move(Pkg));
}
-bool
-WebSocketMessageHeader::IsValid() const
+void
+WebSocketMessage::Save(BinaryWriter& Writer)
{
- return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0;
+ Writer.Write(&m_Header, HeaderSize);
+
+ if (m_Body.has_value())
+ {
+ m_Body.value().Save(Writer);
+ }
+
+ if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
+ {
+ m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
+ }
+
+ m_Header.MessageSize = Writer.Size() - HeaderSize;
+ m_Header.Crc32 = 1; // TODO
+
+ Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
}
bool
-WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader)
+WebSocketMessage::TryLoadHeader(MemoryView Memory)
{
- if (Memory.GetSize() < sizeof(WebSocketMessageHeader))
+ if (Memory.GetSize() < HeaderSize)
{
return false;
}
- void* Dst = &OutHeader;
- memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader));
+ MutableMemoryView HeaderView(&m_Header, HeaderSize);
+
+ HeaderView.CopyFrom(Memory);
- return OutHeader.IsValid();
+ return m_Header.IsValid();
}
void
-WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg)
+WebSocketService::Configure(WebSocketServer& Server)
{
- WebSocketMessageHeader Header;
-
- Writer.Write(&Header, sizeof(WebSocketMessageHeader));
- Msg.Save(Writer);
+ ZEN_ASSERT(m_SocketServer == nullptr);
- Header.Magic = WebSocketMessageHeader::ExpectedMagic;
- Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader);
- Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO
+ m_SocketServer = &Server;
- MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader));
- HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader)));
+ RegisterHandlers(Server);
}
std::unique_ptr<WebSocketServer>