diff options
| author | Per Larsson <[email protected]> | 2022-02-16 12:32:27 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-16 12:32:27 +0100 |
| commit | 87bb9700722e8319aa58484bba03e398dedede87 (patch) | |
| tree | 1a82ce932f0a729f48bf8472d5f88fa897679a8f /zenhttp/websocketasio.cpp | |
| parent | Renamed asio web socket impl. (diff) | |
| download | zen-87bb9700722e8319aa58484bba03e398dedede87.tar.xz zen-87bb9700722e8319aa58484bba03e398dedede87.zip | |
Added websocket message parser.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 297 |
1 files changed, 215 insertions, 82 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index bb3999780..ad8434a5a 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -3,9 +3,12 @@ #include <zenhttp/websocketserver.h> #include <zencore/base64.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/sha1.h> +#include <zencore/stream.h> #include <zencore/string.h> #include <chrono> @@ -61,8 +64,8 @@ struct HttpParser for (const auto& E : HeaderEntries) { - auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size); OutHeaders[Name] = Value; } @@ -120,10 +123,10 @@ struct HttpParser [](http_parser* P, const char* Data, size_t Size) { HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Offset = Parser.HeaderStream.CurrentOffset(); Parser.Url.Size = Size; - Parser.HeaderStream.Append(Data, uint32_t(Size)); + Parser.HeaderStream.Write(Data, uint32_t(Size)); return 0; }, @@ -139,12 +142,12 @@ struct HttpParser if (Parser.CurrentHeader.Name.Size == 0) { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Name.Size += Size; - Parser.HeaderStream.Append(Data, Size); + Parser.HeaderStream.Write(Data, Size); return 0; }, @@ -154,12 +157,12 @@ struct HttpParser if (Parser.CurrentHeader.Value.Size == 0) { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Value.Size += Size; - Parser.HeaderStream.Append(Data, Size); + Parser.HeaderStream.Write(Data, Size); return 0; }, @@ -192,34 +195,6 @@ struct HttpParser size_t Size{}; }; - class MemStream - { - public: - MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} - - void Append(const char* Data, size_t Size) - { - const size_t NewSize = m_Size + Size; - - if (NewSize > m_Buf.size()) - { - m_Buf.resize(m_Buf.size() + m_BlockSize); - } - - memcpy(m_Buf.data() + m_Size, Data, Size); - m_Size += Size; - } - - const char* Data() const { return m_Buf.data(); } - size_t Pos() const { return m_Size; } - void Clear() { m_Size = 0; } - - private: - std::vector<char> m_Buf; - size_t m_Size{}; - size_t m_BlockSize{}; - }; - using UrlEntry = MemStreamEntry; struct HeaderEntry @@ -231,7 +206,7 @@ struct HttpParser static http_parser_settings ParserSettings; http_parser Parser; - MemStream HeaderStream; + SimpleBinaryWriter HeaderStream; std::vector<HeaderEntry> HeaderEntries; HeaderEntry CurrentHeader{}; UrlEntry Url{}; @@ -242,6 +217,93 @@ struct HttpParser http_parser_settings HttpParser::ParserSettings; /////////////////////////////////////////////////////////////////////////////// +class WsMessageParser +{ +public: + WsMessageParser() {} + + void Reset() + { + m_Header.reset(); + m_Stream.Clear(); + } + + bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes) + { + if (m_Header.has_value()) + { + OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size()); + + m_Stream.Write(Buffer.data(), OutConsumedBytes); + + return true; + } + + const size_t PrevOffset = m_Stream.CurrentOffset(); + const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size()); + const size_t RemainingBytes = Buffer.size() - BytesToWrite; + + m_Stream.Write(Buffer.data(), BytesToWrite); + + if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader)) + { + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return true; + } + + zen::WebSocketMessageHeader Header; + if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false) + { + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return false; + } + + m_Header = Header; + + if (RemainingBytes > 0) + { + const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes); + + m_Stream.Write(reinterpret_cast<const char*>(Buffer.data()) + BytesToWrite, RemainingBytesToWrite); + } + + OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; + + return true; + } + + bool IsComplete() + { + if (m_Header.has_value()) + { + const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(); + + return RemainingBytes == 0; + } + + return false; + } + + bool TryLoadMessage(CbPackage& OutPackage) + { + if (IsComplete()) + { + BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); + + return OutPackage.TryLoad(Reader); + } + + return false; + } + +private: + SimpleBinaryWriter m_Stream{64 << 10}; + std::optional<zen::WebSocketMessageHeader> m_Header; +}; + +/////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t { kDisconnected, @@ -294,6 +356,7 @@ public: std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } HttpParser& ParserHttp() { return *m_HttpParser; } + WsMessageParser& MessageParser() { return m_MsgParser; } WsConnectionState Close(); WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } @@ -307,6 +370,7 @@ private: WsConnectionId m_Id; std::unique_ptr<asio::ip::tcp::socket> m_Socket; std::unique_ptr<HttpParser> m_HttpParser; + WsMessageParser m_MsgParser; TimePoint m_StartTime; std::atomic_uint32_t m_Status; asio::streambuf m_ReadBuffer; @@ -392,9 +456,11 @@ private: friend class WsConnection; void AcceptConnection(); - void CloseConnection(WsConnection& Connection, const std::error_code& Ec); - void RemoveConnection(WsConnection& Connection); - void ReadConnection(WsConnection& Connection); + void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); + void RemoveConnection(const WsConnectionId Id); + + void ReadMessage(std::shared_ptr<WsConnection> Connection); + void RouteMessage(const CbPackage& Msg); struct IdHasher { @@ -413,7 +479,7 @@ private: WsConnection::~WsConnection() { - m_Server.RemoveConnection(*this); + m_Server.RemoveConnection(m_Id); } bool @@ -496,7 +562,7 @@ WsServer::AcceptConnection() Connection->InitializeHttpParser(); Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(*Connection); + ReadMessage(Connection); } if (m_Running) @@ -507,84 +573,87 @@ WsServer::AcceptConnection() } void -WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec) +WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) { - if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected) + if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) { if (Ec) { - ZEN_LOG_INFO(WsLog, - "closing connection '{}' ERROR, reason '{}' error code '{}'", - Connection.Id().Value(), - Ec.message(), - Ec.value()); + ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); } else { - ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value()); + ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value()); } } + + const WsConnectionId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + m_Connections.erase(Id); + } } void -WsServer::RemoveConnection(WsConnection& Connection) +WsServer::RemoveConnection(const WsConnectionId Id) { - ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value()); + ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value()); } void -WsServer::ReadConnection(WsConnection& Connection) +WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) { - Connection.ReadBuffer().prepare(64 << 10); + Connection->ReadBuffer().prepare(64 << 10); asio::async_read( - Connection.Socket(), - Connection.ReadBuffer(), + Connection->Socket(), + Connection->ReadBuffer(), asio::transfer_at_least(1), - [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); } - ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value()); using enum WsConnectionState; - switch (Connection.State()) + switch (Connection->State()) { case kHandshaking: { - HttpParser& Parser = Connection.ParserHttp(); - const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data()); - Connection.ReadBuffer().consume(Consumed); + HttpParser& Parser = Connection->ParserHttp(); + const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data()); + Connection->ReadBuffer().consume(Consumed); if (Parser.IsComplete == false) { - return ReadConnection(Connection); + return ReadMessage(Connection); } if (Parser.IsUpgrade == false) { ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason 'not an upgrade request'", - Connection.Id().Value()); + Connection->Id().Value()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; - return async_write(Connection.Socket(), + return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), - [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { - Connection.InitializeHttpParser(); - Connection.SetState(WsConnectionState::kHandshaking); + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -597,11 +666,11 @@ WsServer::ReadConnection(WsConnection& Connection) if (AcceptHash.empty()) { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason); + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; - return async_write(Connection.Socket(), + return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) @@ -610,10 +679,11 @@ WsServer::ReadConnection(WsConnection& Connection) } else { - Connection.InitializeHttpParser(); - Connection.SetState(WsConnectionState::kHandshaking); + // TODO: Always close connection? + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -636,23 +706,24 @@ WsServer::ReadConnection(WsConnection& Connection) std::string Response = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(Response); - ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value()); - async_write(Connection.Socket(), + async_write(Connection->Socket(), Buffer, - [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { - ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value()); + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value()); - Connection.ReleaseHttpParser(); - Connection.SetState(kConnected); + Connection->ReleaseHttpParser(); + Connection->SetState(kConnected); + Connection->MessageParser().Reset(); - ReadConnection(Connection); + ReadMessage(Connection); } }); } @@ -660,7 +731,42 @@ WsServer::ReadConnection(WsConnection& Connection) case kConnected: { - // TODO: Implement RPC API + for (;;) + { + if (Connection->ReadBuffer().size() == 0) + { + break; + } + + WsMessageParser& MessageParser = Connection->MessageParser(); + + size_t ConsumedBytes{}; + const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); + + Connection->ReadBuffer().consume(ConsumedBytes); + + if (Ok == false) + { + ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value()); + MessageParser.Reset(); + } + + if (Ok == false || MessageParser.IsComplete() == false) + { + continue; + } + + CbPackage Message; + if (MessageParser.TryLoadMessage(Message) == false) + { + ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value()); + continue; + } + + RouteMessage(Message); + } + + ReadMessage(Connection); } break; @@ -670,10 +776,37 @@ WsServer::ReadConnection(WsConnection& Connection) }); } +void +WsServer::RouteMessage(const CbPackage& Msg) +{ + ZEN_UNUSED(Msg); + ZEN_LOG_DEBUG(WsLog, "routing message"); +} + } // namespace zen::asio_ws namespace zen { +bool +WebSocketMessageHeader::IsValid() const +{ + return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; +} + +bool +WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) +{ + if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) + { + return false; + } + + void* Dst = &OutHeader; + memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); + + return OutHeader.IsValid(); +} + std::unique_ptr<WebSocketServer> WebSocketServer::Create() { |