diff options
| author | Per Larsson <[email protected]> | 2022-03-19 16:53:20 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-03-19 16:53:20 +0100 |
| commit | de1c792b182aeb15168ed483a803bc93725f2f46 (patch) | |
| tree | d5ea6948e6320ec43ddcba875a7c5f21a5d214cd /zenhttp/websocketasio.cpp | |
| parent | Suppress C4305 in third party includes (diff) | |
| download | zen-de1c792b182aeb15168ed483a803bc93725f2f46.tar.xz zen-de1c792b182aeb15168ed483a803bc93725f2f46.zip | |
Added websocket stream request/response handling.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 123 |
1 files changed, 96 insertions, 27 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp index 966925d98..bbe7e1ad8 100644 --- a/zenhttp/websocketasio.cpp +++ b/zenhttp/websocketasio.cpp @@ -4,6 +4,7 @@ #include <zencore/base64.h> #include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -11,6 +12,7 @@ #include <zencore/sha1.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/trace.h> #include <chrono> #include <optional> @@ -25,6 +27,10 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + namespace zen::websocket { using namespace std::literals; @@ -386,6 +392,8 @@ private: ParseMessageResult WebSocketMessageParser::OnParseMessage(MemoryView Msg) { + ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); + const uint64_t PrevOffset = m_Stream.CurrentOffset(); if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) @@ -393,6 +401,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); m_Stream.Write(Msg.Left(RemaingHeaderSize)); + Msg += RemaingHeaderSize; if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) { @@ -410,24 +419,26 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) .Reason = std::string("Invalid websocket message header")}; } - Msg += RemaingHeaderSize; + if (m_Message.MessageSize() == 0) + { + return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } } ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); - if (Msg.IsEmpty()) + if (Msg.IsEmpty() == false) { - return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + m_Stream.Write(Msg.Left(RemaingMessageSize)); } - const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); - - m_Stream.Write(Msg.Left(RemaingMessageSize)); + auto Status = ParseMessageStatus::kContinue; - const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset(); - - if (IsComplete) + if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) { + Status = ParseMessageStatus::kDone; + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); CbPackage Pkg; @@ -441,8 +452,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg) m_Message.SetBody(std::move(Pkg)); } - return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; } void @@ -486,6 +496,7 @@ public: WebSocketState Close(); MessageParser* Parser() { return m_MsgParser.get(); } void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + std::mutex& WriteMutex() { return m_WriteMutex; } private: WebSocketId m_Id; @@ -494,6 +505,7 @@ private: std::atomic_uint32_t m_State; std::unique_ptr<MessageParser> m_MsgParser; asio::streambuf m_ReadBuffer; + std::mutex m_WriteMutex; }; WebSocketState @@ -635,13 +647,35 @@ WsServer::RegisterService(WebSocketService& Service) bool WsServer::Run() { + static constexpr size_t ReceiveBufferSize = 256 << 10; + static constexpr size_t SendBufferSize = 256 << 10; + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); @@ -706,7 +740,10 @@ WsServer::SendNotification(WebSocketMessage&& Notification) void WsServer::SendResponse(WebSocketMessage&& Response) { - ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse); + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || + Response.MessageType() == WebSocketMessageType::kStreamResponse || + Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); + ZEN_ASSERT(Response.CorrelationId() != 0); SendMessage(std::move(Response)); @@ -970,6 +1007,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) switch (RoutedMessage.MessageType()) { case WebSocketMessageType::kRequest: + case WebSocketMessageType::kStreamRequest: { CbObjectView Request = RoutedMessage.Body().GetObject(); std::string_view Method = Request["Method"].AsString(); @@ -995,7 +1033,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) if (Error || Handled == false) { - std::string ErrorText = Error ? Exception.what() : std::string("Not Found"); + std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); @@ -1063,18 +1101,21 @@ WsServer::SendMessage(WebSocketMessage&& Msg) { BinaryWriter Writer; Msg.Save(Writer); - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); - async_write(Connection->Socket(), - asio::buffer(Buffer.Data(), Buffer.Size()), - [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message()); + ZEN_LOG_TRACE(LogWebSocket, + "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", + ToString(Msg.MessageType()), + Connection->Id().Value(), + Msg.MessageSize(), + Msg.CorrelationId(), + NiceBytes(Writer.Size())); - CloseConnection(Connection, Ec); - } - }); + { + ZEN_TRACE_CPU("WS::SendMessage"); + std::unique_lock _(Connection->WriteMutex()); + ZEN_TRACE_CPU("WS::WriteSocketData"); + asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); + } } } @@ -1458,7 +1499,8 @@ std::atomic_uint32_t WebSocketId::NextId{1}; bool WebSocketMessage::Header::IsValid() const { - return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4; + return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && + uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); } std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; @@ -1490,6 +1532,12 @@ WebSocketMessage::Save(BinaryWriter& Writer) if (m_Body.has_value()) { + const CbObject& Obj = m_Body.value().GetObject(); + MemoryView View = Obj.GetBuffer().GetView(); + + const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); + ZEN_ASSERT(ValidationResult == CbValidateError::None); + m_Body.value().Save(Writer); } @@ -1499,7 +1547,6 @@ WebSocketMessage::Save(BinaryWriter& Writer) } m_Header.MessageSize = Writer.Size() - HeaderSize; - m_Header.Crc32 = 1; // TODO Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); } @@ -1529,6 +1576,28 @@ WebSocketService::Configure(WebSocketServer& Server) RegisterHandlers(Server); } +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) +{ + WebSocketMessage Message; + + Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(CorrelationId); + Message.SetSocketId(SocketId); + Message.SetBody(std::move(StreamResponse)); + + SocketServer().SendResponse(std::move(Message)); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) +{ + CbPackage Response; + Response.SetObject(std::move(StreamResponse)); + + SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); +} + std::unique_ptr<WebSocketServer> WebSocketServer::Create(const WebSocketServerOptions& Options) { |