diff options
| author | Stefan Boberg <[email protected]> | 2023-10-04 13:13:23 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-04 13:13:23 +0200 |
| commit | 3e8db7cd243e8be3b2d5fea2490b9ad70f765590 (patch) | |
| tree | c1533ae25b0b717dd2393960c5449aadd56807c4 /src | |
| parent | factored out http parser from asio into separate files (#444) (diff) | |
| download | zen-3e8db7cd243e8be3b2d5fea2490b9ad70f765590.tar.xz zen-3e8db7cd243e8be3b2d5fea2490b9ad70f765590.zip | |
removed websocket protocol support(#445)
removed websocket support since it is not used right now and is unlikely to be used in the future
Diffstat (limited to 'src')
| -rw-r--r-- | src/zenhttp/include/zenhttp/httptest.h | 6 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/websocket.h | 256 | ||||
| -rw-r--r-- | src/zenhttp/testing/httptest.cpp | 32 | ||||
| -rw-r--r-- | src/zenhttp/websocketasio.cpp | 1613 | ||||
| -rw-r--r-- | src/zenserver-test/zenserver-test.cpp | 52 | ||||
| -rw-r--r-- | src/zenserver/config.cpp | 16 | ||||
| -rw-r--r-- | src/zenserver/config.h | 22 | ||||
| -rw-r--r-- | src/zenserver/zenserver.cpp | 22 |
8 files changed, 11 insertions, 2008 deletions
diff --git a/src/zenhttp/include/zenhttp/httptest.h b/src/zenhttp/include/zenhttp/httptest.h index 74db69785..a4008fb5e 100644 --- a/src/zenhttp/include/zenhttp/httptest.h +++ b/src/zenhttp/include/zenhttp/httptest.h @@ -5,7 +5,6 @@ #include <zencore/logging.h> #include <zencore/stats.h> #include <zenhttp/httpserver.h> -#include <zenhttp/websocket.h> #include <atomic> @@ -16,7 +15,7 @@ namespace zen { /** * Test service to facilitate testing the HTTP framework and client interactions */ -class HttpTestingService : public HttpService, public WebSocketService +class HttpTestingService : public HttpService { public: HttpTestingService(); @@ -43,9 +42,6 @@ public: }; private: - virtual void RegisterHandlers(WebSocketServer& Server) override; - virtual bool HandleRequest(const WebSocketMessage& Request) override; - HttpRequestRouter m_Router; std::atomic<uint32_t> m_Counter{0}; metrics::OperationTiming m_TimingStats; diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h deleted file mode 100644 index adca7e988..000000000 --- a/src/zenhttp/include/zenhttp/websocket.h +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zencore/compactbinarypackage.h> -#include <zencore/memory.h> - -#include <compare> -#include <functional> -#include <future> -#include <memory> -#include <optional> - -#pragma once - -namespace asio { -class io_context; -} - -namespace zen { - -class BinaryWriter; - -/** - * A unique socket ID. - */ -class WebSocketId -{ - static std::atomic_uint32_t NextId; - -public: - WebSocketId() = default; - - uint32_t Value() const { return m_Value; } - - auto operator<=>(const WebSocketId&) const = default; - - static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); } - -private: - WebSocketId(uint32_t Value) : m_Value(Value) {} - - uint32_t m_Value{}; -}; - -/** - * Type of web socket message. - */ -enum class WebSocketMessageType : uint8_t -{ - kInvalid, - kNotification, - kRequest, - kStreamRequest, - kResponse, - kStreamResponse, - kStreamCompleteResponse, - kCount -}; - -inline std::string_view -ToString(WebSocketMessageType Type) -{ - switch (Type) - { - case WebSocketMessageType::kInvalid: - return std::string_view("Invalid"); - case WebSocketMessageType::kNotification: - return std::string_view("Notification"); - case WebSocketMessageType::kRequest: - return std::string_view("Request"); - case WebSocketMessageType::kStreamRequest: - return std::string_view("StreamRequest"); - case WebSocketMessageType::kResponse: - return std::string_view("Response"); - case WebSocketMessageType::kStreamResponse: - return std::string_view("StreamResponse"); - case WebSocketMessageType::kStreamCompleteResponse: - return std::string_view("StreamCompleteResponse"); - default: - return std::string_view("Unknown"); - }; -} - -/** - * Web socket message. - */ -class WebSocketMessage -{ - struct Header - { - static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh - - uint64_t MessageSize{}; - uint32_t Magic{ExpectedMagic}; - uint32_t CorrelationId{}; - uint32_t StatusCode{200u}; - WebSocketMessageType MessageType{}; - uint8_t Reserved[3] = {0}; - - bool IsValid() const; - }; - - static_assert(sizeof(Header) == 24); - - static std::atomic_uint32_t NextCorrelationId; - -public: - static constexpr size_t HeaderSize = sizeof(Header); - - WebSocketMessage() = default; - - WebSocketId SocketId() const { return m_SocketId; } - void SetSocketId(WebSocketId Id) { m_SocketId = Id; } - uint64_t MessageSize() const { return m_Header.MessageSize; } - void SetMessageType(WebSocketMessageType MessageType); - void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } - uint32_t CorrelationId() const { return m_Header.CorrelationId; } - uint32_t StatusCode() const { return m_Header.StatusCode; } - void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; } - WebSocketMessageType MessageType() const { return m_Header.MessageType; } - - const CbPackage& Body() const { return m_Body.value(); } - void SetBody(CbPackage&& Body); - void SetBody(CbObject&& Body); - bool HasBody() const { return m_Body.has_value(); } - - void Save(BinaryWriter& Writer); - bool TryLoadHeader(MemoryView Memory); - - bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } - -private: - Header m_Header{}; - WebSocketId m_SocketId{}; - std::optional<CbPackage> m_Body; -}; - -class WebSocketServer; - -/** - * Base class for handling web socket requests and notifications from connected client(s). - */ -class WebSocketService -{ -public: - virtual ~WebSocketService() = default; - - void Configure(WebSocketServer& Server); - - virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } - virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } - -protected: - WebSocketService() = default; - - virtual void RegisterHandlers(WebSocketServer& Server) = 0; - void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete); - void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete); - - WebSocketServer& SocketServer() - { - ZEN_ASSERT(m_SocketServer); - return *m_SocketServer; - } - -private: - WebSocketServer* m_SocketServer{}; -}; - -/** - * Server options. - */ -struct WebSocketServerOptions -{ - uint16_t Port = 2337; - uint32_t ThreadCount = 1; -}; - -/** - * The web socket server manages client connections and routing of requests and notifications. - */ -class WebSocketServer -{ -public: - virtual ~WebSocketServer() = default; - - virtual bool Run() = 0; - virtual void Shutdown() = 0; - - virtual void RegisterService(WebSocketService& Service) = 0; - virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; - virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; - - virtual void SendNotification(WebSocketMessage&& Notification) = 0; - virtual void SendResponse(WebSocketMessage&& Response) = 0; - - static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options); -}; - -/** - * The state of the web socket. - */ -enum class WebSocketState : uint32_t -{ - kNone, - kHandshaking, - kConnected, - kDisconnected, - kError -}; - -/** - * Type of web socket client event. - */ -enum class WebSocketEvent : uint32_t -{ - kConnected, - kDisconnected, - kError -}; - -/** - * Web socket client connection info. - */ -struct WebSocketConnectInfo -{ - std::string Host; - int16_t Port{8848}; - std::string Endpoint; - std::vector<std::string> Protocols; - uint16_t Version{13}; -}; - -/** - * A connection to a web socket server for sending requests and listening for notifications. - */ -class WebSocketClient -{ -public: - using EventCallback = std::function<void()>; - using NotificationCallback = std::function<void(WebSocketMessage&&)>; - - virtual ~WebSocketClient() = default; - - virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; - virtual void Disconnect() = 0; - virtual bool IsConnected() const = 0; - virtual WebSocketState State() const = 0; - - virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; - virtual void OnNotification(NotificationCallback&& Cb) = 0; - virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; - - static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); -}; - -} // namespace zen diff --git a/src/zenhttp/testing/httptest.cpp b/src/zenhttp/testing/httptest.cpp index a02e36bcc..3a0ad72a9 100644 --- a/src/zenhttp/testing/httptest.cpp +++ b/src/zenhttp/testing/httptest.cpp @@ -140,38 +140,6 @@ HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId))); } -void -HttpTestingService::RegisterHandlers(WebSocketServer& Server) -{ - Server.RegisterRequestHandler("SayHello"sv, *this); -} - -bool -HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg) -{ - CbObjectView Request = RequestMsg.Body().GetObject(); - - std::string_view Method = Request["Method"].AsString(); - - if (Method != "SayHello"sv) - { - return false; - } - - CbObjectWriter Response; - Response.AddString("Result"sv, "Hello Friend!!"); - - WebSocketMessage ResponseMsg; - ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); - ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId()); - ResponseMsg.SetSocketId(RequestMsg.SocketId()); - ResponseMsg.SetBody(Response.Save()); - - SocketServer().SendResponse(std::move(ResponseMsg)); - - return true; -} - ////////////////////////////////////////////////////////////////////////// HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId) diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp deleted file mode 100644 index bbe7e1ad8..000000000 --- a/src/zenhttp/websocketasio.cpp +++ /dev/null @@ -1,1613 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/websocket.h> - -#include <zencore/base64.h> -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinaryvalidation.h> -#include <zencore/intmath.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/memory.h> -#include <zencore/sha1.h> -#include <zencore/stream.h> -#include <zencore/string.h> -#include <zencore/trace.h> - -#include <chrono> -#include <optional> -#include <shared_mutex> -#include <span> -#include <system_error> -#include <thread> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <fmt/format.h> -#include <http_parser.h> -#include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -#if ZEN_PLATFORM_WINDOWS -# include <mstcpip.h> -#endif - -namespace zen::websocket { - -using namespace std::literals; - -ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); - -ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); - -using Clock = std::chrono::steady_clock; -using TimePoint = Clock::time_point; - -/////////////////////////////////////////////////////////////////////////////// -namespace http_header { - static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; - static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; - static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; - static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; - static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; - static constexpr std::string_view Upgrade = "Upgrade"sv; -} // namespace http_header - -/////////////////////////////////////////////////////////////////////////////// -enum class ParseMessageStatus : uint32_t -{ - kError, - kContinue, - kDone, -}; - -struct ParseMessageResult -{ - ParseMessageStatus Status{}; - size_t ByteCount{}; - std::optional<std::string> Reason; -}; - -class MessageParser -{ -public: - virtual ~MessageParser() = default; - - ParseMessageResult ParseMessage(MemoryView Msg); - void Reset(); - -protected: - MessageParser() = default; - - virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; - virtual void OnReset() = 0; - - BinaryWriter m_Stream; -}; - -ParseMessageResult -MessageParser::ParseMessage(MemoryView Msg) -{ - return OnParseMessage(Msg); -} - -void -MessageParser::Reset() -{ - OnReset(); - - m_Stream.Reset(); -} - -/////////////////////////////////////////////////////////////////////////////// -enum class HttpMessageParserType -{ - kRequest, - kResponse, - kBoth -}; - -class HttpMessageParser final : public MessageParser -{ -public: - using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; - - HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } - - virtual ~HttpMessageParser() = default; - - int32_t StatusCode() const { return m_Parser.status_code; } - bool IsUpgrade() const { return m_Parser.upgrade != 0; } - HttpHeaders& Headers() { return m_Headers; } - MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } - - std::string_view StatusText() const - { - return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); - } - - bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); - -private: - void Initialize(); - virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; - virtual void OnReset() override; - int OnMessageBegin(); - int OnUrl(MemoryView Url); - int OnStatus(MemoryView Status); - int OnHeaderField(MemoryView HeaderField); - int OnHeaderValue(MemoryView HeaderValue); - int OnHeadersComplete(); - int OnBody(MemoryView Body); - int OnMessageComplete(); - - struct StreamEntry - { - uint64_t Offset{}; - uint64_t Size{}; - }; - - struct HeaderStreamEntry - { - StreamEntry Field{}; - StreamEntry Value{}; - }; - - HttpMessageParserType m_Type; - http_parser m_Parser; - StreamEntry m_UrlEntry; - StreamEntry m_StatusEntry; - StreamEntry m_BodyEntry; - HeaderStreamEntry m_CurrentHeader; - std::vector<HeaderStreamEntry> m_HeaderEntries; - HttpHeaders m_Headers; - bool m_IsMsgComplete{false}; - - static http_parser_settings ParserSettings; -}; - -http_parser_settings HttpMessageParser::ParserSettings = { - .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, - - .on_url = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, - - .on_status = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, - - .on_header_field = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, - - .on_header_value = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, - - .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, - - .on_body = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, - - .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; - -void -HttpMessageParser::Initialize() -{ - http_parser_init(&m_Parser, - m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST - : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE - : HTTP_BOTH); - m_Parser.data = this; - - m_UrlEntry = {}; - m_StatusEntry = {}; - m_CurrentHeader = {}; - m_BodyEntry = {}; - - m_IsMsgComplete = false; - - m_HeaderEntries.clear(); -} - -ParseMessageResult -HttpMessageParser::OnParseMessage(MemoryView Msg) -{ - const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); - - auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; - - if (m_Parser.http_errno != 0) - { - Status = ParseMessageStatus::kError; - } - - return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; -} - -void -HttpMessageParser::OnReset() -{ - Initialize(); -} - -int -HttpMessageParser::OnMessageBegin() -{ - ZEN_ASSERT(m_IsMsgComplete == false); - ZEN_ASSERT(m_HeaderEntries.empty()); - ZEN_ASSERT(m_Headers.empty()); - - return 0; -} - -int -HttpMessageParser::OnStatus(MemoryView Status) -{ - m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; - - m_Stream.Write(Status); - - return 0; -} - -int -HttpMessageParser::OnUrl(MemoryView Url) -{ - m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; - - m_Stream.Write(Url); - - return 0; -} - -int -HttpMessageParser::OnHeaderField(MemoryView HeaderField) -{ - if (m_CurrentHeader.Value.Size > 0) - { - m_HeaderEntries.push_back(m_CurrentHeader); - m_CurrentHeader = {}; - } - - if (m_CurrentHeader.Field.Size == 0) - { - m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); - } - - m_CurrentHeader.Field.Size += HeaderField.GetSize(); - - m_Stream.Write(HeaderField); - - return 0; -} - -int -HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) -{ - if (m_CurrentHeader.Value.Size == 0) - { - m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); - } - - m_CurrentHeader.Value.Size += HeaderValue.GetSize(); - - m_Stream.Write(HeaderValue); - - return 0; -} - -int -HttpMessageParser::OnHeadersComplete() -{ - if (m_CurrentHeader.Value.Size > 0) - { - m_HeaderEntries.push_back(m_CurrentHeader); - m_CurrentHeader = {}; - } - - m_Headers.clear(); - m_Headers.reserve(m_HeaderEntries.size()); - - const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); - - for (const auto& Entry : m_HeaderEntries) - { - auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); - auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); - - m_Headers.try_emplace(std::move(Field), std::move(Value)); - } - - return 0; -} - -int -HttpMessageParser::OnBody(MemoryView Body) -{ - m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; - - m_Stream.Write(Body); - - return 0; -} - -int -HttpMessageParser::OnMessageComplete() -{ - m_IsMsgComplete = true; - - return 0; -} - -bool -HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) -{ - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; - - OutAcceptHash = std::string(); - - if (m_Headers.contains(http_header::SecWebSocketKey) == false) - { - OutReason = "Missing header Sec-WebSocket-Key"; - return false; - } - - if (m_Headers.contains(http_header::Upgrade) == false) - { - OutReason = "Missing header Upgrade"; - return false; - } - - ExtendableStringBuilder<128> Sb; - Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; - - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); - - SHA1 Hash = HashStream.GetHash(); - - OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); - - return true; -} - -/////////////////////////////////////////////////////////////////////////////// -class WebSocketMessageParser final : public MessageParser -{ -public: - WebSocketMessageParser() : MessageParser() {} - - WebSocketMessage ConsumeMessage(); - -private: - virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; - virtual void OnReset() override; - - WebSocketMessage m_Message; -}; - -ParseMessageResult -WebSocketMessageParser::OnParseMessage(MemoryView Msg) -{ - ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); - - const uint64_t PrevOffset = m_Stream.CurrentOffset(); - - if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) - { - const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); - - m_Stream.Write(Msg.Left(RemaingHeaderSize)); - Msg += RemaingHeaderSize; - - if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) - { - return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; - } - - const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); - - if (IsValidHeader == false) - { - OnReset(); - - return {.Status = ParseMessageStatus::kError, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset, - .Reason = std::string("Invalid websocket message header")}; - } - - if (m_Message.MessageSize() == 0) - { - return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; - } - } - - ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); - - if (Msg.IsEmpty() == false) - { - const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); - m_Stream.Write(Msg.Left(RemaingMessageSize)); - } - - auto Status = ParseMessageStatus::kContinue; - - if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) - { - Status = ParseMessageStatus::kDone; - - BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); - - CbPackage Pkg; - if (Pkg.TryLoad(Reader) == false) - { - return {.Status = ParseMessageStatus::kError, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset, - .Reason = std::string("Invalid websocket message")}; - } - - m_Message.SetBody(std::move(Pkg)); - } - - return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; -} - -void -WebSocketMessageParser::OnReset() -{ - m_Message = WebSocketMessage(); -} - -WebSocketMessage -WebSocketMessageParser::ConsumeMessage() -{ - WebSocketMessage Msg = std::move(m_Message); - m_Message = WebSocketMessage(); - - return Msg; -} - -/////////////////////////////////////////////////////////////////////////////// -class WsConnection : public std::enable_shared_from_this<WsConnection> -{ -public: - WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) - : m_Id(Id) - , m_Socket(std::move(Socket)) - , m_StartTime(Clock::now()) - , m_State() - { - } - - ~WsConnection() = default; - - std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } - - WebSocketId Id() const { return m_Id; } - asio::ip::tcp::socket& Socket() { return *m_Socket; } - TimePoint StartTime() const { return m_StartTime; } - WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } - std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } - WebSocketState Close(); - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } - std::mutex& WriteMutex() { return m_WriteMutex; } - -private: - WebSocketId m_Id; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - TimePoint m_StartTime; - std::atomic_uint32_t m_State; - std::unique_ptr<MessageParser> m_MsgParser; - asio::streambuf m_ReadBuffer; - std::mutex m_WriteMutex; -}; - -WebSocketState -WsConnection::Close() -{ - const auto PrevState = SetState(WebSocketState::kDisconnected); - - if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) - { - m_Socket->close(); - } - - return PrevState; -} - -/////////////////////////////////////////////////////////////////////////////// -class WsThreadPool -{ -public: - WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} - void Start(uint32_t ThreadCount); - void Stop(); - -private: - asio::io_service& m_IoSvc; - std::vector<std::thread> m_Threads; - std::atomic_bool m_Running{false}; -}; - -void -WsThreadPool::Start(uint32_t ThreadCount) -{ - ZEN_ASSERT(m_Threads.empty()); - - ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); - - m_Running = true; - - for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) - { - m_Threads.emplace_back([this, ThreadId = Idx + 1] { - for (;;) - { - if (m_Running == false) - { - break; - } - - try - { - m_IoSvc.run(); - } - catch (std::exception& Err) - { - ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); - } - } - - ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); - }); - } -} - -void -WsThreadPool::Stop() -{ - if (m_Running) - { - m_Running = false; - - for (std::thread& Thread : m_Threads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } - - m_Threads.clear(); - } -} - -/////////////////////////////////////////////////////////////////////////////// -class WsServer final : public WebSocketServer -{ -public: - WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} - virtual ~WsServer() { Shutdown(); } - - virtual bool Run() override; - virtual void Shutdown() override; - - virtual void RegisterService(WebSocketService& Service) override; - virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; - virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; - - virtual void SendNotification(WebSocketMessage&& Notification) override; - virtual void SendResponse(WebSocketMessage&& Response) override; - -private: - friend class WsConnection; - - void AcceptConnection(); - void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); - - void ReadMessage(std::shared_ptr<WsConnection> Connection); - void RouteMessage(WebSocketMessage&& Msg); - void SendMessage(WebSocketMessage&& Msg); - - struct IdHasher - { - size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } - }; - - using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; - using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; - using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; - - WebSocketServerOptions m_Options; - asio::io_service m_IoSvc; - std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; - std::unique_ptr<WsThreadPool> m_ThreadPool; - ConnectionMap m_Connections; - std::shared_mutex m_ConnMutex; - std::vector<WebSocketService*> m_Services; - RequestHandlerMap m_RequestHandlers; - NotificationHandlerMap m_NotificationHandlers; - std::atomic_bool m_Running{}; -}; - -void -WsServer::RegisterService(WebSocketService& Service) -{ - m_Services.push_back(&Service); - - Service.Configure(*this); -} - -bool -WsServer::Run() -{ - static constexpr size_t ReceiveBufferSize = 256 << 10; - static constexpr size_t SendBufferSize = 256 << 10; - - m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); - - m_Acceptor->set_option(asio::ip::v6_only(false)); - m_Acceptor->set_option(asio::socket_base::reuse_address(true)); - m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor->native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); -#endif - - asio::error_code Ec; - m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); - - if (Ec) - { - ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); - - return false; - } - - m_Acceptor->listen(); - m_Running = true; - - ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); - - AcceptConnection(); - - m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); - m_ThreadPool->Start(m_Options.ThreadCount); - - return true; -} - -void -WsServer::Shutdown() -{ - if (m_Running) - { - ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); - - m_Running = false; - - m_Acceptor->close(); - m_Acceptor.reset(); - m_IoSvc.stop(); - - m_ThreadPool->Stop(); - } -} - -void -WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) -{ - auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); - Result.first->second.push_back(&Service); -} - -void -WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) -{ - m_RequestHandlers[Key] = &Service; -} - -void -WsServer::SendNotification(WebSocketMessage&& Notification) -{ - ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); - - SendMessage(std::move(Notification)); -} -void -WsServer::SendResponse(WebSocketMessage&& Response) -{ - ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || - Response.MessageType() == WebSocketMessageType::kStreamResponse || - Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); - - ZEN_ASSERT(Response.CorrelationId() != 0); - - SendMessage(std::move(Response)); -} - -void -WsServer::AcceptConnection() -{ - auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); - asio::ip::tcp::socket& SocketRef = *Socket.get(); - - m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { - if (m_Running) - { - if (Ec) - { - ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); - } - else - { - auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); - - ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); - - { - std::unique_lock _(m_ConnMutex); - m_Connections[Connection->Id()] = Connection; - } - - Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - } - - AcceptConnection(); - } - }); -} - -void -WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) -{ - if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) - { - if (Ec) - { - ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); - } - else - { - ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); - } - } - - const WebSocketId Id = Connection->Id(); - - { - std::unique_lock _(m_ConnMutex); - if (m_Connections.contains(Id)) - { - m_Connections.erase(Id); - } - } -} - -void -WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) -{ - Connection->ReadBuffer().prepare(64 << 10); - - asio::async_read( - Connection->Socket(), - Connection->ReadBuffer(), - asio::transfer_at_least(1), - [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { - if (ReadEc) - { - return CloseConnection(Connection, ReadEc); - } - - switch (Connection->State()) - { - case WebSocketState::kHandshaking: - { - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); - asio::const_buffer Buffer = Connection->ReadBuffer().data(); - - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - - Connection->ReadBuffer().consume(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - return ReadMessage(Connection); - } - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - return CloseConnection(Connection, std::error_code()); - } - - if (Parser.IsUpgrade() == false) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; - - return async_write(Connection->Socket(), - asio::buffer(UpgradeRequiredResponse), - [this, Connection](const asio::error_code& WriteEc, std::size_t) { - if (WriteEc) - { - return CloseConnection(Connection, WriteEc); - } - - Connection->Parser()->Reset(); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - }); - } - - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - - std::string AcceptHash; - std::string Reason; - const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); - - if (ValidHandshake == false) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '{}' FAILED, reason '{}'", - Connection->Id().Value(), - Reason); - - constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; - - return async_write(Connection->Socket(), - asio::buffer(UpgradeRequiredResponse), - [this, &Connection](const asio::error_code& WriteEc, std::size_t) { - if (WriteEc) - { - return CloseConnection(Connection, WriteEc); - } - - Connection->Parser()->Reset(); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - }); - } - - ExtendableStringBuilder<128> Sb; - - Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: Upgrade\r\n"sv; - - // TODO: Verify protocol - if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) - { - Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] - << "\r\n"; - } - - Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; - Sb << "\r\n"sv; - - ZEN_LOG_DEBUG(LogWebSocket, - "accepting handshake from connection '#{} {}'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - std::string Response = Sb.ToString(); - Buffer = asio::buffer(Response); - - async_write(Connection->Socket(), - Buffer, - [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { - if (WriteEc) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '{}' FAILED, reason '{}'", - Connection->Id().Value(), - WriteEc.message()); - - return CloseConnection(Connection, WriteEc); - } - - ZEN_LOG_DEBUG(LogWebSocket, - "handshake ({}B) with connection '#{} {}' OK", - ByteCount, - Connection->Id().Value(), - Connection->RemoteAddr()); - - Connection->SetParser(std::make_unique<WebSocketMessageParser>()); - Connection->SetState(WebSocketState::kConnected); - - ReadMessage(Connection); - }); - } - break; - - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); - - uint64_t RemainingBytes = Connection->ReadBuffer().size(); - - while (RemainingBytes > 0) - { - MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Connection->ReadBuffer().consume(Result.ByteCount); - RemainingBytes = Connection->ReadBuffer().size(); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - - return CloseConnection(Connection, std::error_code()); - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(RemainingBytes == 0); - continue; - } - - WebSocketMessage Message = Parser.ConsumeMessage(); - Parser.Reset(); - - Message.SetSocketId(Connection->Id()); - - RouteMessage(std::move(Message)); - } - - ReadMessage(Connection); - } - break; - - default: - break; - }; - }); -} - -void -WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) -{ - switch (RoutedMessage.MessageType()) - { - case WebSocketMessageType::kRequest: - case WebSocketMessageType::kStreamRequest: - { - CbObjectView Request = RoutedMessage.Body().GetObject(); - std::string_view Method = Request["Method"].AsString(); - bool Handled = false; - bool Error = false; - std::exception Exception; - - if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) - { - WebSocketService* Service = It->second; - ZEN_ASSERT(Service); - - try - { - Handled = Service->HandleRequest(std::move(RoutedMessage)); - } - catch (std::exception& Err) - { - Exception = std::move(Err); - Error = true; - } - } - - if (Error || Handled == false) - { - std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); - - ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); - - CbObjectWriter Response; - Response << "Error"sv << ErrorText; - - WebSocketMessage ResponseMsg; - ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); - ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); - ResponseMsg.SetSocketId(RoutedMessage.SocketId()); - ResponseMsg.SetBody(Response.Save()); - - SendResponse(std::move(ResponseMsg)); - } - } - break; - - case WebSocketMessageType::kNotification: - { - CbObjectView Notification = RoutedMessage.Body().GetObject(); - std::string_view Message = Notification["Message"].AsString(); - - if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) - { - std::vector<WebSocketService*>& Handlers = It->second; - - for (WebSocketService* Handler : Handlers) - { - Handler->HandleNotification(RoutedMessage); - } - } - else - { - ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); - } - } - break; - - default: - break; - }; -} - -void -WsServer::SendMessage(WebSocketMessage&& Msg) -{ - std::shared_ptr<WsConnection> Connection; - - { - std::unique_lock _(m_ConnMutex); - - if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) - { - Connection = It->second; - } - } - - if (Connection.get() == nullptr) - { - ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); - return; - } - - if (Connection.get() != nullptr) - { - BinaryWriter Writer; - Msg.Save(Writer); - - ZEN_LOG_TRACE(LogWebSocket, - "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", - ToString(Msg.MessageType()), - Connection->Id().Value(), - Msg.MessageSize(), - Msg.CorrelationId(), - NiceBytes(Writer.Size())); - - { - ZEN_TRACE_CPU("WS::SendMessage"); - std::unique_lock _(Connection->WriteMutex()); - ZEN_TRACE_CPU("WS::WriteSocketData"); - asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> -{ -public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} - - virtual ~WsClient() { Disconnect(); } - - std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } - - virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; - virtual void Disconnect() override; - virtual bool IsConnected() const override { return false; } - virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); } - - virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; - virtual void OnNotification(NotificationCallback&& Cb) override; - virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; - -private: - WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - void TriggerEvent(WebSocketEvent Evt); - void ReadMessage(); - void RouteMessage(WebSocketMessage&& RoutedMessage); - - using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; - - asio::io_context& m_IoCtx; - WebSocketId m_Id; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<MessageParser> m_MsgParser; - asio::streambuf m_ReadBuffer; - EventCallback m_EventCallbacks[3]; - NotificationCallback m_NotificationCallback; - PendingRequestMap m_PendingRequests; - std::mutex m_RequestMutex; - std::promise<bool> m_ConnectPromise; - std::atomic_uint32_t m_State; - std::string m_Host; - int16_t m_Port{}; -}; - -std::future<bool> -WsClient::Connect(const WebSocketConnectInfo& Info) -{ - if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) - { - return m_ConnectPromise.get_future(); - } - - SetState(WebSocketState::kHandshaking); - - try - { - asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); - m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); - - m_Socket->connect(Endpoint); - - m_Host = m_Socket->remote_endpoint().address().to_string(); - m_Port = Info.Port; - - ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); - } - catch (std::exception& Err) - { - ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); - - SetState(WebSocketState::kError); - m_Socket.reset(); - - TriggerEvent(WebSocketEvent::kDisconnected); - - m_ConnectPromise.set_value(false); - - return m_ConnectPromise.get_future(); - } - - ExtendableStringBuilder<128> Sb; - Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; - Sb << "Host: " << Info.Host << "\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: upgrade\r\n"sv; - Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; - - if (Info.Protocols.empty() == false) - { - Sb << "Sec-WebSocket-Protocol: "sv; - for (size_t Idx = 0; const auto& Protocol : Info.Protocols) - { - if (Idx++) - { - Sb << ", "; - } - Sb << Protocol; - } - } - - Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; - Sb << "\r\n"; - - std::string HandshakeRequest = Sb.ToString(); - asio::const_buffer Buffer = asio::buffer(HandshakeRequest); - - ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); - - m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); - m_MsgParser->Reset(); - - async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); - - Self->Disconnect(); - } - else - { - Self->ReadMessage(); - } - }); - - return m_ConnectPromise.get_future(); -} - -void -WsClient::Disconnect() -{ - if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) - { - ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); - - if (m_Socket && m_Socket->is_open()) - { - m_Socket->close(); - m_Socket.reset(); - } - - TriggerEvent(WebSocketEvent::kDisconnected); - - { - std::unique_lock _(m_RequestMutex); - - for (auto& Kv : m_PendingRequests) - { - Kv.second.set_value(WebSocketMessage()); - } - - m_PendingRequests.clear(); - } - } -} - -std::future<WebSocketMessage> -WsClient::SendRequest(WebSocketMessage&& Request) -{ - ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); - - BinaryWriter Writer; - Request.Save(Writer); - - std::future<WebSocketMessage> FutureResponse; - - { - std::unique_lock _(m_RequestMutex); - - auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); - ZEN_ASSERT(Result.second); - - auto It = Result.first; - FutureResponse = It->second.get_future(); - } - - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); - - async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { - if (Ec) - { - ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); - - Self->Disconnect(); - } - }); - - return FutureResponse; -} - -void -WsClient::OnNotification(NotificationCallback&& Cb) -{ - m_NotificationCallback = std::move(Cb); -} - -void -WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) -{ - m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); -} - -void -WsClient::TriggerEvent(WebSocketEvent Evt) -{ - const uint32_t Index = static_cast<uint32_t>(Evt); - - if (m_EventCallbacks[Index]) - { - m_EventCallbacks[Index](); - } -} - -void -WsClient::ReadMessage() -{ - m_ReadBuffer.prepare(64 << 10); - - async_read(*m_Socket, - m_ReadBuffer, - asio::transfer_at_least(1), - [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { - const WebSocketState State = Self->State(); - - if (State == WebSocketState::kDisconnected) - { - return; - } - - if (Ec) - { - ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); - - return Self->Disconnect(); - } - - switch (State) - { - case WebSocketState::kHandshaking: - { - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); - - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - - ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Self->ReadBuffer().consume(size_t(Result.ByteCount)); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); - - Self->m_ConnectPromise.set_value(false); - - return Self->Disconnect(); - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - return Self->ReadMessage(); - } - - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - - if (Parser.StatusCode() != 101) - { - ZEN_LOG_WARN(LogWsClient, - "handshake FAILED, status '{}', status code '{}'", - Parser.StatusText(), - Parser.StatusCode()); - - Self->m_ConnectPromise.set_value(false); - - return Self->Disconnect(); - } - - ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); - - Self->SetParser(std::make_unique<WebSocketMessageParser>()); - Self->SetState(WebSocketState::kConnected); - Self->ReadMessage(); - Self->TriggerEvent(WebSocketEvent::kConnected); - - Self->m_ConnectPromise.set_value(true); - } - break; - - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); - - uint64_t RemainingBytes = Self->ReadBuffer().size(); - - while (RemainingBytes > 0) - { - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Self->ReadBuffer().consume(Result.ByteCount); - RemainingBytes = Self->ReadBuffer().size(); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - - Parser.Reset(); - continue; - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(RemainingBytes == 0); - continue; - } - - WebSocketMessage Message = Parser.ConsumeMessage(); - Parser.Reset(); - - Self->RouteMessage(std::move(Message)); - } - - Self->ReadMessage(); - } - break; - - default: - break; - } - }); -} - -void -WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) -{ - switch (RoutedMessage.MessageType()) - { - case WebSocketMessageType::kResponse: - { - std::unique_lock _(m_RequestMutex); - - if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) - { - It->second.set_value(std::move(RoutedMessage)); - m_PendingRequests.erase(It); - } - else - { - ZEN_LOG_WARN(LogWsClient, - "route request message FAILED, reason 'unknown correlation ID ({})'", - RoutedMessage.CorrelationId()); - } - } - break; - - case WebSocketMessageType::kNotification: - { - std::unique_lock _(m_RequestMutex); - - if (m_NotificationCallback) - { - m_NotificationCallback(std::move(RoutedMessage)); - } - } - break; - - default: - ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); - break; - }; -} - -} // namespace zen::websocket - -namespace zen { - -std::atomic_uint32_t WebSocketId::NextId{1}; - -bool -WebSocketMessage::Header::IsValid() const -{ - return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && - uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); -} - -std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; - -void -WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) -{ - m_Header.MessageType = MessageType; -} - -void -WebSocketMessage::SetBody(CbPackage&& Body) -{ - m_Body = std::move(Body); -} -void -WebSocketMessage::SetBody(CbObject&& Body) -{ - CbPackage Pkg; - Pkg.SetObject(Body); - - SetBody(std::move(Pkg)); -} - -void -WebSocketMessage::Save(BinaryWriter& Writer) -{ - Writer.Write(&m_Header, HeaderSize); - - if (m_Body.has_value()) - { - const CbObject& Obj = m_Body.value().GetObject(); - MemoryView View = Obj.GetBuffer().GetView(); - - const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); - ZEN_ASSERT(ValidationResult == CbValidateError::None); - - m_Body.value().Save(Writer); - } - - if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) - { - m_Header.CorrelationId = NextCorrelationId.fetch_add(1); - } - - m_Header.MessageSize = Writer.Size() - HeaderSize; - - Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); -} - -bool -WebSocketMessage::TryLoadHeader(MemoryView Memory) -{ - if (Memory.GetSize() < HeaderSize) - { - return false; - } - - MutableMemoryView HeaderView(&m_Header, HeaderSize); - - HeaderView.CopyFrom(Memory); - - return m_Header.IsValid(); -} - -void -WebSocketService::Configure(WebSocketServer& Server) -{ - ZEN_ASSERT(m_SocketServer == nullptr); - - m_SocketServer = &Server; - - RegisterHandlers(Server); -} - -void -WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) -{ - WebSocketMessage Message; - - Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); - Message.SetCorrelationId(CorrelationId); - Message.SetSocketId(SocketId); - Message.SetBody(std::move(StreamResponse)); - - SocketServer().SendResponse(std::move(Message)); -} - -void -WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) -{ - CbPackage Response; - Response.SetObject(std::move(StreamResponse)); - - SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); -} - -std::unique_ptr<WebSocketServer> -WebSocketServer::Create(const WebSocketServerOptions& Options) -{ - return std::make_unique<websocket::WsServer>(Options); -} - -std::shared_ptr<WebSocketClient> -WebSocketClient::Create(asio::io_context& IoCtx) -{ - return std::make_shared<websocket::WsClient>(IoCtx); -} - -} // namespace zen diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index 2f52e3225..3fae3235f 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -21,7 +21,6 @@ #include <zencore/xxhash.h> #include <zenhttp/httpclient.h> #include <zenhttp/httpshared.h> -#include <zenhttp/websocket.h> #include <zenhttp/zenhttp.h> #include <zenutil/cache/cache.h> #include <zenutil/cache/cacherequests.h> @@ -2608,57 +2607,6 @@ TEST_CASE("http.package") CHECK_EQ(ResponsePackage, TestPackage); } -TEST_CASE("websocket.basic") -{ - if (true) - { - return; - } - - std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); - const uint16_t PortNumber = 13337; - const auto MaxWaitTime = std::chrono::seconds(5); - - ZenServerInstance Inst(TestEnv); - Inst.SetTestDir(TestDir); - Inst.SpawnServer(PortNumber, "--websocket-port=8848"sv); - Inst.WaitUntilReady(); - - asio::io_context IoCtx; - IoDispatcher IoDispatcher(IoCtx); - auto WebSocket = WebSocketClient::Create(IoCtx); - - auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"}); - IoDispatcher.Run(); - - ConnectFuture.wait_for(MaxWaitTime); - CHECK(ConnectFuture.get()); - - for (size_t Idx = 0; Idx < 10; Idx++) - { - CbObjectWriter Request; - Request << "Method"sv - << "SayHello"sv; - - WebSocketMessage RequestMsg; - RequestMsg.SetMessageType(WebSocketMessageType::kRequest); - RequestMsg.SetBody(Request.Save()); - - auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg)); - ResponseFuture.wait_for(MaxWaitTime); - - CbObject Response = ResponseFuture.get().Body().GetObject(); - std::string_view Message = Response["Result"].AsString(); - - CHECK(Message == "Hello Friend!!"sv); - } - - WebSocket->Disconnect(); - - IoCtx.stop(); - IoDispatcher.Stop(); -} - std::string OidAsString(const Oid& Id) { diff --git a/src/zenserver/config.cpp b/src/zenserver/config.cpp index 5e24d174b..c0a97ce5b 100644 --- a/src/zenserver/config.cpp +++ b/src/zenserver/config.cpp @@ -796,8 +796,6 @@ ParseConfigFile(const std::filesystem::path& Path, LuaOptions.AddOption("network.httpserverclass"sv, ServerOptions.HttpServerConfig.ServerClass, "http"sv); LuaOptions.AddOption("network.httpserverthreads"sv, ServerOptions.HttpServerConfig.ThreadCount, "http-threads"sv); LuaOptions.AddOption("network.port"sv, ServerOptions.BasePort, "port"sv); - LuaOptions.AddOption("network.websocket.port"sv, ServerOptions.WebSocketPort, "websocket-port"sv); - LuaOptions.AddOption("network.websocket.threadcount"sv, ServerOptions.WebSocketThreads, "websocket-threads"sv); LuaOptions.AddOption("network.httpsys.async.workthreads"sv, ServerOptions.HttpServerConfig.HttpSys.AsyncWorkThreadCount, @@ -1076,20 +1074,6 @@ ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions) cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"), "<port number>"); - options.add_option("network", - "", - "websocket-port", - "Websocket server port", - cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"), - "<port number>"); - - options.add_option("network", - "", - "websocket-threads", - "Number of websocket I/O thread(s) (0 == hardware concurrency)", - cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"), - ""); - options.add_option("httpsys", "", "httpsys-async-work-threads", diff --git a/src/zenserver/config.h b/src/zenserver/config.h index 924375a19..df1ccb752 100644 --- a/src/zenserver/config.h +++ b/src/zenserver/config.h @@ -134,18 +134,16 @@ struct ZenServerOptions ZenObjectStoreConfig ObjectStoreConfig; zen::HttpServerConfig HttpServerConfig; ZenStructuredCacheConfig StructuredCacheConfig; - std::filesystem::path DataDir; // Root directory for state (used for testing) - std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) - std::filesystem::path AbsLogFile; // Absolute path to main log file - std::filesystem::path ConfigFile; // Path to Lua config file - std::string ChildId; // Id assigned by parent process (used for lifetime management) - std::string LogId; // Id for tagging log output - std::string EncryptionKey; // 256 bit AES encryption key - std::string EncryptionIV; // 128 bit AES initialization vector - int BasePort = 1337; // Service listen port (used for both UDP and TCP) - int OwnerPid = 0; // Parent process id (zero for standalone) - int WebSocketPort = 0; // Web socket port (Zero = disabled) - int WebSocketThreads = 0; + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path AbsLogFile; // Absolute path to main log file + std::filesystem::path ConfigFile; // Path to Lua config file + std::string ChildId; // Id assigned by parent process (used for lifetime management) + std::string LogId; // Id for tagging log output + std::string EncryptionKey; // 256 bit AES encryption key + std::string EncryptionIV; // 128 bit AES initialization vector + int BasePort = 1337; // Service listen port (used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) bool InstallService = false; // Flag used to initiate service install (temporary) bool UninstallService = false; // Flag used to initiate service uninstall (temporary) bool IsDebug = false; diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index cf9f03d89..83eb81e3b 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -17,7 +17,6 @@ #include <zencore/trace.h> #include <zencore/workthreadpool.h> #include <zenhttp/httpserver.h> -#include <zenhttp/websocket.h> #include <zenstore/cidstore.h> #include <zenstore/scrubcontext.h> #include <zenutil/basicfile.h> @@ -309,15 +308,6 @@ public: m_Http = zen::CreateHttpServer(ServerOptions.HttpServerConfig); int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort); - if (ServerOptions.WebSocketPort != 0) - { - const uint32_t ThreadCount = - ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency(); - - m_WebSocket = zen::WebSocketServer::Create( - {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))}); - } - // Setup authentication manager { std::string EncryptionKey = ServerOptions.EncryptionKey; @@ -396,11 +386,6 @@ public: #if ZEN_WITH_TESTS m_Http->RegisterService(m_TestingService); - - if (m_WebSocket) - { - m_WebSocket->RegisterService(m_TestingService); - } #endif if (m_HttpProjectService) @@ -526,11 +511,6 @@ public: OnReady(); - if (m_WebSocket) - { - m_WebSocket->Run(); - } - m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); @@ -590,7 +570,6 @@ public: m_CidStore.reset(); m_AuthService.reset(); m_AuthMgr.reset(); - m_WebSocket.reset(); m_Http = {}; m_JobQueue.reset(); } @@ -788,7 +767,6 @@ private: } zen::Ref<zen::HttpServer> m_Http; - std::unique_ptr<zen::WebSocketServer> m_WebSocket; std::unique_ptr<zen::AuthMgr> m_AuthMgr; std::unique_ptr<zen::HttpAuthService> m_AuthService; zen::HttpStatusService m_StatusService; |