aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp1047
1 files changed, 727 insertions, 320 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index ad8434a5a..b800892d2 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -1,12 +1,13 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenhttp/websocketserver.h>
+#include <zenhttp/websocket.h>
#include <zencore/base64.h>
#include <zencore/compactbinarypackage.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
+#include <zencore/memory.h>
#include <zencore/sha1.h>
#include <zencore/stream.h>
#include <zencore/string.h>
@@ -25,11 +26,11 @@ ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
-namespace zen::asio_ws {
+namespace zen::websocket {
using namespace std::literals;
-ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv);
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
using Clock = std::chrono::steady_clock;
using TimePoint = Clock::time_point;
@@ -45,263 +46,396 @@ namespace http_header {
} // namespace http_header
///////////////////////////////////////////////////////////////////////////////
-struct HttpParser
+enum class ParseMessageStatus : uint32_t
{
- HttpParser()
+ kError,
+ kContinue,
+ kDone,
+};
+
+struct ParseMessageResult
+{
+ ParseMessageStatus Status{};
+ size_t ByteCount{};
+ std::optional<std::string> Reason;
+};
+
+class MessageParser
+{
+public:
+ virtual ~MessageParser() = default;
+
+ ParseMessageResult ParseMessage(MemoryView Msg);
+ void Reset();
+
+protected:
+ MessageParser() = default;
+
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
+ virtual void OnReset() = 0;
+
+ SimpleBinaryWriter m_Stream;
+};
+
+ParseMessageResult
+MessageParser::ParseMessage(MemoryView Msg)
+{
+ return OnParseMessage(Msg);
+}
+
+void
+MessageParser::Reset()
+{
+ OnReset();
+
+ m_Stream.Clear();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+enum class HttpMessageParserType
+{
+ kRequest,
+ kResponse,
+ kBoth
+};
+
+class HttpMessageParser final : public MessageParser
+{
+public:
+ using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
+
+ HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) {}
+ virtual ~HttpMessageParser() = default;
+
+ int32_t StatusCode() const { return m_Parser.status_code; }
+ bool IsUpgrade() const { return m_Parser.upgrade != 0; }
+ HttpHeaders& Headers() { return m_Headers; }
+ MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
+
+ std::string_view StatusText() const
{
- http_parser_init(&Parser, HTTP_REQUEST);
- Parser.data = this;
+ return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
}
- size_t Parse(asio::const_buffer Buffer)
+ bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
+
+private:
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+ int OnMessageBegin();
+ int OnUrl(MemoryView Url);
+ int OnStatus(MemoryView Status);
+ int OnHeaderField(MemoryView HeaderField);
+ int OnHeaderValue(MemoryView HeaderValue);
+ int OnHeadersComplete();
+ int OnBody(MemoryView Body);
+ int OnMessageComplete();
+
+ struct StreamEntry
{
- return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast<const char*>(Buffer.data()), Buffer.size());
- }
+ uint64_t Offset{};
+ uint64_t Size{};
+ };
- void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders)
+ struct HeaderStreamEntry
{
- OutHeaders.reserve(HeaderEntries.size());
+ StreamEntry Field{};
+ StreamEntry Value{};
+ };
- for (const auto& E : HeaderEntries)
- {
- auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size);
- auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size);
+ HttpMessageParserType m_Type;
+ http_parser m_Parser;
+ StreamEntry m_UrlEntry;
+ StreamEntry m_StatusEntry;
+ StreamEntry m_BodyEntry;
+ HeaderStreamEntry m_CurrentHeader;
+ std::vector<HeaderStreamEntry> m_HeaderEntries;
+ HttpHeaders m_Headers;
+ bool m_IsMsgComplete{false};
- OutHeaders[Name] = Value;
- }
- }
+ static http_parser_settings ParserSettings;
+};
- std::string ValidateWebSocketHandshake(std::unordered_map<std::string_view, std::string_view>& Headers, std::string& OutReason)
+http_parser_settings HttpMessageParser::ParserSettings = {
+ .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
+
+ .on_url = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
+
+ .on_status = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
+
+ .on_header_field = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
+
+ .on_header_value = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
+
+ .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
+
+ .on_body = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
+
+ .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
+
+ParseMessageResult
+HttpMessageParser::OnParseMessage(MemoryView Msg)
+{
+ const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
+
+ auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+
+ if (m_Parser.http_errno != 0)
{
- static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+ Status = ParseMessageStatus::kError;
+ }
- std::string AcceptHash;
+ return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
+}
- if (Headers.contains(http_header::SecWebSocketKey) == false)
- {
- OutReason = "Missing header Sec-WebSocket-Key";
- return AcceptHash;
- }
+void
+HttpMessageParser::OnReset()
+{
+ http_parser_init(&m_Parser,
+ m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
+ : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
+ : HTTP_BOTH);
+ m_Parser.data = this;
- if (Headers.contains(http_header::Upgrade) == false)
- {
- OutReason = "Missing header Upgrade";
- return AcceptHash;
- }
+ m_UrlEntry = {};
+ m_StatusEntry = {};
+ m_CurrentHeader = {};
+ m_BodyEntry = {};
+
+ m_IsMsgComplete = false;
- ExtendableStringBuilder<128> Sb;
- Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid;
+ m_HeaderEntries.clear();
+}
+
+int
+HttpMessageParser::OnMessageBegin()
+{
+ ZEN_ASSERT(m_IsMsgComplete == false);
+ ZEN_ASSERT(m_HeaderEntries.empty());
+ ZEN_ASSERT(m_Headers.empty());
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnStatus(MemoryView Status)
+{
+ m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
+
+ m_Stream.Write(Status);
+
+ return 0;
+}
- SHA1Stream HashStream;
- HashStream.Append(Sb.Data(), Sb.Size());
+int
+HttpMessageParser::OnUrl(MemoryView Url)
+{
+ m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
- SHA1 Hash = HashStream.GetHash();
+ m_Stream.Write(Url);
- AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
- Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data());
+ return 0;
+}
- return AcceptHash;
+int
+HttpMessageParser::OnHeaderField(MemoryView HeaderField)
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
}
- static void Initialize()
+ if (m_CurrentHeader.Field.Size == 0)
{
- ParserSettings = {.on_message_begin =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- Parser.Url = UrlEntry{};
- Parser.CurrentHeader = HeaderEntry{};
- Parser.IsUpgrade = false;
- Parser.IsComplete = false;
-
- Parser.HeaderStream.Clear();
- Parser.HeaderEntries.clear();
-
- return 0;
- },
- .on_url =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- Parser.Url.Offset = Parser.HeaderStream.CurrentOffset();
- Parser.Url.Size = Size;
-
- Parser.HeaderStream.Write(Data, uint32_t(Size));
-
- return 0;
- },
- .on_header_field =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size > 0)
- {
- Parser.HeaderEntries.push_back(Parser.CurrentHeader);
- Parser.CurrentHeader = HeaderEntry{};
- }
-
- if (Parser.CurrentHeader.Name.Size == 0)
- {
- Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset();
- }
-
- Parser.CurrentHeader.Name.Size += Size;
-
- Parser.HeaderStream.Write(Data, Size);
-
- return 0;
- },
- .on_header_value =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size == 0)
- {
- Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset();
- }
-
- Parser.CurrentHeader.Value.Size += Size;
-
- Parser.HeaderStream.Write(Data, Size);
-
- return 0;
- },
- .on_headers_complete =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size > 0)
- {
- Parser.HeaderEntries.push_back(Parser.CurrentHeader);
- Parser.CurrentHeader = HeaderEntry{};
- }
-
- Parser.IsUpgrade = P->upgrade > 0;
-
- return 0;
- },
- .on_message_complete =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
- Parser.IsComplete = true;
- Parser.IsUpgrade = P->upgrade > 0;
- return 0;
- }};
+ m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
}
- struct MemStreamEntry
+ m_CurrentHeader.Field.Size += HeaderField.GetSize();
+
+ m_Stream.Write(HeaderField);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
+{
+ if (m_CurrentHeader.Value.Size == 0)
{
- size_t Offset{};
- size_t Size{};
- };
+ m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Value.Size += HeaderValue.GetSize();
- using UrlEntry = MemStreamEntry;
+ m_Stream.Write(HeaderValue);
- struct HeaderEntry
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeadersComplete()
+{
+ if (m_CurrentHeader.Value.Size > 0)
{
- MemStreamEntry Name;
- MemStreamEntry Value;
- };
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
- static http_parser_settings ParserSettings;
+ m_Headers.clear();
+ m_Headers.reserve(m_HeaderEntries.size());
- http_parser Parser;
- SimpleBinaryWriter HeaderStream;
- std::vector<HeaderEntry> HeaderEntries;
- HeaderEntry CurrentHeader{};
- UrlEntry Url{};
- bool IsUpgrade = false;
- bool IsComplete = false;
-};
+ const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
+
+ for (const auto& Entry : m_HeaderEntries)
+ {
+ auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
+ auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
-http_parser_settings HttpParser::ParserSettings;
+ m_Headers.try_emplace(std::move(Field), std::move(Value));
+ }
-///////////////////////////////////////////////////////////////////////////////
-class WsMessageParser
+ return 0;
+}
+
+int
+HttpMessageParser::OnBody(MemoryView Body)
{
-public:
- WsMessageParser() {}
+ m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
+
+ m_Stream.Write(Body);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnMessageComplete()
+{
+ m_IsMsgComplete = true;
+
+ return 0;
+}
+
+bool
+HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
+{
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ OutAcceptHash = std::string();
- void Reset()
+ if (m_Headers.contains(http_header::SecWebSocketKey) == false)
{
- m_Header.reset();
- m_Stream.Clear();
+ OutReason = "Missing header Sec-WebSocket-Key";
+ return false;
}
- bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes)
+ if (m_Headers.contains(http_header::Upgrade) == false)
{
- if (m_Header.has_value())
- {
- OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size());
+ OutReason = "Missing header Upgrade";
+ return false;
+ }
- m_Stream.Write(Buffer.data(), OutConsumedBytes);
+ ExtendableStringBuilder<128> Sb;
+ Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
- return true;
- }
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
- const size_t PrevOffset = m_Stream.CurrentOffset();
- const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size());
- const size_t RemainingBytes = Buffer.size() - BytesToWrite;
+ SHA1 Hash = HashStream.GetHash();
- m_Stream.Write(Buffer.data(), BytesToWrite);
+ OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
- if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader))
- {
- OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+ return true;
+}
- return true;
- }
+///////////////////////////////////////////////////////////////////////////////
+class WebSocketMessageParser final : public MessageParser
+{
+public:
+ WebSocketMessageParser() : MessageParser() {}
- zen::WebSocketMessageHeader Header;
- if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false)
- {
- OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+ bool TryLoadMessage(CbPackage& OutMsg);
- return false;
- }
+private:
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
- m_Header = Header;
+ SimpleBinaryWriter m_HeaderStream;
+ WebSocketMessageHeader m_Header;
+};
- if (RemainingBytes > 0)
- {
- const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes);
+ParseMessageResult
+WebSocketMessageParser::OnParseMessage(MemoryView Msg)
+{
+ const uint64_t PrevOffset = m_Stream.CurrentOffset();
- m_Stream.Write(reinterpret_cast<const char*>(Buffer.data()) + BytesToWrite, RemainingBytesToWrite);
- }
+ if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader))
+ {
+ const uint64_t RemaingHeaderSize = sizeof(WebSocketMessageHeader) - m_HeaderStream.CurrentOffset();
- OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset;
+ m_HeaderStream.Write(Msg.Left(RemaingHeaderSize));
- return true;
- }
+ Msg.RightChopInline(RemaingHeaderSize);
- bool IsComplete()
- {
- if (m_Header.has_value())
+ if (m_HeaderStream.CurrentOffset() < sizeof(WebSocketMessageHeader))
{
- const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset();
-
- return RemainingBytes == 0;
+ return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
- return false;
+ const bool IsValidHeader = WebSocketMessageHeader::Read(m_HeaderStream.GetView(), m_Header);
+
+ if (IsValidHeader == false)
+ {
+ return {.Status = ParseMessageStatus::kError, .Reason = std::string("Invalid websocket message header")};
+ }
}
- bool TryLoadMessage(CbPackage& OutPackage)
+ if (Msg.GetSize() == 0)
{
- if (IsComplete())
- {
- BinaryReader Reader(m_Stream.Data(), m_Stream.Size());
+ return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
- return OutPackage.TryLoad(Reader);
- }
+ const uint64_t RemaingContentSize = m_Header.ContentLength - m_HeaderStream.CurrentOffset();
- return false;
+ m_Stream.Write(Msg.Left(RemaingContentSize));
+
+ const auto Status = m_Stream.CurrentOffset() == m_Header.ContentLength ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+}
+
+void
+WebSocketMessageParser::OnReset()
+{
+ m_HeaderStream.Clear();
+ m_Header = {};
+}
+
+bool
+WebSocketMessageParser::TryLoadMessage(CbPackage& OutMsg)
+{
+ const bool IsParsed = m_Header.IsValid() && m_Stream.CurrentOffset() == m_Header.ContentLength;
+
+ if (IsParsed)
+ {
+ BinaryReader Reader(m_Stream.Data(), m_Stream.Size());
+
+ return OutMsg.TryLoad(Reader);
}
-private:
- SimpleBinaryWriter m_Stream{64 << 10};
- std::optional<zen::WebSocketMessageHeader> m_Header;
-};
+ return false;
+}
///////////////////////////////////////////////////////////////////////////////
enum class WsConnectionState : uint32_t
@@ -346,34 +480,34 @@ public:
, m_StartTime(Clock::now())
, m_Status()
{
+ m_RemoteAddr = m_Socket->remote_endpoint().address().to_string();
}
~WsConnection();
- WsConnectionId Id() const { return m_Id; }
- asio::ip::tcp::socket& Socket() { return *m_Socket; }
- TimePoint StartTime() const { return m_StartTime; }
std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
- asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- HttpParser& ParserHttp() { return *m_HttpParser; }
- WsMessageParser& MessageParser() { return m_MsgParser; }
- WsConnectionState Close();
- WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); }
+ WsConnectionId Id() const { return m_Id; }
+ std::string_view RemoteAddr() const { return m_RemoteAddr; }
+ asio::ip::tcp::socket& Socket() { return *m_Socket; }
+ TimePoint StartTime() const { return m_StartTime; }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ WsConnectionState Close();
+ WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); }
WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); }
- void InitializeHttpParser() { m_HttpParser = std::make_unique<HttpParser>(); }
- void ReleaseHttpParser() { m_HttpParser.reset(); }
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
private:
WsServer& m_Server;
WsConnectionId m_Id;
std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<HttpParser> m_HttpParser;
- WsMessageParser m_MsgParser;
+ std::unique_ptr<MessageParser> m_MsgParser;
TimePoint m_StartTime;
- std::atomic_uint32_t m_Status;
asio::streambuf m_ReadBuffer;
+ std::string m_RemoteAddr;
+ std::atomic_uint32_t m_Status;
};
WsConnectionState
@@ -402,6 +536,7 @@ public:
private:
asio::io_service& m_IoSvc;
std::vector<std::thread> m_Threads;
+ std::atomic_bool m_Running{false};
};
void
@@ -409,18 +544,28 @@ WsThreadPool::Start(uint32_t ThreadCount)
{
ZEN_ASSERT(m_Threads.empty());
- ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount);
+ ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
+
+ m_Running = true;
for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
{
m_Threads.emplace_back([this, ThreadId = Idx + 1] {
- try
- {
- m_IoSvc.run();
- }
- catch (std::exception& Err)
+ for (;;)
{
- ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what());
+ if (m_Running == false)
+ {
+ break;
+ }
+
+ try
+ {
+ m_IoSvc.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
+ }
}
ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId);
@@ -431,15 +576,20 @@ WsThreadPool::Start(uint32_t ThreadCount)
void
WsThreadPool::Stop()
{
- for (std::thread& Thread : m_Threads)
+ if (m_Running)
{
- if (Thread.joinable())
+ m_Running = false;
+
+ for (std::thread& Thread : m_Threads)
{
- Thread.join();
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
}
- }
- m_Threads.clear();
+ m_Threads.clear();
+ }
}
///////////////////////////////////////////////////////////////////////////////
@@ -485,8 +635,6 @@ WsConnection::~WsConnection()
bool
WsServer::Run(const WebSocketServerOptions& Options)
{
- HttpParser::Initialize();
-
m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
m_Acceptor->set_option(asio::ip::v6_only(false));
@@ -500,7 +648,7 @@ WsServer::Run(const WebSocketServerOptions& Options)
if (Ec)
{
- ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value());
+ ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
return false;
}
@@ -508,7 +656,7 @@ WsServer::Run(const WebSocketServerOptions& Options)
m_Acceptor->listen();
m_Running = true;
- ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port);
+ ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", Options.Port);
AcceptConnection();
@@ -523,7 +671,7 @@ WsServer::Shutdown()
{
if (m_Running)
{
- ZEN_LOG_INFO(WsLog, "websocket server shutting down");
+ ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
m_Running = false;
@@ -544,22 +692,24 @@ WsServer::AcceptConnection()
m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value());
+ ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, error code '{}'", Ec.value());
}
else
{
- auto ConnId = WsConnectionId::New();
-
- ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value());
-
+ auto ConnId = WsConnectionId::New();
auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket));
+ ZEN_LOG_DEBUG(LogWebSocket, "accept connection OK, addr '{}', ID '{}'", Connection->RemoteAddr(), ConnId.Value());
+
{
std::unique_lock _(m_ConnMutex);
m_Connections[ConnId] = Connection;
}
- Connection->InitializeHttpParser();
+ auto Parser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest);
+ Parser->Reset();
+
+ Connection->SetParser(std::move(Parser));
Connection->SetState(WsConnectionState::kHandshaking);
ReadMessage(Connection);
@@ -579,11 +729,15 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e
{
if (Ec)
{
- ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value());
+ ZEN_LOG_INFO(LogWebSocket,
+ "connection '{}' closed, ERROR '{}' error code '{}'",
+ Connection->Id().Value(),
+ Ec.message(),
+ Ec.value());
}
else
{
- ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value());
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
}
}
@@ -591,14 +745,17 @@ WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::e
{
std::unique_lock _(m_ConnMutex);
- m_Connections.erase(Id);
+ if (m_Connections.contains(Id))
+ {
+ m_Connections.erase(Id);
+ }
}
}
void
WsServer::RemoveConnection(const WsConnectionId Id)
{
- ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value());
+ ZEN_LOG_INFO(LogWebSocket, "removing connection '{}'", Id.Value());
}
void
@@ -616,7 +773,11 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
return CloseConnection(Connection, ReadEc);
}
- ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value());
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "reading {}B from connection '#{} {}'",
+ ByteCount,
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
using enum WsConnectionState;
@@ -624,20 +785,32 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
{
case kHandshaking:
{
- HttpParser& Parser = Connection->ParserHttp();
- const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data());
- Connection->ReadBuffer().consume(Consumed);
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
+ asio::const_buffer Buffer = Connection->ReadBuffer().data();
+
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
- if (Parser.IsComplete == false)
+ if (Result.Status == ParseMessageStatus::kContinue)
{
return ReadMessage(Connection);
}
- if (Parser.IsUpgrade == false)
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Parser.IsUpgrade() == false)
{
- ZEN_LOG_DEBUG(WsLog,
- "handshake with connection '{}' FAILED, reason 'not an upgrade request'",
- Connection->Id().Value());
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
@@ -646,27 +819,28 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
[this, Connection](const asio::error_code& WriteEc, std::size_t) {
if (WriteEc)
{
- CloseConnection(Connection, WriteEc);
+ return CloseConnection(Connection, WriteEc);
}
- else
- {
- Connection->InitializeHttpParser();
- Connection->SetState(WsConnectionState::kHandshaking);
- ReadMessage(Connection);
- }
+ Connection->Parser()->Reset();
+ Connection->SetState(WsConnectionState::kHandshaking);
+
+ ReadMessage(Connection);
});
}
- std::unordered_map<std::string_view, std::string_view> Headers;
- Parser.GetHeaders(Headers);
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+ std::string AcceptHash;
std::string Reason;
- std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason);
+ const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
- if (AcceptHash.empty())
+ if (ValidHandshake == false)
{
- ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason);
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ Reason);
constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
@@ -675,16 +849,13 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
[this, &Connection](const asio::error_code& WriteEc, std::size_t) {
if (WriteEc)
{
- CloseConnection(Connection, WriteEc);
+ return CloseConnection(Connection, WriteEc);
}
- else
- {
- // TODO: Always close connection?
- Connection->InitializeHttpParser();
- Connection->SetState(WsConnectionState::kHandshaking);
- ReadMessage(Connection);
- }
+ Connection->Parser()->Reset();
+ Connection->SetState(WsConnectionState::kHandshaking);
+
+ ReadMessage(Connection);
});
}
@@ -695,95 +866,325 @@ WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
Sb << "Connection: Upgrade\r\n"sv;
// TODO: Verify protocol
- if (Headers.contains(http_header::SecWebSocketProtocol))
+ if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
{
- Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n";
+ Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
+ << "\r\n";
}
Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
Sb << "\r\n"sv;
- std::string Response = Sb.ToString();
- asio::const_buffer Buffer = asio::buffer(Response);
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "accepting handshake from connection '#{} {}'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
- ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value());
+ std::string Response = Sb.ToString();
+ Buffer = asio::buffer(Response);
async_write(Connection->Socket(),
Buffer,
- [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) {
+ [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
if (WriteEc)
{
- CloseConnection(Connection, WriteEc);
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ WriteEc.message());
+
+ return CloseConnection(Connection, WriteEc);
}
- else
- {
- ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value());
- Connection->ReleaseHttpParser();
- Connection->SetState(kConnected);
- Connection->MessageParser().Reset();
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake ({}B) with connection '#{} {}' OK",
+ ByteCount,
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
- ReadMessage(Connection);
- }
+ Connection->SetParser(std::make_unique<WebSocketMessageParser>());
+ Connection->SetState(kConnected);
});
}
break;
case kConnected:
{
- for (;;)
- {
- if (Connection->ReadBuffer().size() == 0)
- {
- break;
- }
+ // for (;;)
+ //{
+ // if (Connection->ReadBuffer().size() == 0)
+ // {
+ // break;
+ // }
+
+ // WsMessageParser& MessageParser = Connection->MessageParser();
+
+ // size_t ConsumedBytes{};
+ // const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes);
+
+ // Connection->ReadBuffer().consume(ConsumedBytes);
+
+ // if (Ok == false)
+ // {
+ // ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, connection '{}'",
+ // Connection->Id().Value()); MessageParser.Reset();
+ // }
+
+ // if (Ok == false || MessageParser.IsComplete() == false)
+ // {
+ // continue;
+ // }
+
+ // CbPackage Message;
+ // if (MessageParser.TryLoadMessage(Message) == false)
+ // {
+ // ZEN_LOG_WARN(LogWebSocket, "invalid websocket message, connection '{}'",
+ // Connection->Id().Value()); continue;
+ // }
+
+ // RouteMessage(Message);
+ //}
+
+ // ReadMessage(Connection);
+ }
+ break;
+
+ default:
+ break;
+ };
+ });
+}
+
+void
+WsServer::RouteMessage(const CbPackage& Msg)
+{
+ ZEN_UNUSED(Msg);
+ ZEN_LOG_DEBUG(LogWebSocket, "routing message");
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsClient final : public WebSocketClient
+{
+public:
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Logger(zen::logging::Get("websocket-client")) {}
+
+ virtual ~WsClient() { Disconnect(); }
+
+ virtual bool Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const { return false; }
+ virtual WebSocketState State() const { return static_cast<WebSocketState>(m_State.load()); }
+
+ virtual void On(WebSocketEvent Evt, EventCallback&& Cb) override;
+
+private:
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ void TriggerEvent(WebSocketEvent Evt);
+ void BeginRead();
+ spdlog::logger& Log() { return m_Logger; }
+
+ asio::io_context& m_IoCtx;
+ spdlog::logger& m_Logger;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ EventCallback m_EventCallbacks[3];
+ std::atomic_uint32_t m_State;
+ std::string m_Host;
+ int16_t m_Port{};
+};
+
+bool
+WsClient::Connect(const WebSocketConnectInfo& Info)
+{
+ if (State() == WebSocketState::kConnecting || State() == WebSocketState::kConnected)
+ {
+ return true;
+ }
+
+ SetState(WebSocketState::kConnecting);
+
+ try
+ {
+ asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
+
+ m_Socket->connect(Endpoint);
+
+ m_Host = m_Socket->remote_endpoint().address().to_string();
+ m_Port = Info.Port;
+
+ ZEN_INFO("connected to websocket server '{}:{}'", m_Host, m_Port);
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_WARN("connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
+
+ SetState(WebSocketState::kFailedToConnect);
+ m_Socket.reset();
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ return false;
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
+ Sb << "Host: " << Info.Host << "\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: upgrade\r\n"sv;
+ Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
+
+ if (Info.Protocols.empty() == false)
+ {
+ Sb << "Sec-WebSocket-Protocol: "sv;
+ for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
+ {
+ if (Idx++)
+ {
+ Sb << ", ";
+ }
+ Sb << Protocol;
+ }
+ }
+
+ Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
+ Sb << "\r\n";
+
+ std::string HandshakeRequest = Sb.ToString();
+ asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
+
+ ZEN_DEBUG("handshaking with '{}:{}'", m_Host, m_Port);
+
+ m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
+ m_MsgParser->Reset();
+
+ async_write(*m_Socket, Buffer, [this, _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_ERROR("write data FAILED, reason '{}'", Ec.message());
+
+ Disconnect();
+ }
+ else
+ {
+ BeginRead();
+ }
+ });
+
+ return true;
+}
+
+void
+WsClient::Disconnect()
+{
+ if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
+ {
+ ZEN_INFO("closing connection to '{}:{}'", m_Host, m_Port);
+
+ if (m_Socket && m_Socket->is_open())
+ {
+ m_Socket->close();
+ m_Socket.reset();
+ }
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+ }
+}
+
+void
+WsClient::On(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+{
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = Cb;
+}
+
+void
+WsClient::TriggerEvent(WebSocketEvent Evt)
+{
+ const uint32_t Index = static_cast<uint32_t>(Evt);
+
+ if (m_EventCallbacks[Index])
+ {
+ m_EventCallbacks[Index]();
+ }
+}
+
+void
+WsClient::BeginRead()
+{
+ m_ReadBuffer.prepare(64 << 10);
+
+ async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t ByteCount) {
+ if (Ec)
+ {
+ ZEN_DEBUG("read data from '{}:{}' FAILED, reason '{}'", m_Host, m_Port, Ec.message());
+
+ Disconnect();
+ }
+ else
+ {
+ ZEN_DEBUG("reading {}B from '{}:{}'", ByteCount, m_Host, m_Port);
+
+ switch (State())
+ {
+ case WebSocketState::kConnecting:
+ {
+ ZEN_ASSERT(m_MsgParser.get() != nullptr);
- WsMessageParser& MessageParser = Connection->MessageParser();
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(m_MsgParser.get());
- size_t ConsumedBytes{};
- const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes);
+ asio::const_buffer Buffer = m_ReadBuffer.data();
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
- Connection->ReadBuffer().consume(ConsumedBytes);
+ m_ReadBuffer.consume(size_t(Result.ByteCount));
- if (Ok == false)
- {
- ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value());
- MessageParser.Reset();
- }
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_WARN("handshake with '{}:{}' FAILED, status code '{}'", m_Host, m_Port, Parser.StatusCode());
+
+ return Disconnect();
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return BeginRead();
+ }
- if (Ok == false || MessageParser.IsComplete() == false)
- {
- continue;
- }
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
- CbPackage Message;
- if (MessageParser.TryLoadMessage(Message) == false)
- {
- ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value());
- continue;
- }
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_WARN("handshake with '{}:{}' FAILED, status '{}', status code '{}'",
+ m_Host,
+ m_Port,
+ Parser.StatusText(),
+ Parser.StatusCode());
- RouteMessage(Message);
+ return Disconnect();
}
- ReadMessage(Connection);
+ ZEN_INFO("handshake with '{}:{}' OK, status '{}'", m_Host, m_Port, Parser.StatusText());
+
+ m_MsgParser = std::make_unique<WebSocketMessageParser>();
+
+ SetState(WebSocketState::kConnected);
+ TriggerEvent(WebSocketEvent::kConnected);
+
+ BeginRead();
}
break;
- default:
+ case WebSocketState::kConnected:
+ {
+ BeginRead();
+ }
break;
};
- });
-}
-
-void
-WsServer::RouteMessage(const CbPackage& Msg)
-{
- ZEN_UNUSED(Msg);
- ZEN_LOG_DEBUG(WsLog, "routing message");
+ }
+ });
}
-} // namespace zen::asio_ws
+} // namespace zen::websocket
namespace zen {
@@ -810,7 +1211,13 @@ WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeade
std::unique_ptr<WebSocketServer>
WebSocketServer::Create()
{
- return std::make_unique<asio_ws::WsServer>();
+ return std::make_unique<websocket::WsServer>();
+}
+
+std::unique_ptr<WebSocketClient>
+WebSocketClient::Create(asio::io_context& IoCtx)
+{
+ return std::make_unique<websocket::WsClient>(IoCtx);
}
} // namespace zen