// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include #include ZEN_THIRD_PARTY_INCLUDES_END namespace zen::asio_ws { using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; /////////////////////////////////////////////////////////////////////////////// namespace http_header { static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; static constexpr std::string_view Upgrade = "Upgrade"sv; } // namespace http_header /////////////////////////////////////////////////////////////////////////////// struct HttpParser { HttpParser() { http_parser_init(&Parser, HTTP_REQUEST); Parser.data = this; } size_t Parse(asio::const_buffer Buffer) { return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast(Buffer.data()), Buffer.size()); } void GetHeaders(std::unordered_map& OutHeaders) { OutHeaders.reserve(HeaderEntries.size()); for (const auto& E : HeaderEntries) { auto Name = std::string_view((const char*)HeaderStream.Data() + E.Name.Offset, E.Name.Size); auto Value = std::string_view((const char*)HeaderStream.Data() + E.Value.Offset, E.Value.Size); OutHeaders[Name] = Value; } } std::string ValidateWebSocketHandshake(std::unordered_map& Headers, std::string& OutReason) { static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; std::string AcceptHash; if (Headers.contains(http_header::SecWebSocketKey) == false) { OutReason = "Missing header Sec-WebSocket-Key"; return AcceptHash; } if (Headers.contains(http_header::Upgrade) == false) { OutReason = "Missing header Upgrade"; return AcceptHash; } ExtendableStringBuilder<128> Sb; Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid; SHA1Stream HashStream; HashStream.Append(Sb.Data(), Sb.Size()); SHA1 Hash = HashStream.GetHash(); AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data()); return AcceptHash; } static void Initialize() { ParserSettings = {.on_message_begin = [](http_parser* P) { HttpParser& Parser = *reinterpret_cast(P->data); Parser.Url = UrlEntry{}; Parser.CurrentHeader = HeaderEntry{}; Parser.IsUpgrade = false; Parser.IsComplete = false; Parser.HeaderStream.Clear(); Parser.HeaderEntries.clear(); return 0; }, .on_url = [](http_parser* P, const char* Data, size_t Size) { HttpParser& Parser = *reinterpret_cast(P->data); Parser.Url.Offset = Parser.HeaderStream.CurrentOffset(); Parser.Url.Size = Size; Parser.HeaderStream.Write(Data, uint32_t(Size)); return 0; }, .on_header_field = [](http_parser* P, const char* Data, size_t Size) { HttpParser& Parser = *reinterpret_cast(P->data); if (Parser.CurrentHeader.Value.Size > 0) { Parser.HeaderEntries.push_back(Parser.CurrentHeader); Parser.CurrentHeader = HeaderEntry{}; } if (Parser.CurrentHeader.Name.Size == 0) { Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Name.Size += Size; Parser.HeaderStream.Write(Data, Size); return 0; }, .on_header_value = [](http_parser* P, const char* Data, size_t Size) { HttpParser& Parser = *reinterpret_cast(P->data); if (Parser.CurrentHeader.Value.Size == 0) { Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.CurrentOffset(); } Parser.CurrentHeader.Value.Size += Size; Parser.HeaderStream.Write(Data, Size); return 0; }, .on_headers_complete = [](http_parser* P) { HttpParser& Parser = *reinterpret_cast(P->data); if (Parser.CurrentHeader.Value.Size > 0) { Parser.HeaderEntries.push_back(Parser.CurrentHeader); Parser.CurrentHeader = HeaderEntry{}; } Parser.IsUpgrade = P->upgrade > 0; return 0; }, .on_message_complete = [](http_parser* P) { HttpParser& Parser = *reinterpret_cast(P->data); Parser.IsComplete = true; Parser.IsUpgrade = P->upgrade > 0; return 0; }}; } struct MemStreamEntry { size_t Offset{}; size_t Size{}; }; using UrlEntry = MemStreamEntry; struct HeaderEntry { MemStreamEntry Name; MemStreamEntry Value; }; static http_parser_settings ParserSettings; http_parser Parser; SimpleBinaryWriter HeaderStream; std::vector HeaderEntries; HeaderEntry CurrentHeader{}; UrlEntry Url{}; bool IsUpgrade = false; bool IsComplete = false; }; http_parser_settings HttpParser::ParserSettings; /////////////////////////////////////////////////////////////////////////////// class WsMessageParser { public: WsMessageParser() {} void Reset() { m_Header.reset(); m_Stream.Clear(); } bool Parse(asio::const_buffer Buffer, size_t& OutConsumedBytes) { if (m_Header.has_value()) { OutConsumedBytes = Min(m_Header.value().ContentLength, Buffer.size()); m_Stream.Write(Buffer.data(), OutConsumedBytes); return true; } const size_t PrevOffset = m_Stream.CurrentOffset(); const size_t BytesToWrite = Min(sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(), Buffer.size()); const size_t RemainingBytes = Buffer.size() - BytesToWrite; m_Stream.Write(Buffer.data(), BytesToWrite); if (m_Stream.CurrentOffset() < sizeof(zen::WebSocketMessageHeader)) { OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; return true; } zen::WebSocketMessageHeader Header; if (zen::WebSocketMessageHeader::Read(m_Stream.GetView(), Header) == false) { OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; return false; } m_Header = Header; if (RemainingBytes > 0) { const size_t RemainingBytesToWrite = Min(m_Header.value().ContentLength, RemainingBytes); m_Stream.Write(reinterpret_cast(Buffer.data()) + BytesToWrite, RemainingBytesToWrite); } OutConsumedBytes = m_Stream.CurrentOffset() - PrevOffset; return true; } bool IsComplete() { if (m_Header.has_value()) { const size_t RemainingBytes = m_Header.value().ContentLength + sizeof(zen::WebSocketMessageHeader) - m_Stream.CurrentOffset(); return RemainingBytes == 0; } return false; } bool TryLoadMessage(CbPackage& OutPackage) { if (IsComplete()) { BinaryReader Reader(m_Stream.Data(), m_Stream.Size()); return OutPackage.TryLoad(Reader); } return false; } private: SimpleBinaryWriter m_Stream{64 << 10}; std::optional m_Header; }; /////////////////////////////////////////////////////////////////////////////// enum class WsConnectionState : uint32_t { kDisconnected, kHandshaking, kConnected }; /////////////////////////////////////////////////////////////////////////////// class WsConnectionId { static std::atomic_uint32_t WsConnectionCounter; public: WsConnectionId() = default; uint32_t Value() const { return m_Value; } auto operator<=>(const WsConnectionId& RHS) const = default; static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); } private: WsConnectionId(uint32_t Value) : m_Value(Value) {} uint32_t m_Value{}; }; std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1}; class WsServer; /////////////////////////////////////////////////////////////////////////////// class WsConnection : public std::enable_shared_from_this { public: WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr Socket) : m_Server(Server) , m_Id(Id) , m_Socket(std::move(Socket)) , m_StartTime(Clock::now()) , m_Status() { } ~WsConnection(); WsConnectionId Id() const { return m_Id; } asio::ip::tcp::socket& Socket() { return *m_Socket; } TimePoint StartTime() const { return m_StartTime; } std::shared_ptr AsShared() { return shared_from_this(); } asio::streambuf& ReadBuffer() { return m_ReadBuffer; } HttpParser& ParserHttp() { return *m_HttpParser; } WsMessageParser& MessageParser() { return m_MsgParser; } WsConnectionState Close(); WsConnectionState State() const { return static_cast(m_Status.load(std::memory_order_relaxed)); } WsConnectionState SetState(WsConnectionState NewState) { return static_cast(m_Status.exchange(uint32_t(NewState))); } void InitializeHttpParser() { m_HttpParser = std::make_unique(); } void ReleaseHttpParser() { m_HttpParser.reset(); } private: WsServer& m_Server; WsConnectionId m_Id; std::unique_ptr m_Socket; std::unique_ptr m_HttpParser; WsMessageParser m_MsgParser; TimePoint m_StartTime; std::atomic_uint32_t m_Status; asio::streambuf m_ReadBuffer; }; WsConnectionState WsConnection::Close() { using enum WsConnectionState; const auto PrevState = SetState(kDisconnected); if (PrevState != kDisconnected && m_Socket->is_open()) { m_Socket->close(); } return PrevState; } /////////////////////////////////////////////////////////////////////////////// class WsThreadPool { public: WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} void Start(uint32_t ThreadCount); void Stop(); private: asio::io_service& m_IoSvc; std::vector m_Threads; }; void WsThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount); for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { try { m_IoSvc.run(); } catch (std::exception& Err) { ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what()); } ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); }); } } void WsThreadPool::Stop() { for (std::thread& Thread : m_Threads) { if (Thread.joinable()) { Thread.join(); } } m_Threads.clear(); } /////////////////////////////////////////////////////////////////////////////// class WsServer final : public WebSocketServer { public: WsServer() = default; virtual ~WsServer() { Shutdown(); } virtual bool Run(const WebSocketServerOptions& Options) override; virtual void Shutdown() override; private: friend class WsConnection; void AcceptConnection(); void CloseConnection(std::shared_ptr Connection, const std::error_code& Ec); void RemoveConnection(const WsConnectionId Id); void ReadMessage(std::shared_ptr Connection); void RouteMessage(const CbPackage& Msg); struct IdHasher { size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); } }; using ConnectionMap = std::unordered_map, IdHasher>; asio::io_service m_IoSvc; std::unique_ptr m_Acceptor; std::unique_ptr m_ThreadPool; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::atomic_bool m_Running{}; }; WsConnection::~WsConnection() { m_Server.RemoveConnection(m_Id); } bool WsServer::Run(const WebSocketServerOptions& Options) { HttpParser::Initialize(); m_Acceptor = std::make_unique(m_IoSvc, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); if (Ec) { ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value()); return false; } m_Acceptor->listen(); m_Running = true; ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port); AcceptConnection(); m_ThreadPool = std::make_unique(m_IoSvc); m_ThreadPool->Start(Options.ThreadCount); return true; } void WsServer::Shutdown() { if (m_Running) { ZEN_LOG_INFO(WsLog, "websocket server shutting down"); m_Running = false; m_Acceptor->close(); m_Acceptor.reset(); m_IoSvc.stop(); m_ThreadPool->Stop(); } } void WsServer::AcceptConnection() { auto Socket = std::make_unique(m_IoSvc); asio::ip::tcp::socket& SocketRef = *Socket.get(); m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { if (Ec) { ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value()); } else { auto ConnId = WsConnectionId::New(); ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value()); auto Connection = std::make_shared(*this, ConnId, std::move(ConnectedSocket)); { std::unique_lock _(m_ConnMutex); m_Connections[ConnId] = Connection; } Connection->InitializeHttpParser(); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); } if (m_Running) { AcceptConnection(); } }); } void WsServer::CloseConnection(std::shared_ptr Connection, const std::error_code& Ec) { if (const auto State = Connection->Close(); State != WsConnectionState::kDisconnected) { if (Ec) { ZEN_LOG_INFO(WsLog, "connection '{}' closed, ERROR '{}' error code '{}'", Connection->Id().Value(), Ec.message(), Ec.value()); } else { ZEN_LOG_INFO(WsLog, "connection '{}' closed", Connection->Id().Value()); } } const WsConnectionId Id = Connection->Id(); { std::unique_lock _(m_ConnMutex); m_Connections.erase(Id); } } void WsServer::RemoveConnection(const WsConnectionId Id) { ZEN_LOG_INFO(WsLog, "removing connection '{}'", Id.Value()); } void WsServer::ReadMessage(std::shared_ptr Connection) { Connection->ReadBuffer().prepare(64 << 10); asio::async_read( Connection->Socket(), Connection->ReadBuffer(), asio::transfer_at_least(1), [this, Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable { if (ReadEc) { return CloseConnection(Connection, ReadEc); } ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection->Id().Value()); using enum WsConnectionState; switch (Connection->State()) { case kHandshaking: { HttpParser& Parser = Connection->ParserHttp(); const size_t Consumed = Parser.Parse(Connection->ReadBuffer().data()); Connection->ReadBuffer().consume(Consumed); if (Parser.IsComplete == false) { return ReadMessage(Connection); } if (Parser.IsUpgrade == false) { ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason 'not an upgrade request'", Connection->Id().Value()); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { Connection->InitializeHttpParser(); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); } }); } std::unordered_map Headers; Parser.GetHeaders(Headers); std::string Reason; std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason); if (AcceptHash.empty()) { ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection->Id().Value(), Reason); constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; return async_write(Connection->Socket(), asio::buffer(UpgradeRequiredResponse), [this, &Connection](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { // TODO: Always close connection? Connection->InitializeHttpParser(); Connection->SetState(WsConnectionState::kHandshaking); ReadMessage(Connection); } }); } ExtendableStringBuilder<128> Sb; Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; Sb << "Upgrade: websocket\r\n"sv; Sb << "Connection: Upgrade\r\n"sv; // TODO: Verify protocol if (Headers.contains(http_header::SecWebSocketProtocol)) { Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n"; } Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; Sb << "\r\n"sv; std::string Response = Sb.ToString(); asio::const_buffer Buffer = asio::buffer(Response); ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection->Id().Value()); async_write(Connection->Socket(), Buffer, [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) { if (WriteEc) { CloseConnection(Connection, WriteEc); } else { ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection->Id().Value()); Connection->ReleaseHttpParser(); Connection->SetState(kConnected); Connection->MessageParser().Reset(); ReadMessage(Connection); } }); } break; case kConnected: { for (;;) { if (Connection->ReadBuffer().size() == 0) { break; } WsMessageParser& MessageParser = Connection->MessageParser(); size_t ConsumedBytes{}; const bool Ok = MessageParser.Parse(Connection->ReadBuffer().data(), ConsumedBytes); Connection->ReadBuffer().consume(ConsumedBytes); if (Ok == false) { ZEN_LOG_WARN(WsLog, "parse websocket message FAILED, connection '{}'", Connection->Id().Value()); MessageParser.Reset(); } if (Ok == false || MessageParser.IsComplete() == false) { continue; } CbPackage Message; if (MessageParser.TryLoadMessage(Message) == false) { ZEN_LOG_WARN(WsLog, "invalid websocket message, connection '{}'", Connection->Id().Value()); continue; } RouteMessage(Message); } ReadMessage(Connection); } break; default: break; }; }); } void WsServer::RouteMessage(const CbPackage& Msg) { ZEN_UNUSED(Msg); ZEN_LOG_DEBUG(WsLog, "routing message"); } } // namespace zen::asio_ws namespace zen { bool WebSocketMessageHeader::IsValid() const { return Magic == ExpectedMagic && ContentLength != 0 && Crc32 != 0; } bool WebSocketMessageHeader::Read(MemoryView Memory, WebSocketMessageHeader& OutHeader) { if (Memory.GetSize() < sizeof(WebSocketMessageHeader)) { return false; } void* Dst = &OutHeader; memcpy(Dst, Memory.GetData(), sizeof(WebSocketMessageHeader)); return OutHeader.IsValid(); } std::unique_ptr WebSocketServer::Create() { return std::make_unique(); } } // namespace zen