diff options
| author | Per Larsson <[email protected]> | 2022-02-09 20:51:59 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-09 20:51:59 +0100 |
| commit | 8692e7882ed9b7e8d3fb298ba51b7779d58a73b9 (patch) | |
| tree | 7738ea0431ea46b98ac30eb0ab2702968f61b25d /zenhttp/websocketserver.cpp | |
| parent | Merge branch 'main' of https://github.com/EpicGames/zen (diff) | |
| download | zen-8692e7882ed9b7e8d3fb298ba51b7779d58a73b9.tar.xz zen-8692e7882ed9b7e8d3fb298ba51b7779d58a73b9.zip | |
Initial websocket support.
Diffstat (limited to 'zenhttp/websocketserver.cpp')
| -rw-r--r-- | zenhttp/websocketserver.cpp | 471 |
1 files changed, 471 insertions, 0 deletions
diff --git a/zenhttp/websocketserver.cpp b/zenhttp/websocketserver.cpp new file mode 100644 index 000000000..776ed1019 --- /dev/null +++ b/zenhttp/websocketserver.cpp @@ -0,0 +1,471 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocketserver.h> + +#include <zencore/base64.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/sha1.h> +#include <zencore/string.h> + +#include <span> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpParser +{ + HttpParser() + { + http_parser_init(&Parser, HTTP_REQUEST); + Parser.data = this; + } + + size_t Parse(const char* Data, const size_t Size) { return http_parser_execute(&Parser, &ParserSettings, Data, Size); } + + void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders) + { + OutHeaders.reserve(HeaderEntries.size()); + + for (const auto& E : HeaderEntries) + { + auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size); + auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size); + + OutHeaders[Name] = Value; + } + } + + static void Initialize() + { + ParserSettings = {.on_message_begin = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url = UrlEntry{}; + Parser.CurrentHeader = HeaderEntry{}; + Parser.IsUpgrade = false; + Parser.IsComplete = false; + + Parser.HeaderStream.Clear(); + Parser.HeaderEntries.clear(); + + return 0; + }, + .on_url = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + Parser.Url.Offset = Parser.HeaderStream.Pos(); + Parser.Url.Size = Size; + + Parser.HeaderStream.Append(Data, uint32_t(Size)); + + return 0; + }, + .on_header_field = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + if (Parser.CurrentHeader.Name.Size == 0) + { + Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Name.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_header_value = + [](http_parser* P, const char* Data, size_t Size) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size == 0) + { + Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos(); + } + + Parser.CurrentHeader.Value.Size += Size; + + Parser.HeaderStream.Append(Data, Size); + + return 0; + }, + .on_headers_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + + if (Parser.CurrentHeader.Value.Size > 0) + { + Parser.HeaderEntries.push_back(Parser.CurrentHeader); + Parser.CurrentHeader = HeaderEntry{}; + } + + Parser.IsUpgrade = P->upgrade > 0; + + return 0; + }, + .on_message_complete = + [](http_parser* P) { + HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data); + Parser.IsComplete = true; + Parser.IsUpgrade = P->upgrade > 0; + return 0; + }}; + } + + struct MemStreamEntry + { + size_t Offset{}; + size_t Size{}; + }; + + class MemStream + { + public: + MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {} + + void Append(const char* Data, size_t Size) + { + const size_t NewSize = m_Size + Size; + + if (NewSize > m_Buf.size()) + { + m_Buf.resize(m_Buf.size() + m_BlockSize); + } + + memcpy(m_Buf.data() + m_Size, Data, Size); + m_Size += Size; + } + + const char* Data() const { return m_Buf.data(); } + size_t Pos() const { return m_Size; } + void Clear() { m_Size = 0; } + + private: + std::vector<char> m_Buf; + size_t m_Size{}; + size_t m_BlockSize{}; + }; + + using UrlEntry = MemStreamEntry; + + struct HeaderEntry + { + MemStreamEntry Name; + MemStreamEntry Value; + }; + + static http_parser_settings ParserSettings; + + http_parser Parser; + MemStream HeaderStream; + std::vector<HeaderEntry> HeaderEntries; + HeaderEntry CurrentHeader{}; + UrlEntry Url{}; + bool IsUpgrade = false; + bool IsComplete = false; +}; + +http_parser_settings HttpParser::ParserSettings; + +class AsioWebSocketServer final : public WebSocketServer +{ +public: + AsioWebSocketServer() : m_Log(zen::logging::Get("websocket")) { HttpParser::Initialize(); } + + virtual ~AsioWebSocketServer() { Shutdown(); } + + virtual bool Run(const WebSocketServerOptions& Options) override + { + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoService, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec); + + if (Ec) + { + ZEN_ERROR("failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + + BeginAccept(); + + StartIoThreads(Options.ThreadCount); + + m_Running.store(true, std::memory_order_relaxed); + + ZEN_INFO("websocket server running on port '{}'", Options.Port); + + return true; + } + + virtual void Shutdown() override + { + if (m_Running) + { + ZEN_INFO("websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoService.stop(); + + StopIoThreads(); + } + } + +private: + enum class WebSocketState : uint32_t + { + kNone, + kHandshake, + kRead, + kWrite, + kError + }; + + struct WebSocketConnection : public std::enable_shared_from_this<WebSocketConnection> + { + WebSocketConnection(std::unique_ptr<asio::ip::tcp::socket>&& S, uint32_t ConnId) : Socket(std::move(S)), Id(ConnId) {} + + std::unique_ptr<asio::ip::tcp::socket> Socket; + asio::streambuf ReadBuffer; + WebSocketState State; + HttpParser Parser; + uint32_t Id; + }; + + void BeginAccept() + { + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, NewSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("accept error, error code '{}'", Ec.value()); + } + else + { + const uint32_t Id = m_ConnectionId.fetch_add(1); + auto Connection = std::make_shared<WebSocketConnection>(std::move(NewSocket), Id); + + Connection->State = WebSocketState::kHandshake; + + BeginRead(Connection); + } + + if (m_Running.load(std::memory_order_relaxed)) + { + BeginAccept(); + } + else + { + m_Acceptor->close(); + } + }); + } + + void BeginRead(std::shared_ptr<WebSocketConnection> Connection) + { + Connection->ReadBuffer.prepare(64 << 10); + + asio::async_read(*Connection->Socket, + Connection->ReadBuffer, + asio::transfer_at_least(1), + [Conn = Connection->shared_from_this(), this](const asio::error_code& Ec, std::size_t ByteCount) { + if (Ec) + { + ZEN_ERROR("read FAILED, connection '{}', error code '{}'", Conn->Id, Ec.value()); + Conn->Socket->close(); + return; + } + + ZEN_TRACE("reading {}B from connection '{}'", ByteCount, Conn->Id); + + WebSocketState NextState = WebSocketState::kError; + + switch (Conn->State) + { + case WebSocketState::kHandshake: + NextState = ProcessHandshake(Conn); + break; + } + + Conn->State = NextState; + + if (Conn->State == WebSocketState::kError) + { + ZEN_TRACE("process error, connection '{}'", Conn->Id); + Conn->Socket->close(); + return; + } + + BeginRead(Conn); + }); + } + + WebSocketState ProcessHandshake(std::shared_ptr<WebSocketConnection> Connection) + { + HttpParser& Parser = Connection->Parser; + const asio::const_buffer& Buffer = Connection->ReadBuffer.data(); + + const size_t BytesParsed = Parser.Parse(reinterpret_cast<const char*>(Buffer.data()), Buffer.size()); + Connection->ReadBuffer.consume(BytesParsed); + + if (Parser.IsComplete) + { + if (Parser.IsUpgrade == false) + { + ZEN_DEBUG("invalid websocket handshake request, closing connection '{}'", Connection->Id); + + return WebSocketState::kError; + } + + static constexpr std::string_view WebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view WebSocketOriginKey = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view WebSocketProtocolKey = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view WebSocketVersionKey = "Sec-WebSocket-Version"sv; + static constexpr std::string_view WebSocketAcceptKey = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view UpgradeKey = "Upgrade"sv; + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + std::unordered_map<std::string_view, std::string_view> Headers; + Parser.GetHeaders(Headers); + + ZEN_DEBUG("handshake, Origin='{}', Protocol='{}', Version='{}', Key='{}'", + Headers[WebSocketOriginKey], + Headers[WebSocketProtocolKey], + Headers[WebSocketVersionKey], + Headers[WebSocketKey]); + + ExtendableStringBuilder<128> Sb; + Sb << Headers[WebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + Sb.Reset(); + + const uint32_t EncodedSize = Base64::GetEncodedDataSize(sizeof(SHA1::Hash)); + Sb.AddUninitialized(EncodedSize); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), Sb.Data()); + + std::string AcceptHash = Sb.ToString(); + + Sb.Reset(); + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + Sb << WebSocketProtocolKey << ": " << Headers[WebSocketProtocolKey] << "\r\n"; + Sb << WebSocketAcceptKey << ": " << AcceptHash << "\r\n" + << "\r\n"sv; + + IoBuffer Response = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Sb.ToView())); + asio::const_buffer ResponseView(Response.Data(), Response.Size()); + const uint64_t ResponseLength = Response.Size(); + + asio::async_write( + *Connection->Socket, + asio::buffer(Response.Data(), Response.Size()), + asio::transfer_exactly(ResponseLength), + [this, Conn = Connection->shared_from_this(), Buf = Response](const asio::error_code& Ec, std::size_t ByteCount) { + if (Ec) + { + ZEN_ERROR("write {}B FAILED, error code '{}'", ByteCount, Ec.value()); + } + else + { + ZEN_DEBUG("write {}B OK", ByteCount); + } + }); + + return WebSocketState::kRead; + } + + return WebSocketState::kHandshake; + } + + void StartIoThreads(uint32_t ThreadCount) + { + ZEN_DEBUG("starting '{}' websocket I/O thread(s)"); + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_ThreadPool.emplace_back([this, ThreadId = Idx + 1] { + try + { + m_IoService.run(); + } + catch (std::exception& Err) + { + ZEN_ERROR("process websocket request FAILED, reason '{}'", Err.what()); + } + + ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId); + }); + } + } + + void StopIoThreads() + { + for (std::thread& Thread : m_ThreadPool) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_ThreadPool.clear(); + } + + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + + asio::io_service m_IoService; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::atomic_bool m_Running{}; + std::atomic_uint32_t m_ConnectionId{1}; + std::vector<std::thread> m_ThreadPool; +}; + +} // namespace zen::asio_http + +namespace zen { + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create() +{ + return std::make_unique<asio_http::AsioWebSocketServer>(); +} + +} // namespace zen |