diff options
| author | Per Larsson <[email protected]> | 2022-02-15 14:01:27 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-15 14:01:27 +0100 |
| commit | 921ce02cca7c15113452fde59bafc8fb58663b98 (patch) | |
| tree | aa156785a389f0713a9999a68aaad25b851c634b | |
| parent | Initial websocket support. (diff) | |
| download | zen-921ce02cca7c15113452fde59bafc8fb58663b98.tar.xz zen-921ce02cca7c15113452fde59bafc8fb58663b98.zip | |
Refactored websocket server and added static logger support.
| -rw-r--r-- | zencore/include/zencore/logging.h | 61 | ||||
| -rw-r--r-- | zenhttp/asiowebsocketserver.cpp | 681 | ||||
| -rw-r--r-- | zenhttp/websocketserver.cpp | 471 |
3 files changed, 742 insertions, 471 deletions
diff --git a/zencore/include/zencore/logging.h b/zencore/include/zencore/logging.h index 468e5d6e2..74ab0f81f 100644 --- a/zencore/include/zencore/logging.h +++ b/zencore/include/zencore/logging.h @@ -38,6 +38,67 @@ using logging::ConsoleLog; using zen::ConsoleLog; using zen::Log; +struct LogCategory +{ + LogCategory(std::string_view InCategory) : Category(InCategory) {} + + spdlog::logger& Logger() + { + static spdlog::logger& Inst = zen::logging::Get(Category); + return Inst; + } + + std::string Category; +}; + +#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ + static struct LogCategory##Category : public LogCategory \ + { \ + LogCategory##Category() : LogCategory(Name) {} \ + } Category; + +#define ZEN_LOG_TRACE(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().trace(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_DEBUG(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().debug(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_INFO(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().info(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_WARN(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().warn(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_ERROR(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().error(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + +#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + Category.Logger().critical(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) + // Helper macros for logging #define ZEN_TRACE(fmtstr, ...) \ diff --git a/zenhttp/asiowebsocketserver.cpp b/zenhttp/asiowebsocketserver.cpp new file mode 100644 index 000000000..00e1c60ba --- /dev/null +++ b/zenhttp/asiowebsocketserver.cpp @@ -0,0 +1,681 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocketserver.h> + +#include <zencore/base64.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/sha1.h> +#include <zencore/string.h> + +#include <chrono> +#include <compare> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_ws { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +struct HttpParser +{ + HttpParser() + { + http_parser_init(&Parser, HTTP_REQUEST); + Parser.data = this; + } + + size_t Parse(asio::const_buffer Buffer) + { + return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); + } + + void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) + { + OutHeaders.reserve(HeaderEntries.size()); + + for (const auto& E : HeaderEntries) + { + auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + + OutHeaders[Name] = Value; + } + } + + std::string ValidateWebSocketHandshake(std::unordered_map<std::string_view, std::string_view>& Headers, std::string& OutReason) + { + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + std::string AcceptHash; + + if (Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return AcceptHash; + } + + if (Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return AcceptHash; + } + + ExtendableStringBuilder<128> Sb; + Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); + + return AcceptHash; + } + + static void Initialize() + { + ParserSettings = {.on_message_begin = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url = UrlEntry{}; + Parser.CurrentHeader = HeaderEntry{}; + Parser.IsUpgrade = false; + Parser.IsComplete = false; + + Parser.HeaderStream.Clear(); + Parser.HeaderEntries.clear(); + + return 0; + }, + .on_url = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Size = Size; + + Parser.HeaderStream.Append(Data, uint32_t(Size)); + + return 0; + }, + .on_header_field = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + if (Parser.CurrentHeader.Name.Size == 0) + { + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Name.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_header_value = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size == 0) + { + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Value.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_headers_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + Parser.IsUpgrade = P->upgrade > 0; + + return 0; + }, + .on_message_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + Parser.IsComplete = true; + Parser.IsUpgrade = P->upgrade > 0; + return 0; + }}; + } + + struct MemStreamEntry + { + size_t Offset{}; + size_t Size{}; + }; + + class MemStream + { + public: + MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} + + void Append(const char* Data, size_t Size) + { + const size_t NewSize = m_Size + Size; + + if (NewSize > m_Buf.size()) + { + m_Buf.resize(m_Buf.size() + m_BlockSize); + } + + memcpy(m_Buf.data() + m_Size, Data, Size); + m_Size += Size; + } + + const char* Data() const { return m_Buf.data(); } + size_t Pos() const { return m_Size; } + void Clear() { m_Size = 0; } + + private: + std::vector<char> m_Buf; + size_t m_Size{}; + size_t m_BlockSize{}; + }; + + using UrlEntry = MemStreamEntry; + + struct HeaderEntry + { + MemStreamEntry Name; + MemStreamEntry Value; + }; + + static http_parser_settings ParserSettings; + + http_parser Parser; + MemStream HeaderStream; + std::vector<HeaderEntry> HeaderEntries; + HeaderEntry CurrentHeader{}; + UrlEntry Url{}; + bool IsUpgrade = false; + bool IsComplete = false; +}; + +http_parser_settings HttpParser::ParserSettings; + +/////////////////////////////////////////////////////////////////////////////// +enum class WsConnectionState : uint32_t +{ + kDisconnected, + kHandshaking, + kConnected +}; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnectionId +{ + static std::atomic_uint32_t WsConnectionCounter; + +public: + WsConnectionId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WsConnectionId& RHS) const = default; + + static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } + +private: + WsConnectionId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; + +class WsServer; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Server(Server) + , m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_Status() + { + } + + ~WsConnection(); + + WsConnectionId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + HttpParser& ParserHttp() { return *m_HttpParser; } + WsConnectionState Close(); + WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } + + WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); } + + void InitializeHttpParser() { m_HttpParser = std::make_unique<HttpParser>(); } + void ReleaseHttpParser() { m_HttpParser.reset(); } + +private: + WsServer& m_Server; + WsConnectionId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<HttpParser> m_HttpParser; + TimePoint m_StartTime; + std::atomic_uint32_t m_Status; + asio::streambuf m_ReadBuffer; +}; + +WsConnectionState +WsConnection::Close() +{ + using enum WsConnectionState; + + const auto PrevState = SetState(kDisconnected); + + if (PrevState != kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + + ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer() = default; + virtual ~WsServer() { Shutdown(); } + + virtual bool Run(const WebSocketServerOptions& Options) override; + virtual void Shutdown() override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(WsConnection& Connection, const std::error_code& Ec); + void RemoveConnection(WsConnection& Connection); + void ReadConnection(WsConnection& Connection); + + struct IdHasher + { + size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WsConnectionId, std::shared_ptr<WsConnection>, IdHasher>; + + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::atomic_bool m_Running{}; +}; + +WsConnection::~WsConnection() +{ + m_Server.RemoveConnection(*this); +} + +bool +WsServer::Run(const WebSocketServerOptions& Options) +{ + HttpParser::Initialize(); + + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(WsLog, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); + } + else + { + auto ConnId = WsConnectionId::New(); + + ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); + + auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket)); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[ConnId] = Connection; + } + + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadConnection(*Connection); + } + + if (m_Running) + { + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec) +{ + if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(WsLog, + "closing connection '{}' ERROR, reason '{}' error code '{}'", + Connection.Id().Value(), + Ec.message(), + Ec.value()); + } + else + { + ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value()); + } + } +} + +void +WsServer::RemoveConnection(WsConnection& Connection) +{ + ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value()); +} + +void +WsServer::ReadConnection(WsConnection& Connection) +{ + Connection.ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection.Socket(), + Connection.ReadBuffer(), + asio::transfer_at_least(1), + [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value()); + + using enum WsConnectionState; + + switch (Connection.State()) + { + case kHandshaking: + { + HttpParser& Parser = Connection.ParserHttp(); + const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data()); + Connection.ReadBuffer().consume(Consumed); + + if (Parser.IsComplete == false) + { + return ReadConnection(Connection); + } + + if (Parser.IsUpgrade == false) + { + ZEN_LOG_DEBUG(WsLog, + "handshake with connection '{}' FAILED, reason 'not an upgrade request'", + Connection.Id().Value()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + std::unordered_map<std::string_view, std::string_view> Headers; + Parser.GetHeaders(Headers); + + std::string Reason; + std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); + + if (AcceptHash.empty()) + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Headers.contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + std::string Response = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(Response); + + ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value()); + + async_write(Connection.Socket(), + Buffer, + [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value()); + + Connection.ReleaseHttpParser(); + Connection.SetState(kConnected); + + ReadConnection(Connection); + } + }); + } + break; + case kConnected: + { + // TODO: Implement RPC API + } + break; + default: + break; + }; + }); +} + +} // namespace zen::asio_ws + +namespace zen { + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create() +{ + return std::make_unique<asio_ws::WsServer>(); +} + +} // namespace zen diff --git a/zenhttp/websocketserver.cpp b/zenhttp/websocketserver.cpp deleted file mode 100644 index 776ed1019..000000000 --- a/zenhttp/websocketserver.cpp +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/websocketserver.h> - -#include <zencore/base64.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/sha1.h> -#include <zencore/string.h> - -#include <span> -#include <thread> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <fmt/format.h> -#include <http_parser.h> -#include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen::asio_http { - -using namespace std::literals; - -struct HttpParser -{ - HttpParser() - { - http_parser_init(&Parser, HTTP_REQUEST); - Parser.data = this; - } - - size_t Parse(const char* Data, const size_t Size) { return http_parser_execute(&Parser, &ParserSettings, Data, Size); } - - void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) - { - OutHeaders.reserve(HeaderEntries.size()); - - for (const auto& E : HeaderEntries) - { - auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); - - OutHeaders[Name] = Value; - } - } - - static void Initialize() - { - ParserSettings = {.on_message_begin = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url = UrlEntry{}; - Parser.CurrentHeader = HeaderEntry{}; - Parser.IsUpgrade = false; - Parser.IsComplete = false; - - Parser.HeaderStream.Clear(); - Parser.HeaderEntries.clear(); - - return 0; - }, - .on_url = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url.Offset = Parser.HeaderStream.Pos(); - Parser.Url.Size = Size; - - Parser.HeaderStream.Append(Data, uint32_t(Size)); - - return 0; - }, - .on_header_field = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - if (Parser.CurrentHeader.Name.Size == 0) - { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); - } - - Parser.CurrentHeader.Name.Size += Size; - - Parser.HeaderStream.Append(Data, Size); - - return 0; - }, - .on_header_value = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size == 0) - { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); - } - - Parser.CurrentHeader.Value.Size += Size; - - Parser.HeaderStream.Append(Data, Size); - - return 0; - }, - .on_headers_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - Parser.IsUpgrade = P->upgrade > 0; - - return 0; - }, - .on_message_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - Parser.IsComplete = true; - Parser.IsUpgrade = P->upgrade > 0; - return 0; - }}; - } - - struct MemStreamEntry - { - size_t Offset{}; - size_t Size{}; - }; - - class MemStream - { - public: - MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} - - void Append(const char* Data, size_t Size) - { - const size_t NewSize = m_Size + Size; - - if (NewSize > m_Buf.size()) - { - m_Buf.resize(m_Buf.size() + m_BlockSize); - } - - memcpy(m_Buf.data() + m_Size, Data, Size); - m_Size += Size; - } - - const char* Data() const { return m_Buf.data(); } - size_t Pos() const { return m_Size; } - void Clear() { m_Size = 0; } - - private: - std::vector<char> m_Buf; - size_t m_Size{}; - size_t m_BlockSize{}; - }; - - using UrlEntry = MemStreamEntry; - - struct HeaderEntry - { - MemStreamEntry Name; - MemStreamEntry Value; - }; - - static http_parser_settings ParserSettings; - - http_parser Parser; - MemStream HeaderStream; - std::vector<HeaderEntry> HeaderEntries; - HeaderEntry CurrentHeader{}; - UrlEntry Url{}; - bool IsUpgrade = false; - bool IsComplete = false; -}; - -http_parser_settings HttpParser::ParserSettings; - -class AsioWebSocketServer final : public WebSocketServer -{ -public: - AsioWebSocketServer() : m_Log(zen::logging::Get("websocket")) { HttpParser::Initialize(); } - - virtual ~AsioWebSocketServer() { Shutdown(); } - - virtual bool Run(const WebSocketServerOptions& Options) override - { - m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoService, asio::ip::tcp::v6()); - - m_Acceptor->set_option(asio::ip::v6_only(false)); - m_Acceptor->set_option(asio::socket_base::reuse_address(true)); - m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - asio::error_code Ec; - m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); - - if (Ec) - { - ZEN_ERROR("failed to bind websocket endpoint, error code '{}'", Ec.value()); - - return false; - } - - m_Acceptor->listen(); - - BeginAccept(); - - StartIoThreads(Options.ThreadCount); - - m_Running.store(true, std::memory_order_relaxed); - - ZEN_INFO("websocket server running on port '{}'", Options.Port); - - return true; - } - - virtual void Shutdown() override - { - if (m_Running) - { - ZEN_INFO("websocket server shutting down"); - - m_Running = false; - - m_Acceptor->close(); - m_Acceptor.reset(); - m_IoService.stop(); - - StopIoThreads(); - } - } - -private: - enum class WebSocketState : uint32_t - { - kNone, - kHandshake, - kRead, - kWrite, - kError - }; - - struct WebSocketConnection : public std::enable_shared_from_this<WebSocketConnection> - { - WebSocketConnection(std::unique_ptr<asio::ip::tcp::socket>&& S, uint32_t ConnId) : Socket(std::move(S)), Id(ConnId) {} - - std::unique_ptr<asio::ip::tcp::socket> Socket; - asio::streambuf ReadBuffer; - WebSocketState State; - HttpParser Parser; - uint32_t Id; - }; - - void BeginAccept() - { - auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoService); - asio::ip::tcp::socket& SocketRef = *Socket.get(); - - m_Acceptor->async_accept(SocketRef, [this, NewSocket = std::move(Socket)](const asio::error_code& Ec) mutable { - if (Ec) - { - ZEN_WARN("accept error, error code '{}'", Ec.value()); - } - else - { - const uint32_t Id = m_ConnectionId.fetch_add(1); - auto Connection = std::make_shared<WebSocketConnection>(std::move(NewSocket), Id); - - Connection->State = WebSocketState::kHandshake; - - BeginRead(Connection); - } - - if (m_Running.load(std::memory_order_relaxed)) - { - BeginAccept(); - } - else - { - m_Acceptor->close(); - } - }); - } - - void BeginRead(std::shared_ptr<WebSocketConnection> Connection) - { - Connection->ReadBuffer.prepare(64 << 10); - - asio::async_read(*Connection->Socket, - Connection->ReadBuffer, - asio::transfer_at_least(1), - [Conn = Connection->shared_from_this(), this](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_ERROR("read FAILED, connection '{}', error code '{}'", Conn->Id, Ec.value()); - Conn->Socket->close(); - return; - } - - ZEN_TRACE("reading {}B from connection '{}'", ByteCount, Conn->Id); - - WebSocketState NextState = WebSocketState::kError; - - switch (Conn->State) - { - case WebSocketState::kHandshake: - NextState = ProcessHandshake(Conn); - break; - } - - Conn->State = NextState; - - if (Conn->State == WebSocketState::kError) - { - ZEN_TRACE("process error, connection '{}'", Conn->Id); - Conn->Socket->close(); - return; - } - - BeginRead(Conn); - }); - } - - WebSocketState ProcessHandshake(std::shared_ptr<WebSocketConnection> Connection) - { - HttpParser& Parser = Connection->Parser; - const asio::const_buffer& Buffer = Connection->ReadBuffer.data(); - - const size_t BytesParsed = Parser.Parse(reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); - Connection->ReadBuffer.consume(BytesParsed); - - if (Parser.IsComplete) - { - if (Parser.IsUpgrade == false) - { - ZEN_DEBUG("invalid websocket handshake request, closing connection '{}'", Connection->Id); - - return WebSocketState::kError; - } - - static constexpr std::string_view WebSocketKey = "Sec-WebSocket-Key"sv; - static constexpr std::string_view WebSocketOriginKey = "Sec-WebSocket-Origin"sv; - static constexpr std::string_view WebSocketProtocolKey = "Sec-WebSocket-Protocol"sv; - static constexpr std::string_view WebSocketVersionKey = "Sec-WebSocket-Version"sv; - static constexpr std::string_view WebSocketAcceptKey = "Sec-WebSocket-Accept"sv; - static constexpr std::string_view UpgradeKey = "Upgrade"sv; - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; - - std::unordered_map<std::string_view, std::string_view> Headers; - Parser.GetHeaders(Headers); - - ZEN_DEBUG("handshake, Origin='{}', Protocol='{}', Version='{}', Key='{}'", - Headers[WebSocketOriginKey], - Headers[WebSocketProtocolKey], - Headers[WebSocketVersionKey], - Headers[WebSocketKey]); - - ExtendableStringBuilder<128> Sb; - Sb << Headers[WebSocketKey] << WebSocketGuid; - - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); - - SHA1 Hash = HashStream.GetHash(); - Sb.Reset(); - - const uint32_t EncodedSize = Base64::GetEncodedDataSize(sizeof(SHA1::Hash)); - Sb.AddUninitialized(EncodedSize); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), Sb.Data()); - - std::string AcceptHash = Sb.ToString(); - - Sb.Reset(); - Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: Upgrade\r\n"sv; - Sb << WebSocketProtocolKey << ": " << Headers[WebSocketProtocolKey] << "\r\n"; - Sb << WebSocketAcceptKey << ": " << AcceptHash << "\r\n" - << "\r\n"sv; - - IoBuffer Response = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Sb.ToView())); - asio::const_buffer ResponseView(Response.Data(), Response.Size()); - const uint64_t ResponseLength = Response.Size(); - - asio::async_write( - *Connection->Socket, - asio::buffer(Response.Data(), Response.Size()), - asio::transfer_exactly(ResponseLength), - [this, Conn = Connection->shared_from_this(), Buf = Response](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_ERROR("write {}B FAILED, error code '{}'", ByteCount, Ec.value()); - } - else - { - ZEN_DEBUG("write {}B OK", ByteCount); - } - }); - - return WebSocketState::kRead; - } - - return WebSocketState::kHandshake; - } - - void StartIoThreads(uint32_t ThreadCount) - { - ZEN_DEBUG("starting '{}' websocket I/O thread(s)"); - - for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) - { - m_ThreadPool.emplace_back([this, ThreadId = Idx + 1] { - try - { - m_IoService.run(); - } - catch (std::exception& Err) - { - ZEN_ERROR("process websocket request FAILED, reason '{}'", Err.what()); - } - - ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); - }); - } - } - - void StopIoThreads() - { - for (std::thread& Thread : m_ThreadPool) - { - if (Thread.joinable()) - { - Thread.join(); - } - } - - m_ThreadPool.clear(); - } - - spdlog::logger& Log() { return m_Log; } - spdlog::logger& m_Log; - - asio::io_service m_IoService; - std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; - std::atomic_bool m_Running{}; - std::atomic_uint32_t m_ConnectionId{1}; - std::vector<std::thread> m_ThreadPool; -}; - -} // namespace zen::asio_http - -namespace zen { - -std::unique_ptr<WebSocketServer> -WebSocketServer::Create() -{ - return std::make_unique<asio_http::AsioWebSocketServer>(); -} - -} // namespace zen |