aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-10-04 13:13:23 +0200
committerGitHub <[email protected]>2023-10-04 13:13:23 +0200
commit3e8db7cd243e8be3b2d5fea2490b9ad70f765590 (patch)
treec1533ae25b0b717dd2393960c5449aadd56807c4 /src
parentfactored out http parser from asio into separate files (#444) (diff)
downloadzen-3e8db7cd243e8be3b2d5fea2490b9ad70f765590.tar.xz
zen-3e8db7cd243e8be3b2d5fea2490b9ad70f765590.zip
removed websocket protocol support(#445)
removed websocket support since it is not used right now and is unlikely to be used in the future
Diffstat (limited to 'src')
-rw-r--r--src/zenhttp/include/zenhttp/httptest.h6
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h256
-rw-r--r--src/zenhttp/testing/httptest.cpp32
-rw-r--r--src/zenhttp/websocketasio.cpp1613
-rw-r--r--src/zenserver-test/zenserver-test.cpp52
-rw-r--r--src/zenserver/config.cpp16
-rw-r--r--src/zenserver/config.h22
-rw-r--r--src/zenserver/zenserver.cpp22
8 files changed, 11 insertions, 2008 deletions
diff --git a/src/zenhttp/include/zenhttp/httptest.h b/src/zenhttp/include/zenhttp/httptest.h
index 74db69785..a4008fb5e 100644
--- a/src/zenhttp/include/zenhttp/httptest.h
+++ b/src/zenhttp/include/zenhttp/httptest.h
@@ -5,7 +5,6 @@
#include <zencore/logging.h>
#include <zencore/stats.h>
#include <zenhttp/httpserver.h>
-#include <zenhttp/websocket.h>
#include <atomic>
@@ -16,7 +15,7 @@ namespace zen {
/**
* Test service to facilitate testing the HTTP framework and client interactions
*/
-class HttpTestingService : public HttpService, public WebSocketService
+class HttpTestingService : public HttpService
{
public:
HttpTestingService();
@@ -43,9 +42,6 @@ public:
};
private:
- virtual void RegisterHandlers(WebSocketServer& Server) override;
- virtual bool HandleRequest(const WebSocketMessage& Request) override;
-
HttpRequestRouter m_Router;
std::atomic<uint32_t> m_Counter{0};
metrics::OperationTiming m_TimingStats;
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
deleted file mode 100644
index adca7e988..000000000
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ /dev/null
@@ -1,256 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zencore/compactbinarypackage.h>
-#include <zencore/memory.h>
-
-#include <compare>
-#include <functional>
-#include <future>
-#include <memory>
-#include <optional>
-
-#pragma once
-
-namespace asio {
-class io_context;
-}
-
-namespace zen {
-
-class BinaryWriter;
-
-/**
- * A unique socket ID.
- */
-class WebSocketId
-{
- static std::atomic_uint32_t NextId;
-
-public:
- WebSocketId() = default;
-
- uint32_t Value() const { return m_Value; }
-
- auto operator<=>(const WebSocketId&) const = default;
-
- static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); }
-
-private:
- WebSocketId(uint32_t Value) : m_Value(Value) {}
-
- uint32_t m_Value{};
-};
-
-/**
- * Type of web socket message.
- */
-enum class WebSocketMessageType : uint8_t
-{
- kInvalid,
- kNotification,
- kRequest,
- kStreamRequest,
- kResponse,
- kStreamResponse,
- kStreamCompleteResponse,
- kCount
-};
-
-inline std::string_view
-ToString(WebSocketMessageType Type)
-{
- switch (Type)
- {
- case WebSocketMessageType::kInvalid:
- return std::string_view("Invalid");
- case WebSocketMessageType::kNotification:
- return std::string_view("Notification");
- case WebSocketMessageType::kRequest:
- return std::string_view("Request");
- case WebSocketMessageType::kStreamRequest:
- return std::string_view("StreamRequest");
- case WebSocketMessageType::kResponse:
- return std::string_view("Response");
- case WebSocketMessageType::kStreamResponse:
- return std::string_view("StreamResponse");
- case WebSocketMessageType::kStreamCompleteResponse:
- return std::string_view("StreamCompleteResponse");
- default:
- return std::string_view("Unknown");
- };
-}
-
-/**
- * Web socket message.
- */
-class WebSocketMessage
-{
- struct Header
- {
- static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
-
- uint64_t MessageSize{};
- uint32_t Magic{ExpectedMagic};
- uint32_t CorrelationId{};
- uint32_t StatusCode{200u};
- WebSocketMessageType MessageType{};
- uint8_t Reserved[3] = {0};
-
- bool IsValid() const;
- };
-
- static_assert(sizeof(Header) == 24);
-
- static std::atomic_uint32_t NextCorrelationId;
-
-public:
- static constexpr size_t HeaderSize = sizeof(Header);
-
- WebSocketMessage() = default;
-
- WebSocketId SocketId() const { return m_SocketId; }
- void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
- uint64_t MessageSize() const { return m_Header.MessageSize; }
- void SetMessageType(WebSocketMessageType MessageType);
- void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
- uint32_t CorrelationId() const { return m_Header.CorrelationId; }
- uint32_t StatusCode() const { return m_Header.StatusCode; }
- void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; }
- WebSocketMessageType MessageType() const { return m_Header.MessageType; }
-
- const CbPackage& Body() const { return m_Body.value(); }
- void SetBody(CbPackage&& Body);
- void SetBody(CbObject&& Body);
- bool HasBody() const { return m_Body.has_value(); }
-
- void Save(BinaryWriter& Writer);
- bool TryLoadHeader(MemoryView Memory);
-
- bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; }
-
-private:
- Header m_Header{};
- WebSocketId m_SocketId{};
- std::optional<CbPackage> m_Body;
-};
-
-class WebSocketServer;
-
-/**
- * Base class for handling web socket requests and notifications from connected client(s).
- */
-class WebSocketService
-{
-public:
- virtual ~WebSocketService() = default;
-
- void Configure(WebSocketServer& Server);
-
- virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); }
- virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); }
-
-protected:
- WebSocketService() = default;
-
- virtual void RegisterHandlers(WebSocketServer& Server) = 0;
- void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete);
- void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete);
-
- WebSocketServer& SocketServer()
- {
- ZEN_ASSERT(m_SocketServer);
- return *m_SocketServer;
- }
-
-private:
- WebSocketServer* m_SocketServer{};
-};
-
-/**
- * Server options.
- */
-struct WebSocketServerOptions
-{
- uint16_t Port = 2337;
- uint32_t ThreadCount = 1;
-};
-
-/**
- * The web socket server manages client connections and routing of requests and notifications.
- */
-class WebSocketServer
-{
-public:
- virtual ~WebSocketServer() = default;
-
- virtual bool Run() = 0;
- virtual void Shutdown() = 0;
-
- virtual void RegisterService(WebSocketService& Service) = 0;
- virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0;
- virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0;
-
- virtual void SendNotification(WebSocketMessage&& Notification) = 0;
- virtual void SendResponse(WebSocketMessage&& Response) = 0;
-
- static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options);
-};
-
-/**
- * The state of the web socket.
- */
-enum class WebSocketState : uint32_t
-{
- kNone,
- kHandshaking,
- kConnected,
- kDisconnected,
- kError
-};
-
-/**
- * Type of web socket client event.
- */
-enum class WebSocketEvent : uint32_t
-{
- kConnected,
- kDisconnected,
- kError
-};
-
-/**
- * Web socket client connection info.
- */
-struct WebSocketConnectInfo
-{
- std::string Host;
- int16_t Port{8848};
- std::string Endpoint;
- std::vector<std::string> Protocols;
- uint16_t Version{13};
-};
-
-/**
- * A connection to a web socket server for sending requests and listening for notifications.
- */
-class WebSocketClient
-{
-public:
- using EventCallback = std::function<void()>;
- using NotificationCallback = std::function<void(WebSocketMessage&&)>;
-
- virtual ~WebSocketClient() = default;
-
- virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0;
- virtual void Disconnect() = 0;
- virtual bool IsConnected() const = 0;
- virtual WebSocketState State() const = 0;
-
- virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0;
- virtual void OnNotification(NotificationCallback&& Cb) = 0;
- virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0;
-
- static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx);
-};
-
-} // namespace zen
diff --git a/src/zenhttp/testing/httptest.cpp b/src/zenhttp/testing/httptest.cpp
index a02e36bcc..3a0ad72a9 100644
--- a/src/zenhttp/testing/httptest.cpp
+++ b/src/zenhttp/testing/httptest.cpp
@@ -140,38 +140,6 @@ HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId)));
}
-void
-HttpTestingService::RegisterHandlers(WebSocketServer& Server)
-{
- Server.RegisterRequestHandler("SayHello"sv, *this);
-}
-
-bool
-HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg)
-{
- CbObjectView Request = RequestMsg.Body().GetObject();
-
- std::string_view Method = Request["Method"].AsString();
-
- if (Method != "SayHello"sv)
- {
- return false;
- }
-
- CbObjectWriter Response;
- Response.AddString("Result"sv, "Hello Friend!!");
-
- WebSocketMessage ResponseMsg;
- ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
- ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId());
- ResponseMsg.SetSocketId(RequestMsg.SocketId());
- ResponseMsg.SetBody(Response.Save());
-
- SocketServer().SendResponse(std::move(ResponseMsg));
-
- return true;
-}
-
//////////////////////////////////////////////////////////////////////////
HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId)
diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp
deleted file mode 100644
index bbe7e1ad8..000000000
--- a/src/zenhttp/websocketasio.cpp
+++ /dev/null
@@ -1,1613 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/websocket.h>
-
-#include <zencore/base64.h>
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinaryvalidation.h>
-#include <zencore/intmath.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/memory.h>
-#include <zencore/sha1.h>
-#include <zencore/stream.h>
-#include <zencore/string.h>
-#include <zencore/trace.h>
-
-#include <chrono>
-#include <optional>
-#include <shared_mutex>
-#include <span>
-#include <system_error>
-#include <thread>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <fmt/format.h>
-#include <http_parser.h>
-#include <asio.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-#if ZEN_PLATFORM_WINDOWS
-# include <mstcpip.h>
-#endif
-
-namespace zen::websocket {
-
-using namespace std::literals;
-
-ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
-
-ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
-
-using Clock = std::chrono::steady_clock;
-using TimePoint = Clock::time_point;
-
-///////////////////////////////////////////////////////////////////////////////
-namespace http_header {
- static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
- static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
- static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
- static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
- static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
- static constexpr std::string_view Upgrade = "Upgrade"sv;
-} // namespace http_header
-
-///////////////////////////////////////////////////////////////////////////////
-enum class ParseMessageStatus : uint32_t
-{
- kError,
- kContinue,
- kDone,
-};
-
-struct ParseMessageResult
-{
- ParseMessageStatus Status{};
- size_t ByteCount{};
- std::optional<std::string> Reason;
-};
-
-class MessageParser
-{
-public:
- virtual ~MessageParser() = default;
-
- ParseMessageResult ParseMessage(MemoryView Msg);
- void Reset();
-
-protected:
- MessageParser() = default;
-
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
- virtual void OnReset() = 0;
-
- BinaryWriter m_Stream;
-};
-
-ParseMessageResult
-MessageParser::ParseMessage(MemoryView Msg)
-{
- return OnParseMessage(Msg);
-}
-
-void
-MessageParser::Reset()
-{
- OnReset();
-
- m_Stream.Reset();
-}
-
-///////////////////////////////////////////////////////////////////////////////
-enum class HttpMessageParserType
-{
- kRequest,
- kResponse,
- kBoth
-};
-
-class HttpMessageParser final : public MessageParser
-{
-public:
- using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
-
- HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
-
- virtual ~HttpMessageParser() = default;
-
- int32_t StatusCode() const { return m_Parser.status_code; }
- bool IsUpgrade() const { return m_Parser.upgrade != 0; }
- HttpHeaders& Headers() { return m_Headers; }
- MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
-
- std::string_view StatusText() const
- {
- return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
- }
-
- bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
-
-private:
- void Initialize();
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
- virtual void OnReset() override;
- int OnMessageBegin();
- int OnUrl(MemoryView Url);
- int OnStatus(MemoryView Status);
- int OnHeaderField(MemoryView HeaderField);
- int OnHeaderValue(MemoryView HeaderValue);
- int OnHeadersComplete();
- int OnBody(MemoryView Body);
- int OnMessageComplete();
-
- struct StreamEntry
- {
- uint64_t Offset{};
- uint64_t Size{};
- };
-
- struct HeaderStreamEntry
- {
- StreamEntry Field{};
- StreamEntry Value{};
- };
-
- HttpMessageParserType m_Type;
- http_parser m_Parser;
- StreamEntry m_UrlEntry;
- StreamEntry m_StatusEntry;
- StreamEntry m_BodyEntry;
- HeaderStreamEntry m_CurrentHeader;
- std::vector<HeaderStreamEntry> m_HeaderEntries;
- HttpHeaders m_Headers;
- bool m_IsMsgComplete{false};
-
- static http_parser_settings ParserSettings;
-};
-
-http_parser_settings HttpMessageParser::ParserSettings = {
- .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
-
- .on_url = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
-
- .on_status = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
-
- .on_header_field = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
-
- .on_header_value = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
-
- .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
-
- .on_body = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
-
- .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
-
-void
-HttpMessageParser::Initialize()
-{
- http_parser_init(&m_Parser,
- m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
- : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
- : HTTP_BOTH);
- m_Parser.data = this;
-
- m_UrlEntry = {};
- m_StatusEntry = {};
- m_CurrentHeader = {};
- m_BodyEntry = {};
-
- m_IsMsgComplete = false;
-
- m_HeaderEntries.clear();
-}
-
-ParseMessageResult
-HttpMessageParser::OnParseMessage(MemoryView Msg)
-{
- const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
-
- auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
-
- if (m_Parser.http_errno != 0)
- {
- Status = ParseMessageStatus::kError;
- }
-
- return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
-}
-
-void
-HttpMessageParser::OnReset()
-{
- Initialize();
-}
-
-int
-HttpMessageParser::OnMessageBegin()
-{
- ZEN_ASSERT(m_IsMsgComplete == false);
- ZEN_ASSERT(m_HeaderEntries.empty());
- ZEN_ASSERT(m_Headers.empty());
-
- return 0;
-}
-
-int
-HttpMessageParser::OnStatus(MemoryView Status)
-{
- m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
-
- m_Stream.Write(Status);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnUrl(MemoryView Url)
-{
- m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
-
- m_Stream.Write(Url);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeaderField(MemoryView HeaderField)
-{
- if (m_CurrentHeader.Value.Size > 0)
- {
- m_HeaderEntries.push_back(m_CurrentHeader);
- m_CurrentHeader = {};
- }
-
- if (m_CurrentHeader.Field.Size == 0)
- {
- m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
- }
-
- m_CurrentHeader.Field.Size += HeaderField.GetSize();
-
- m_Stream.Write(HeaderField);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
-{
- if (m_CurrentHeader.Value.Size == 0)
- {
- m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
- }
-
- m_CurrentHeader.Value.Size += HeaderValue.GetSize();
-
- m_Stream.Write(HeaderValue);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeadersComplete()
-{
- if (m_CurrentHeader.Value.Size > 0)
- {
- m_HeaderEntries.push_back(m_CurrentHeader);
- m_CurrentHeader = {};
- }
-
- m_Headers.clear();
- m_Headers.reserve(m_HeaderEntries.size());
-
- const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
-
- for (const auto& Entry : m_HeaderEntries)
- {
- auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
- auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
-
- m_Headers.try_emplace(std::move(Field), std::move(Value));
- }
-
- return 0;
-}
-
-int
-HttpMessageParser::OnBody(MemoryView Body)
-{
- m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
-
- m_Stream.Write(Body);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnMessageComplete()
-{
- m_IsMsgComplete = true;
-
- return 0;
-}
-
-bool
-HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
-{
- static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
-
- OutAcceptHash = std::string();
-
- if (m_Headers.contains(http_header::SecWebSocketKey) == false)
- {
- OutReason = "Missing header Sec-WebSocket-Key";
- return false;
- }
-
- if (m_Headers.contains(http_header::Upgrade) == false)
- {
- OutReason = "Missing header Upgrade";
- return false;
- }
-
- ExtendableStringBuilder<128> Sb;
- Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
-
- SHA1Stream HashStream;
- HashStream.Append(Sb.Data(), Sb.Size());
-
- SHA1 Hash = HashStream.GetHash();
-
- OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
- Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
-
- return true;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WebSocketMessageParser final : public MessageParser
-{
-public:
- WebSocketMessageParser() : MessageParser() {}
-
- WebSocketMessage ConsumeMessage();
-
-private:
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
- virtual void OnReset() override;
-
- WebSocketMessage m_Message;
-};
-
-ParseMessageResult
-WebSocketMessageParser::OnParseMessage(MemoryView Msg)
-{
- ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
-
- const uint64_t PrevOffset = m_Stream.CurrentOffset();
-
- if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
- {
- const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
-
- m_Stream.Write(Msg.Left(RemaingHeaderSize));
- Msg += RemaingHeaderSize;
-
- if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
- {
- return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
- }
-
- const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
-
- if (IsValidHeader == false)
- {
- OnReset();
-
- return {.Status = ParseMessageStatus::kError,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
- .Reason = std::string("Invalid websocket message header")};
- }
-
- if (m_Message.MessageSize() == 0)
- {
- return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
- }
- }
-
- ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
-
- if (Msg.IsEmpty() == false)
- {
- const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
- m_Stream.Write(Msg.Left(RemaingMessageSize));
- }
-
- auto Status = ParseMessageStatus::kContinue;
-
- if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
- {
- Status = ParseMessageStatus::kDone;
-
- BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
-
- CbPackage Pkg;
- if (Pkg.TryLoad(Reader) == false)
- {
- return {.Status = ParseMessageStatus::kError,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
- .Reason = std::string("Invalid websocket message")};
- }
-
- m_Message.SetBody(std::move(Pkg));
- }
-
- return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
-}
-
-void
-WebSocketMessageParser::OnReset()
-{
- m_Message = WebSocketMessage();
-}
-
-WebSocketMessage
-WebSocketMessageParser::ConsumeMessage()
-{
- WebSocketMessage Msg = std::move(m_Message);
- m_Message = WebSocketMessage();
-
- return Msg;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsConnection : public std::enable_shared_from_this<WsConnection>
-{
-public:
- WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
- : m_Id(Id)
- , m_Socket(std::move(Socket))
- , m_StartTime(Clock::now())
- , m_State()
- {
- }
-
- ~WsConnection() = default;
-
- std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
-
- WebSocketId Id() const { return m_Id; }
- asio::ip::tcp::socket& Socket() { return *m_Socket; }
- TimePoint StartTime() const { return m_StartTime; }
- WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
- std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
- asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
- WebSocketState Close();
- MessageParser* Parser() { return m_MsgParser.get(); }
- void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
- std::mutex& WriteMutex() { return m_WriteMutex; }
-
-private:
- WebSocketId m_Id;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- TimePoint m_StartTime;
- std::atomic_uint32_t m_State;
- std::unique_ptr<MessageParser> m_MsgParser;
- asio::streambuf m_ReadBuffer;
- std::mutex m_WriteMutex;
-};
-
-WebSocketState
-WsConnection::Close()
-{
- const auto PrevState = SetState(WebSocketState::kDisconnected);
-
- if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open())
- {
- m_Socket->close();
- }
-
- return PrevState;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsThreadPool
-{
-public:
- WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
- void Start(uint32_t ThreadCount);
- void Stop();
-
-private:
- asio::io_service& m_IoSvc;
- std::vector<std::thread> m_Threads;
- std::atomic_bool m_Running{false};
-};
-
-void
-WsThreadPool::Start(uint32_t ThreadCount)
-{
- ZEN_ASSERT(m_Threads.empty());
-
- ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
-
- m_Running = true;
-
- for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
- {
- m_Threads.emplace_back([this, ThreadId = Idx + 1] {
- for (;;)
- {
- if (m_Running == false)
- {
- break;
- }
-
- try
- {
- m_IoSvc.run();
- }
- catch (std::exception& Err)
- {
- ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
- }
- }
-
- ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId);
- });
- }
-}
-
-void
-WsThreadPool::Stop()
-{
- if (m_Running)
- {
- m_Running = false;
-
- for (std::thread& Thread : m_Threads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
-
- m_Threads.clear();
- }
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsServer final : public WebSocketServer
-{
-public:
- WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {}
- virtual ~WsServer() { Shutdown(); }
-
- virtual bool Run() override;
- virtual void Shutdown() override;
-
- virtual void RegisterService(WebSocketService& Service) override;
- virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
- virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
-
- virtual void SendNotification(WebSocketMessage&& Notification) override;
- virtual void SendResponse(WebSocketMessage&& Response) override;
-
-private:
- friend class WsConnection;
-
- void AcceptConnection();
- void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
-
- void ReadMessage(std::shared_ptr<WsConnection> Connection);
- void RouteMessage(WebSocketMessage&& Msg);
- void SendMessage(WebSocketMessage&& Msg);
-
- struct IdHasher
- {
- size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
- };
-
- using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
- using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
- using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
-
- WebSocketServerOptions m_Options;
- asio::io_service m_IoSvc;
- std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
- std::unique_ptr<WsThreadPool> m_ThreadPool;
- ConnectionMap m_Connections;
- std::shared_mutex m_ConnMutex;
- std::vector<WebSocketService*> m_Services;
- RequestHandlerMap m_RequestHandlers;
- NotificationHandlerMap m_NotificationHandlers;
- std::atomic_bool m_Running{};
-};
-
-void
-WsServer::RegisterService(WebSocketService& Service)
-{
- m_Services.push_back(&Service);
-
- Service.Configure(*this);
-}
-
-bool
-WsServer::Run()
-{
- static constexpr size_t ReceiveBufferSize = 256 << 10;
- static constexpr size_t SendBufferSize = 256 << 10;
-
- m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
-
- m_Acceptor->set_option(asio::ip::v6_only(false));
- m_Acceptor->set_option(asio::socket_base::reuse_address(true));
- m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
- m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor->native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-#endif
-
- asio::error_code Ec;
- m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
-
- if (Ec)
- {
- ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
-
- return false;
- }
-
- m_Acceptor->listen();
- m_Running = true;
-
- ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port);
-
- AcceptConnection();
-
- m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
- m_ThreadPool->Start(m_Options.ThreadCount);
-
- return true;
-}
-
-void
-WsServer::Shutdown()
-{
- if (m_Running)
- {
- ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
-
- m_Running = false;
-
- m_Acceptor->close();
- m_Acceptor.reset();
- m_IoSvc.stop();
-
- m_ThreadPool->Stop();
- }
-}
-
-void
-WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
-{
- auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
- Result.first->second.push_back(&Service);
-}
-
-void
-WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
-{
- m_RequestHandlers[Key] = &Service;
-}
-
-void
-WsServer::SendNotification(WebSocketMessage&& Notification)
-{
- ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
-
- SendMessage(std::move(Notification));
-}
-void
-WsServer::SendResponse(WebSocketMessage&& Response)
-{
- ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
- Response.MessageType() == WebSocketMessageType::kStreamResponse ||
- Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
-
- ZEN_ASSERT(Response.CorrelationId() != 0);
-
- SendMessage(std::move(Response));
-}
-
-void
-WsServer::AcceptConnection()
-{
- auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
- asio::ip::tcp::socket& SocketRef = *Socket.get();
-
- m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
- if (m_Running)
- {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
- }
- else
- {
- auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
-
- ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
-
- {
- std::unique_lock _(m_ConnMutex);
- m_Connections[Connection->Id()] = Connection;
- }
-
- Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- }
-
- AcceptConnection();
- }
- });
-}
-
-void
-WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
-{
- if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
- {
- if (Ec)
- {
- ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value());
- }
- else
- {
- ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
- }
- }
-
- const WebSocketId Id = Connection->Id();
-
- {
- std::unique_lock _(m_ConnMutex);
- if (m_Connections.contains(Id))
- {
- m_Connections.erase(Id);
- }
- }
-}
-
-void
-WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
-{
- Connection->ReadBuffer().prepare(64 << 10);
-
- asio::async_read(
- Connection->Socket(),
- Connection->ReadBuffer(),
- asio::transfer_at_least(1),
- [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
- if (ReadEc)
- {
- return CloseConnection(Connection, ReadEc);
- }
-
- switch (Connection->State())
- {
- case WebSocketState::kHandshaking:
- {
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
- asio::const_buffer Buffer = Connection->ReadBuffer().data();
-
- ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
-
- Connection->ReadBuffer().consume(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return ReadMessage(Connection);
- }
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWebSocket,
- "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- return CloseConnection(Connection, std::error_code());
- }
-
- if (Parser.IsUpgrade() == false)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
-
- return async_write(Connection->Socket(),
- asio::buffer(UpgradeRequiredResponse),
- [this, Connection](const asio::error_code& WriteEc, std::size_t) {
- if (WriteEc)
- {
- return CloseConnection(Connection, WriteEc);
- }
-
- Connection->Parser()->Reset();
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- });
- }
-
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
-
- std::string AcceptHash;
- std::string Reason;
- const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
-
- if (ValidHandshake == false)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '{}' FAILED, reason '{}'",
- Connection->Id().Value(),
- Reason);
-
- constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
-
- return async_write(Connection->Socket(),
- asio::buffer(UpgradeRequiredResponse),
- [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
- if (WriteEc)
- {
- return CloseConnection(Connection, WriteEc);
- }
-
- Connection->Parser()->Reset();
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- });
- }
-
- ExtendableStringBuilder<128> Sb;
-
- Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
- Sb << "Upgrade: websocket\r\n"sv;
- Sb << "Connection: Upgrade\r\n"sv;
-
- // TODO: Verify protocol
- if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
- {
- Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
- << "\r\n";
- }
-
- Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
- Sb << "\r\n"sv;
-
- ZEN_LOG_DEBUG(LogWebSocket,
- "accepting handshake from connection '#{} {}'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- std::string Response = Sb.ToString();
- Buffer = asio::buffer(Response);
-
- async_write(Connection->Socket(),
- Buffer,
- [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
- if (WriteEc)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '{}' FAILED, reason '{}'",
- Connection->Id().Value(),
- WriteEc.message());
-
- return CloseConnection(Connection, WriteEc);
- }
-
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake ({}B) with connection '#{} {}' OK",
- ByteCount,
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- Connection->SetParser(std::make_unique<WebSocketMessageParser>());
- Connection->SetState(WebSocketState::kConnected);
-
- ReadMessage(Connection);
- });
- }
- break;
-
- case WebSocketState::kConnected:
- {
- WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
-
- uint64_t RemainingBytes = Connection->ReadBuffer().size();
-
- while (RemainingBytes > 0)
- {
- MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Connection->ReadBuffer().consume(Result.ByteCount);
- RemainingBytes = Connection->ReadBuffer().size();
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
-
- return CloseConnection(Connection, std::error_code());
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(RemainingBytes == 0);
- continue;
- }
-
- WebSocketMessage Message = Parser.ConsumeMessage();
- Parser.Reset();
-
- Message.SetSocketId(Connection->Id());
-
- RouteMessage(std::move(Message));
- }
-
- ReadMessage(Connection);
- }
- break;
-
- default:
- break;
- };
- });
-}
-
-void
-WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
-{
- switch (RoutedMessage.MessageType())
- {
- case WebSocketMessageType::kRequest:
- case WebSocketMessageType::kStreamRequest:
- {
- CbObjectView Request = RoutedMessage.Body().GetObject();
- std::string_view Method = Request["Method"].AsString();
- bool Handled = false;
- bool Error = false;
- std::exception Exception;
-
- if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
- {
- WebSocketService* Service = It->second;
- ZEN_ASSERT(Service);
-
- try
- {
- Handled = Service->HandleRequest(std::move(RoutedMessage));
- }
- catch (std::exception& Err)
- {
- Exception = std::move(Err);
- Error = true;
- }
- }
-
- if (Error || Handled == false)
- {
- std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
-
- ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
-
- CbObjectWriter Response;
- Response << "Error"sv << ErrorText;
-
- WebSocketMessage ResponseMsg;
- ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
- ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
- ResponseMsg.SetSocketId(RoutedMessage.SocketId());
- ResponseMsg.SetBody(Response.Save());
-
- SendResponse(std::move(ResponseMsg));
- }
- }
- break;
-
- case WebSocketMessageType::kNotification:
- {
- CbObjectView Notification = RoutedMessage.Body().GetObject();
- std::string_view Message = Notification["Message"].AsString();
-
- if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
- {
- std::vector<WebSocketService*>& Handlers = It->second;
-
- for (WebSocketService* Handler : Handlers)
- {
- Handler->HandleNotification(RoutedMessage);
- }
- }
- else
- {
- ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
- }
- }
- break;
-
- default:
- break;
- };
-}
-
-void
-WsServer::SendMessage(WebSocketMessage&& Msg)
-{
- std::shared_ptr<WsConnection> Connection;
-
- {
- std::unique_lock _(m_ConnMutex);
-
- if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
- {
- Connection = It->second;
- }
- }
-
- if (Connection.get() == nullptr)
- {
- ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
- return;
- }
-
- if (Connection.get() != nullptr)
- {
- BinaryWriter Writer;
- Msg.Save(Writer);
-
- ZEN_LOG_TRACE(LogWebSocket,
- "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
- ToString(Msg.MessageType()),
- Connection->Id().Value(),
- Msg.MessageSize(),
- Msg.CorrelationId(),
- NiceBytes(Writer.Size()));
-
- {
- ZEN_TRACE_CPU("WS::SendMessage");
- std::unique_lock _(Connection->WriteMutex());
- ZEN_TRACE_CPU("WS::WriteSocketData");
- asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
- }
- }
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
-{
-public:
- WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
-
- virtual ~WsClient() { Disconnect(); }
-
- std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
-
- virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
- virtual void Disconnect() override;
- virtual bool IsConnected() const override { return false; }
- virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); }
-
- virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
- virtual void OnNotification(NotificationCallback&& Cb) override;
- virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
-
-private:
- WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
- MessageParser* Parser() { return m_MsgParser.get(); }
- void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
- asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- void TriggerEvent(WebSocketEvent Evt);
- void ReadMessage();
- void RouteMessage(WebSocketMessage&& RoutedMessage);
-
- using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
-
- asio::io_context& m_IoCtx;
- WebSocketId m_Id;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<MessageParser> m_MsgParser;
- asio::streambuf m_ReadBuffer;
- EventCallback m_EventCallbacks[3];
- NotificationCallback m_NotificationCallback;
- PendingRequestMap m_PendingRequests;
- std::mutex m_RequestMutex;
- std::promise<bool> m_ConnectPromise;
- std::atomic_uint32_t m_State;
- std::string m_Host;
- int16_t m_Port{};
-};
-
-std::future<bool>
-WsClient::Connect(const WebSocketConnectInfo& Info)
-{
- if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
- {
- return m_ConnectPromise.get_future();
- }
-
- SetState(WebSocketState::kHandshaking);
-
- try
- {
- asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
- m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
-
- m_Socket->connect(Endpoint);
-
- m_Host = m_Socket->remote_endpoint().address().to_string();
- m_Port = Info.Port;
-
- ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
- }
- catch (std::exception& Err)
- {
- ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
-
- SetState(WebSocketState::kError);
- m_Socket.reset();
-
- TriggerEvent(WebSocketEvent::kDisconnected);
-
- m_ConnectPromise.set_value(false);
-
- return m_ConnectPromise.get_future();
- }
-
- ExtendableStringBuilder<128> Sb;
- Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
- Sb << "Host: " << Info.Host << "\r\n"sv;
- Sb << "Upgrade: websocket\r\n"sv;
- Sb << "Connection: upgrade\r\n"sv;
- Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
-
- if (Info.Protocols.empty() == false)
- {
- Sb << "Sec-WebSocket-Protocol: "sv;
- for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
- {
- if (Idx++)
- {
- Sb << ", ";
- }
- Sb << Protocol;
- }
- }
-
- Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
- Sb << "\r\n";
-
- std::string HandshakeRequest = Sb.ToString();
- asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
-
- ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
-
- m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
- m_MsgParser->Reset();
-
- async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
-
- Self->Disconnect();
- }
- else
- {
- Self->ReadMessage();
- }
- });
-
- return m_ConnectPromise.get_future();
-}
-
-void
-WsClient::Disconnect()
-{
- if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
- {
- ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
-
- if (m_Socket && m_Socket->is_open())
- {
- m_Socket->close();
- m_Socket.reset();
- }
-
- TriggerEvent(WebSocketEvent::kDisconnected);
-
- {
- std::unique_lock _(m_RequestMutex);
-
- for (auto& Kv : m_PendingRequests)
- {
- Kv.second.set_value(WebSocketMessage());
- }
-
- m_PendingRequests.clear();
- }
- }
-}
-
-std::future<WebSocketMessage>
-WsClient::SendRequest(WebSocketMessage&& Request)
-{
- ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
-
- BinaryWriter Writer;
- Request.Save(Writer);
-
- std::future<WebSocketMessage> FutureResponse;
-
- {
- std::unique_lock _(m_RequestMutex);
-
- auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
- ZEN_ASSERT(Result.second);
-
- auto It = Result.first;
- FutureResponse = It->second.get_future();
- }
-
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
-
- async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
-
- Self->Disconnect();
- }
- });
-
- return FutureResponse;
-}
-
-void
-WsClient::OnNotification(NotificationCallback&& Cb)
-{
- m_NotificationCallback = std::move(Cb);
-}
-
-void
-WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
-{
- m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
-}
-
-void
-WsClient::TriggerEvent(WebSocketEvent Evt)
-{
- const uint32_t Index = static_cast<uint32_t>(Evt);
-
- if (m_EventCallbacks[Index])
- {
- m_EventCallbacks[Index]();
- }
-}
-
-void
-WsClient::ReadMessage()
-{
- m_ReadBuffer.prepare(64 << 10);
-
- async_read(*m_Socket,
- m_ReadBuffer,
- asio::transfer_at_least(1),
- [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
- const WebSocketState State = Self->State();
-
- if (State == WebSocketState::kDisconnected)
- {
- return;
- }
-
- if (Ec)
- {
- ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
-
- return Self->Disconnect();
- }
-
- switch (State)
- {
- case WebSocketState::kHandshaking:
- {
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
-
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
-
- ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Self->ReadBuffer().consume(size_t(Result.ByteCount));
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
-
- Self->m_ConnectPromise.set_value(false);
-
- return Self->Disconnect();
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return Self->ReadMessage();
- }
-
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
-
- if (Parser.StatusCode() != 101)
- {
- ZEN_LOG_WARN(LogWsClient,
- "handshake FAILED, status '{}', status code '{}'",
- Parser.StatusText(),
- Parser.StatusCode());
-
- Self->m_ConnectPromise.set_value(false);
-
- return Self->Disconnect();
- }
-
- ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
-
- Self->SetParser(std::make_unique<WebSocketMessageParser>());
- Self->SetState(WebSocketState::kConnected);
- Self->ReadMessage();
- Self->TriggerEvent(WebSocketEvent::kConnected);
-
- Self->m_ConnectPromise.set_value(true);
- }
- break;
-
- case WebSocketState::kConnected:
- {
- WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
-
- uint64_t RemainingBytes = Self->ReadBuffer().size();
-
- while (RemainingBytes > 0)
- {
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Self->ReadBuffer().consume(Result.ByteCount);
- RemainingBytes = Self->ReadBuffer().size();
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
-
- Parser.Reset();
- continue;
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(RemainingBytes == 0);
- continue;
- }
-
- WebSocketMessage Message = Parser.ConsumeMessage();
- Parser.Reset();
-
- Self->RouteMessage(std::move(Message));
- }
-
- Self->ReadMessage();
- }
- break;
-
- default:
- break;
- }
- });
-}
-
-void
-WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
-{
- switch (RoutedMessage.MessageType())
- {
- case WebSocketMessageType::kResponse:
- {
- std::unique_lock _(m_RequestMutex);
-
- if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
- {
- It->second.set_value(std::move(RoutedMessage));
- m_PendingRequests.erase(It);
- }
- else
- {
- ZEN_LOG_WARN(LogWsClient,
- "route request message FAILED, reason 'unknown correlation ID ({})'",
- RoutedMessage.CorrelationId());
- }
- }
- break;
-
- case WebSocketMessageType::kNotification:
- {
- std::unique_lock _(m_RequestMutex);
-
- if (m_NotificationCallback)
- {
- m_NotificationCallback(std::move(RoutedMessage));
- }
- }
- break;
-
- default:
- ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType()));
- break;
- };
-}
-
-} // namespace zen::websocket
-
-namespace zen {
-
-std::atomic_uint32_t WebSocketId::NextId{1};
-
-bool
-WebSocketMessage::Header::IsValid() const
-{
- return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
- uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
-}
-
-std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
-
-void
-WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
-{
- m_Header.MessageType = MessageType;
-}
-
-void
-WebSocketMessage::SetBody(CbPackage&& Body)
-{
- m_Body = std::move(Body);
-}
-void
-WebSocketMessage::SetBody(CbObject&& Body)
-{
- CbPackage Pkg;
- Pkg.SetObject(Body);
-
- SetBody(std::move(Pkg));
-}
-
-void
-WebSocketMessage::Save(BinaryWriter& Writer)
-{
- Writer.Write(&m_Header, HeaderSize);
-
- if (m_Body.has_value())
- {
- const CbObject& Obj = m_Body.value().GetObject();
- MemoryView View = Obj.GetBuffer().GetView();
-
- const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
- ZEN_ASSERT(ValidationResult == CbValidateError::None);
-
- m_Body.value().Save(Writer);
- }
-
- if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
- {
- m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
- }
-
- m_Header.MessageSize = Writer.Size() - HeaderSize;
-
- Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
-}
-
-bool
-WebSocketMessage::TryLoadHeader(MemoryView Memory)
-{
- if (Memory.GetSize() < HeaderSize)
- {
- return false;
- }
-
- MutableMemoryView HeaderView(&m_Header, HeaderSize);
-
- HeaderView.CopyFrom(Memory);
-
- return m_Header.IsValid();
-}
-
-void
-WebSocketService::Configure(WebSocketServer& Server)
-{
- ZEN_ASSERT(m_SocketServer == nullptr);
-
- m_SocketServer = &Server;
-
- RegisterHandlers(Server);
-}
-
-void
-WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
-{
- WebSocketMessage Message;
-
- Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
- Message.SetCorrelationId(CorrelationId);
- Message.SetSocketId(SocketId);
- Message.SetBody(std::move(StreamResponse));
-
- SocketServer().SendResponse(std::move(Message));
-}
-
-void
-WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
-{
- CbPackage Response;
- Response.SetObject(std::move(StreamResponse));
-
- SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
-}
-
-std::unique_ptr<WebSocketServer>
-WebSocketServer::Create(const WebSocketServerOptions& Options)
-{
- return std::make_unique<websocket::WsServer>(Options);
-}
-
-std::shared_ptr<WebSocketClient>
-WebSocketClient::Create(asio::io_context& IoCtx)
-{
- return std::make_shared<websocket::WsClient>(IoCtx);
-}
-
-} // namespace zen
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
index 2f52e3225..3fae3235f 100644
--- a/src/zenserver-test/zenserver-test.cpp
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -21,7 +21,6 @@
#include <zencore/xxhash.h>
#include <zenhttp/httpclient.h>
#include <zenhttp/httpshared.h>
-#include <zenhttp/websocket.h>
#include <zenhttp/zenhttp.h>
#include <zenutil/cache/cache.h>
#include <zenutil/cache/cacherequests.h>
@@ -2608,57 +2607,6 @@ TEST_CASE("http.package")
CHECK_EQ(ResponsePackage, TestPackage);
}
-TEST_CASE("websocket.basic")
-{
- if (true)
- {
- return;
- }
-
- std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
- const uint16_t PortNumber = 13337;
- const auto MaxWaitTime = std::chrono::seconds(5);
-
- ZenServerInstance Inst(TestEnv);
- Inst.SetTestDir(TestDir);
- Inst.SpawnServer(PortNumber, "--websocket-port=8848"sv);
- Inst.WaitUntilReady();
-
- asio::io_context IoCtx;
- IoDispatcher IoDispatcher(IoCtx);
- auto WebSocket = WebSocketClient::Create(IoCtx);
-
- auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"});
- IoDispatcher.Run();
-
- ConnectFuture.wait_for(MaxWaitTime);
- CHECK(ConnectFuture.get());
-
- for (size_t Idx = 0; Idx < 10; Idx++)
- {
- CbObjectWriter Request;
- Request << "Method"sv
- << "SayHello"sv;
-
- WebSocketMessage RequestMsg;
- RequestMsg.SetMessageType(WebSocketMessageType::kRequest);
- RequestMsg.SetBody(Request.Save());
-
- auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg));
- ResponseFuture.wait_for(MaxWaitTime);
-
- CbObject Response = ResponseFuture.get().Body().GetObject();
- std::string_view Message = Response["Result"].AsString();
-
- CHECK(Message == "Hello Friend!!"sv);
- }
-
- WebSocket->Disconnect();
-
- IoCtx.stop();
- IoDispatcher.Stop();
-}
-
std::string
OidAsString(const Oid& Id)
{
diff --git a/src/zenserver/config.cpp b/src/zenserver/config.cpp
index 5e24d174b..c0a97ce5b 100644
--- a/src/zenserver/config.cpp
+++ b/src/zenserver/config.cpp
@@ -796,8 +796,6 @@ ParseConfigFile(const std::filesystem::path& Path,
LuaOptions.AddOption("network.httpserverclass"sv, ServerOptions.HttpServerConfig.ServerClass, "http"sv);
LuaOptions.AddOption("network.httpserverthreads"sv, ServerOptions.HttpServerConfig.ThreadCount, "http-threads"sv);
LuaOptions.AddOption("network.port"sv, ServerOptions.BasePort, "port"sv);
- LuaOptions.AddOption("network.websocket.port"sv, ServerOptions.WebSocketPort, "websocket-port"sv);
- LuaOptions.AddOption("network.websocket.threadcount"sv, ServerOptions.WebSocketThreads, "websocket-threads"sv);
LuaOptions.AddOption("network.httpsys.async.workthreads"sv,
ServerOptions.HttpServerConfig.HttpSys.AsyncWorkThreadCount,
@@ -1076,20 +1074,6 @@ ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions)
cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"),
"<port number>");
- options.add_option("network",
- "",
- "websocket-port",
- "Websocket server port",
- cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"),
- "<port number>");
-
- options.add_option("network",
- "",
- "websocket-threads",
- "Number of websocket I/O thread(s) (0 == hardware concurrency)",
- cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"),
- "");
-
options.add_option("httpsys",
"",
"httpsys-async-work-threads",
diff --git a/src/zenserver/config.h b/src/zenserver/config.h
index 924375a19..df1ccb752 100644
--- a/src/zenserver/config.h
+++ b/src/zenserver/config.h
@@ -134,18 +134,16 @@ struct ZenServerOptions
ZenObjectStoreConfig ObjectStoreConfig;
zen::HttpServerConfig HttpServerConfig;
ZenStructuredCacheConfig StructuredCacheConfig;
- std::filesystem::path DataDir; // Root directory for state (used for testing)
- std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental)
- std::filesystem::path AbsLogFile; // Absolute path to main log file
- std::filesystem::path ConfigFile; // Path to Lua config file
- std::string ChildId; // Id assigned by parent process (used for lifetime management)
- std::string LogId; // Id for tagging log output
- std::string EncryptionKey; // 256 bit AES encryption key
- std::string EncryptionIV; // 128 bit AES initialization vector
- int BasePort = 1337; // Service listen port (used for both UDP and TCP)
- int OwnerPid = 0; // Parent process id (zero for standalone)
- int WebSocketPort = 0; // Web socket port (Zero = disabled)
- int WebSocketThreads = 0;
+ std::filesystem::path DataDir; // Root directory for state (used for testing)
+ std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental)
+ std::filesystem::path AbsLogFile; // Absolute path to main log file
+ std::filesystem::path ConfigFile; // Path to Lua config file
+ std::string ChildId; // Id assigned by parent process (used for lifetime management)
+ std::string LogId; // Id for tagging log output
+ std::string EncryptionKey; // 256 bit AES encryption key
+ std::string EncryptionIV; // 128 bit AES initialization vector
+ int BasePort = 1337; // Service listen port (used for both UDP and TCP)
+ int OwnerPid = 0; // Parent process id (zero for standalone)
bool InstallService = false; // Flag used to initiate service install (temporary)
bool UninstallService = false; // Flag used to initiate service uninstall (temporary)
bool IsDebug = false;
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
index cf9f03d89..83eb81e3b 100644
--- a/src/zenserver/zenserver.cpp
+++ b/src/zenserver/zenserver.cpp
@@ -17,7 +17,6 @@
#include <zencore/trace.h>
#include <zencore/workthreadpool.h>
#include <zenhttp/httpserver.h>
-#include <zenhttp/websocket.h>
#include <zenstore/cidstore.h>
#include <zenstore/scrubcontext.h>
#include <zenutil/basicfile.h>
@@ -309,15 +308,6 @@ public:
m_Http = zen::CreateHttpServer(ServerOptions.HttpServerConfig);
int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort);
- if (ServerOptions.WebSocketPort != 0)
- {
- const uint32_t ThreadCount =
- ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency();
-
- m_WebSocket = zen::WebSocketServer::Create(
- {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))});
- }
-
// Setup authentication manager
{
std::string EncryptionKey = ServerOptions.EncryptionKey;
@@ -396,11 +386,6 @@ public:
#if ZEN_WITH_TESTS
m_Http->RegisterService(m_TestingService);
-
- if (m_WebSocket)
- {
- m_WebSocket->RegisterService(m_TestingService);
- }
#endif
if (m_HttpProjectService)
@@ -526,11 +511,6 @@ public:
OnReady();
- if (m_WebSocket)
- {
- m_WebSocket->Run();
- }
-
m_Http->Run(IsInteractiveMode);
SetNewState(kShuttingDown);
@@ -590,7 +570,6 @@ public:
m_CidStore.reset();
m_AuthService.reset();
m_AuthMgr.reset();
- m_WebSocket.reset();
m_Http = {};
m_JobQueue.reset();
}
@@ -788,7 +767,6 @@ private:
}
zen::Ref<zen::HttpServer> m_Http;
- std::unique_ptr<zen::WebSocketServer> m_WebSocket;
std::unique_ptr<zen::AuthMgr> m_AuthMgr;
std::unique_ptr<zen::HttpAuthService> m_AuthService;
zen::HttpStatusService m_StatusService;