diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src/zenhttp/websocketasio.cpp | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src/zenhttp/websocketasio.cpp')
| -rw-r--r-- | src/zenhttp/websocketasio.cpp | 1613 |
1 files changed, 1613 insertions, 0 deletions
diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..bbe7e1ad8 --- /dev/null +++ b/src/zenhttp/websocketasio.cpp @@ -0,0 +1,1613 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocket.h> + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/sha1.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/trace.h> + +#include <chrono> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + +namespace zen::websocket { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +enum class ParseMessageStatus : uint32_t +{ + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + BinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Reset(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const + { + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); + } + + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + void Initialize(); + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry + { + uint64_t Offset{}; + uint64_t Size{}; + }; + + struct HeaderStreamEntry + { + StreamEntry Field{}; + StreamEntry Value{}; + }; + + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; + + static http_parser_settings ParserSettings; +}; + +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) + { + Status = ParseMessageStatus::kError; + } + + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} + +void +HttpMessageParser::OnReset() +{ + Initialize(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} + +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; + + m_Stream.Write(Url); + + return 0; +} + +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + if (m_CurrentHeader.Field.Size == 0) + { + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) + { + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); + + m_Stream.Write(HeaderValue); + + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); + + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); + + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } + + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) +{ + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); + + if (m_Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return false; + } + + if (m_Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} + + WebSocketMessage ConsumeMessage(); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + + WebSocketMessage m_Message; +}; + +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); + + const uint64_t PrevOffset = m_Stream.CurrentOffset(); + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingHeaderSize)); + Msg += RemaingHeaderSize; + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); + + if (IsValidHeader == false) + { + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; + } + + if (m_Message.MessageSize() == 0) + { + return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + } + + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); + + if (Msg.IsEmpty() == false) + { + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + m_Stream.Write(Msg.Left(RemaingMessageSize)); + } + + auto Status = ParseMessageStatus::kContinue; + + if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) + { + Status = ParseMessageStatus::kDone; + + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); + + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } + + m_Message.SetBody(std::move(Pkg)); + } + + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_Message = WebSocketMessage(); +} + +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() +{ + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); + + return Msg; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_State() + { + } + + ~WsConnection() = default; + + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + + WebSocketId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + std::mutex& WriteMutex() { return m_WriteMutex; } + +private: + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + std::mutex m_WriteMutex; +}; + +WebSocketState +WsConnection::Close() +{ + const auto PrevState = SetState(WebSocketState::kDisconnected); + + if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + for (;;) + { + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + } + + ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + if (m_Running) + { + m_Running = false; + + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} + virtual ~WsServer() { Shutdown(); } + + virtual bool Run() override; + virtual void Shutdown() override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); + + void ReadMessage(std::shared_ptr<WsConnection> Connection); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); + + struct IdHasher + { + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; + + WebSocketServerOptions m_Options; + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; + std::atomic_bool m_Running{}; +}; + +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + +bool +WsServer::Run() +{ + static constexpr size_t ReceiveBufferSize = 256 << 10; + static constexpr size_t SendBufferSize = 256 << 10; + + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(m_Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) +{ + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} + +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} + +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); + + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || + Response.MessageType() == WebSocketMessageType::kStreamResponse || + Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); + + ZEN_ASSERT(Response.CorrelationId() != 0); + + SendMessage(std::move(Response)); +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (m_Running) + { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); + } + else + { + auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); + + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id()] = Connection; + } + + Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + } + + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) +{ + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); + } + else + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); + } + } + + const WebSocketId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } + } +} + +void +WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) +{ + Connection->ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection->Socket(), + Connection->ReadBuffer(), + asio::transfer_at_least(1), + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + switch (Connection->State()) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + Connection->ReadBuffer().consume(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + std::string AcceptHash; + std::string Reason; + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); + + if (ValidHandshake == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); + + async_write(Connection->Socket(), + Buffer, + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { + if (WriteEc) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); + } + + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); + + Connection->SetParser(std::make_unique<WebSocketMessageParser>()); + Connection->SetState(WebSocketState::kConnected); + + ReadMessage(Connection); + }); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); + + uint64_t RemainingBytes = Connection->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } + + ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + case WebSocketMessageType::kStreamRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + + default: + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) + { + Connection = It->second; + } + } + + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + + ZEN_LOG_TRACE(LogWebSocket, + "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", + ToString(Msg.MessageType()), + Connection->Id().Value(), + Msg.MessageSize(), + Msg.CorrelationId(), + NiceBytes(Writer.Size())); + + { + ZEN_TRACE_CPU("WS::SendMessage"); + std::unique_lock _(Connection->WriteMutex()); + ZEN_TRACE_CPU("WS::WriteSocketData"); + asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} + + virtual ~WsClient() { Disconnect(); } + + std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const override { return false; } + virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); } + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; + + asio::io_context& m_IoCtx; + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +std::future<bool> +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) + { + return m_ConnectPromise.get_future(); + } + + SetState(WebSocketState::kHandshaking); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kError); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + else + { + Self->ReadMessage(); + } + }); + + return m_ConnectPromise.get_future(); +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } + } +} + +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) +{ + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + + BinaryWriter Writer; + Request.Save(Writer); + + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); + + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); + + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); + + return FutureResponse; +} + +void +WsClient::OnNotification(NotificationCallback&& Cb) +{ + m_NotificationCallback = std::move(Cb); +} + +void +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast<uint32_t>(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::ReadMessage() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + const WebSocketState State = Self->State(); + + if (State == WebSocketState::kDisconnected) + { + return; + } + + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); + + return Self->Disconnect(); + } + + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); + + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + + ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); + + Self->SetParser(std::make_unique<WebSocketMessageParser>()); + Self->SetState(WebSocketState::kConnected); + Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); + + uint64_t RemainingBytes = Self->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + Parser.Reset(); + continue; + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Self->RouteMessage(std::move(Message)); + } + + Self->ReadMessage(); + } + break; + + default: + break; + } + }); +} + +void +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + + default: + ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); + break; + }; +} + +} // namespace zen::websocket + +namespace zen { + +std::atomic_uint32_t WebSocketId::NextId{1}; + +bool +WebSocketMessage::Header::IsValid() const +{ + return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && + uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); +} + +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + +void +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) +{ + m_Header.MessageType = MessageType; +} + +void +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) +{ + CbPackage Pkg; + Pkg.SetObject(Body); + + SetBody(std::move(Pkg)); +} + +void +WebSocketMessage::Save(BinaryWriter& Writer) +{ + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + const CbObject& Obj = m_Body.value().GetObject(); + MemoryView View = Obj.GetBuffer().GetView(); + + const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); + ZEN_ASSERT(ValidationResult == CbValidateError::None); + + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); +} + +bool +WebSocketMessage::TryLoadHeader(MemoryView Memory) +{ + if (Memory.GetSize() < HeaderSize) + { + return false; + } + + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); + + return m_Header.IsValid(); +} + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_SocketServer == nullptr); + + m_SocketServer = &Server; + + RegisterHandlers(Server); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) +{ + WebSocketMessage Message; + + Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(CorrelationId); + Message.SetSocketId(SocketId); + Message.SetBody(std::move(StreamResponse)); + + SocketServer().SendResponse(std::move(Message)); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) +{ + CbPackage Response; + Response.SetObject(std::move(StreamResponse)); + + SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); +} + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create(const WebSocketServerOptions& Options) +{ + return std::make_unique<websocket::WsServer>(Options); +} + +std::shared_ptr<WebSocketClient> +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_shared<websocket::WsClient>(IoCtx); +} + +} // namespace zen |