diff options
Diffstat (limited to 'zenhttp/websocketserver.cpp')
| -rw-r--r-- | zenhttp/websocketserver.cpp | 471 |
1 files changed, 0 insertions, 471 deletions
diff --git a/zenhttp/websocketserver.cpp b/zenhttp/websocketserver.cpp deleted file mode 100644 index 776ed1019..000000000 --- a/zenhttp/websocketserver.cpp +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/websocketserver.h> - -#include <zencore/base64.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/sha1.h> -#include <zencore/string.h> - -#include <span> -#include <thread> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <fmt/format.h> -#include <http_parser.h> -#include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen::asio_http { - -using namespace std::literals; - -struct HttpParser -{ - HttpParser() - { - http_parser_init(&Parser, HTTP_REQUEST); - Parser.data = this; - } - - size_t Parse(const char* Data, const size_t Size) { return http_parser_execute(&Parser, &ParserSettings, Data, Size); } - - void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) - { - OutHeaders.reserve(HeaderEntries.size()); - - for (const auto& E : HeaderEntries) - { - auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); - auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); - - OutHeaders[Name] = Value; - } - } - - static void Initialize() - { - ParserSettings = {.on_message_begin = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url = UrlEntry{}; - Parser.CurrentHeader = HeaderEntry{}; - Parser.IsUpgrade = false; - Parser.IsComplete = false; - - Parser.HeaderStream.Clear(); - Parser.HeaderEntries.clear(); - - return 0; - }, - .on_url = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - Parser.Url.Offset = Parser.HeaderStream.Pos(); - Parser.Url.Size = Size; - - Parser.HeaderStream.Append(Data, uint32_t(Size)); - - return 0; - }, - .on_header_field = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - if (Parser.CurrentHeader.Name.Size == 0) - { - Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); - } - - Parser.CurrentHeader.Name.Size += Size; - - Parser.HeaderStream.Append(Data, Size); - - return 0; - }, - .on_header_value = - [](http_parser* P, const char* Data, size_t Size) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size == 0) - { - Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); - } - - Parser.CurrentHeader.Value.Size += Size; - - Parser.HeaderStream.Append(Data, Size); - - return 0; - }, - .on_headers_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - - if (Parser.CurrentHeader.Value.Size > 0) - { - Parser.HeaderEntries.push_back(Parser.CurrentHeader); - Parser.CurrentHeader = HeaderEntry{}; - } - - Parser.IsUpgrade = P->upgrade > 0; - - return 0; - }, - .on_message_complete = - [](http_parser* P) { - HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); - Parser.IsComplete = true; - Parser.IsUpgrade = P->upgrade > 0; - return 0; - }}; - } - - struct MemStreamEntry - { - size_t Offset{}; - size_t Size{}; - }; - - class MemStream - { - public: - MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} - - void Append(const char* Data, size_t Size) - { - const size_t NewSize = m_Size + Size; - - if (NewSize > m_Buf.size()) - { - m_Buf.resize(m_Buf.size() + m_BlockSize); - } - - memcpy(m_Buf.data() + m_Size, Data, Size); - m_Size += Size; - } - - const char* Data() const { return m_Buf.data(); } - size_t Pos() const { return m_Size; } - void Clear() { m_Size = 0; } - - private: - std::vector<char> m_Buf; - size_t m_Size{}; - size_t m_BlockSize{}; - }; - - using UrlEntry = MemStreamEntry; - - struct HeaderEntry - { - MemStreamEntry Name; - MemStreamEntry Value; - }; - - static http_parser_settings ParserSettings; - - http_parser Parser; - MemStream HeaderStream; - std::vector<HeaderEntry> HeaderEntries; - HeaderEntry CurrentHeader{}; - UrlEntry Url{}; - bool IsUpgrade = false; - bool IsComplete = false; -}; - -http_parser_settings HttpParser::ParserSettings; - -class AsioWebSocketServer final : public WebSocketServer -{ -public: - AsioWebSocketServer() : m_Log(zen::logging::Get("websocket")) { HttpParser::Initialize(); } - - virtual ~AsioWebSocketServer() { Shutdown(); } - - virtual bool Run(const WebSocketServerOptions& Options) override - { - m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoService, asio::ip::tcp::v6()); - - m_Acceptor->set_option(asio::ip::v6_only(false)); - m_Acceptor->set_option(asio::socket_base::reuse_address(true)); - m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - asio::error_code Ec; - m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); - - if (Ec) - { - ZEN_ERROR("failed to bind websocket endpoint, error code '{}'", Ec.value()); - - return false; - } - - m_Acceptor->listen(); - - BeginAccept(); - - StartIoThreads(Options.ThreadCount); - - m_Running.store(true, std::memory_order_relaxed); - - ZEN_INFO("websocket server running on port '{}'", Options.Port); - - return true; - } - - virtual void Shutdown() override - { - if (m_Running) - { - ZEN_INFO("websocket server shutting down"); - - m_Running = false; - - m_Acceptor->close(); - m_Acceptor.reset(); - m_IoService.stop(); - - StopIoThreads(); - } - } - -private: - enum class WebSocketState : uint32_t - { - kNone, - kHandshake, - kRead, - kWrite, - kError - }; - - struct WebSocketConnection : public std::enable_shared_from_this<WebSocketConnection> - { - WebSocketConnection(std::unique_ptr<asio::ip::tcp::socket>&& S, uint32_t ConnId) : Socket(std::move(S)), Id(ConnId) {} - - std::unique_ptr<asio::ip::tcp::socket> Socket; - asio::streambuf ReadBuffer; - WebSocketState State; - HttpParser Parser; - uint32_t Id; - }; - - void BeginAccept() - { - auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoService); - asio::ip::tcp::socket& SocketRef = *Socket.get(); - - m_Acceptor->async_accept(SocketRef, [this, NewSocket = std::move(Socket)](const asio::error_code& Ec) mutable { - if (Ec) - { - ZEN_WARN("accept error, error code '{}'", Ec.value()); - } - else - { - const uint32_t Id = m_ConnectionId.fetch_add(1); - auto Connection = std::make_shared<WebSocketConnection>(std::move(NewSocket), Id); - - Connection->State = WebSocketState::kHandshake; - - BeginRead(Connection); - } - - if (m_Running.load(std::memory_order_relaxed)) - { - BeginAccept(); - } - else - { - m_Acceptor->close(); - } - }); - } - - void BeginRead(std::shared_ptr<WebSocketConnection> Connection) - { - Connection->ReadBuffer.prepare(64 << 10); - - asio::async_read(*Connection->Socket, - Connection->ReadBuffer, - asio::transfer_at_least(1), - [Conn = Connection->shared_from_this(), this](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_ERROR("read FAILED, connection '{}', error code '{}'", Conn->Id, Ec.value()); - Conn->Socket->close(); - return; - } - - ZEN_TRACE("reading {}B from connection '{}'", ByteCount, Conn->Id); - - WebSocketState NextState = WebSocketState::kError; - - switch (Conn->State) - { - case WebSocketState::kHandshake: - NextState = ProcessHandshake(Conn); - break; - } - - Conn->State = NextState; - - if (Conn->State == WebSocketState::kError) - { - ZEN_TRACE("process error, connection '{}'", Conn->Id); - Conn->Socket->close(); - return; - } - - BeginRead(Conn); - }); - } - - WebSocketState ProcessHandshake(std::shared_ptr<WebSocketConnection> Connection) - { - HttpParser& Parser = Connection->Parser; - const asio::const_buffer& Buffer = Connection->ReadBuffer.data(); - - const size_t BytesParsed = Parser.Parse(reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); - Connection->ReadBuffer.consume(BytesParsed); - - if (Parser.IsComplete) - { - if (Parser.IsUpgrade == false) - { - ZEN_DEBUG("invalid websocket handshake request, closing connection '{}'", Connection->Id); - - return WebSocketState::kError; - } - - static constexpr std::string_view WebSocketKey = "Sec-WebSocket-Key"sv; - static constexpr std::string_view WebSocketOriginKey = "Sec-WebSocket-Origin"sv; - static constexpr std::string_view WebSocketProtocolKey = "Sec-WebSocket-Protocol"sv; - static constexpr std::string_view WebSocketVersionKey = "Sec-WebSocket-Version"sv; - static constexpr std::string_view WebSocketAcceptKey = "Sec-WebSocket-Accept"sv; - static constexpr std::string_view UpgradeKey = "Upgrade"sv; - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; - - std::unordered_map<std::string_view, std::string_view> Headers; - Parser.GetHeaders(Headers); - - ZEN_DEBUG("handshake, Origin='{}', Protocol='{}', Version='{}', Key='{}'", - Headers[WebSocketOriginKey], - Headers[WebSocketProtocolKey], - Headers[WebSocketVersionKey], - Headers[WebSocketKey]); - - ExtendableStringBuilder<128> Sb; - Sb << Headers[WebSocketKey] << WebSocketGuid; - - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); - - SHA1 Hash = HashStream.GetHash(); - Sb.Reset(); - - const uint32_t EncodedSize = Base64::GetEncodedDataSize(sizeof(SHA1::Hash)); - Sb.AddUninitialized(EncodedSize); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), Sb.Data()); - - std::string AcceptHash = Sb.ToString(); - - Sb.Reset(); - Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: Upgrade\r\n"sv; - Sb << WebSocketProtocolKey << ": " << Headers[WebSocketProtocolKey] << "\r\n"; - Sb << WebSocketAcceptKey << ": " << AcceptHash << "\r\n" - << "\r\n"sv; - - IoBuffer Response = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Sb.ToView())); - asio::const_buffer ResponseView(Response.Data(), Response.Size()); - const uint64_t ResponseLength = Response.Size(); - - asio::async_write( - *Connection->Socket, - asio::buffer(Response.Data(), Response.Size()), - asio::transfer_exactly(ResponseLength), - [this, Conn = Connection->shared_from_this(), Buf = Response](const asio::error_code& Ec, std::size_t ByteCount) { - if (Ec) - { - ZEN_ERROR("write {}B FAILED, error code '{}'", ByteCount, Ec.value()); - } - else - { - ZEN_DEBUG("write {}B OK", ByteCount); - } - }); - - return WebSocketState::kRead; - } - - return WebSocketState::kHandshake; - } - - void StartIoThreads(uint32_t ThreadCount) - { - ZEN_DEBUG("starting '{}' websocket I/O thread(s)"); - - for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) - { - m_ThreadPool.emplace_back([this, ThreadId = Idx + 1] { - try - { - m_IoService.run(); - } - catch (std::exception& Err) - { - ZEN_ERROR("process websocket request FAILED, reason '{}'", Err.what()); - } - - ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); - }); - } - } - - void StopIoThreads() - { - for (std::thread& Thread : m_ThreadPool) - { - if (Thread.joinable()) - { - Thread.join(); - } - } - - m_ThreadPool.clear(); - } - - spdlog::logger& Log() { return m_Log; } - spdlog::logger& m_Log; - - asio::io_service m_IoService; - std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; - std::atomic_bool m_Running{}; - std::atomic_uint32_t m_ConnectionId{1}; - std::vector<std::thread> m_ThreadPool; -}; - -} // namespace zen::asio_http - -namespace zen { - -std::unique_ptr<WebSocketServer> -WebSocketServer::Create() -{ - return std::make_unique<asio_http::AsioWebSocketServer>(); -} - -} // namespace zen |