diff options
| author | Per Larsson <[email protected]> | 2022-02-15 14:07:40 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-15 14:07:40 +0100 |
| commit | 386f56cd3b7f06fc30318adcbdc0753ddc02c127 (patch) | |
| tree | bdf94e1cc1b7076745273b1633730d2710a2e305 /zenhttp/websocketasio.cpp | |
| parent | Refactored websocket server and added static logger support. (diff) | |
| download | zen-386f56cd3b7f06fc30318adcbdc0753ddc02c127.tar.xz zen-386f56cd3b7f06fc30318adcbdc0753ddc02c127.zip | |
Renamed asio web socket impl.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
| -rw-r--r-- | zenhttp/websocketasio.cpp | 683 |
1 files changed, 683 insertions, 0 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..bb3999780 --- /dev/null +++ b/zenhttp/websocketasio.cpp @@ -0,0 +1,683 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocketserver.h> + +#include <zencore/base64.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/sha1.h> +#include <zencore/string.h> + +#include <chrono> +#include <compare> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_ws { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +struct HttpParser +{ + HttpParser() + { + http_parser_init(&Parser, HTTP_REQUEST); + Parser.data = this; + } + + size_t Parse(asio::const_buffer Buffer) + { + return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); + } + + void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) + { + OutHeaders.reserve(HeaderEntries.size()); + + for (const auto& E : HeaderEntries) + { + auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + + OutHeaders[Name] = Value; + } + } + + std::string ValidateWebSocketHandshake(std::unordered_map<std::string_view, std::string_view>& Headers, std::string& OutReason) + { + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + std::string AcceptHash; + + if (Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return AcceptHash; + } + + if (Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return AcceptHash; + } + + ExtendableStringBuilder<128> Sb; + Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); + + return AcceptHash; + } + + static void Initialize() + { + ParserSettings = {.on_message_begin = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url = UrlEntry{}; + Parser.CurrentHeader = HeaderEntry{}; + Parser.IsUpgrade = false; + Parser.IsComplete = false; + + Parser.HeaderStream.Clear(); + Parser.HeaderEntries.clear(); + + return 0; + }, + .on_url = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Size = Size; + + Parser.HeaderStream.Append(Data, uint32_t(Size)); + + return 0; + }, + .on_header_field = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + if (Parser.CurrentHeader.Name.Size == 0) + { + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Name.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_header_value = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size == 0) + { + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Value.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_headers_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + Parser.IsUpgrade = P->upgrade > 0; + + return 0; + }, + .on_message_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + Parser.IsComplete = true; + Parser.IsUpgrade = P->upgrade > 0; + return 0; + }}; + } + + struct MemStreamEntry + { + size_t Offset{}; + size_t Size{}; + }; + + class MemStream + { + public: + MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} + + void Append(const char* Data, size_t Size) + { + const size_t NewSize = m_Size + Size; + + if (NewSize > m_Buf.size()) + { + m_Buf.resize(m_Buf.size() + m_BlockSize); + } + + memcpy(m_Buf.data() + m_Size, Data, Size); + m_Size += Size; + } + + const char* Data() const { return m_Buf.data(); } + size_t Pos() const { return m_Size; } + void Clear() { m_Size = 0; } + + private: + std::vector<char> m_Buf; + size_t m_Size{}; + size_t m_BlockSize{}; + }; + + using UrlEntry = MemStreamEntry; + + struct HeaderEntry + { + MemStreamEntry Name; + MemStreamEntry Value; + }; + + static http_parser_settings ParserSettings; + + http_parser Parser; + MemStream HeaderStream; + std::vector<HeaderEntry> HeaderEntries; + HeaderEntry CurrentHeader{}; + UrlEntry Url{}; + bool IsUpgrade = false; + bool IsComplete = false; +}; + +http_parser_settings HttpParser::ParserSettings; + +/////////////////////////////////////////////////////////////////////////////// +enum class WsConnectionState : uint32_t +{ + kDisconnected, + kHandshaking, + kConnected +}; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnectionId +{ + static std::atomic_uint32_t WsConnectionCounter; + +public: + WsConnectionId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WsConnectionId& RHS) const = default; + + static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } + +private: + WsConnectionId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; + +class WsServer; + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Server(Server) + , m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_Status() + { + } + + ~WsConnection(); + + WsConnectionId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + HttpParser& ParserHttp() { return *m_HttpParser; } + WsConnectionState Close(); + WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); } + + WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); } + + void InitializeHttpParser() { m_HttpParser = std::make_unique<HttpParser>(); } + void ReleaseHttpParser() { m_HttpParser.reset(); } + +private: + WsServer& m_Server; + WsConnectionId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<HttpParser> m_HttpParser; + TimePoint m_StartTime; + std::atomic_uint32_t m_Status; + asio::streambuf m_ReadBuffer; +}; + +WsConnectionState +WsConnection::Close() +{ + using enum WsConnectionState; + + const auto PrevState = SetState(kDisconnected); + + if (PrevState != kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + + ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer() = default; + virtual ~WsServer() { Shutdown(); } + + virtual bool Run(const WebSocketServerOptions& Options) override; + virtual void Shutdown() override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(WsConnection& Connection, const std::error_code& Ec); + void RemoveConnection(WsConnection& Connection); + void ReadConnection(WsConnection& Connection); + + struct IdHasher + { + size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WsConnectionId, std::shared_ptr<WsConnection>, IdHasher>; + + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::atomic_bool m_Running{}; +}; + +WsConnection::~WsConnection() +{ + m_Server.RemoveConnection(*this); +} + +bool +WsServer::Run(const WebSocketServerOptions& Options) +{ + HttpParser::Initialize(); + + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(WsLog, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); + } + else + { + auto ConnId = WsConnectionId::New(); + + ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); + + auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket)); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[ConnId] = Connection; + } + + Connection->InitializeHttpParser(); + Connection->SetState(WsConnectionState::kHandshaking); + + ReadConnection(*Connection); + } + + if (m_Running) + { + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec) +{ + if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(WsLog, + "closing connection '{}' ERROR, reason '{}' error code '{}'", + Connection.Id().Value(), + Ec.message(), + Ec.value()); + } + else + { + ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value()); + } + } +} + +void +WsServer::RemoveConnection(WsConnection& Connection) +{ + ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value()); +} + +void +WsServer::ReadConnection(WsConnection& Connection) +{ + Connection.ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection.Socket(), + Connection.ReadBuffer(), + asio::transfer_at_least(1), + [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value()); + + using enum WsConnectionState; + + switch (Connection.State()) + { + case kHandshaking: + { + HttpParser& Parser = Connection.ParserHttp(); + const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data()); + Connection.ReadBuffer().consume(Consumed); + + if (Parser.IsComplete == false) + { + return ReadConnection(Connection); + } + + if (Parser.IsUpgrade == false) + { + ZEN_LOG_DEBUG(WsLog, + "handshake with connection '{}' FAILED, reason 'not an upgrade request'", + Connection.Id().Value()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + std::unordered_map<std::string_view, std::string_view> Headers; + Parser.GetHeaders(Headers); + + std::string Reason; + std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); + + if (AcceptHash.empty()) + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection.Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + Connection.InitializeHttpParser(); + Connection.SetState(WsConnectionState::kHandshaking); + + ReadConnection(Connection); + } + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Headers.contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + std::string Response = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(Response); + + ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value()); + + async_write(Connection.Socket(), + Buffer, + [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + CloseConnection(Connection, WriteEc); + } + else + { + ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value()); + + Connection.ReleaseHttpParser(); + Connection.SetState(kConnected); + + ReadConnection(Connection); + } + }); + } + break; + + case kConnected: + { + // TODO: Implement RPC API + } + break; + + default: + break; + }; + }); +} + +} // namespace zen::asio_ws + +namespace zen { + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create() +{ + return std::make_unique<asio_ws::WsServer>(); +} + +} // namespace zen |