diff options
| author | Per Larsson <[email protected]> | 2022-02-18 06:56:20 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-18 06:56:20 +0100 |
| commit | 4b9bac3c5baf7633cd51cffcf8e63cb5527ddb36 (patch) | |
| tree | 9d2f5e83679c0eea5de63b129eb1a2779501b28b /zenhttp/websocketasio.cpp | |
| parent | Renamed file. (diff) | |
| download | zen-4b9bac3c5baf7633cd51cffcf8e63cb5527ddb36.tar.xz zen-4b9bac3c5baf7633cd51cffcf8e63cb5527ddb36.zip | |
Simple websocket client/server test.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 1047 |
1 files changed, 727 insertions, 320 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index ad8434a5a..b800892d2 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -1,12 +1,13 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenhttp/websocketserver.h> +#include <zenhttp/websocket.h> #include <zencore/base64.h> #include <zencore/compactbinarypackage.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> +#include <zencore/memory.h> #include <zencore/sha1.h> #include <zencore/stream.h> #include <zencore/string.h> @@ -25,11 +26,11 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END -namespace zen::asio_ws { +namespace zen::websocket { using namespace std::literals; -ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; @@ -45,263 +46,396 @@ namespace http_header { } // namespace http_header /////////////////////////////////////////////////////////////////////////////// -struct HttpParser +enum class ParseMessageStatus : uint32_t { - HttpParser() + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + SimpleBinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Clear(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {} + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const { - http_parser_init(&Parser, HTTP_REQUEST); - Parser.data = this; + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); } - size_t Parse(asio::const_buffer Buffer) + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry { - return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); - } + uint64_t Offset{}; + uint64_t Size{}; + }; - void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) + struct HeaderStreamEntry { - OutHeaders.reserve(HeaderEntries.size()); + StreamEntry Field{}; + StreamEntry Value{}; + }; - for (const auto& E : HeaderEntries) - { - auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size); + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; - OutHeaders[Name] = Value; - } - } + static http_parser_settings ParserSettings; +}; - std::string ValidateWebSocketHandshake(std::unordered_map<std::string_view, std::string_view>& Headers, std::string& OutReason) +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) { - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + Status = ParseMessageStatus::kError; + } - std::string AcceptHash; + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} - if (Headers.contains(http_header::SecWebSocketKey) == false) - { - OutReason = "Missing header Sec-WebSocket-Key"; - return AcceptHash; - } +void +HttpMessageParser::OnReset() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; - if (Headers.contains(http_header::Upgrade) == false) - { - OutReason = "Missing header Upgrade"; - return AcceptHash; - } + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; - ExtendableStringBuilder<128> Sb; - Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; + m_HeaderEntries.clear(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; - SHA1 Hash = HashStream.GetHash(); + m_Stream.Write(Url); - AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); + return 0; +} - return AcceptHash; +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; } - static void Initialize() + if (m_CurrentHeader.Field.Size == 0) { - ParserSettings = {.on_message_begin = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url = UrlEntry{}; - Parser.CurrentHeader = HeaderEntry{}; - Parser.IsUpgrade = false; - Parser.IsComplete = false; - - Parser.HeaderStream.Clear(); - Parser.HeaderEntries.clear(); - - return 0; - }, - .on_url = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url.Offset = Parser.HeaderStream.CurrentOffset(); - Parser.Url.Size = Size; - - Parser.HeaderStream.Write(Data, uint32_t(Size)); - - return 0; - }, - .on_header_field = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - if (Parser.CurrentHeader.Name.Size == 0) - { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset(); - } - - Parser.CurrentHeader.Name.Size += Size; - - Parser.HeaderStream.Write(Data, Size); - - return 0; - }, - .on_header_value = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size == 0) - { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset(); - } - - Parser.CurrentHeader.Value.Size += Size; - - Parser.HeaderStream.Write(Data, Size); - - return 0; - }, - .on_headers_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - Parser.IsUpgrade = P->upgrade > 0; - - return 0; - }, - .on_message_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - Parser.IsComplete = true; - Parser.IsUpgrade = P->upgrade > 0; - return 0; - }}; + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); } - struct MemStreamEntry + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) { - size_t Offset{}; - size_t Size{}; - }; + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); - using UrlEntry = MemStreamEntry; + m_Stream.Write(HeaderValue); - struct HeaderEntry + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) { - MemStreamEntry Name; - MemStreamEntry Value; - }; + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } - static http_parser_settings ParserSettings; + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); - http_parser Parser; - SimpleBinaryWriter HeaderStream; - std::vector<HeaderEntry> HeaderEntries; - HeaderEntry CurrentHeader{}; - UrlEntry Url{}; - bool IsUpgrade = false; - bool IsComplete = false; -}; + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); -http_parser_settings HttpParser::ParserSettings; + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } -/////////////////////////////////////////////////////////////////////////////// -class WsMessageParser + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) { -public: - WsMessageParser() {} + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); - void Reset() + if (m_Headers.contains(http_header::SecWebSocketKey) == false) { - m_Header.reset(); - m_Stream.Clear(); + OutReason = "Missing header Sec-WebSocket-Key"; + return false; } - bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes) + if (m_Headers.contains(http_header::Upgrade) == false) { - if (m_Header.has_value()) - { - OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size()); + OutReason = "Missing header Upgrade"; + return false; + } - m_Stream.Write(Buffer.data(), OutConsumedBytes); + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; - return true; - } + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); - const size_t PrevOffset = m_Stream.CurrentOffset(); - const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size()); - const size_t RemainingBytes = Buffer.size() - BytesToWrite; + SHA1 Hash = HashStream.GetHash(); - m_Stream.Write(Buffer.data(), BytesToWrite); + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); - if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader)) - { - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + return true; +} - return true; - } +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} - zen::WebSocketMessageHeader Header; - if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false) - { - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + bool TryLoadMessage(CbPackage& OutMsg); - return false; - } +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; - m_Header = Header; + SimpleBinaryWriter m_HeaderStream; + WebSocketMessageHeader m_Header; +}; - if (RemainingBytes > 0) - { - const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes); +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + const uint64_t PrevOffset = m_Stream.CurrentOffset(); - m_Stream.Write(reinterpret_cast<const char*>(Buffer.data()) + BytesToWrite, RemainingBytesToWrite); - } + if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) + { + const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset(); - OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + m_HeaderStream.Write(Msg.Left(RemaingHeaderSize)); - return true; - } + Msg.RightChopInline(RemaingHeaderSize); - bool IsComplete() - { - if (m_Header.has_value()) + if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader)) { - const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(); - - return RemainingBytes == 0; + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } - return false; + const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header); + + if (IsValidHeader == false) + { + return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")}; + } } - bool TryLoadMessage(CbPackage& OutPackage) + if (Msg.GetSize() == 0) { - if (IsComplete()) - { - BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } - return OutPackage.TryLoad(Reader); - } + const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset(); - return false; + m_Stream.Write(Msg.Left(RemaingContentSize)); + + const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_HeaderStream.Clear(); + m_Header = {}; +} + +bool +WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg) +{ + const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength; + + if (IsParsed) + { + BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + + return OutMsg.TryLoad(Reader); } -private: - SimpleBinaryWriter m_Stream{64 << 10}; - std::optional<zen::WebSocketMessageHeader> m_Header; -}; + return false; +} /////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t @@ -346,34 +480,34 @@ public: , m_StartTime(Clock::now()) , m_Status() { + m_RemoteAddr = m_Socket->remote_endpoint().address().to_string(); } ~WsConnection(); - WsConnectionId Id() const { return m_Id; } - asio::ip::tcp::socket& Socket() { return *m_Socket; } - TimePoint StartTime() const { return m_StartTime; } std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - HttpParser& ParserHttp() { return *m_HttpParser; } - WsMessageParser& MessageParser() { return m_MsgParser; } - WsConnectionState Close(); - WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } + WsConnectionId Id() const { return m_Id; } + std::string_view RemoteAddr() const { return m_RemoteAddr; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WsConnectionState Close(); + WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); } - void InitializeHttpParser() { m_HttpParser = std::make_unique<HttpParser>(); } - void ReleaseHttpParser() { m_HttpParser.reset(); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } private: WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<HttpParser> m_HttpParser; - WsMessageParser m_MsgParser; + std::unique_ptr<MessageParser> m_MsgParser; TimePoint m_StartTime; - std::atomic_uint32_t m_Status; asio::streambuf m_ReadBuffer; + std::string m_RemoteAddr; + std::atomic_uint32_t m_Status; }; WsConnectionState @@ -402,6 +536,7 @@ public: private: asio::io_service& m_IoSvc; std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; }; void @@ -409,18 +544,28 @@ WsThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); - ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { - try - { - m_IoSvc.run(); - } - catch (std::exception& Err) + for (;;) { - ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } } ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); @@ -431,15 +576,20 @@ WsThreadPool::Start(uint32_t ThreadCount) void WsThreadPool::Stop() { - for (std::thread& Thread : m_Threads) + if (m_Running) { - if (Thread.joinable()) + m_Running = false; + + for (std::thread& Thread : m_Threads) { - Thread.join(); + if (Thread.joinable()) + { + Thread.join(); + } } - } - m_Threads.clear(); + m_Threads.clear(); + } } /////////////////////////////////////////////////////////////////////////////// @@ -485,8 +635,6 @@ WsConnection::~WsConnection() bool WsServer::Run(const WebSocketServerOptions& Options) { - HttpParser::Initialize(); - m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); @@ -500,7 +648,7 @@ WsServer::Run(const WebSocketServerOptions& Options) if (Ec) { - ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); return false; } @@ -508,7 +656,7 @@ WsServer::Run(const WebSocketServerOptions& Options) m_Acceptor->listen(); m_Running = true; - ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", Options.Port); AcceptConnection(); @@ -523,7 +671,7 @@ WsServer::Shutdown() { if (m_Running) { - ZEN_LOG_INFO(WsLog, "websocket server shutting down"); + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); m_Running = false; @@ -544,22 +692,24 @@ WsServer::AcceptConnection() m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value()); } else { - auto ConnId = WsConnectionId::New(); - - ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); - + auto ConnId = WsConnectionId::New(); auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket)); + ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value()); + { std::unique_lock _(m_ConnMutex); m_Connections[ConnId] = Connection; } - Connection->InitializeHttpParser(); + auto Parser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest); + Parser->Reset(); + + Connection->SetParser(std::move(Parser)); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); @@ -579,11 +729,15 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e { if (Ec) { - ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); + ZEN_LOG_INFO(LogWebSocket, + "connection '{}' closed, ERROR '{}' error code '{}'", + Connection->Id().Value(), + Ec.message(), + Ec.value()); } else { - ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value()); + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); } } @@ -591,14 +745,17 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e { std::unique_lock _(m_ConnMutex); - m_Connections.erase(Id); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } } } void WsServer::RemoveConnection(const WsConnectionId Id) { - ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value()); + ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value()); } void @@ -616,7 +773,11 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value()); + ZEN_LOG_DEBUG(LogWebSocket, + "reading {}B from connection '#{} {}'", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); using enum WsConnectionState; @@ -624,20 +785,32 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) { case kHandshaking: { - HttpParser& Parser = Connection->ParserHttp(); - const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data()); - Connection->ReadBuffer().consume(Consumed); + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - if (Parser.IsComplete == false) + if (Result.Status == ParseMessageStatus::kContinue) { return ReadMessage(Connection); } - if (Parser.IsUpgrade == false) + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) { - ZEN_LOG_DEBUG(WsLog, - "handshake with connection '{}' FAILED, reason 'not an upgrade request'", - Connection->Id().Value()); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; @@ -646,27 +819,28 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + return CloseConnection(Connection, WriteEc); } - else - { - Connection->InitializeHttpParser(); - Connection->SetState(WsConnectionState::kHandshaking); - ReadMessage(Connection); - } + Connection->Parser()->Reset(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadMessage(Connection); }); } - std::unordered_map<std::string_view, std::string_view> Headers; - Parser.GetHeaders(Headers); + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + std::string AcceptHash; std::string Reason; - std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); - if (AcceptHash.empty()) + if (ValidHandshake == false) { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; @@ -675,16 +849,13 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + return CloseConnection(Connection, WriteEc); } - else - { - // TODO: Always close connection? - Connection->InitializeHttpParser(); - Connection->SetState(WsConnectionState::kHandshaking); - ReadMessage(Connection); - } + Connection->Parser()->Reset(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadMessage(Connection); }); } @@ -695,95 +866,325 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) Sb << "Connection: Upgrade\r\n"sv; // TODO: Verify protocol - if (Headers.contains(http_header::SecWebSocketProtocol)) + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) { - Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; } Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; Sb << "\r\n"sv; - std::string Response = Sb.ToString(); - asio::const_buffer Buffer = asio::buffer(Response); + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); - ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value()); + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); async_write(Connection->Socket(), Buffer, - [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { if (WriteEc) { - CloseConnection(Connection, WriteEc); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); } - else - { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value()); - Connection->ReleaseHttpParser(); - Connection->SetState(kConnected); - Connection->MessageParser().Reset(); + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); - ReadMessage(Connection); - } + Connection->SetParser(std::make_unique<WebSocketMessageParser>()); + Connection->SetState(kConnected); }); } break; case kConnected: { - for (;;) - { - if (Connection->ReadBuffer().size() == 0) - { - break; - } + // for (;;) + //{ + // if (Connection->ReadBuffer().size() == 0) + // { + // break; + // } + + // WsMessageParser& MessageParser = Connection->MessageParser(); + + // size_t ConsumedBytes{}; + // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + + // Connection->ReadBuffer().consume(ConsumedBytes); + + // if (Ok == false) + // { + // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'", + // Connection->Id().Value()); MessageParser.Reset(); + // } + + // if (Ok == false || MessageParser.IsComplete() == false) + // { + // continue; + // } + + // CbPackage Message; + // if (MessageParser.TryLoadMessage(Message) == false) + // { + // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'", + // Connection->Id().Value()); continue; + // } + + // RouteMessage(Message); + //} + + // ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(const CbPackage& Msg) +{ + ZEN_UNUSED(Msg); + ZEN_LOG_DEBUG(LogWebSocket, "routing message"); +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {} + + virtual ~WsClient() { Disconnect(); } + + virtual bool Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const { return false; } + virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); } + + virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + void TriggerEvent(WebSocketEvent Evt); + void BeginRead(); + spdlog::logger& Log() { return m_Logger; } + + asio::io_context& m_IoCtx; + spdlog::logger& m_Logger; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +bool +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected) + { + return true; + } + + SetState(WebSocketState::kConnecting); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kFailedToConnect); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_ERROR("write data FAILED, reason '{}'", Ec.message()); + + Disconnect(); + } + else + { + BeginRead(); + } + }); + + return true; +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + } +} + +void +WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast<uint32_t>(Evt)] = Cb; +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast<uint32_t>(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::BeginRead() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) { + if (Ec) + { + ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message()); + + Disconnect(); + } + else + { + ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port); + + switch (State()) + { + case WebSocketState::kConnecting: + { + ZEN_ASSERT(m_MsgParser.get() != nullptr); - WsMessageParser& MessageParser = Connection->MessageParser(); + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get()); - size_t ConsumedBytes{}; - const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + asio::const_buffer Buffer = m_ReadBuffer.data(); + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - Connection->ReadBuffer().consume(ConsumedBytes); + m_ReadBuffer.consume(size_t(Result.ByteCount)); - if (Ok == false) - { - ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value()); - MessageParser.Reset(); - } + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode()); + + return Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return BeginRead(); + } - if (Ok == false || MessageParser.IsComplete() == false) - { - continue; - } + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - CbPackage Message; - if (MessageParser.TryLoadMessage(Message) == false) - { - ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value()); - continue; - } + if (Parser.StatusCode() != 101) + { + ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'", + m_Host, + m_Port, + Parser.StatusText(), + Parser.StatusCode()); - RouteMessage(Message); + return Disconnect(); } - ReadMessage(Connection); + ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText()); + + m_MsgParser = std::make_unique<WebSocketMessageParser>(); + + SetState(WebSocketState::kConnected); + TriggerEvent(WebSocketEvent::kConnected); + + BeginRead(); } break; - default: + case WebSocketState::kConnected: + { + BeginRead(); + } break; }; - }); -} - -void -WsServer::RouteMessage(const CbPackage& Msg) -{ - ZEN_UNUSED(Msg); - ZEN_LOG_DEBUG(WsLog, "routing message"); + } + }); } -} // namespace zen::asio_ws +} // namespace zen::websocket namespace zen { @@ -810,7 +1211,13 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade std::unique_ptr<WebSocketServer> WebSocketServer::Create() { - return std::make_unique<asio_ws::WsServer>(); + return std::make_unique<websocket::WsServer>(); +} + +std::unique_ptr<WebSocketClient> +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_unique<websocket::WsClient>(IoCtx); } } // namespace zen |