aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-16 12:32:27 +0100
committerPer Larsson <[email protected]>2022-02-16 12:32:27 +0100
commit87bb9700722e8319aa58484bba03e398dedede87 (patch)
tree1a82ce932f0a729f48bf8472d5f88fa897679a8f /zenhttp/websocketasio.cpp
parentRenamed asio web socket impl. (diff)
downloadzen-87bb9700722e8319aa58484bba03e398dedede87.tar.xz
zen-87bb9700722e8319aa58484bba03e398dedede87.zip
Added websocket message parser.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp297
1 files changed, 215 insertions, 82 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index bb3999780..ad8434a5a 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -3,9 +3,12 @@
#include <zenhttp/websocketserver.h>
#include <zencore/base64.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/sha1.h>
+#include <zencore/stream.h>
#include <zencore/string.h>
#include <chrono>
@@ -61,8 +64,8 @@ struct HttpParser
for (const auto& E : HeaderEntries)
{
- auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size);
- auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size);
+ auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size);
+ auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size);
OutHeaders[Name] = Value;
}
@@ -120,10 +123,10 @@ struct HttpParser
[](http_parser* P, const char* Data, size_t Size) {
HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
- Parser.Url.Offset = Parser.HeaderStream.Pos();
+ Parser.Url.Offset = Parser.HeaderStream.CurrentOffset();
Parser.Url.Size = Size;
- Parser.HeaderStream.Append(Data, uint32_t(Size));
+ Parser.HeaderStream.Write(Data, uint32_t(Size));
return 0;
},
@@ -139,12 +142,12 @@ struct HttpParser
if (Parser.CurrentHeader.Name.Size == 0)
{
- Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos();
+ Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset();
}
Parser.CurrentHeader.Name.Size += Size;
- Parser.HeaderStream.Append(Data, Size);
+ Parser.HeaderStream.Write(Data, Size);
return 0;
},
@@ -154,12 +157,12 @@ struct HttpParser
if (Parser.CurrentHeader.Value.Size == 0)
{
- Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos();
+ Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset();
}
Parser.CurrentHeader.Value.Size += Size;
- Parser.HeaderStream.Append(Data, Size);
+ Parser.HeaderStream.Write(Data, Size);
return 0;
},
@@ -192,34 +195,6 @@ struct HttpParser
size_t Size{};
};
- class MemStream
- {
- public:
- MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {}
-
- void Append(const char* Data, size_t Size)
- {
- const size_t NewSize = m_Size + Size;
-
- if (NewSize > m_Buf.size())
- {
- m_Buf.resize(m_Buf.size() + m_BlockSize);
- }
-
- memcpy(m_Buf.data() + m_Size, Data, Size);
- m_Size += Size;
- }
-
- const char* Data() const { return m_Buf.data(); }
- size_t Pos() const { return m_Size; }
- void Clear() { m_Size = 0; }
-
- private:
- std::vector<char> m_Buf;
- size_t m_Size{};
- size_t m_BlockSize{};
- };
-
using UrlEntry = MemStreamEntry;
struct HeaderEntry
@@ -231,7 +206,7 @@ struct HttpParser
static http_parser_settings ParserSettings;
http_parser Parser;
- MemStream HeaderStream;
+ SimpleBinaryWriter HeaderStream;
std::vector<HeaderEntry> HeaderEntries;
HeaderEntry CurrentHeader{};
UrlEntry Url{};
@@ -242,6 +217,93 @@ struct HttpParser
http_parser_settings HttpParser::ParserSettings;
///////////////////////////////////////////////////////////////////////////////
+class WsMessageParser
+{
+public:
+ WsMessageParser() {}
+
+ void Reset()
+ {
+ m_Header.reset();
+ m_Stream.Clear();
+ }
+
+ bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes)
+ {
+ if (m_Header.has_value())
+ {
+ OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size());
+
+ m_Stream.Write(Buffer.data(), OutConsumedBytes);
+
+ return true;
+ }
+
+ const size_t PrevOffset = m_Stream.CurrentOffset();
+ const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size());
+ const size_t RemainingBytes = Buffer.size() - BytesToWrite;
+
+ m_Stream.Write(Buffer.data(), BytesToWrite);
+
+ if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader))
+ {
+ OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+
+ return true;
+ }
+
+ zen::WebSocketMessageHeader Header;
+ if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false)
+ {
+ OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+
+ return false;
+ }
+
+ m_Header = Header;
+
+ if (RemainingBytes > 0)
+ {
+ const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes);
+
+ m_Stream.Write(reinterpret_cast<const char*>(Buffer.data()) + BytesToWrite, RemainingBytesToWrite);
+ }
+
+ OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+
+ return true;
+ }
+
+ bool IsComplete()
+ {
+ if (m_Header.has_value())
+ {
+ const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset();
+
+ return RemainingBytes == 0;
+ }
+
+ return false;
+ }
+
+ bool TryLoadMessage(CbPackage& OutPackage)
+ {
+ if (IsComplete())
+ {
+ BinaryReader Reader(m_Stream.Data(), m_Stream.Size());
+
+ return OutPackage.TryLoad(Reader);
+ }
+
+ return false;
+ }
+
+private:
+ SimpleBinaryWriter m_Stream{64 << 10};
+ std::optional<zen::WebSocketMessageHeader> m_Header;
+};
+
+///////////////////////////////////////////////////////////////////////////////
enum class WsConnectionState : uint32_t
{
kDisconnected,
@@ -294,6 +356,7 @@ public:
std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
HttpParser& ParserHttp() { return *m_HttpParser; }
+ WsMessageParser& MessageParser() { return m_MsgParser; }
WsConnectionState Close();
WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); }
@@ -307,6 +370,7 @@ private:
WsConnectionId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
std::unique_ptr<HttpParser> m_HttpParser;
+ WsMessageParser m_MsgParser;
TimePoint m_StartTime;
std::atomic_uint32_t m_Status;
asio::streambuf m_ReadBuffer;
@@ -392,9 +456,11 @@ private:
friend class WsConnection;
void AcceptConnection();
- void CloseConnection(WsConnection& Connection, const std::error_code& Ec);
- void RemoveConnection(WsConnection& Connection);
- void ReadConnection(WsConnection& Connection);
+ void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
+ void RemoveConnection(const WsConnectionId Id);
+
+ void ReadMessage(std::shared_ptr<WsConnection> Connection);
+ void RouteMessage(const CbPackage& Msg);
struct IdHasher
{
@@ -413,7 +479,7 @@ private:
WsConnection::~WsConnection()
{
- m_Server.RemoveConnection(*this);
+ m_Server.RemoveConnection(m_Id);
}
bool
@@ -496,7 +562,7 @@ WsServer::AcceptConnection()
Connection->InitializeHttpParser();
Connection->SetState(WsConnectionState::kHandshaking);
- ReadConnection(*Connection);
+ ReadMessage(Connection);
}
if (m_Running)
@@ -507,84 +573,87 @@ WsServer::AcceptConnection()
}
void
-WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec)
+WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
{
- if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected)
+ if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected)
{
if (Ec)
{
- ZEN_LOG_INFO(WsLog,
- "closing connection '{}' ERROR, reason '{}' error code '{}'",
- Connection.Id().Value(),
- Ec.message(),
- Ec.value());
+ ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value());
}
else
{
- ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value());
+ ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value());
}
}
+
+ const WsConnectionId Id = Connection->Id();
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ m_Connections.erase(Id);
+ }
}
void
-WsServer::RemoveConnection(WsConnection& Connection)
+WsServer::RemoveConnection(const WsConnectionId Id)
{
- ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value());
+ ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value());
}
void
-WsServer::ReadConnection(WsConnection& Connection)
+WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
{
- Connection.ReadBuffer().prepare(64 << 10);
+ Connection->ReadBuffer().prepare(64 << 10);
asio::async_read(
- Connection.Socket(),
- Connection.ReadBuffer(),
+ Connection->Socket(),
+ Connection->ReadBuffer(),
asio::transfer_at_least(1),
- [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable {
+ [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable {
if (ReadEc)
{
return CloseConnection(Connection, ReadEc);
}
- ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value());
+ ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value());
using enum WsConnectionState;
- switch (Connection.State())
+ switch (Connection->State())
{
case kHandshaking:
{
- HttpParser& Parser = Connection.ParserHttp();
- const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data());
- Connection.ReadBuffer().consume(Consumed);
+ HttpParser& Parser = Connection->ParserHttp();
+ const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data());
+ Connection->ReadBuffer().consume(Consumed);
if (Parser.IsComplete == false)
{
- return ReadConnection(Connection);
+ return ReadMessage(Connection);
}
if (Parser.IsUpgrade == false)
{
ZEN_LOG_DEBUG(WsLog,
"handshake with connection '{}' FAILED, reason 'not an upgrade request'",
- Connection.Id().Value());
+ Connection->Id().Value());
constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
- return async_write(Connection.Socket(),
+ return async_write(Connection->Socket(),
asio::buffer(UpgradeRequiredResponse),
- [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ [this, Connection](const asio::error_code& WriteEc, std::size_t) {
if (WriteEc)
{
CloseConnection(Connection, WriteEc);
}
else
{
- Connection.InitializeHttpParser();
- Connection.SetState(WsConnectionState::kHandshaking);
+ Connection->InitializeHttpParser();
+ Connection->SetState(WsConnectionState::kHandshaking);
- ReadConnection(Connection);
+ ReadMessage(Connection);
}
});
}
@@ -597,11 +666,11 @@ WsServer::ReadConnection(WsConnection& Connection)
if (AcceptHash.empty())
{
- ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason);
+ ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason);
constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
- return async_write(Connection.Socket(),
+ return async_write(Connection->Socket(),
asio::buffer(UpgradeRequiredResponse),
[this, &Connection](const asio::error_code& WriteEc, std::size_t) {
if (WriteEc)
@@ -610,10 +679,11 @@ WsServer::ReadConnection(WsConnection& Connection)
}
else
{
- Connection.InitializeHttpParser();
- Connection.SetState(WsConnectionState::kHandshaking);
+ // TODO: Always close connection?
+ Connection->InitializeHttpParser();
+ Connection->SetState(WsConnectionState::kHandshaking);
- ReadConnection(Connection);
+ ReadMessage(Connection);
}
});
}
@@ -636,23 +706,24 @@ WsServer::ReadConnection(WsConnection& Connection)
std::string Response = Sb.ToString();
asio::const_buffer Buffer = asio::buffer(Response);
- ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value());
+ ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value());
- async_write(Connection.Socket(),
+ async_write(Connection->Socket(),
Buffer,
- [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) {
+ [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) {
if (WriteEc)
{
CloseConnection(Connection, WriteEc);
}
else
{
- ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value());
+ ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value());
- Connection.ReleaseHttpParser();
- Connection.SetState(kConnected);
+ Connection->ReleaseHttpParser();
+ Connection->SetState(kConnected);
+ Connection->MessageParser().Reset();
- ReadConnection(Connection);
+ ReadMessage(Connection);
}
});
}
@@ -660,7 +731,42 @@ WsServer::ReadConnection(WsConnection& Connection)
case kConnected:
{
- // TODO: Implement RPC API
+ for (;;)
+ {
+ if (Connection->ReadBuffer().size() == 0)
+ {
+ break;
+ }
+
+ WsMessageParser& MessageParser = Connection->MessageParser();
+
+ size_t ConsumedBytes{};
+ const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes);
+
+ Connection->ReadBuffer().consume(ConsumedBytes);
+
+ if (Ok == false)
+ {
+ ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value());
+ MessageParser.Reset();
+ }
+
+ if (Ok == false || MessageParser.IsComplete() == false)
+ {
+ continue;
+ }
+
+ CbPackage Message;
+ if (MessageParser.TryLoadMessage(Message) == false)
+ {
+ ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value());
+ continue;
+ }
+
+ RouteMessage(Message);
+ }
+
+ ReadMessage(Connection);
}
break;
@@ -670,10 +776,37 @@ WsServer::ReadConnection(WsConnection& Connection)
});
}
+void
+WsServer::RouteMessage(const CbPackage& Msg)
+{
+ ZEN_UNUSED(Msg);
+ ZEN_LOG_DEBUG(WsLog, "routing message");
+}
+
} // namespace zen::asio_ws
namespace zen {
+bool
+WebSocketMessageHeader::IsValid() const
+{
+ return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0;
+}
+
+bool
+WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader)
+{
+ if (Memory.GetSize() < sizeof(WebSocketMessageHeader))
+ {
+ return false;
+ }
+
+ void* Dst = &OutHeader;
+ memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader));
+
+ return OutHeader.IsValid();
+}
+
std::unique_ptr<WebSocketServer>
WebSocketServer::Create()
{