aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-03-19 16:53:20 +0100
committerPer Larsson <[email protected]>2022-03-19 16:53:20 +0100
commitde1c792b182aeb15168ed483a803bc93725f2f46 (patch)
treed5ea6948e6320ec43ddcba875a7c5f21a5d214cd /zenhttp/websocketasio.cpp
parentSuppress C4305 in third party includes (diff)
downloadzen-de1c792b182aeb15168ed483a803bc93725f2f46.tar.xz
zen-de1c792b182aeb15168ed483a803bc93725f2f46.zip
Added websocket stream request/response handling.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp123
1 files changed, 96 insertions, 27 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index 966925d98..bbe7e1ad8 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -4,6 +4,7 @@
#include <zencore/base64.h>
#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -11,6 +12,7 @@
#include <zencore/sha1.h>
#include <zencore/stream.h>
#include <zencore/string.h>
+#include <zencore/trace.h>
#include <chrono>
#include <optional>
@@ -25,6 +27,10 @@ ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
+#if ZEN_PLATFORM_WINDOWS
+# include <mstcpip.h>
+#endif
+
namespace zen::websocket {
using namespace std::literals;
@@ -386,6 +392,8 @@ private:
ParseMessageResult
WebSocketMessageParser::OnParseMessage(MemoryView Msg)
{
+ ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
+
const uint64_t PrevOffset = m_Stream.CurrentOffset();
if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
@@ -393,6 +401,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
m_Stream.Write(Msg.Left(RemaingHeaderSize));
+ Msg += RemaingHeaderSize;
if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
@@ -410,24 +419,26 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
.Reason = std::string("Invalid websocket message header")};
}
- Msg += RemaingHeaderSize;
+ if (m_Message.MessageSize() == 0)
+ {
+ return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
}
ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
- if (Msg.IsEmpty())
+ if (Msg.IsEmpty() == false)
{
- return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
}
- const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
-
- m_Stream.Write(Msg.Left(RemaingMessageSize));
+ auto Status = ParseMessageStatus::kContinue;
- const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset();
-
- if (IsComplete)
+ if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
{
+ Status = ParseMessageStatus::kDone;
+
BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
CbPackage Pkg;
@@ -441,8 +452,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
m_Message.SetBody(std::move(Pkg));
}
- return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
void
@@ -486,6 +496,7 @@ public:
WebSocketState Close();
MessageParser* Parser() { return m_MsgParser.get(); }
void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ std::mutex& WriteMutex() { return m_WriteMutex; }
private:
WebSocketId m_Id;
@@ -494,6 +505,7 @@ private:
std::atomic_uint32_t m_State;
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
+ std::mutex m_WriteMutex;
};
WebSocketState
@@ -635,13 +647,35 @@ WsServer::RegisterService(WebSocketService& Service)
bool
WsServer::Run()
{
+ static constexpr size_t ReceiveBufferSize = 256 << 10;
+ static constexpr size_t SendBufferSize = 256 << 10;
+
m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
m_Acceptor->set_option(asio::ip::v6_only(false));
m_Acceptor->set_option(asio::socket_base::reuse_address(true));
m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor->native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
asio::error_code Ec;
m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
@@ -706,7 +740,10 @@ WsServer::SendNotification(WebSocketMessage&& Notification)
void
WsServer::SendResponse(WebSocketMessage&& Response)
{
- ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse);
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
+
ZEN_ASSERT(Response.CorrelationId() != 0);
SendMessage(std::move(Response));
@@ -970,6 +1007,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
switch (RoutedMessage.MessageType())
{
case WebSocketMessageType::kRequest:
+ case WebSocketMessageType::kStreamRequest:
{
CbObjectView Request = RoutedMessage.Body().GetObject();
std::string_view Method = Request["Method"].AsString();
@@ -995,7 +1033,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
if (Error || Handled == false)
{
- std::string ErrorText = Error ? Exception.what() : std::string("Not Found");
+ std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
@@ -1063,18 +1101,21 @@ WsServer::SendMessage(WebSocketMessage&& Msg)
{
BinaryWriter Writer;
Msg.Save(Writer);
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
- async_write(Connection->Socket(),
- asio::buffer(Buffer.Data(), Buffer.Size()),
- [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message());
+ ZEN_LOG_TRACE(LogWebSocket,
+ "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
+ ToString(Msg.MessageType()),
+ Connection->Id().Value(),
+ Msg.MessageSize(),
+ Msg.CorrelationId(),
+ NiceBytes(Writer.Size()));
- CloseConnection(Connection, Ec);
- }
- });
+ {
+ ZEN_TRACE_CPU("WS::SendMessage");
+ std::unique_lock _(Connection->WriteMutex());
+ ZEN_TRACE_CPU("WS::WriteSocketData");
+ asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
+ }
}
}
@@ -1458,7 +1499,8 @@ std::atomic_uint32_t WebSocketId::NextId{1};
bool
WebSocketMessage::Header::IsValid() const
{
- return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4;
+ return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
+ uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
}
std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
@@ -1490,6 +1532,12 @@ WebSocketMessage::Save(BinaryWriter& Writer)
if (m_Body.has_value())
{
+ const CbObject& Obj = m_Body.value().GetObject();
+ MemoryView View = Obj.GetBuffer().GetView();
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
+ ZEN_ASSERT(ValidationResult == CbValidateError::None);
+
m_Body.value().Save(Writer);
}
@@ -1499,7 +1547,6 @@ WebSocketMessage::Save(BinaryWriter& Writer)
}
m_Header.MessageSize = Writer.Size() - HeaderSize;
- m_Header.Crc32 = 1; // TODO
Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
}
@@ -1529,6 +1576,28 @@ WebSocketService::Configure(WebSocketServer& Server)
RegisterHandlers(Server);
}
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
+{
+ CbPackage Response;
+ Response.SetObject(std::move(StreamResponse));
+
+ SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
+}
+
std::unique_ptr<WebSocketServer>
WebSocketServer::Create(const WebSocketServerOptions& Options)
{