aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-21 14:22:38 +0100
committerPer Larsson <[email protected]>2022-02-21 14:22:38 +0100
commitfd9f9086b3ddd0c38fa87d7e49f6341dacdcc125 (patch)
tree50be00e62f1e0d3a0521d08d5c7f00fe7787e014
parentBasic websocket service and test. (diff)
downloadzen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.tar.xz
zen-fd9f9086b3ddd0c38fa87d7e49f6341dacdcc125.zip
Refactored websocket message.
-rw-r--r--zencore/include/zencore/stream.h48
-rw-r--r--zencore/stream.cpp19
-rw-r--r--zenhttp/include/zenhttp/websocket.h156
-rw-r--r--zenhttp/websocketasio.cpp516
-rw-r--r--zenserver-test/zenserver-test.cpp54
-rw-r--r--zenserver/testing/httptest.cpp31
-rw-r--r--zenserver/testing/httptest.h4
7 files changed, 523 insertions, 305 deletions
diff --git a/zencore/include/zencore/stream.h b/zencore/include/zencore/stream.h
index 6d7e7d19f..54d7e1014 100644
--- a/zencore/include/zencore/stream.h
+++ b/zencore/include/zencore/stream.h
@@ -27,12 +27,29 @@ public:
m_Offset += ByteCount;
}
+ inline void Write(MemoryView Memory) { Write(Memory.GetData(), Memory.GetSize()); }
+
inline uint64_t CurrentOffset() const { return m_Offset; }
inline const uint8_t* Data() const { return m_Buffer.data(); }
inline const uint8_t* GetData() const { return m_Buffer.data(); }
inline uint64_t Size() const { return m_Buffer.size(); }
inline uint64_t GetSize() const { return m_Buffer.size(); }
+ void Reset();
+
+ inline MemoryView GetView(uint64_t Offset = 0) const
+ {
+ MemoryView View(m_Buffer.data(), m_Offset);
+ View.RightChopInline(Offset);
+ return View;
+ }
+
+ inline MutableMemoryView GetMutableView(uint64_t Offset = 0)
+ {
+ MutableMemoryView View(m_Buffer.data(), m_Offset);
+ View.RightChopInline(Offset);
+ return View;
+ }
private:
RwLock m_Lock;
@@ -49,37 +66,6 @@ MakeMemoryView(const BinaryWriter& Stream)
}
/**
- * Non thread-safe stream writer
- */
-
-class SimpleBinaryWriter
-{
- static constexpr uint32_t DefaultBlockSize = 64;
-
-public:
- SimpleBinaryWriter(uint32_t BlockSize = DefaultBlockSize) : m_BlockSize(BlockSize), m_Offset{0} {}
- ~SimpleBinaryWriter() = default;
-
- void Write(MemoryView Memory);
- void Write(const void* Data, size_t Size) { Write(MemoryView(Data, Size)); }
- void Clear();
-
- inline uint64_t CurrentOffset() const { return m_Offset; }
-
- inline const uint8_t* Data() const { return m_Buffer.data(); }
- inline const uint8_t* GetData() const { return m_Buffer.data(); }
- inline uint64_t Size() const { return m_Buffer.size(); }
- inline uint64_t GetSize() const { return m_Buffer.size(); }
-
- MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); }
-
-private:
- std::vector<uint8_t> m_Buffer;
- uint64_t m_Offset;
- uint32_t m_BlockSize;
-};
-
-/**
* Binary stream reader
*/
diff --git a/zencore/stream.cpp b/zencore/stream.cpp
index 36953363f..8faf90af2 100644
--- a/zencore/stream.cpp
+++ b/zencore/stream.cpp
@@ -26,25 +26,10 @@ BinaryWriter::Write(const void* data, size_t ByteCount, uint64_t Offset)
}
void
-SimpleBinaryWriter::Write(MemoryView Memory)
+BinaryWriter::Reset()
{
- const uint64_t NeededSize = m_Offset + Memory.GetSize();
-
- if (NeededSize > m_Buffer.size())
- {
- const size_t NewCapacity = RoundUp(NeededSize, m_BlockSize);
-
- m_Buffer.resize(NewCapacity);
- }
-
- memcpy(m_Buffer.data() + m_Offset, Memory.GetData(), Memory.GetSize());
-
- m_Offset += Memory.GetSize();
-}
+ RwLock::ExclusiveLockScope _(m_Lock);
-void
-SimpleBinaryWriter::Clear()
-{
m_Buffer.clear();
m_Offset = 0;
}
diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h
index 1ab9b4804..336d98b42 100644
--- a/zenhttp/include/zenhttp/websocket.h
+++ b/zenhttp/include/zenhttp/websocket.h
@@ -1,10 +1,13 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zencore/compactbinarypackage.h>
#include <zencore/memory.h>
#include <compare>
#include <functional>
+#include <future>
#include <memory>
+#include <optional>
#pragma once
@@ -14,10 +17,11 @@ class io_context;
namespace zen {
-class CbPackage;
-class CbObject;
class BinaryWriter;
+/**
+ * A unique socket ID.
+ */
class WebSocketId
{
static std::atomic_uint32_t NextId;
@@ -37,45 +41,132 @@ private:
uint32_t m_Value{};
};
+/**
+ * Type of web socket message.
+ */
+enum class WebSocketMessageType : uint8_t
+{
+ kInvalid,
+ kNotification,
+ kRequest,
+ kResponse
+};
+
+/**
+ * Web socket message.
+ */
+class WebSocketMessage
+{
+ struct Header
+ {
+ static constexpr uint32_t HeaderMagic = 0x7a776d68; // zwmh
+
+ uint64_t MessageSize{};
+ uint32_t Magic{HeaderMagic};
+ uint32_t CorrelationId{};
+ uint32_t Crc32{};
+ WebSocketMessageType MessageType{};
+ uint8_t Reserved[3] = {0};
+
+ bool IsValid() const;
+ };
+
+ static_assert(sizeof Header == 24);
+
+ static std::atomic_uint32_t NextCorrelationId;
+
+public:
+ static constexpr size_t HeaderSize = sizeof(Header);
+
+ WebSocketMessage() = default;
+
+ WebSocketId SocketId() const { return m_SocketId; }
+ void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
+ void SetMessageType(WebSocketMessageType MessageType);
+ WebSocketMessageType MessageType() const { return m_Header.MessageType; }
+ uint64_t MessageSize() const { return m_Header.MessageSize; }
+ void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
+ uint32_t CorrelationId() const { return m_Header.CorrelationId; }
+
+ const CbPackage& Body() const { return m_Body.value(); }
+ void SetBody(CbPackage&& Body);
+ void SetBody(CbObject&& Body);
+ bool HasBody() const { return m_Body.has_value(); }
+
+ void Save(BinaryWriter& Writer);
+ bool TryLoadHeader(MemoryView Memory);
+
+ bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; }
+
+private:
+ Header m_Header{};
+ WebSocketId m_SocketId{};
+ std::optional<CbPackage> m_Body;
+};
+
class WebSocketServer;
+/**
+ * Base class for handling web socket requests and notifications from connected client(s).
+ */
class WebSocketService
{
public:
virtual ~WebSocketService() = default;
- virtual bool HandleMessage(WebSocketId Id, const CbPackage& Msg) = 0;
- void Configure(WebSocketServer& Server);
+ void Configure(WebSocketServer& Server);
+
+ virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); }
+ virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); }
protected:
WebSocketService() = default;
- void PublishMessage(WebSocketId Id, CbPackage&& Msg);
- void PublishMessage(WebSocketId Id, CbObject&& Msg);
+ virtual void RegisterHandlers(WebSocketServer& Server) = 0;
+
+ WebSocketServer& SocketServer()
+ {
+ ZEN_ASSERT(m_SocketServer);
+ return *m_SocketServer;
+ }
private:
- WebSocketServer* m_Server = nullptr;
+ WebSocketServer* m_SocketServer{};
};
+/**
+ * Server options.
+ */
struct WebSocketServerOptions
{
uint16_t Port = 8848;
uint32_t ThreadCount = 1;
};
+/**
+ * The web socket server manages client connections and routing of requests and notifications.
+ */
class WebSocketServer
{
public:
virtual ~WebSocketServer() = default;
- virtual void RegisterService(WebSocketService& Service) = 0;
- virtual bool Run(const WebSocketServerOptions& Options) = 0;
- virtual void Shutdown() = 0;
- virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) = 0;
+ virtual bool Run(const WebSocketServerOptions& Options) = 0;
+ virtual void Shutdown() = 0;
+
+ virtual void RegisterService(WebSocketService& Service) = 0;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) = 0;
+ virtual void SendResponse(WebSocketMessage&& Response) = 0;
static std::unique_ptr<WebSocketServer> Create();
};
+/**
+ * The state of the web socket.
+ */
enum class WebSocketState : uint32_t
{
kNone,
@@ -85,6 +176,9 @@ enum class WebSocketState : uint32_t
kError
};
+/**
+ * Type of web socket client event.
+ */
enum class WebSocketEvent : uint32_t
{
kConnected,
@@ -92,6 +186,9 @@ enum class WebSocketEvent : uint32_t
kError
};
+/**
+ * Web socket client connection info.
+ */
struct WebSocketConnectInfo
{
std::string Host;
@@ -101,40 +198,27 @@ struct WebSocketConnectInfo
uint16_t Version{13};
};
+/**
+ * A connection to a web socket server for sending requests and listening for notifications.
+ */
class WebSocketClient
{
public:
- using EventCallback = std::function<void()>;
- using MessageCallback = std::function<void(CbPackage&)>;
+ using EventCallback = std::function<void()>;
+ using NotificationCallback = std::function<void(WebSocketMessage&&)>;
virtual ~WebSocketClient() = default;
- virtual bool Connect(const WebSocketConnectInfo& Info) = 0;
- virtual void Disconnect() = 0;
- virtual bool IsConnected() const = 0;
- virtual WebSocketState State() const = 0;
- virtual void SendMsg(CbObject&& Msg) = 0;
- virtual void SendMsg(CbPackage&& Msg) = 0;
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0;
+ virtual void Disconnect() = 0;
+ virtual bool IsConnected() const = 0;
+ virtual WebSocketState State() const = 0;
- virtual void On(WebSocketEvent Evt, EventCallback&& Cb) = 0;
- virtual void OnMessage(MessageCallback&& Cb) = 0;
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0;
+ virtual void OnNotification(NotificationCallback&& Cb) = 0;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0;
static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx);
};
-struct WebSocketMessageHeader
-{
- static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
-
- uint32_t Magic{};
- uint64_t ContentLength{};
- uint32_t Crc32{};
-
- bool IsValid() const;
-
- static bool Read(MemoryView Memory, WebSocketMessageHeader& OutHeader);
-
- static void Write(BinaryWriter& Writer, const CbPackage& Msg);
-};
-
} // namespace zen
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index f6f58f38c..13d0177ee 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -3,7 +3,7 @@
#include <zenhttp/websocket.h>
#include <zencore/base64.h>
-#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinarybuilder.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -75,7 +75,7 @@ protected:
virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
virtual void OnReset() = 0;
- SimpleBinaryWriter m_Stream;
+ BinaryWriter m_Stream;
};
ParseMessageResult
@@ -89,7 +89,7 @@ MessageParser::Reset()
{
OnReset();
- m_Stream.Clear();
+ m_Stream.Reset();
}
///////////////////////////////////////////////////////////////////////////////
@@ -374,13 +374,13 @@ class WebSocketMessageParser final : public MessageParser
public:
WebSocketMessageParser() : MessageParser() {}
- bool TryLoadMessage(CbPackage& OutMsg);
+ WebSocketMessage ConsumeMessage();
private:
virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
virtual void OnReset() override;
- WebSocketMessageHeader m_Header;
+ WebSocketMessage m_Message;
};
ParseMessageResult
@@ -388,65 +388,76 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
{
const uint64_t PrevOffset = m_Stream.CurrentOffset();
- if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
- const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset();
+ const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
m_Stream.Write(Msg.Left(RemaingHeaderSize));
- Msg.RightChopInline(RemaingHeaderSize);
-
- if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header);
+ const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
if (IsValidHeader == false)
{
- return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")};
+ OnReset();
+
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message header")};
}
+
+ Msg += RemaingHeaderSize;
}
- ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader));
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
if (Msg.IsEmpty())
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const uint64_t RemaingContentSize =
- Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize());
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
+
+ const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset();
+
+ if (IsComplete)
+ {
+ BinaryReader Reader(m_Stream.GetView(WebSocketMessage::HeaderSize));
- m_Stream.Write(Msg.Left(RemaingContentSize));
+ CbPackage Pkg;
+ if (Pkg.TryLoad(Reader) == false)
+ {
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message")};
+ }
- const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength
- ? ParseMessageStatus::kDone
- : ParseMessageStatus::kContinue;
+ m_Message.SetBody(std::move(Pkg));
+ }
- return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
void
WebSocketMessageParser::OnReset()
{
- m_Header = {};
+ m_Message = WebSocketMessage();
}
-bool
-WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
+WebSocketMessage
+WebSocketMessageParser::ConsumeMessage()
{
- const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader);
+ WebSocketMessage Msg = std::move(m_Message);
+ m_Message = WebSocketMessage();
- if (IsParsed)
- {
- BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader)));
-
- return OutMsg.TryLoad(Reader);
- }
-
- return false;
+ return Msg;
}
///////////////////////////////////////////////////////////////////////////////
@@ -574,10 +585,15 @@ public:
WsServer() = default;
virtual ~WsServer() { Shutdown(); }
- virtual void RegisterService(WebSocketService& Service) override;
virtual bool Run(const WebSocketServerOptions& Options) override;
virtual void Shutdown() override;
- virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override;
+
+ virtual void RegisterService(WebSocketService& Service) override;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) override;
+ virtual void SendResponse(WebSocketMessage&& Response) override;
private:
friend class WsConnection;
@@ -586,14 +602,17 @@ private:
void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
void ReadMessage(std::shared_ptr<WsConnection> Connection);
- void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg);
+ void RouteMessage(WebSocketMessage&& Msg);
+ void SendMessage(WebSocketMessage&& Msg);
struct IdHasher
{
size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
};
- using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
+ using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
asio::io_service m_IoSvc;
std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
@@ -601,6 +620,8 @@ private:
ConnectionMap m_Connections;
std::shared_mutex m_ConnMutex;
std::vector<WebSocketService*> m_Services;
+ RequestHandlerMap m_RequestHandlers;
+ NotificationHandlerMap m_NotificationHandlers;
std::atomic_bool m_Running{};
};
@@ -664,34 +685,32 @@ WsServer::Shutdown()
}
void
-WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
{
- std::shared_ptr<WsConnection> Connection;
-
- {
- std::unique_lock _(m_ConnMutex);
-
- if (auto It = m_Connections.find(Id); It != m_Connections.end())
- {
- Connection = It->second;
- }
- }
+ auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
+ Result.first->second.push_back(&Service);
+}
- BinaryWriter Writer;
- WebSocketMessageHeader::Write(Writer, Msg);
+void
+WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
+{
+ m_RequestHandlers[Key] = &Service;
+}
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+void
+WsServer::SendNotification(WebSocketMessage&& Notification)
+{
+ ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
- ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr());
+ SendMessage(std::move(Notification));
+}
+void
+WsServer::SendResponse(WebSocketMessage&& Response)
+{
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse);
+ ZEN_ASSERT(Response.CorrelationId() != 0);
- async_write(Connection->Socket(),
- asio::buffer(Buffer.Data(), Buffer.Size()),
- [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- CloseConnection(Connection, Ec);
- }
- });
+ SendMessage(std::move(Response));
}
void
@@ -768,7 +787,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
Connection->Socket(),
Connection->ReadBuffer(),
asio::transfer_at_least(1),
- [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable {
+ [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
if (ReadEc)
{
return CloseConnection(Connection, ReadEc);
@@ -911,20 +930,15 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
{
WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
- MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount);
+ uint64_t RemainingBytes = Connection->ReadBuffer().size();
- while (MessageData.IsEmpty() == false)
+ while (RemainingBytes > 0)
{
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- MessageData.RightChopInline(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(MessageData.IsEmpty());
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
- return ReadMessage(Connection);
- }
+ Connection->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Connection->ReadBuffer().size();
if (Result.Status == ParseMessageStatus::kError)
{
@@ -933,20 +947,19 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
return CloseConnection(Connection, std::error_code());
}
- CbPackage Message;
- if (Parser.TryLoadMessage(Message) == false)
+ if (Result.Status == ParseMessageStatus::kContinue)
{
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'");
-
- return CloseConnection(Connection, std::error_code());
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
}
- RouteMessage(Connection, Message);
-
+ WebSocketMessage Message = Parser.ConsumeMessage();
Parser.Reset();
- }
- Connection->ReadBuffer().consume(ByteCount);
+ Message.SetSocketId(Connection->Id());
+
+ RouteMessage(std::move(Message));
+ }
ReadMessage(Connection);
}
@@ -959,19 +972,114 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
}
void
-WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg)
+WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
{
- ZEN_LOG_DEBUG(LogWebSocket, "routing message");
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kRequest:
+ {
+ CbObjectView Request = RoutedMessage.Body().GetObject();
+ std::string_view Method = Request["Method"].AsString();
+ bool Handled = false;
+ bool Error = false;
+ std::exception Exception;
+
+ if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
+ {
+ WebSocketService* Service = It->second;
+ ZEN_ASSERT(Service);
+
+ try
+ {
+ Handled = Service->HandleRequest(std::move(RoutedMessage));
+ }
+ catch (std::exception& Err)
+ {
+ Exception = std::move(Err);
+ Error = true;
+ }
+ }
+
+ if (Error || Handled == false)
+ {
+ std::string ErrorText = Error ? Exception.what() : std::string("Not Found");
+
+ ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
+
+ CbObjectWriter Response;
+ Response << "Error"sv << ErrorText;
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
+ ResponseMsg.SetSocketId(RoutedMessage.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SendResponse(std::move(ResponseMsg));
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ CbObjectView Notification = RoutedMessage.Body().GetObject();
+ std::string_view Message = Notification["Message"].AsString();
+
+ if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
+ {
+ std::vector<WebSocketService*>& Handlers = It->second;
+
+ for (WebSocketService* Handler : Handlers)
+ {
+ Handler->HandleNotification(RoutedMessage);
+ }
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
+ }
+ }
+ break;
+ };
+}
+
+void
+WsServer::SendMessage(WebSocketMessage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
- for (auto Server : m_Services)
{
- if (Server->HandleMessage(Connection->Id(), Msg))
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
{
- return;
+ Connection = It->second;
}
}
- ZEN_LOG_WARN(LogWebSocket, "unhandled message");
+ if (Connection.get() == nullptr)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
+ return;
+ }
+
+ if (Connection.get() != nullptr)
+ {
+ BinaryWriter Writer;
+ Msg.Save(Writer);
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(Connection->Socket(),
+ asio::buffer(Buffer.Data(), Buffer.Size()),
+ [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message());
+
+ CloseConnection(Connection, Ec);
+ }
+ });
+ }
}
///////////////////////////////////////////////////////////////////////////////
@@ -984,15 +1092,14 @@ public:
std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
- virtual bool Connect(const WebSocketConnectInfo& Info) override;
- virtual void Disconnect() override;
- virtual bool IsConnected() const { return false; }
- virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
- virtual void SendMsg(CbPackage&& Msg) override;
- virtual void SendMsg(CbObject&& Msg) override;
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const { return false; }
+ virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
- virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override;
- virtual void OnMessage(MessageCallback&& Cb) override;
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
+ virtual void OnNotification(NotificationCallback&& Cb) override;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
private:
WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
@@ -1001,7 +1108,9 @@ private:
asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
void TriggerEvent(WebSocketEvent Evt);
void ReadMessage();
- void RouteMessage(CbPackage&& Msg);
+ void RouteMessage(WebSocketMessage&& RoutedMessage);
+
+ using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
asio::io_context& m_IoCtx;
WebSocketId m_Id;
@@ -1009,18 +1118,21 @@ private:
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
EventCallback m_EventCallbacks[3];
- MessageCallback m_MsgCallback;
+ NotificationCallback m_NotificationCallback;
+ PendingRequestMap m_PendingRequests;
+ std::mutex m_RequestMutex;
+ std::promise<bool> m_ConnectPromise;
std::atomic_uint32_t m_State;
std::string m_Host;
int16_t m_Port{};
};
-bool
+std::future<bool>
WsClient::Connect(const WebSocketConnectInfo& Info)
{
if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
{
- return true;
+ return m_ConnectPromise.get_future();
}
SetState(WebSocketState::kHandshaking);
@@ -1046,7 +1158,9 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
TriggerEvent(WebSocketEvent::kDisconnected);
- return false;
+ m_ConnectPromise.set_value(false);
+
+ return m_ConnectPromise.get_future();
}
ExtendableStringBuilder<128> Sb;
@@ -1093,7 +1207,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
}
});
- return true;
+ return m_ConnectPromise.get_future();
}
void
@@ -1110,48 +1224,64 @@ WsClient::Disconnect()
}
TriggerEvent(WebSocketEvent::kDisconnected);
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ for (auto& Kv : m_PendingRequests)
+ {
+ Kv.second.set_value(WebSocketMessage());
+ }
+
+ m_PendingRequests.clear();
+ }
}
}
-void
-WsClient::SendMsg(CbPackage&& Msg)
+std::future<WebSocketMessage>
+WsClient::SendRequest(WebSocketMessage&& Request)
{
+ ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
+
BinaryWriter Writer;
- WebSocketMessageHeader::Write(Writer, Msg);
+ Request.Save(Writer);
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+ std::future<WebSocketMessage> FutureResponse;
+
+ {
+ std::unique_lock _(m_RequestMutex);
- ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size());
+ auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
+ ZEN_ASSERT(Result.second);
- async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) {
+ auto It = Result.first;
+ FutureResponse = It->second.get_future();
+ }
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
if (Ec)
{
- ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message());
+ ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
Self->Disconnect();
}
});
-}
-
-void
-WsClient::SendMsg(CbObject&& Msg)
-{
- CbPackage Pkg;
- Pkg.SetObject(std::move(Msg));
- WsClient::SendMsg(std::move(Pkg));
+ return FutureResponse;
}
void
-WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+WsClient::OnNotification(NotificationCallback&& Cb)
{
- m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+ m_NotificationCallback = std::move(Cb);
}
void
-WsClient::OnMessage(MessageCallback&& Cb)
+WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
{
- m_MsgCallback = std::move(Cb);
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
}
void
@@ -1199,6 +1329,8 @@ WsClient::ReadMessage()
{
ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
+ Self->m_ConnectPromise.set_value(false);
+
return Self->Disconnect();
}
@@ -1216,6 +1348,8 @@ WsClient::ReadMessage()
Parser.StatusText(),
Parser.StatusCode());
+ Self->m_ConnectPromise.set_value(false);
+
return Self->Disconnect();
}
@@ -1225,6 +1359,8 @@ WsClient::ReadMessage()
Self->SetState(WebSocketState::kConnected);
Self->ReadMessage();
Self->TriggerEvent(WebSocketEvent::kConnected);
+
+ Self->m_ConnectPromise.set_value(true);
}
break;
@@ -1232,44 +1368,36 @@ WsClient::ReadMessage()
{
WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
+ uint64_t RemainingBytes = Self->ReadBuffer().size();
- while (MessageData.IsEmpty() == false)
+ while (RemainingBytes > 0)
{
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- MessageData.RightChopInline(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(MessageData.IsEmpty());
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
- return Self->ReadMessage();
- }
+ Self->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Self->ReadBuffer().size();
if (Result.Status == ParseMessageStatus::kError)
{
ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
Parser.Reset();
-
continue;
}
- CbPackage Message;
- if (Parser.TryLoadMessage(Message))
- {
- Self->RouteMessage(std::move(Message));
- }
- else
+ if (Result.Status == ParseMessageStatus::kContinue)
{
- ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'");
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
}
+ WebSocketMessage Message = Parser.ConsumeMessage();
Parser.Reset();
+
+ Self->RouteMessage(std::move(Message));
}
- Self->ReadBuffer().consume(ByteCount);
Self->ReadMessage();
}
break;
@@ -1278,12 +1406,39 @@ WsClient::ReadMessage()
}
void
-WsClient::RouteMessage(CbPackage&& Msg)
+WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
{
- if (m_MsgCallback)
+ switch (RoutedMessage.MessageType())
{
- m_MsgCallback(Msg);
- }
+ case WebSocketMessageType::kResponse:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
+ {
+ It->second.set_value(std::move(RoutedMessage));
+ m_PendingRequests.erase(It);
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "route request message FAILED, reason 'unknown correlation ID ({})'",
+ RoutedMessage.CorrelationId());
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (m_NotificationCallback)
+ {
+ m_NotificationCallback(std::move(RoutedMessage));
+ }
+ }
+ break;
+ };
}
} // namespace zen::websocket
@@ -1292,67 +1447,78 @@ namespace zen {
std::atomic_uint32_t WebSocketId::NextId{1};
-void
-WebSocketService::Configure(WebSocketServer& Server)
+bool
+WebSocketMessage::Header::IsValid() const
{
- ZEN_ASSERT(m_Server == nullptr);
-
- m_Server = &Server;
+ return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4;
}
+std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
+
void
-WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
{
- ZEN_ASSERT(m_Server != nullptr);
-
- m_Server->PublishMessage(Id, std::move(Msg));
+ m_Header.MessageType = MessageType;
}
void
-WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg)
+WebSocketMessage::SetBody(CbPackage&& Body)
+{
+ m_Body = std::move(Body);
+}
+void
+WebSocketMessage::SetBody(CbObject&& Body)
{
- ZEN_ASSERT(m_Server != nullptr);
-
CbPackage Pkg;
- Pkg.SetObject(std::move(Msg));
+ Pkg.SetObject(Body);
- m_Server->PublishMessage(Id, std::move(Pkg));
+ SetBody(std::move(Pkg));
}
-bool
-WebSocketMessageHeader::IsValid() const
+void
+WebSocketMessage::Save(BinaryWriter& Writer)
{
- return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0;
+ Writer.Write(&m_Header, HeaderSize);
+
+ if (m_Body.has_value())
+ {
+ m_Body.value().Save(Writer);
+ }
+
+ if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
+ {
+ m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
+ }
+
+ m_Header.MessageSize = Writer.Size() - HeaderSize;
+ m_Header.Crc32 = 1; // TODO
+
+ Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
}
bool
-WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader)
+WebSocketMessage::TryLoadHeader(MemoryView Memory)
{
- if (Memory.GetSize() < sizeof(WebSocketMessageHeader))
+ if (Memory.GetSize() < HeaderSize)
{
return false;
}
- void* Dst = &OutHeader;
- memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader));
+ MutableMemoryView HeaderView(&m_Header, HeaderSize);
+
+ HeaderView.CopyFrom(Memory);
- return OutHeader.IsValid();
+ return m_Header.IsValid();
}
void
-WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg)
+WebSocketService::Configure(WebSocketServer& Server)
{
- WebSocketMessageHeader Header;
-
- Writer.Write(&Header, sizeof(WebSocketMessageHeader));
- Msg.Save(Writer);
+ ZEN_ASSERT(m_SocketServer == nullptr);
- Header.Magic = WebSocketMessageHeader::ExpectedMagic;
- Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader);
- Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO
+ m_SocketServer = &Server;
- MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader));
- HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader)));
+ RegisterHandlers(Server);
}
std::unique_ptr<WebSocketServer>
diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp
index 73048b504..78829a2d1 100644
--- a/zenserver-test/zenserver-test.cpp
+++ b/zenserver-test/zenserver-test.cpp
@@ -2083,8 +2083,9 @@ TEST_CASE("http.package")
TEST_CASE("websocket.basic")
{
- std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
- const uint16_t PortNumber = 13337;
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto MaxWaitTime = std::chrono::seconds(5);
ZenServerInstance Inst(TestEnv);
Inst.SetTestDir(TestDir);
@@ -2095,43 +2096,30 @@ TEST_CASE("websocket.basic")
IoDispatcher IoDispatcher(IoCtx);
auto WebSocket = WebSocketClient::Create(IoCtx);
- std::atomic_bool Done{false};
- WebSocketEvent Event;
-
- WebSocket->On(WebSocketEvent::kConnected, [&]() {
- CbObjectWriter Req;
- Req.BeginObject("Header");
- Req << "Method"sv
- << "TestHelloZen"sv;
- Req << "CorrelationId" << uint64_t(1);
- Req.EndObject();
-
- WebSocket->SendMsg(Req.Save());
- });
+ auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"});
+ IoDispatcher.Run();
- WebSocket->On(WebSocketEvent::kDisconnected, [&]() {
- Event = WebSocketEvent::kDisconnected;
- Done = true;
- });
+ ConnectFuture.wait_for(MaxWaitTime);
+ CHECK(ConnectFuture.get());
- CbPackage Response;
- WebSocket->OnMessage([&](const CbPackage& Msg) {
- Response = Msg;
- Done = true;
- });
+ for (size_t Idx = 0; Idx < 10; Idx++)
+ {
+ CbObjectWriter Request;
+ Request << "Method"sv
+ << "SayHello"sv;
- WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"});
+ WebSocketMessage RequestMsg;
+ RequestMsg.SetMessageType(WebSocketMessageType::kRequest);
+ RequestMsg.SetBody(Request.Save());
- IoDispatcher.Run();
+ auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg));
+ ResponseFuture.wait_for(MaxWaitTime);
- while (Done == false && IoDispatcher.IsRunning())
- {
- std::this_thread::sleep_for(std::chrono::seconds(2));
- };
+ CbObject Response = ResponseFuture.get().Body().GetObject();
+ std::string_view Message = Response["Result"].AsString();
- CbObject ResponseObject = Response.GetObject();
- std::string_view Message = ResponseObject["Result"].AsString();
- CHECK(Message == "Hello Friend!!"sv);
+ CHECK(Message == "Hello Friend!!"sv);
+ }
WebSocket->Disconnect();
diff --git a/zenserver/testing/httptest.cpp b/zenserver/testing/httptest.cpp
index 41a4f064b..10b69c469 100644
--- a/zenserver/testing/httptest.cpp
+++ b/zenserver/testing/httptest.cpp
@@ -8,6 +8,8 @@
namespace zen {
+using namespace std::literals;
+
HttpTestingService::HttpTestingService()
{
m_Router.RegisterRoute(
@@ -136,29 +138,34 @@ HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
return (InsertResult.first->second = new PackageHandler(*this, RequestId)).Get();
}
-bool
-HttpTestingService::HandleMessage(WebSocketId SocketId, const CbPackage& Msg)
+void
+HttpTestingService::RegisterHandlers(WebSocketServer& Server)
{
- using namespace std::literals;
+ Server.RegisterRequestHandler("SayHello"sv, *this);
+}
- CbObject Request = Msg.GetObject();
+bool
+HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg)
+{
+ CbObjectView Request = RequestMsg.Body().GetObject();
- CbObjectView Header = Request["Header"sv].AsObjectView();
- std::string_view Method = Header["Method"].AsString();
- const uint64_t CorrelationId = Header["CorrelationId"].AsUInt64();
+ std::string_view Method = Request["Method"].AsString();
- if (Method != "TestHelloZen"sv)
+ if (Method != "SayHello"sv)
{
return false;
}
CbObjectWriter Response;
- Response.BeginObject("Header");
- Response << "CorrelationId"sv << CorrelationId;
- Response.EndObject();
Response.AddString("Result"sv, "Hello Friend!!");
- PublishMessage(SocketId, Response.Save());
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId());
+ ResponseMsg.SetSocketId(RequestMsg.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SocketServer().SendResponse(std::move(ResponseMsg));
return true;
}
diff --git a/zenserver/testing/httptest.h b/zenserver/testing/httptest.h
index 267d59b36..57d2d63f3 100644
--- a/zenserver/testing/httptest.h
+++ b/zenserver/testing/httptest.h
@@ -23,7 +23,6 @@ public:
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override;
- virtual bool HandleMessage(WebSocketId SocketId, const CbPackage& Msg) override;
class PackageHandler : public IHttpPackageHandler
{
@@ -42,6 +41,9 @@ public:
};
private:
+ virtual void RegisterHandlers(WebSocketServer& Server) override;
+ virtual bool HandleRequest(const WebSocketMessage& Request) override;
+
HttpRequestRouter m_Router;
std::atomic<uint32_t> m_Counter{0};
metrics::OperationTiming m_TimingStats;