aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenhttp/websocketasio.cpp
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenhttp/websocketasio.cpp')
-rw-r--r--src/zenhttp/websocketasio.cpp1613
1 files changed, 1613 insertions, 0 deletions
diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp
new file mode 100644
index 000000000..bbe7e1ad8
--- /dev/null
+++ b/src/zenhttp/websocketasio.cpp
@@ -0,0 +1,1613 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/base64.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/intmath.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/sha1.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/trace.h>
+
+#include <chrono>
+#include <optional>
+#include <shared_mutex>
+#include <span>
+#include <system_error>
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <mstcpip.h>
+#endif
+
+namespace zen::websocket {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
+
+using Clock = std::chrono::steady_clock;
+using TimePoint = Clock::time_point;
+
+///////////////////////////////////////////////////////////////////////////////
+namespace http_header {
+ static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
+ static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
+ static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
+ static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
+ static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
+ static constexpr std::string_view Upgrade = "Upgrade"sv;
+} // namespace http_header
+
+///////////////////////////////////////////////////////////////////////////////
+enum class ParseMessageStatus : uint32_t
+{
+ kError,
+ kContinue,
+ kDone,
+};
+
+struct ParseMessageResult
+{
+ ParseMessageStatus Status{};
+ size_t ByteCount{};
+ std::optional<std::string> Reason;
+};
+
+class MessageParser
+{
+public:
+ virtual ~MessageParser() = default;
+
+ ParseMessageResult ParseMessage(MemoryView Msg);
+ void Reset();
+
+protected:
+ MessageParser() = default;
+
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
+ virtual void OnReset() = 0;
+
+ BinaryWriter m_Stream;
+};
+
+ParseMessageResult
+MessageParser::ParseMessage(MemoryView Msg)
+{
+ return OnParseMessage(Msg);
+}
+
+void
+MessageParser::Reset()
+{
+ OnReset();
+
+ m_Stream.Reset();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+enum class HttpMessageParserType
+{
+ kRequest,
+ kResponse,
+ kBoth
+};
+
+class HttpMessageParser final : public MessageParser
+{
+public:
+ using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
+
+ HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
+
+ virtual ~HttpMessageParser() = default;
+
+ int32_t StatusCode() const { return m_Parser.status_code; }
+ bool IsUpgrade() const { return m_Parser.upgrade != 0; }
+ HttpHeaders& Headers() { return m_Headers; }
+ MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
+
+ std::string_view StatusText() const
+ {
+ return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
+ }
+
+ bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
+
+private:
+ void Initialize();
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+ int OnMessageBegin();
+ int OnUrl(MemoryView Url);
+ int OnStatus(MemoryView Status);
+ int OnHeaderField(MemoryView HeaderField);
+ int OnHeaderValue(MemoryView HeaderValue);
+ int OnHeadersComplete();
+ int OnBody(MemoryView Body);
+ int OnMessageComplete();
+
+ struct StreamEntry
+ {
+ uint64_t Offset{};
+ uint64_t Size{};
+ };
+
+ struct HeaderStreamEntry
+ {
+ StreamEntry Field{};
+ StreamEntry Value{};
+ };
+
+ HttpMessageParserType m_Type;
+ http_parser m_Parser;
+ StreamEntry m_UrlEntry;
+ StreamEntry m_StatusEntry;
+ StreamEntry m_BodyEntry;
+ HeaderStreamEntry m_CurrentHeader;
+ std::vector<HeaderStreamEntry> m_HeaderEntries;
+ HttpHeaders m_Headers;
+ bool m_IsMsgComplete{false};
+
+ static http_parser_settings ParserSettings;
+};
+
+http_parser_settings HttpMessageParser::ParserSettings = {
+ .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
+
+ .on_url = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
+
+ .on_status = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
+
+ .on_header_field = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
+
+ .on_header_value = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
+
+ .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
+
+ .on_body = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
+
+ .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
+
+void
+HttpMessageParser::Initialize()
+{
+ http_parser_init(&m_Parser,
+ m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
+ : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
+ : HTTP_BOTH);
+ m_Parser.data = this;
+
+ m_UrlEntry = {};
+ m_StatusEntry = {};
+ m_CurrentHeader = {};
+ m_BodyEntry = {};
+
+ m_IsMsgComplete = false;
+
+ m_HeaderEntries.clear();
+}
+
+ParseMessageResult
+HttpMessageParser::OnParseMessage(MemoryView Msg)
+{
+ const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
+
+ auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+
+ if (m_Parser.http_errno != 0)
+ {
+ Status = ParseMessageStatus::kError;
+ }
+
+ return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
+}
+
+void
+HttpMessageParser::OnReset()
+{
+ Initialize();
+}
+
+int
+HttpMessageParser::OnMessageBegin()
+{
+ ZEN_ASSERT(m_IsMsgComplete == false);
+ ZEN_ASSERT(m_HeaderEntries.empty());
+ ZEN_ASSERT(m_Headers.empty());
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnStatus(MemoryView Status)
+{
+ m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
+
+ m_Stream.Write(Status);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnUrl(MemoryView Url)
+{
+ m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
+
+ m_Stream.Write(Url);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderField(MemoryView HeaderField)
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ if (m_CurrentHeader.Field.Size == 0)
+ {
+ m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Field.Size += HeaderField.GetSize();
+
+ m_Stream.Write(HeaderField);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
+{
+ if (m_CurrentHeader.Value.Size == 0)
+ {
+ m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Value.Size += HeaderValue.GetSize();
+
+ m_Stream.Write(HeaderValue);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeadersComplete()
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ m_Headers.clear();
+ m_Headers.reserve(m_HeaderEntries.size());
+
+ const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
+
+ for (const auto& Entry : m_HeaderEntries)
+ {
+ auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
+ auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
+
+ m_Headers.try_emplace(std::move(Field), std::move(Value));
+ }
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnBody(MemoryView Body)
+{
+ m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
+
+ m_Stream.Write(Body);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnMessageComplete()
+{
+ m_IsMsgComplete = true;
+
+ return 0;
+}
+
+bool
+HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
+{
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ OutAcceptHash = std::string();
+
+ if (m_Headers.contains(http_header::SecWebSocketKey) == false)
+ {
+ OutReason = "Missing header Sec-WebSocket-Key";
+ return false;
+ }
+
+ if (m_Headers.contains(http_header::Upgrade) == false)
+ {
+ OutReason = "Missing header Upgrade";
+ return false;
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
+
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
+
+ SHA1 Hash = HashStream.GetHash();
+
+ OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
+
+ return true;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WebSocketMessageParser final : public MessageParser
+{
+public:
+ WebSocketMessageParser() : MessageParser() {}
+
+ WebSocketMessage ConsumeMessage();
+
+private:
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+
+ WebSocketMessage m_Message;
+};
+
+ParseMessageResult
+WebSocketMessageParser::OnParseMessage(MemoryView Msg)
+{
+ ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
+
+ const uint64_t PrevOffset = m_Stream.CurrentOffset();
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
+
+ m_Stream.Write(Msg.Left(RemaingHeaderSize));
+ Msg += RemaingHeaderSize;
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+
+ const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
+
+ if (IsValidHeader == false)
+ {
+ OnReset();
+
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message header")};
+ }
+
+ if (m_Message.MessageSize() == 0)
+ {
+ return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+ }
+
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
+
+ if (Msg.IsEmpty() == false)
+ {
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
+ }
+
+ auto Status = ParseMessageStatus::kContinue;
+
+ if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
+ {
+ Status = ParseMessageStatus::kDone;
+
+ BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
+
+ CbPackage Pkg;
+ if (Pkg.TryLoad(Reader) == false)
+ {
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message")};
+ }
+
+ m_Message.SetBody(std::move(Pkg));
+ }
+
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+}
+
+void
+WebSocketMessageParser::OnReset()
+{
+ m_Message = WebSocketMessage();
+}
+
+WebSocketMessage
+WebSocketMessageParser::ConsumeMessage()
+{
+ WebSocketMessage Msg = std::move(m_Message);
+ m_Message = WebSocketMessage();
+
+ return Msg;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsConnection : public std::enable_shared_from_this<WsConnection>
+{
+public:
+ WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ : m_Id(Id)
+ , m_Socket(std::move(Socket))
+ , m_StartTime(Clock::now())
+ , m_State()
+ {
+ }
+
+ ~WsConnection() = default;
+
+ std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
+
+ WebSocketId Id() const { return m_Id; }
+ asio::ip::tcp::socket& Socket() { return *m_Socket; }
+ TimePoint StartTime() const { return m_StartTime; }
+ WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
+ std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ WebSocketState Close();
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ std::mutex& WriteMutex() { return m_WriteMutex; }
+
+private:
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ TimePoint m_StartTime;
+ std::atomic_uint32_t m_State;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ std::mutex m_WriteMutex;
+};
+
+WebSocketState
+WsConnection::Close()
+{
+ const auto PrevState = SetState(WebSocketState::kDisconnected);
+
+ if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open())
+ {
+ m_Socket->close();
+ }
+
+ return PrevState;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsThreadPool
+{
+public:
+ WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
+ void Start(uint32_t ThreadCount);
+ void Stop();
+
+private:
+ asio::io_service& m_IoSvc;
+ std::vector<std::thread> m_Threads;
+ std::atomic_bool m_Running{false};
+};
+
+void
+WsThreadPool::Start(uint32_t ThreadCount)
+{
+ ZEN_ASSERT(m_Threads.empty());
+
+ ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
+
+ m_Running = true;
+
+ for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
+ {
+ m_Threads.emplace_back([this, ThreadId = Idx + 1] {
+ for (;;)
+ {
+ if (m_Running == false)
+ {
+ break;
+ }
+
+ try
+ {
+ m_IoSvc.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
+ }
+ }
+
+ ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId);
+ });
+ }
+}
+
+void
+WsThreadPool::Stop()
+{
+ if (m_Running)
+ {
+ m_Running = false;
+
+ for (std::thread& Thread : m_Threads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_Threads.clear();
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsServer final : public WebSocketServer
+{
+public:
+ WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {}
+ virtual ~WsServer() { Shutdown(); }
+
+ virtual bool Run() override;
+ virtual void Shutdown() override;
+
+ virtual void RegisterService(WebSocketService& Service) override;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) override;
+ virtual void SendResponse(WebSocketMessage&& Response) override;
+
+private:
+ friend class WsConnection;
+
+ void AcceptConnection();
+ void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
+
+ void ReadMessage(std::shared_ptr<WsConnection> Connection);
+ void RouteMessage(WebSocketMessage&& Msg);
+ void SendMessage(WebSocketMessage&& Msg);
+
+ struct IdHasher
+ {
+ size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
+ };
+
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
+ using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
+
+ WebSocketServerOptions m_Options;
+ asio::io_service m_IoSvc;
+ std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
+ std::unique_ptr<WsThreadPool> m_ThreadPool;
+ ConnectionMap m_Connections;
+ std::shared_mutex m_ConnMutex;
+ std::vector<WebSocketService*> m_Services;
+ RequestHandlerMap m_RequestHandlers;
+ NotificationHandlerMap m_NotificationHandlers;
+ std::atomic_bool m_Running{};
+};
+
+void
+WsServer::RegisterService(WebSocketService& Service)
+{
+ m_Services.push_back(&Service);
+
+ Service.Configure(*this);
+}
+
+bool
+WsServer::Run()
+{
+ static constexpr size_t ReceiveBufferSize = 256 << 10;
+ static constexpr size_t SendBufferSize = 256 << 10;
+
+ m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
+
+ m_Acceptor->set_option(asio::ip::v6_only(false));
+ m_Acceptor->set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor->native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+
+ asio::error_code Ec;
+ m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
+
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
+
+ return false;
+ }
+
+ m_Acceptor->listen();
+ m_Running = true;
+
+ ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port);
+
+ AcceptConnection();
+
+ m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
+ m_ThreadPool->Start(m_Options.ThreadCount);
+
+ return true;
+}
+
+void
+WsServer::Shutdown()
+{
+ if (m_Running)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
+
+ m_Running = false;
+
+ m_Acceptor->close();
+ m_Acceptor.reset();
+ m_IoSvc.stop();
+
+ m_ThreadPool->Stop();
+ }
+}
+
+void
+WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
+{
+ auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
+ Result.first->second.push_back(&Service);
+}
+
+void
+WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
+{
+ m_RequestHandlers[Key] = &Service;
+}
+
+void
+WsServer::SendNotification(WebSocketMessage&& Notification)
+{
+ ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
+
+ SendMessage(std::move(Notification));
+}
+void
+WsServer::SendResponse(WebSocketMessage&& Response)
+{
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
+
+ ZEN_ASSERT(Response.CorrelationId() != 0);
+
+ SendMessage(std::move(Response));
+}
+
+void
+WsServer::AcceptConnection()
+{
+ auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
+ asio::ip::tcp::socket& SocketRef = *Socket.get();
+
+ m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
+ if (m_Running)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
+ }
+ else
+ {
+ auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
+
+ ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ m_Connections[Connection->Id()] = Connection;
+ }
+
+ Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ }
+
+ AcceptConnection();
+ }
+ });
+}
+
+void
+WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
+{
+ if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value());
+ }
+ else
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
+ }
+ }
+
+ const WebSocketId Id = Connection->Id();
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ if (m_Connections.contains(Id))
+ {
+ m_Connections.erase(Id);
+ }
+ }
+}
+
+void
+WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
+{
+ Connection->ReadBuffer().prepare(64 << 10);
+
+ asio::async_read(
+ Connection->Socket(),
+ Connection->ReadBuffer(),
+ asio::transfer_at_least(1),
+ [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
+ if (ReadEc)
+ {
+ return CloseConnection(Connection, ReadEc);
+ }
+
+ switch (Connection->State())
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
+ asio::const_buffer Buffer = Connection->ReadBuffer().data();
+
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return ReadMessage(Connection);
+ }
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Parser.IsUpgrade() == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ std::string AcceptHash;
+ std::string Reason;
+ const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
+
+ if (ValidHandshake == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ Reason);
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ExtendableStringBuilder<128> Sb;
+
+ Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: Upgrade\r\n"sv;
+
+ // TODO: Verify protocol
+ if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
+ {
+ Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
+ << "\r\n";
+ }
+
+ Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
+ Sb << "\r\n"sv;
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "accepting handshake from connection '#{} {}'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ std::string Response = Sb.ToString();
+ Buffer = asio::buffer(Response);
+
+ async_write(Connection->Socket(),
+ Buffer,
+ [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
+ if (WriteEc)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ WriteEc.message());
+
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake ({}B) with connection '#{} {}' OK",
+ ByteCount,
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ Connection->SetParser(std::make_unique<WebSocketMessageParser>());
+ Connection->SetState(WebSocketState::kConnected);
+
+ ReadMessage(Connection);
+ });
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
+
+ uint64_t RemainingBytes = Connection->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Connection->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Message.SetSocketId(Connection->Id());
+
+ RouteMessage(std::move(Message));
+ }
+
+ ReadMessage(Connection);
+ }
+ break;
+
+ default:
+ break;
+ };
+ });
+}
+
+void
+WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kRequest:
+ case WebSocketMessageType::kStreamRequest:
+ {
+ CbObjectView Request = RoutedMessage.Body().GetObject();
+ std::string_view Method = Request["Method"].AsString();
+ bool Handled = false;
+ bool Error = false;
+ std::exception Exception;
+
+ if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
+ {
+ WebSocketService* Service = It->second;
+ ZEN_ASSERT(Service);
+
+ try
+ {
+ Handled = Service->HandleRequest(std::move(RoutedMessage));
+ }
+ catch (std::exception& Err)
+ {
+ Exception = std::move(Err);
+ Error = true;
+ }
+ }
+
+ if (Error || Handled == false)
+ {
+ std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
+
+ ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
+
+ CbObjectWriter Response;
+ Response << "Error"sv << ErrorText;
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
+ ResponseMsg.SetSocketId(RoutedMessage.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SendResponse(std::move(ResponseMsg));
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ CbObjectView Notification = RoutedMessage.Body().GetObject();
+ std::string_view Message = Notification["Message"].AsString();
+
+ if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
+ {
+ std::vector<WebSocketService*>& Handlers = It->second;
+
+ for (WebSocketService* Handler : Handlers)
+ {
+ Handler->HandleNotification(RoutedMessage);
+ }
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
+ }
+ }
+ break;
+
+ default:
+ break;
+ };
+}
+
+void
+WsServer::SendMessage(WebSocketMessage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
+
+ {
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
+ {
+ Connection = It->second;
+ }
+ }
+
+ if (Connection.get() == nullptr)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
+ return;
+ }
+
+ if (Connection.get() != nullptr)
+ {
+ BinaryWriter Writer;
+ Msg.Save(Writer);
+
+ ZEN_LOG_TRACE(LogWebSocket,
+ "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
+ ToString(Msg.MessageType()),
+ Connection->Id().Value(),
+ Msg.MessageSize(),
+ Msg.CorrelationId(),
+ NiceBytes(Writer.Size()));
+
+ {
+ ZEN_TRACE_CPU("WS::SendMessage");
+ std::unique_lock _(Connection->WriteMutex());
+ ZEN_TRACE_CPU("WS::WriteSocketData");
+ asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
+ }
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
+{
+public:
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
+
+ virtual ~WsClient() { Disconnect(); }
+
+ std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
+
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const override { return false; }
+ virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); }
+
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
+ virtual void OnNotification(NotificationCallback&& Cb) override;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
+
+private:
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ void TriggerEvent(WebSocketEvent Evt);
+ void ReadMessage();
+ void RouteMessage(WebSocketMessage&& RoutedMessage);
+
+ using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
+
+ asio::io_context& m_IoCtx;
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ EventCallback m_EventCallbacks[3];
+ NotificationCallback m_NotificationCallback;
+ PendingRequestMap m_PendingRequests;
+ std::mutex m_RequestMutex;
+ std::promise<bool> m_ConnectPromise;
+ std::atomic_uint32_t m_State;
+ std::string m_Host;
+ int16_t m_Port{};
+};
+
+std::future<bool>
+WsClient::Connect(const WebSocketConnectInfo& Info)
+{
+ if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
+ {
+ return m_ConnectPromise.get_future();
+ }
+
+ SetState(WebSocketState::kHandshaking);
+
+ try
+ {
+ asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
+
+ m_Socket->connect(Endpoint);
+
+ m_Host = m_Socket->remote_endpoint().address().to_string();
+ m_Port = Info.Port;
+
+ ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
+
+ SetState(WebSocketState::kError);
+ m_Socket.reset();
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ m_ConnectPromise.set_value(false);
+
+ return m_ConnectPromise.get_future();
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
+ Sb << "Host: " << Info.Host << "\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: upgrade\r\n"sv;
+ Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
+
+ if (Info.Protocols.empty() == false)
+ {
+ Sb << "Sec-WebSocket-Protocol: "sv;
+ for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
+ {
+ if (Idx++)
+ {
+ Sb << ", ";
+ }
+ Sb << Protocol;
+ }
+ }
+
+ Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
+ Sb << "\r\n";
+
+ std::string HandshakeRequest = Sb.ToString();
+ asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
+
+ ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
+
+ m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
+ m_MsgParser->Reset();
+
+ async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ else
+ {
+ Self->ReadMessage();
+ }
+ });
+
+ return m_ConnectPromise.get_future();
+}
+
+void
+WsClient::Disconnect()
+{
+ if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
+ {
+ ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
+
+ if (m_Socket && m_Socket->is_open())
+ {
+ m_Socket->close();
+ m_Socket.reset();
+ }
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ for (auto& Kv : m_PendingRequests)
+ {
+ Kv.second.set_value(WebSocketMessage());
+ }
+
+ m_PendingRequests.clear();
+ }
+ }
+}
+
+std::future<WebSocketMessage>
+WsClient::SendRequest(WebSocketMessage&& Request)
+{
+ ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
+
+ BinaryWriter Writer;
+ Request.Save(Writer);
+
+ std::future<WebSocketMessage> FutureResponse;
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
+ ZEN_ASSERT(Result.second);
+
+ auto It = Result.first;
+ FutureResponse = It->second.get_future();
+ }
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ });
+
+ return FutureResponse;
+}
+
+void
+WsClient::OnNotification(NotificationCallback&& Cb)
+{
+ m_NotificationCallback = std::move(Cb);
+}
+
+void
+WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+{
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+}
+
+void
+WsClient::TriggerEvent(WebSocketEvent Evt)
+{
+ const uint32_t Index = static_cast<uint32_t>(Evt);
+
+ if (m_EventCallbacks[Index])
+ {
+ m_EventCallbacks[Index]();
+ }
+}
+
+void
+WsClient::ReadMessage()
+{
+ m_ReadBuffer.prepare(64 << 10);
+
+ async_read(*m_Socket,
+ m_ReadBuffer,
+ asio::transfer_at_least(1),
+ [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
+ const WebSocketState State = Self->State();
+
+ if (State == WebSocketState::kDisconnected)
+ {
+ return;
+ }
+
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
+
+ return Self->Disconnect();
+ }
+
+ switch (State)
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
+
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
+
+ ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(size_t(Result.ByteCount));
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return Self->ReadMessage();
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "handshake FAILED, status '{}', status code '{}'",
+ Parser.StatusText(),
+ Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
+
+ Self->SetParser(std::make_unique<WebSocketMessageParser>());
+ Self->SetState(WebSocketState::kConnected);
+ Self->ReadMessage();
+ Self->TriggerEvent(WebSocketEvent::kConnected);
+
+ Self->m_ConnectPromise.set_value(true);
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
+
+ uint64_t RemainingBytes = Self->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Self->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ Parser.Reset();
+ continue;
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Self->RouteMessage(std::move(Message));
+ }
+
+ Self->ReadMessage();
+ }
+ break;
+
+ default:
+ break;
+ }
+ });
+}
+
+void
+WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kResponse:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
+ {
+ It->second.set_value(std::move(RoutedMessage));
+ m_PendingRequests.erase(It);
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "route request message FAILED, reason 'unknown correlation ID ({})'",
+ RoutedMessage.CorrelationId());
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (m_NotificationCallback)
+ {
+ m_NotificationCallback(std::move(RoutedMessage));
+ }
+ }
+ break;
+
+ default:
+ ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType()));
+ break;
+ };
+}
+
+} // namespace zen::websocket
+
+namespace zen {
+
+std::atomic_uint32_t WebSocketId::NextId{1};
+
+bool
+WebSocketMessage::Header::IsValid() const
+{
+ return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
+ uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
+}
+
+std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
+
+void
+WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
+{
+ m_Header.MessageType = MessageType;
+}
+
+void
+WebSocketMessage::SetBody(CbPackage&& Body)
+{
+ m_Body = std::move(Body);
+}
+void
+WebSocketMessage::SetBody(CbObject&& Body)
+{
+ CbPackage Pkg;
+ Pkg.SetObject(Body);
+
+ SetBody(std::move(Pkg));
+}
+
+void
+WebSocketMessage::Save(BinaryWriter& Writer)
+{
+ Writer.Write(&m_Header, HeaderSize);
+
+ if (m_Body.has_value())
+ {
+ const CbObject& Obj = m_Body.value().GetObject();
+ MemoryView View = Obj.GetBuffer().GetView();
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
+ ZEN_ASSERT(ValidationResult == CbValidateError::None);
+
+ m_Body.value().Save(Writer);
+ }
+
+ if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
+ {
+ m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
+ }
+
+ m_Header.MessageSize = Writer.Size() - HeaderSize;
+
+ Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
+}
+
+bool
+WebSocketMessage::TryLoadHeader(MemoryView Memory)
+{
+ if (Memory.GetSize() < HeaderSize)
+ {
+ return false;
+ }
+
+ MutableMemoryView HeaderView(&m_Header, HeaderSize);
+
+ HeaderView.CopyFrom(Memory);
+
+ return m_Header.IsValid();
+}
+
+void
+WebSocketService::Configure(WebSocketServer& Server)
+{
+ ZEN_ASSERT(m_SocketServer == nullptr);
+
+ m_SocketServer = &Server;
+
+ RegisterHandlers(Server);
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
+{
+ CbPackage Response;
+ Response.SetObject(std::move(StreamResponse));
+
+ SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
+}
+
+std::unique_ptr<WebSocketServer>
+WebSocketServer::Create(const WebSocketServerOptions& Options)
+{
+ return std::make_unique<websocket::WsServer>(Options);
+}
+
+std::shared_ptr<WebSocketClient>
+WebSocketClient::Create(asio::io_context& IoCtx)
+{
+ return std::make_shared<websocket::WsClient>(IoCtx);
+}
+
+} // namespace zen