aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp237
1 files changed, 126 insertions, 111 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index eb01e010e..1952c97a2 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -32,6 +32,8 @@ using namespace std::literals;
ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
+
using Clock = std::chrono::steady_clock;
using TimePoint = Clock::time_point;
@@ -104,7 +106,8 @@ class HttpMessageParser final : public MessageParser
public:
using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
- HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {}
+ HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
+
virtual ~HttpMessageParser() = default;
int32_t StatusCode() const { return m_Parser.status_code; }
@@ -120,6 +123,7 @@ public:
bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
private:
+ void Initialize();
virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
virtual void OnReset() override;
int OnMessageBegin();
@@ -183,6 +187,25 @@ http_parser_settings HttpMessageParser::ParserSettings = {
.on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
+void
+HttpMessageParser::Initialize()
+{
+ http_parser_init(&m_Parser,
+ m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
+ : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
+ : HTTP_BOTH);
+ m_Parser.data = this;
+
+ m_UrlEntry = {};
+ m_StatusEntry = {};
+ m_CurrentHeader = {};
+ m_BodyEntry = {};
+
+ m_IsMsgComplete = false;
+
+ m_HeaderEntries.clear();
+}
+
ParseMessageResult
HttpMessageParser::OnParseMessage(MemoryView Msg)
{
@@ -201,20 +224,7 @@ HttpMessageParser::OnParseMessage(MemoryView Msg)
void
HttpMessageParser::OnReset()
{
- http_parser_init(&m_Parser,
- m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
- : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
- : HTTP_BOTH);
- m_Parser.data = this;
-
- m_UrlEntry = {};
- m_StatusEntry = {};
- m_CurrentHeader = {};
- m_BodyEntry = {};
-
- m_IsMsgComplete = false;
-
- m_HeaderEntries.clear();
+ Initialize();
}
int
@@ -930,13 +940,15 @@ WsServer::RouteMessage(std::shared_ptr<WsConnection> Connection, const CbPackage
}
///////////////////////////////////////////////////////////////////////////////
-class WsClient final : public WebSocketClient
+class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
{
public:
- WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {}
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WsConnectionId::New()) {}
virtual ~WsClient() { Disconnect(); }
+ std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
+
virtual bool Connect(const WebSocketConnectInfo& Info) override;
virtual void Disconnect() override;
virtual bool IsConnected() const { return false; }
@@ -946,14 +958,16 @@ public:
virtual void OnMessage(MessageCallback&& Cb) override;
private:
- WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
- void TriggerEvent(WebSocketEvent Evt);
- void ReadMessage();
- void RouteMessage(CbPackage&& Msg);
- spdlog::logger& Log() { return m_Logger; }
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ void TriggerEvent(WebSocketEvent Evt);
+ void ReadMessage();
+ void RouteMessage(CbPackage&& Msg);
asio::io_context& m_IoCtx;
- spdlog::logger& m_Logger;
+ WsConnectionId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
@@ -984,11 +998,11 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
m_Host = m_Socket->remote_endpoint().address().to_string();
m_Port = Info.Port;
- ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port);
+ ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
}
catch (std::exception& Err)
{
- ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
+ ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
SetState(WebSocketState::kError);
m_Socket.reset();
@@ -1024,21 +1038,21 @@ WsClient::Connect(const WebSocketConnectInfo& Info)
std::string HandshakeRequest = Sb.ToString();
asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
- ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port);
+ ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
m_MsgParser->Reset();
- async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
+ async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
if (Ec)
{
- ZEN_ERROR("write data FAILED, reason '{}'", Ec.message());
+ ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
- Disconnect();
+ Self->Disconnect();
}
else
{
- ReadMessage();
+ Self->ReadMessage();
}
});
@@ -1050,7 +1064,7 @@ WsClient::Disconnect()
{
if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
{
- ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port);
+ ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
if (m_Socket && m_Socket->is_open())
{
@@ -1090,111 +1104,112 @@ WsClient::ReadMessage()
{
m_ReadBuffer.prepare(64 << 10);
- async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) {
- if (Ec)
- {
- ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message());
-
- return Disconnect();
- }
+ async_read(*m_Socket,
+ m_ReadBuffer,
+ asio::transfer_at_least(1),
+ [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
- switch (State())
- {
- case WebSocketState::kHandshaking:
- {
- ZEN_ASSERT(m_MsgParser.get() != nullptr);
+ return Self->Disconnect();
+ }
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get());
+ const WebSocketState State = Self->State();
- asio::const_buffer Buffer = m_ReadBuffer.data();
- ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+ switch (State)
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
- m_ReadBuffer.consume(size_t(Result.ByteCount));
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode());
+ ParseMessageResult Result = Parser.ParseMessage(MessageData);
- return Disconnect();
- }
+ Self->ReadBuffer().consume(size_t(Result.ByteCount));
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return ReadMessage();
- }
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+ return Self->Disconnect();
+ }
- if (Parser.StatusCode() != 101)
- {
- ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'",
- m_Host,
- m_Port,
- Parser.StatusText(),
- Parser.StatusCode());
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return Self->ReadMessage();
+ }
- return Disconnect();
- }
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
- ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText());
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "handshake FAILED, status '{}', status code '{}'",
+ Parser.StatusText(),
+ Parser.StatusCode());
- m_MsgParser = std::make_unique<WebSocketMessageParser>();
+ return Self->Disconnect();
+ }
- SetState(WebSocketState::kConnected);
- TriggerEvent(WebSocketEvent::kConnected);
+ ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
- ReadMessage();
- }
- break;
+ Self->SetParser(std::make_unique<WebSocketMessageParser>());
+ Self->SetState(WebSocketState::kConnected);
+ Self->TriggerEvent(WebSocketEvent::kConnected);
- case WebSocketState::kConnected:
- {
- WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(m_MsgParser.get());
+ Self->ReadMessage();
+ }
+ break;
- MemoryView MessageData = MemoryView(m_ReadBuffer.data().data(), ByteCount);
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
- while (MessageData.IsEmpty() == false)
- {
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
- MessageData.RightChopInline(Result.ByteCount);
+ while (MessageData.IsEmpty() == false)
+ {
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(MessageData.IsEmpty());
+ MessageData.RightChopInline(Result.ByteCount);
- return ReadMessage();
- }
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(MessageData.IsEmpty());
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+ return Self->ReadMessage();
+ }
- Parser.Reset();
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
- continue;
- }
+ Parser.Reset();
- CbPackage Message;
- if (Parser.TryLoadMessage(Message))
- {
- RouteMessage(std::move(Message));
- }
- else
- {
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason 'invalid message'");
- }
+ continue;
+ }
- Parser.Reset();
- }
+ CbPackage Message;
+ if (Parser.TryLoadMessage(Message))
+ {
+ Self->RouteMessage(std::move(Message));
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason 'invalid message'");
+ }
- m_ReadBuffer.consume(ByteCount);
+ Parser.Reset();
+ }
- ReadMessage();
- }
- break;
- };
- });
+ Self->ReadBuffer().consume(ByteCount);
+ Self->ReadMessage();
+ }
+ break;
+ }
+ });
}
void
@@ -1236,10 +1251,10 @@ WebSocketServer::Create()
return std::make_unique<websocket::WsServer>();
}
-std::unique_ptr<WebSocketClient>
+std::shared_ptr<WebSocketClient>
WebSocketClient::Create(asio::io_context& IoCtx)
{
- return std::make_unique<websocket::WsClient>(IoCtx);
+ return std::make_shared<websocket::WsClient>(IoCtx);
}
} // namespace zen