aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-18 14:48:41 +0100
committerPer Larsson <[email protected]>2022-02-18 14:48:41 +0100
commit08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4 (patch)
tree8fd8e7189c9807280003eabb3040cb35db2e5cd4 /zenhttp/websocketasio.cpp
parentWeb socket client is shared between I/O thead and client. (diff)
downloadzen-08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4.tar.xz
zen-08cc02bf1b5cad75ed7b97c0dc7cfc082da537c4.zip
Basic websocket service and test.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp212
1 files changed, 161 insertions, 51 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index 1952c97a2..f6f58f38c 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -13,7 +13,6 @@
#include <zencore/string.h>
#include <chrono>
-#include <compare>
#include <optional>
#include <shared_mutex>
#include <span>
@@ -381,7 +380,6 @@ private:
virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
virtual void OnReset() override;
- SimpleBinaryWriter m_HeaderStream;
WebSocketMessageHeader m_Header;
};
@@ -390,20 +388,20 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
{
const uint64_t PrevOffset = m_Stream.CurrentOffset();
- if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
{
- const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset();
+ const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_Stream.CurrentOffset();
- m_HeaderStream.Write(Msg.Left(RemaingHeaderSize));
+ m_Stream.Write(Msg.Left(RemaingHeaderSize));
Msg.RightChopInline(RemaingHeaderSize);
- if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ if (m_Stream.CurrentOffset() < sizeof(WebSocketMessageHeader))
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header);
+ const bool IsValidHeader = WebSocketMessageHeader::Read(m_Stream.GetView(), m_Header);
if (IsValidHeader == false)
{
@@ -411,16 +409,21 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
}
}
- if (Msg.GetSize() == 0)
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= sizeof(WebSocketMessageHeader));
+
+ if (Msg.IsEmpty())
{
return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset();
+ const uint64_t RemaingContentSize =
+ Min(m_Header.ContentLength - (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)), Msg.GetSize());
m_Stream.Write(Msg.Left(RemaingContentSize));
- const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+ const auto Status = (m_Stream.CurrentOffset() - sizeof(WebSocketMessageHeader)) == m_Header.ContentLength
+ ? ParseMessageStatus::kDone
+ : ParseMessageStatus::kContinue;
return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
@@ -428,18 +431,17 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
void
WebSocketMessageParser::OnReset()
{
- m_HeaderStream.Clear();
m_Header = {};
}
bool
WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
{
- const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength;
+ const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength + sizeof(WebSocketMessageHeader);
if (IsParsed)
{
- BinaryReader Reader(m_Stream.Data(), m_Stream.Size());
+ BinaryReader Reader(m_Stream.GetView().RightChop(sizeof(WebSocketMessageHeader)));
return OutMsg.TryLoad(Reader);
}
@@ -448,32 +450,10 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
}
///////////////////////////////////////////////////////////////////////////////
-class WsConnectionId
-{
- static std::atomic_uint32_t WsConnectionCounter;
-
-public:
- WsConnectionId() = default;
-
- uint32_t Value() const { return m_Value; }
-
- auto operator<=>(const WsConnectionId& RHS) const = default;
-
- static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); }
-
-private:
- WsConnectionId(uint32_t Value) : m_Value(Value) {}
-
- uint32_t m_Value{};
-};
-
-std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1};
-
-///////////////////////////////////////////////////////////////////////////////
class WsConnection : public std::enable_shared_from_this<WsConnection>
{
public:
- WsConnection(WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
: m_Id(Id)
, m_Socket(std::move(Socket))
, m_StartTime(Clock::now())
@@ -485,7 +465,7 @@ public:
std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
- WsConnectionId Id() const { return m_Id; }
+ WebSocketId Id() const { return m_Id; }
asio::ip::tcp::socket& Socket() { return *m_Socket; }
TimePoint StartTime() const { return m_StartTime; }
WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
@@ -497,7 +477,7 @@ public:
void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
private:
- WsConnectionId m_Id;
+ WebSocketId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
TimePoint m_StartTime;
std::atomic_uint32_t m_State;
@@ -594,8 +574,10 @@ public:
WsServer() = default;
virtual ~WsServer() { Shutdown(); }
+ virtual void RegisterService(WebSocketService& Service) override;
virtual bool Run(const WebSocketServerOptions& Options) override;
virtual void Shutdown() override;
+ virtual void PublishMessage(WebSocketId Id, CbPackage&& Msg) override;
private:
friend class WsConnection;
@@ -608,19 +590,28 @@ private:
struct IdHasher
{
- size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); }
+ size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
};
- using ConnectionMap = std::unordered_map<WsConnectionId, std::shared_ptr<WsConnection>, IdHasher>;
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
asio::io_service m_IoSvc;
std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
std::unique_ptr<WsThreadPool> m_ThreadPool;
ConnectionMap m_Connections;
std::shared_mutex m_ConnMutex;
+ std::vector<WebSocketService*> m_Services;
std::atomic_bool m_Running{};
};
+void
+WsServer::RegisterService(WebSocketService& Service)
+{
+ m_Services.push_back(&Service);
+
+ Service.Configure(*this);
+}
+
bool
WsServer::Run(const WebSocketServerOptions& Options)
{
@@ -673,6 +664,37 @@ WsServer::Shutdown()
}
void
+WsServer::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
+
+ {
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Id); It != m_Connections.end())
+ {
+ Connection = It->second;
+ }
+ }
+
+ BinaryWriter Writer;
+ WebSocketMessageHeader::Write(Writer, Msg);
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ ZEN_LOG_WARN(LogWebSocket, "sending message {}B to '#{} {}' ", Buffer.Size(), Connection->Id().Value(), Connection->RemoteAddr());
+
+ async_write(Connection->Socket(),
+ asio::buffer(Buffer.Data(), Buffer.Size()),
+ [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ CloseConnection(Connection, Ec);
+ }
+ });
+}
+
+void
WsServer::AcceptConnection()
{
auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
@@ -685,7 +707,7 @@ WsServer::AcceptConnection()
}
else
{
- auto Connection = std::make_shared<WsConnection>(WsConnectionId::New(), std::move(ConnectedSocket));
+ auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
@@ -726,7 +748,7 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e
}
}
- const WsConnectionId Id = Connection->Id();
+ const WebSocketId Id = Connection->Id();
{
std::unique_lock _(m_ConnMutex);
@@ -763,6 +785,8 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+ Connection->ReadBuffer().consume(Result.ByteCount);
+
if (Result.Status == ParseMessageStatus::kContinue)
{
return ReadMessage(Connection);
@@ -770,10 +794,10 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
if (Result.Status == ParseMessageStatus::kError)
{
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
+ ZEN_LOG_WARN(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
return CloseConnection(Connection, std::error_code());
}
@@ -877,6 +901,8 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
Connection->SetParser(std::make_unique<WebSocketMessageParser>());
Connection->SetState(kConnected);
+
+ ReadMessage(Connection);
});
}
break;
@@ -935,15 +961,24 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
void
WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg)
{
- ZEN_UNUSED(Connection, Msg);
ZEN_LOG_DEBUG(LogWebSocket, "routing message");
+
+ for (auto Server : m_Services)
+ {
+ if (Server->HandleMessage(Connection->Id(), Msg))
+ {
+ return;
+ }
+ }
+
+ ZEN_LOG_WARN(LogWebSocket, "unhandled message");
}
///////////////////////////////////////////////////////////////////////////////
class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
{
public:
- WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {}
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
virtual ~WsClient() { Disconnect(); }
@@ -953,6 +988,8 @@ public:
virtual void Disconnect() override;
virtual bool IsConnected() const { return false; }
virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
+ virtual void SendMsg(CbPackage&& Msg) override;
+ virtual void SendMsg(CbObject&& Msg) override;
virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override;
virtual void OnMessage(MessageCallback&& Cb) override;
@@ -967,7 +1004,7 @@ private:
void RouteMessage(CbPackage&& Msg);
asio::io_context& m_IoCtx;
- WsConnectionId m_Id;
+ WebSocketId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
@@ -1077,6 +1114,35 @@ WsClient::Disconnect()
}
void
+WsClient::SendMsg(CbPackage&& Msg)
+{
+ BinaryWriter Writer;
+ WebSocketMessageHeader::Write(Writer, Msg);
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ ZEN_LOG_DEBUG(LogWsClient, "sending message {}B", Buffer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWsClient, "send messge FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ });
+}
+
+void
+WsClient::SendMsg(CbObject&& Msg)
+{
+ CbPackage Pkg;
+ Pkg.SetObject(std::move(Msg));
+
+ WsClient::SendMsg(std::move(Pkg));
+}
+
+void
WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
{
m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
@@ -1157,9 +1223,8 @@ WsClient::ReadMessage()
Self->SetParser(std::make_unique<WebSocketMessageParser>());
Self->SetState(WebSocketState::kConnected);
- Self->TriggerEvent(WebSocketEvent::kConnected);
-
Self->ReadMessage();
+ Self->TriggerEvent(WebSocketEvent::kConnected);
}
break;
@@ -1225,6 +1290,35 @@ WsClient::RouteMessage(CbPackage&& Msg)
namespace zen {
+std::atomic_uint32_t WebSocketId::NextId{1};
+
+void
+WebSocketService::Configure(WebSocketServer& Server)
+{
+ ZEN_ASSERT(m_Server == nullptr);
+
+ m_Server = &Server;
+}
+
+void
+WebSocketService::PublishMessage(WebSocketId Id, CbPackage&& Msg)
+{
+ ZEN_ASSERT(m_Server != nullptr);
+
+ m_Server->PublishMessage(Id, std::move(Msg));
+}
+
+void
+WebSocketService::PublishMessage(WebSocketId Id, CbObject&& Msg)
+{
+ ZEN_ASSERT(m_Server != nullptr);
+
+ CbPackage Pkg;
+ Pkg.SetObject(std::move(Msg));
+
+ m_Server->PublishMessage(Id, std::move(Pkg));
+}
+
bool
WebSocketMessageHeader::IsValid() const
{
@@ -1245,6 +1339,22 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade
return OutHeader.IsValid();
}
+void
+WebSocketMessageHeader::Write(BinaryWriter& Writer, const CbPackage& Msg)
+{
+ WebSocketMessageHeader Header;
+
+ Writer.Write(&Header, sizeof(WebSocketMessageHeader));
+ Msg.Save(Writer);
+
+ Header.Magic = WebSocketMessageHeader::ExpectedMagic;
+ Header.ContentLength = Writer.CurrentOffset() - sizeof(WebSocketMessageHeader);
+ Header.Crc32 = WebSocketMessageHeader::ExpectedMagic; // TODO
+
+ MutableMemoryView HeaderView(const_cast<uint8_t*>(Writer.Data()), sizeof(WebSocketMessageHeader));
+ HeaderView.CopyFrom(MemoryView(&Header, sizeof(WebSocketMessageHeader)));
+}
+
std::unique_ptr<WebSocketServer>
WebSocketServer::Create()
{