diff options
| author | Per Larsson <[email protected]> | 2022-02-21 15:14:11 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-21 15:14:11 +0100 |
| commit | db1c9605e3afbaf86f4231ba4eb7976d896f286b (patch) | |
| tree | 54b451da4247c69575ff1a05ed006ecef3905c85 | |
| parent | If open(O_CREAT) is used then a file mode must be given (diff) | |
| parent | Removed optional offset for GetView. (diff) | |
| download | zen-db1c9605e3afbaf86f4231ba4eb7976d896f286b.tar.xz zen-db1c9605e3afbaf86f4231ba4eb7976d896f286b.zip | |
Initial support for websockets.
| -rw-r--r-- | zencore/include/zencore/logging.h | 61 | ||||
| -rw-r--r-- | zencore/include/zencore/stream.h | 6 | ||||
| -rw-r--r-- | zencore/stream.cpp | 9 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/websocket.h | 224 | ||||
| -rw-r--r-- | zenhttp/websocketasio.cpp | 1537 | ||||
| -rw-r--r-- | zenserver-test/zenserver-test.cpp | 96 | ||||
| -rw-r--r-- | zenserver/config.cpp | 14 | ||||
| -rw-r--r-- | zenserver/config.h | 24 | ||||
| -rw-r--r-- | zenserver/testing/httptest.cpp | 34 | ||||
| -rw-r--r-- | zenserver/testing/httptest.h | 6 | ||||
| -rw-r--r-- | zenserver/zenserver.cpp | 21 |
11 files changed, 2020 insertions, 12 deletions
diff --git a/zencore/include/zencore/logging.h b/zencore/include/zencore/logging.h index 468e5d6e2..74ab0f81f 100644 --- a/zencore/include/zencore/logging.h +++ b/zencore/include/zencore/logging.h @@ -38,6 +38,67 @@ using logging::ConsoleLog; using zen::ConsoleLog; using zen::Log; +struct LogCategory +{ + LogCategory(std::string_view InCategory) : Category(InCategory) {} + + spdlog::logger& Logger() + { + static spdlog::logger& Inst = zen::logging::Get(Category); + return Inst; + } + + std::string Category; +}; + +#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ + static struct LogCategory##Category : public LogCategory \ + { \ + LogCategory##Category() : LogCategory(Name) {} \ + } Category; + +#define ZEN_LOG_TRACE(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().trace(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_DEBUG(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().debug(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_INFO(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().info(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_WARN(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().warn(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_ERROR(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().error(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().critical(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + // Helper macros for logging #define ZEN_TRACE(fmtstr, ...) \ diff --git a/zencore/include/zencore/stream.h b/zencore/include/zencore/stream.h index 9d1a7628c..efff2c541 100644 --- a/zencore/include/zencore/stream.h +++ b/zencore/include/zencore/stream.h @@ -27,12 +27,18 @@ public: m_Offset += ByteCount; } + inline void Write(MemoryView Memory) { Write(Memory.GetData(), Memory.GetSize()); } + inline uint64_t CurrentOffset() const { return m_Offset; } inline const uint8_t* Data() const { return m_Buffer.data(); } inline const uint8_t* GetData() const { return m_Buffer.data(); } inline uint64_t Size() const { return m_Buffer.size(); } inline uint64_t GetSize() const { return m_Buffer.size(); } + void Reset(); + + inline MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); } + inline MutableMemoryView GetMutableView() { return MutableMemoryView(m_Buffer.data(), m_Offset); } private: RwLock m_Lock; diff --git a/zencore/stream.cpp b/zencore/stream.cpp index aa9705764..8faf90af2 100644 --- a/zencore/stream.cpp +++ b/zencore/stream.cpp @@ -25,6 +25,15 @@ BinaryWriter::Write(const void* data, size_t ByteCount, uint64_t Offset) memcpy(m_Buffer.data() + Offset, data, ByteCount); } +void +BinaryWriter::Reset() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + m_Buffer.clear(); + m_Offset = 0; +} + ////////////////////////////////////////////////////////////////////////// // // Testing related code follows... diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..a514e6002 --- /dev/null +++ b/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,224 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compactbinarypackage.h> +#include <zencore/memory.h> + +#include <compare> +#include <functional> +#include <future> +#include <memory> +#include <optional> + +#pragma once + +namespace asio { +class io_context; +} + +namespace zen { + +class BinaryWriter; + +/** + * A unique socket ID. + */ +class WebSocketId +{ + static std::atomic_uint32_t NextId; + +public: + WebSocketId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WebSocketId&) const = default; + + static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); } + +private: + WebSocketId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +/** + * Type of web socket message. + */ +enum class WebSocketMessageType : uint8_t +{ + kInvalid, + kNotification, + kRequest, + kResponse +}; + +/** + * Web socket message. + */ +class WebSocketMessage +{ + struct Header + { + static constexpr uint32_t HeaderMagic = 0x7a776d68; // zwmh + + uint64_t MessageSize{}; + uint32_t Magic{HeaderMagic}; + uint32_t CorrelationId{}; + uint32_t Crc32{}; + WebSocketMessageType MessageType{}; + uint8_t Reserved[3] = {0}; + + bool IsValid() const; + }; + + static_assert(sizeof Header == 24); + + static std::atomic_uint32_t NextCorrelationId; + +public: + static constexpr size_t HeaderSize = sizeof(Header); + + WebSocketMessage() = default; + + WebSocketId SocketId() const { return m_SocketId; } + void SetSocketId(WebSocketId Id) { m_SocketId = Id; } + void SetMessageType(WebSocketMessageType MessageType); + WebSocketMessageType MessageType() const { return m_Header.MessageType; } + uint64_t MessageSize() const { return m_Header.MessageSize; } + void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } + uint32_t CorrelationId() const { return m_Header.CorrelationId; } + + const CbPackage& Body() const { return m_Body.value(); } + void SetBody(CbPackage&& Body); + void SetBody(CbObject&& Body); + bool HasBody() const { return m_Body.has_value(); } + + void Save(BinaryWriter& Writer); + bool TryLoadHeader(MemoryView Memory); + + bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } + +private: + Header m_Header{}; + WebSocketId m_SocketId{}; + std::optional<CbPackage> m_Body; +}; + +class WebSocketServer; + +/** + * Base class for handling web socket requests and notifications from connected client(s). + */ +class WebSocketService +{ +public: + virtual ~WebSocketService() = default; + + void Configure(WebSocketServer& Server); + + virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } + virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } + +protected: + WebSocketService() = default; + + virtual void RegisterHandlers(WebSocketServer& Server) = 0; + + WebSocketServer& SocketServer() + { + ZEN_ASSERT(m_SocketServer); + return *m_SocketServer; + } + +private: + WebSocketServer* m_SocketServer{}; +}; + +/** + * Server options. + */ +struct WebSocketServerOptions +{ + uint16_t Port = 2337; + uint32_t ThreadCount = 1; +}; + +/** + * The web socket server manages client connections and routing of requests and notifications. + */ +class WebSocketServer +{ +public: + virtual ~WebSocketServer() = default; + + virtual bool Run() = 0; + virtual void Shutdown() = 0; + + virtual void RegisterService(WebSocketService& Service) = 0; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; + + virtual void SendNotification(WebSocketMessage&& Notification) = 0; + virtual void SendResponse(WebSocketMessage&& Response) = 0; + + static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options); +}; + +/** + * The state of the web socket. + */ +enum class WebSocketState : uint32_t +{ + kNone, + kHandshaking, + kConnected, + kDisconnected, + kError +}; + +/** + * Type of web socket client event. + */ +enum class WebSocketEvent : uint32_t +{ + kConnected, + kDisconnected, + kError +}; + +/** + * Web socket client connection info. + */ +struct WebSocketConnectInfo +{ + std::string Host; + int16_t Port{8848}; + std::string Endpoint; + std::vector<std::string> Protocols; + uint16_t Version{13}; +}; + +/** + * A connection to a web socket server for sending requests and listening for notifications. + */ +class WebSocketClient +{ +public: + using EventCallback = std::function<void()>; + using NotificationCallback = std::function<void(WebSocketMessage&&)>; + + virtual ~WebSocketClient() = default; + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; + virtual void Disconnect() = 0; + virtual bool IsConnected() const = 0; + virtual WebSocketState State() const = 0; + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; + virtual void OnNotification(NotificationCallback&& Cb) = 0; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; + + static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); +}; + +} // namespace zen diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..1a95b12bc --- /dev/null +++ b/zenhttp/websocketasio.cpp @@ -0,0 +1,1537 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocket.h> + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/sha1.h> +#include <zencore/stream.h> +#include <zencore/string.h> + +#include <chrono> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::websocket { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +enum class ParseMessageStatus : uint32_t +{ + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + BinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Reset(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const + { + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); + } + + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + void Initialize(); + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry + { + uint64_t Offset{}; + uint64_t Size{}; + }; + + struct HeaderStreamEntry + { + StreamEntry Field{}; + StreamEntry Value{}; + }; + + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; + + static http_parser_settings ParserSettings; +}; + +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) + { + Status = ParseMessageStatus::kError; + } + + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} + +void +HttpMessageParser::OnReset() +{ + Initialize(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} + +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; + + m_Stream.Write(Url); + + return 0; +} + +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + if (m_CurrentHeader.Field.Size == 0) + { + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) + { + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); + + m_Stream.Write(HeaderValue); + + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); + + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); + + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } + + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) +{ + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); + + if (m_Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return false; + } + + if (m_Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} + + WebSocketMessage ConsumeMessage(); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + + WebSocketMessage m_Message; +}; + +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + const uint64_t PrevOffset = m_Stream.CurrentOffset(); + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingHeaderSize)); + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); + + if (IsValidHeader == false) + { + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; + } + + Msg += RemaingHeaderSize; + } + + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); + + if (Msg.IsEmpty()) + { + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingMessageSize)); + + const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); + + if (IsComplete) + { + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); + + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } + + m_Message.SetBody(std::move(Pkg)); + } + + return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_Message = WebSocketMessage(); +} + +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() +{ + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); + + return Msg; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_State() + { + } + + ~WsConnection() = default; + + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + + WebSocketId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + +private: + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; +}; + +WebSocketState +WsConnection::Close() +{ + using enum WebSocketState; + + const auto PrevState = SetState(kDisconnected); + + if (PrevState != kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + for (;;) + { + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + } + + ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + if (m_Running) + { + m_Running = false; + + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} + virtual ~WsServer() { Shutdown(); } + + virtual bool Run() override; + virtual void Shutdown() override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); + + void ReadMessage(std::shared_ptr<WsConnection> Connection); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); + + struct IdHasher + { + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; + + WebSocketServerOptions m_Options; + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; + std::atomic_bool m_Running{}; +}; + +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + +bool +WsServer::Run() +{ + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(m_Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) +{ + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} + +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} + +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); + + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); + + SendMessage(std::move(Response)); +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (m_Running) + { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); + } + else + { + auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); + + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id()] = Connection; + } + + Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + } + + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) +{ + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(LogWebSocket, + "connection '{}' closed, ERROR '{}' error code '{}'", + Connection->Id().Value(), + Ec.message(), + Ec.value()); + } + else + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); + } + } + + const WebSocketId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } + } +} + +void +WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) +{ + Connection->ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection->Socket(), + Connection->ReadBuffer(), + asio::transfer_at_least(1), + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + using enum WebSocketState; + + switch (Connection->State()) + { + case kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + Connection->ReadBuffer().consume(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + std::string AcceptHash; + std::string Reason; + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); + + if (ValidHandshake == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); + + async_write(Connection->Socket(), + Buffer, + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { + if (WriteEc) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); + } + + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); + + Connection->SetParser(std::make_unique<WebSocketMessageParser>()); + Connection->SetState(kConnected); + + ReadMessage(Connection); + }); + } + break; + + case kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); + + uint64_t RemainingBytes = Connection->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } + + ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) + { + Connection = It->second; + } + } + + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(Connection->Socket(), + asio::buffer(Buffer.Data(), Buffer.Size()), + [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + + CloseConnection(Connection, Ec); + } + }); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} + + virtual ~WsClient() { Disconnect(); } + + std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; + + asio::io_context& m_IoCtx; + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +std::future<bool> +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) + { + return m_ConnectPromise.get_future(); + } + + SetState(WebSocketState::kHandshaking); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kError); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + else + { + Self->ReadMessage(); + } + }); + + return m_ConnectPromise.get_future(); +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } + } +} + +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) +{ + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + + BinaryWriter Writer; + Request.Save(Writer); + + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); + + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); + + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); + + return FutureResponse; +} + +void +WsClient::OnNotification(NotificationCallback&& Cb) +{ + m_NotificationCallback = std::move(Cb); +} + +void +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast<uint32_t>(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::ReadMessage() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); + + return Self->Disconnect(); + } + + const WebSocketState State = Self->State(); + + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); + + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + + ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); + + Self->SetParser(std::make_unique<WebSocketMessageParser>()); + Self->SetState(WebSocketState::kConnected); + Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); + + uint64_t RemainingBytes = Self->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + Parser.Reset(); + continue; + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Self->RouteMessage(std::move(Message)); + } + + Self->ReadMessage(); + } + break; + } + }); +} + +void +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + }; +} + +} // namespace zen::websocket + +namespace zen { + +std::atomic_uint32_t WebSocketId::NextId{1}; + +bool +WebSocketMessage::Header::IsValid() const +{ + return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; +} + +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + +void +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) +{ + m_Header.MessageType = MessageType; +} + +void +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) +{ + CbPackage Pkg; + Pkg.SetObject(Body); + + SetBody(std::move(Pkg)); +} + +void +WebSocketMessage::Save(BinaryWriter& Writer) +{ + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + m_Header.Crc32 = 1; // TODO + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); +} + +bool +WebSocketMessage::TryLoadHeader(MemoryView Memory) +{ + if (Memory.GetSize() < HeaderSize) + { + return false; + } + + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); + + return m_Header.IsValid(); +} + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_SocketServer == nullptr); + + m_SocketServer = &Server; + + RegisterHandlers(Server); +} + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create(const WebSocketServerOptions& Options) +{ + return std::make_unique<websocket::WsServer>(Options); +} + +std::shared_ptr<WebSocketClient> +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_shared<websocket::WsClient>(IoCtx); +} + +} // namespace zen diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index 293af7816..6a1b54b79 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -19,6 +19,7 @@ #include <zencore/timer.h> #include <zenhttp/httpclient.h> #include <zenhttp/httpshared.h> +#include <zenhttp/websocket.h> #include <zenhttp/zenhttp.h> #include <zenutil/cache/cache.h> #include <zenutil/zenserverprocess.h> @@ -2588,6 +2589,56 @@ private: std::vector<std::unique_ptr<ZenServerInstance> > m_Instances; }; +class IoDispatcher +{ +public: + IoDispatcher(asio::io_context& IoCtx) : m_IoCtx(IoCtx) {} + ~IoDispatcher() { Stop(); } + + void Run() + { + Stop(); + + m_Running = true; + + m_IoThread = std::thread([this]() { + try + { + m_IoCtx.run(); + } + catch (std::exception& Error) + { + m_Error = Error; + } + + m_Running = false; + }); + } + + void Stop() + { + if (m_Running) + { + m_Running = false; + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + } + + bool IsRunning() const { return m_Running; } + + const std::exception& Error() { return m_Error; } + +private: + asio::io_context& m_IoCtx; + std::thread m_IoThread; + std::exception m_Error; + std::atomic_bool m_Running{false}; +}; + TEST_CASE("http.basics") { using namespace std::literals; @@ -2655,6 +2706,51 @@ TEST_CASE("http.package") CHECK_EQ(ResponsePackage, TestPackage); } +TEST_CASE("websocket.basic") +{ + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto MaxWaitTime = std::chrono::seconds(5); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber, "--websocket-port=8848"sv); + Inst.WaitUntilReady(); + + asio::io_context IoCtx; + IoDispatcher IoDispatcher(IoCtx); + auto WebSocket = WebSocketClient::Create(IoCtx); + + auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"}); + IoDispatcher.Run(); + + ConnectFuture.wait_for(MaxWaitTime); + CHECK(ConnectFuture.get()); + + for (size_t Idx = 0; Idx < 10; Idx++) + { + CbObjectWriter Request; + Request << "Method"sv + << "SayHello"sv; + + WebSocketMessage RequestMsg; + RequestMsg.SetMessageType(WebSocketMessageType::kRequest); + RequestMsg.SetBody(Request.Save()); + + auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg)); + ResponseFuture.wait_for(MaxWaitTime); + + CbObject Response = ResponseFuture.get().Body().GetObject(); + std::string_view Message = Response["Result"].AsString(); + + CHECK(Message == "Hello Friend!!"sv); + } + + WebSocket->Disconnect(); + + IoDispatcher.Stop(); +} + # if 0 TEST_CASE("lifetime.owner") { diff --git a/zenserver/config.cpp b/zenserver/config.cpp index cb6d5ea6d..bcacc16c0 100644 --- a/zenserver/config.cpp +++ b/zenserver/config.cpp @@ -193,6 +193,20 @@ ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions) cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"), "<port number>"); + options.add_option("network", + "", + "websocket-port", + "Websocket server port", + cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"), + "<port number>"); + + options.add_option("network", + "", + "websocket-threads", + "Number of websocket I/O thread(s) (0 == hardware concurrency)", + cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"), + ""); + #if ZEN_ENABLE_MESH options.add_option("network", "m", diff --git a/zenserver/config.h b/zenserver/config.h index 69e65498c..fd569bdb1 100644 --- a/zenserver/config.h +++ b/zenserver/config.h @@ -87,17 +87,19 @@ struct ZenServerOptions { ZenUpstreamCacheConfig UpstreamCacheConfig; ZenGcConfig GcConfig; - std::filesystem::path DataDir; // Root directory for state (used for testing) - std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) - std::filesystem::path AbsLogFile; // Absolute path to main log file - std::filesystem::path ConfigFile; // Path to Lua config file - std::string ChildId; // Id assigned by parent process (used for lifetime management) - std::string LogId; // Id for tagging log output - std::string HttpServerClass; // Choice of HTTP server implementation - std::string EncryptionKey; // 256 bit AES encryption key - std::string EncryptionIV; // 128 bit AES initialization vector - int BasePort = 1337; // Service listen port (used for both UDP and TCP) - int OwnerPid = 0; // Parent process id (zero for standalone) + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path AbsLogFile; // Absolute path to main log file + std::filesystem::path ConfigFile; // Path to Lua config file + std::string ChildId; // Id assigned by parent process (used for lifetime management) + std::string LogId; // Id for tagging log output + std::string HttpServerClass; // Choice of HTTP server implementation + std::string EncryptionKey; // 256 bit AES encryption key + std::string EncryptionIV; // 128 bit AES initialization vector + int BasePort = 1337; // Service listen port (used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) + int WebSocketPort = 0; // Web socket port (Zero = disabled) + int WebSocketThreads = 0; bool InstallService = false; // Flag used to initiate service install (temporary) bool UninstallService = false; // Flag used to initiate service uninstall (temporary) bool IsDebug = false; diff --git a/zenserver/testing/httptest.cpp b/zenserver/testing/httptest.cpp index 230d5d6c5..10b69c469 100644 --- a/zenserver/testing/httptest.cpp +++ b/zenserver/testing/httptest.cpp @@ -8,6 +8,8 @@ namespace zen { +using namespace std::literals; + HttpTestingService::HttpTestingService() { m_Router.RegisterRoute( @@ -136,6 +138,38 @@ HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return (InsertResult.first->second = new PackageHandler(*this, RequestId)).Get(); } +void +HttpTestingService::RegisterHandlers(WebSocketServer& Server) +{ + Server.RegisterRequestHandler("SayHello"sv, *this); +} + +bool +HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg) +{ + CbObjectView Request = RequestMsg.Body().GetObject(); + + std::string_view Method = Request["Method"].AsString(); + + if (Method != "SayHello"sv) + { + return false; + } + + CbObjectWriter Response; + Response.AddString("Result"sv, "Hello Friend!!"); + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId()); + ResponseMsg.SetSocketId(RequestMsg.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SocketServer().SendResponse(std::move(ResponseMsg)); + + return true; +} + ////////////////////////////////////////////////////////////////////////// HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId) diff --git a/zenserver/testing/httptest.h b/zenserver/testing/httptest.h index f7ea0c31c..57d2d63f3 100644 --- a/zenserver/testing/httptest.h +++ b/zenserver/testing/httptest.h @@ -5,6 +5,7 @@ #include <zencore/logging.h> #include <zencore/stats.h> #include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> #include <atomic> @@ -13,7 +14,7 @@ namespace zen { /** * Test service to facilitate testing the HTTP framework and client interactions */ -class HttpTestingService : public HttpService +class HttpTestingService : public HttpService, public WebSocketService { public: HttpTestingService(); @@ -40,6 +41,9 @@ public: }; private: + virtual void RegisterHandlers(WebSocketServer& Server) override; + virtual bool HandleRequest(const WebSocketMessage& Request) override; + HttpRequestRouter m_Router; std::atomic<uint32_t> m_Counter{0}; metrics::OperationTiming m_TimingStats; diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index ea0f52db2..78a62e202 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -15,6 +15,7 @@ #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> #include <zenstore/basicfile.h> #include <zenstore/cas.h> #include <zenstore/cidstore.h> @@ -204,6 +205,15 @@ public: m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass); int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort); + if (ServerOptions.WebSocketPort != 0) + { + const uint32 ThreadCount = + ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency(); + + m_WebSocket = zen::WebSocketServer::Create( + {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))}); + } + // Setup authentication manager { std::string EncryptionKey = ServerOptions.EncryptionKey; @@ -304,6 +314,11 @@ public: m_Http->RegisterService(m_TestingService); m_Http->RegisterService(m_AdminService); + if (m_WebSocket) + { + m_WebSocket->RegisterService(m_TestingService); + } + if (m_HttpProjectService) { m_Http->RegisterService(*m_HttpProjectService); @@ -396,6 +411,11 @@ public: OnReady(); + if (m_WebSocket) + { + m_WebSocket->Run(); + } + m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); @@ -559,6 +579,7 @@ private: } zen::Ref<zen::HttpServer> m_Http; + std::unique_ptr<zen::WebSocketServer> m_WebSocket; std::unique_ptr<zen::AuthMgr> m_AuthMgr; std::unique_ptr<zen::HttpAuthService> m_AuthService; zen::HttpStatusService m_StatusService; |