diff options
| author | Per Larsson <[email protected]> | 2022-03-19 16:53:20 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-03-19 16:53:20 +0100 |
| commit | de1c792b182aeb15168ed483a803bc93725f2f46 (patch) | |
| tree | d5ea6948e6320ec43ddcba875a7c5f21a5d214cd /zenhttp | |
| parent | Suppress C4305 in third party includes (diff) | |
| download | zen-de1c792b182aeb15168ed483a803bc93725f2f46.tar.xz zen-de1c792b182aeb15168ed483a803bc93725f2f46.zip | |
Added websocket stream request/response handling.
Diffstat (limited to 'zenhttp')
| -rw-r--r-- | zenhttp/include/zenhttp/websocket.h | 44 | ||||
| -rw-r--r-- | zenhttp/websocketasio.cpp | 123 |
2 files changed, 134 insertions, 33 deletions
diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h index 132dd1679..1280868ec 100644 --- a/zenhttp/include/zenhttp/websocket.h +++ b/zenhttp/include/zenhttp/websocket.h @@ -49,9 +49,37 @@ enum class WebSocketMessageType : uint8_t kInvalid, kNotification, kRequest, - kResponse + kStreamRequest, + kResponse, + kStreamResponse, + kStreamCompleteResponse, + kCount }; +inline std::string_view +ToString(WebSocketMessageType Type) +{ + switch (Type) + { + case WebSocketMessageType::kInvalid: + return std::string_view("Invalid"); + case WebSocketMessageType::kNotification: + return std::string_view("Notification"); + case WebSocketMessageType::kRequest: + return std::string_view("Request"); + case WebSocketMessageType::kStreamRequest: + return std::string_view("StreamRequest"); + case WebSocketMessageType::kResponse: + return std::string_view("Response"); + case WebSocketMessageType::kStreamResponse: + return std::string_view("StreamResponse"); + case WebSocketMessageType::kStreamCompleteResponse: + return std::string_view("StreamCompleteResponse"); + default: + return std::string_view("Unknown"); + }; +} + /** * Web socket message. */ @@ -59,12 +87,12 @@ class WebSocketMessage { struct Header { - static constexpr uint32_t HeaderMagic = 0x7a776d68; // zwmh + static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh uint64_t MessageSize{}; - uint32_t Magic{HeaderMagic}; + uint32_t Magic{ExpectedMagic}; uint32_t CorrelationId{}; - uint32_t Crc32{}; + uint32_t StatusCode{200u}; WebSocketMessageType MessageType{}; uint8_t Reserved[3] = {0}; @@ -82,11 +110,13 @@ public: WebSocketId SocketId() const { return m_SocketId; } void SetSocketId(WebSocketId Id) { m_SocketId = Id; } - void SetMessageType(WebSocketMessageType MessageType); - WebSocketMessageType MessageType() const { return m_Header.MessageType; } uint64_t MessageSize() const { return m_Header.MessageSize; } + void SetMessageType(WebSocketMessageType MessageType); void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } uint32_t CorrelationId() const { return m_Header.CorrelationId; } + uint32_t StatusCode() const { m_Header.StatusCode; } + void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; } + WebSocketMessageType MessageType() const { return m_Header.MessageType; } const CbPackage& Body() const { return m_Body.value(); } void SetBody(CbPackage&& Body); @@ -123,6 +153,8 @@ protected: WebSocketService() = default; virtual void RegisterHandlers(WebSocketServer& Server) = 0; + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete); + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete); WebSocketServer& SocketServer() { diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index 966925d98..bbe7e1ad8 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -4,6 +4,7 @@ #include <zencore/base64.h> #include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -11,6 +12,7 @@ #include <zencore/sha1.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/trace.h> #include <chrono> #include <optional> @@ -25,6 +27,10 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + namespace zen::websocket { using namespace std::literals; @@ -386,6 +392,8 @@ private: ParseMessageResult WebSocketMessageParser::OnParseMessage(MemoryView Msg) { + ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); + const uint64_t PrevOffset = m_Stream.CurrentOffset(); if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) @@ -393,6 +401,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); + Msg += RemaingHeaderSize; if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { @@ -410,24 +419,26 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) .Reason = std::string("Invalid websocket message header")}; } - Msg += RemaingHeaderSize; + if (m_Message.MessageSize() == 0) + { + return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } } ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); - if (Msg.IsEmpty()) + if (Msg.IsEmpty() == false) { - return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + m_Stream.Write(Msg.Left(RemaingMessageSize)); } - const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); - - m_Stream.Write(Msg.Left(RemaingMessageSize)); + auto Status = ParseMessageStatus::kContinue; - const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); - - if (IsComplete) + if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) { + Status = ParseMessageStatus::kDone; + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); CbPackage Pkg; @@ -441,8 +452,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) m_Message.SetBody(std::move(Pkg)); } - return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void @@ -486,6 +496,7 @@ public: WebSocketState Close(); MessageParser* Parser() { return m_MsgParser.get(); } void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + std::mutex& WriteMutex() { return m_WriteMutex; } private: WebSocketId m_Id; @@ -494,6 +505,7 @@ private: std::atomic_uint32_t m_State; std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; + std::mutex m_WriteMutex; }; WebSocketState @@ -635,13 +647,35 @@ WsServer::RegisterService(WebSocketService& Service) bool WsServer::Run() { + static constexpr size_t ReceiveBufferSize = 256 << 10; + static constexpr size_t SendBufferSize = 256 << 10; + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); @@ -706,7 +740,10 @@ WsServer::SendNotification(WebSocketMessage&& Notification) void WsServer::SendResponse(WebSocketMessage&& Response) { - ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || + Response.MessageType() == WebSocketMessageType::kStreamResponse || + Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); SendMessage(std::move(Response)); @@ -970,6 +1007,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) switch (RoutedMessage.MessageType()) { case WebSocketMessageType::kRequest: + case WebSocketMessageType::kStreamRequest: { CbObjectView Request = RoutedMessage.Body().GetObject(); std::string_view Method = Request["Method"].AsString(); @@ -995,7 +1033,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) if (Error || Handled == false) { - std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); @@ -1063,18 +1101,21 @@ WsServer::SendMessage(WebSocketMessage&& Msg) { BinaryWriter Writer; Msg.Save(Writer); - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); - async_write(Connection->Socket(), - asio::buffer(Buffer.Data(), Buffer.Size()), - [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + ZEN_LOG_TRACE(LogWebSocket, + "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", + ToString(Msg.MessageType()), + Connection->Id().Value(), + Msg.MessageSize(), + Msg.CorrelationId(), + NiceBytes(Writer.Size())); - CloseConnection(Connection, Ec); - } - }); + { + ZEN_TRACE_CPU("WS::SendMessage"); + std::unique_lock _(Connection->WriteMutex()); + ZEN_TRACE_CPU("WS::WriteSocketData"); + asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); + } } } @@ -1458,7 +1499,8 @@ std::atomic_uint32_t WebSocketId::NextId{1}; bool WebSocketMessage::Header::IsValid() const { - return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; + return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && + uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); } std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; @@ -1490,6 +1532,12 @@ WebSocketMessage::Save(BinaryWriter& Writer) if (m_Body.has_value()) { + const CbObject& Obj = m_Body.value().GetObject(); + MemoryView View = Obj.GetBuffer().GetView(); + + const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); + ZEN_ASSERT(ValidationResult == CbValidateError::None); + m_Body.value().Save(Writer); } @@ -1499,7 +1547,6 @@ WebSocketMessage::Save(BinaryWriter& Writer) } m_Header.MessageSize = Writer.Size() - HeaderSize; - m_Header.Crc32 = 1; // TODO Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } @@ -1529,6 +1576,28 @@ WebSocketService::Configure(WebSocketServer& Server) RegisterHandlers(Server); } +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) +{ + WebSocketMessage Message; + + Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(CorrelationId); + Message.SetSocketId(SocketId); + Message.SetBody(std::move(StreamResponse)); + + SocketServer().SendResponse(std::move(Message)); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) +{ + CbPackage Response; + Response.SetObject(std::move(StreamResponse)); + + SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); +} + std::unique_ptr<WebSocketServer> WebSocketServer::Create(const WebSocketServerOptions& Options) { |