aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp318
1 files changed, 170 insertions, 148 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index b800892d2..eb01e010e 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -438,14 +438,6 @@ WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
}
///////////////////////////////////////////////////////////////////////////////
-enum class WsConnectionState : uint32_t
-{
- kDisconnected,
- kHandshaking,
- kConnected
-};
-
-///////////////////////////////////////////////////////////////////////////////
class WsConnectionId
{
static std::atomic_uint32_t WsConnectionCounter;
@@ -467,53 +459,46 @@ private:
std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1};
-class WsServer;
-
///////////////////////////////////////////////////////////////////////////////
class WsConnection : public std::enable_shared_from_this<WsConnection>
{
public:
- WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
- : m_Server(Server)
- , m_Id(Id)
+ WsConnection(WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ : m_Id(Id)
, m_Socket(std::move(Socket))
, m_StartTime(Clock::now())
- , m_Status()
+ , m_State()
{
- m_RemoteAddr = m_Socket->remote_endpoint().address().to_string();
}
- ~WsConnection();
+ ~WsConnection() = default;
std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
WsConnectionId Id() const { return m_Id; }
- std::string_view RemoteAddr() const { return m_RemoteAddr; }
asio::ip::tcp::socket& Socket() { return *m_Socket; }
TimePoint StartTime() const { return m_StartTime; }
+ WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
+ std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- WsConnectionState Close();
- WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); }
- WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); }
-
- MessageParser* Parser() { return m_MsgParser.get(); }
- void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ WebSocketState Close();
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
private:
- WsServer& m_Server;
WsConnectionId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<MessageParser> m_MsgParser;
TimePoint m_StartTime;
+ std::atomic_uint32_t m_State;
+ std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
- std::string m_RemoteAddr;
- std::atomic_uint32_t m_Status;
};
-WsConnectionState
+WebSocketState
WsConnection::Close()
{
- using enum WsConnectionState;
+ using enum WebSocketState;
const auto PrevState = SetState(kDisconnected);
@@ -607,10 +592,9 @@ private:
void AcceptConnection();
void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
- void RemoveConnection(const WsConnectionId Id);
void ReadMessage(std::shared_ptr<WsConnection> Connection);
- void RouteMessage(const CbPackage& Msg);
+ void RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg);
struct IdHasher
{
@@ -627,11 +611,6 @@ private:
std::atomic_bool m_Running{};
};
-WsConnection::~WsConnection()
-{
- m_Server.RemoveConnection(m_Id);
-}
-
bool
WsServer::Run(const WebSocketServerOptions& Options)
{
@@ -692,25 +671,21 @@ WsServer::AcceptConnection()
m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value());
+ ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
}
else
{
- auto ConnId = WsConnectionId::New();
- auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket));
+ auto Connection = std::make_shared<WsConnection>(WsConnectionId::New(), std::move(ConnectedSocket));
- ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value());
+ ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
{
std::unique_lock _(m_ConnMutex);
- m_Connections[ConnId] = Connection;
+ m_Connections[Connection->Id()] = Connection;
}
- auto Parser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest);
- Parser->Reset();
-
- Connection->SetParser(std::move(Parser));
- Connection->SetState(WsConnectionState::kHandshaking);
+ Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
+ Connection->SetState(WebSocketState::kHandshaking);
ReadMessage(Connection);
}
@@ -725,7 +700,7 @@ WsServer::AcceptConnection()
void
WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
{
- if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected)
+ if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
{
if (Ec)
{
@@ -753,12 +728,6 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e
}
void
-WsServer::RemoveConnection(const WsConnectionId Id)
-{
- ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value());
-}
-
-void
WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
{
Connection->ReadBuffer().prepare(64 << 10);
@@ -773,13 +742,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
return CloseConnection(Connection, ReadEc);
}
- ZEN_LOG_DEBUG(LogWebSocket,
- "reading {}B from connection '#{} {}'",
- ByteCount,
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- using enum WsConnectionState;
+ using enum WebSocketState;
switch (Connection->State())
{
@@ -823,7 +786,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
}
Connection->Parser()->Reset();
- Connection->SetState(WsConnectionState::kHandshaking);
+ Connection->SetState(WebSocketState::kHandshaking);
ReadMessage(Connection);
});
@@ -853,7 +816,7 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
}
Connection->Parser()->Reset();
- Connection->SetState(WsConnectionState::kHandshaking);
+ Connection->SetState(WebSocketState::kHandshaking);
ReadMessage(Connection);
});
@@ -910,42 +873,46 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
case kConnected:
{
- // for (;;)
- //{
- // if (Connection->ReadBuffer().size() == 0)
- // {
- // break;
- // }
-
- // WsMessageParser& MessageParser = Connection->MessageParser();
-
- // size_t ConsumedBytes{};
- // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes);
-
- // Connection->ReadBuffer().consume(ConsumedBytes);
-
- // if (Ok == false)
- // {
- // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'",
- // Connection->Id().Value()); MessageParser.Reset();
- // }
-
- // if (Ok == false || MessageParser.IsComplete() == false)
- // {
- // continue;
- // }
-
- // CbPackage Message;
- // if (MessageParser.TryLoadMessage(Message) == false)
- // {
- // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'",
- // Connection->Id().Value()); continue;
- // }
-
- // RouteMessage(Message);
- //}
-
- // ReadMessage(Connection);
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
+
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), ByteCount);
+
+ while (MessageData.IsEmpty() == false)
+ {
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ MessageData.RightChopInline(Result.ByteCount);
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(MessageData.IsEmpty());
+
+ return ReadMessage(Connection);
+ }
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ CbPackage Message;
+ if (Parser.TryLoadMessage(Message) == false)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'");
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ RouteMessage(Connection, Message);
+
+ Parser.Reset();
+ }
+
+ Connection->ReadBuffer().consume(ByteCount);
+
+ ReadMessage(Connection);
}
break;
@@ -956,9 +923,9 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
}
void
-WsServer::RouteMessage(const CbPackage& Msg)
+WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage& Msg)
{
- ZEN_UNUSED(Msg);
+ ZEN_UNUSED(Connection, Msg);
ZEN_LOG_DEBUG(LogWebSocket, "routing message");
}
@@ -976,11 +943,13 @@ public:
virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override;
+ virtual void OnMessage(MessageCallback&& Cb) override;
private:
WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
void TriggerEvent(WebSocketEvent Evt);
- void BeginRead();
+ void ReadMessage();
+ void RouteMessage(CbPackage&& Msg);
spdlog::logger& Log() { return m_Logger; }
asio::io_context& m_IoCtx;
@@ -989,6 +958,7 @@ private:
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
EventCallback m_EventCallbacks[3];
+ MessageCallback m_MsgCallback;
std::atomic_uint32_t m_State;
std::string m_Host;
int16_t m_Port{};
@@ -997,12 +967,12 @@ private:
bool
WsClient::Connect(const WebSocketConnectInfo& Info)
{
- if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected)
+ if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
{
return true;
}
- SetState(WebSocketState::kConnecting);
+ SetState(WebSocketState::kHandshaking);
try
{
@@ -1020,7 +990,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
{
ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
- SetState(WebSocketState::kFailedToConnect);
+ SetState(WebSocketState::kError);
m_Socket.reset();
TriggerEvent(WebSocketEvent::kDisconnected);
@@ -1068,7 +1038,7 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
}
else
{
- BeginRead();
+ ReadMessage();
}
});
@@ -1095,7 +1065,13 @@ WsClient::Disconnect()
void
WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
{
- m_EventCallbacks[static_cast<uint32_t>(Evt)] = Cb;
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+}
+
+void
+WsClient::OnMessage(MessageCallback&& Cb)
+{
+ m_MsgCallback = std::move(Cb);
}
void
@@ -1110,7 +1086,7 @@ WsClient::TriggerEvent(WebSocketEvent Evt)
}
void
-WsClient::BeginRead()
+WsClient::ReadMessage()
{
m_ReadBuffer.prepare(64 << 10);
@@ -1119,71 +1095,117 @@ WsClient::BeginRead()
{
ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message());
- Disconnect();
+ return Disconnect();
}
- else
+
+ switch (State())
{
- ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port);
+ case WebSocketState::kHandshaking:
+ {
+ ZEN_ASSERT(m_MsgParser.get() != nullptr);
- switch (State())
- {
- case WebSocketState::kConnecting:
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get());
+
+ asio::const_buffer Buffer = m_ReadBuffer.data();
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+
+ m_ReadBuffer.consume(size_t(Result.ByteCount));
+
+ if (Result.Status == ParseMessageStatus::kError)
{
- ZEN_ASSERT(m_MsgParser.get() != nullptr);
+ ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode());
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get());
+ return Disconnect();
+ }
- asio::const_buffer Buffer = m_ReadBuffer.data();
- ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return ReadMessage();
+ }
- m_ReadBuffer.consume(size_t(Result.ByteCount));
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode());
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'",
+ m_Host,
+ m_Port,
+ Parser.StatusText(),
+ Parser.StatusCode());
- return Disconnect();
- }
+ return Disconnect();
+ }
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return BeginRead();
- }
+ ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText());
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+ m_MsgParser = std::make_unique<WebSocketMessageParser>();
+
+ SetState(WebSocketState::kConnected);
+ TriggerEvent(WebSocketEvent::kConnected);
+
+ ReadMessage();
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(m_MsgParser.get());
- if (Parser.StatusCode() != 101)
+ MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount);
+
+ while (MessageData.IsEmpty() == false)
+ {
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ MessageData.RightChopInline(Result.ByteCount);
+
+ if (Result.Status == ParseMessageStatus::kContinue)
{
- ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'",
- m_Host,
- m_Port,
- Parser.StatusText(),
- Parser.StatusCode());
+ ZEN_ASSERT(MessageData.IsEmpty());
- return Disconnect();
+ return ReadMessage();
}
- ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText());
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
- m_MsgParser = std::make_unique<WebSocketMessageParser>();
+ Parser.Reset();
- SetState(WebSocketState::kConnected);
- TriggerEvent(WebSocketEvent::kConnected);
+ continue;
+ }
- BeginRead();
- }
- break;
+ CbPackage Message;
+ if (Parser.TryLoadMessage(Message))
+ {
+ RouteMessage(std::move(Message));
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'");
+ }
- case WebSocketState::kConnected:
- {
- BeginRead();
+ Parser.Reset();
}
- break;
- };
- }
+
+ m_ReadBuffer.consume(ByteCount);
+
+ ReadMessage();
+ }
+ break;
+ };
});
}
+void
+WsClient::RouteMessage(CbPackage&& Msg)
+{
+ if (m_MsgCallback)
+ {
+ m_MsgCallback(Msg);
+ }
+}
+
} // namespace zen::websocket
namespace zen {