aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketserver.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-09 20:51:59 +0100
committerPer Larsson <[email protected]>2022-02-09 20:51:59 +0100
commit8692e7882ed9b7e8d3fb298ba51b7779d58a73b9 (patch)
tree7738ea0431ea46b98ac30eb0ab2702968f61b25d /zenhttp/websocketserver.cpp
parentMerge branch 'main' of https://github.com/EpicGames/zen (diff)
downloadzen-8692e7882ed9b7e8d3fb298ba51b7779d58a73b9.tar.xz
zen-8692e7882ed9b7e8d3fb298ba51b7779d58a73b9.zip
Initial websocket support.
Diffstat (limited to 'zenhttp/websocketserver.cpp')
-rw-r--r--zenhttp/websocketserver.cpp471
1 files changed, 471 insertions, 0 deletions
diff --git a/zenhttp/websocketserver.cpp b/zenhttp/websocketserver.cpp
new file mode 100644
index 000000000..776ed1019
--- /dev/null
+++ b/zenhttp/websocketserver.cpp
@@ -0,0 +1,471 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/websocketserver.h>
+
+#include <zencore/base64.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/sha1.h>
+#include <zencore/string.h>
+
+#include <span>
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_http {
+
+using namespace std::literals;
+
+struct HttpParser
+{
+ HttpParser()
+ {
+ http_parser_init(&Parser, HTTP_REQUEST);
+ Parser.data = this;
+ }
+
+ size_t Parse(const char* Data, const size_t Size) { return http_parser_execute(&Parser, &ParserSettings, Data, Size); }
+
+ void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders)
+ {
+ OutHeaders.reserve(HeaderEntries.size());
+
+ for (const auto& E : HeaderEntries)
+ {
+ auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size);
+ auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size);
+
+ OutHeaders[Name] = Value;
+ }
+ }
+
+ static void Initialize()
+ {
+ ParserSettings = {.on_message_begin =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ Parser.Url = UrlEntry{};
+ Parser.CurrentHeader = HeaderEntry{};
+ Parser.IsUpgrade = false;
+ Parser.IsComplete = false;
+
+ Parser.HeaderStream.Clear();
+ Parser.HeaderEntries.clear();
+
+ return 0;
+ },
+ .on_url =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ Parser.Url.Offset = Parser.HeaderStream.Pos();
+ Parser.Url.Size = Size;
+
+ Parser.HeaderStream.Append(Data, uint32_t(Size));
+
+ return 0;
+ },
+ .on_header_field =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size > 0)
+ {
+ Parser.HeaderEntries.push_back(Parser.CurrentHeader);
+ Parser.CurrentHeader = HeaderEntry{};
+ }
+
+ if (Parser.CurrentHeader.Name.Size == 0)
+ {
+ Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos();
+ }
+
+ Parser.CurrentHeader.Name.Size += Size;
+
+ Parser.HeaderStream.Append(Data, Size);
+
+ return 0;
+ },
+ .on_header_value =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size == 0)
+ {
+ Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos();
+ }
+
+ Parser.CurrentHeader.Value.Size += Size;
+
+ Parser.HeaderStream.Append(Data, Size);
+
+ return 0;
+ },
+ .on_headers_complete =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size > 0)
+ {
+ Parser.HeaderEntries.push_back(Parser.CurrentHeader);
+ Parser.CurrentHeader = HeaderEntry{};
+ }
+
+ Parser.IsUpgrade = P->upgrade > 0;
+
+ return 0;
+ },
+ .on_message_complete =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+ Parser.IsComplete = true;
+ Parser.IsUpgrade = P->upgrade > 0;
+ return 0;
+ }};
+ }
+
+ struct MemStreamEntry
+ {
+ size_t Offset{};
+ size_t Size{};
+ };
+
+ class MemStream
+ {
+ public:
+ MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {}
+
+ void Append(const char* Data, size_t Size)
+ {
+ const size_t NewSize = m_Size + Size;
+
+ if (NewSize > m_Buf.size())
+ {
+ m_Buf.resize(m_Buf.size() + m_BlockSize);
+ }
+
+ memcpy(m_Buf.data() + m_Size, Data, Size);
+ m_Size += Size;
+ }
+
+ const char* Data() const { return m_Buf.data(); }
+ size_t Pos() const { return m_Size; }
+ void Clear() { m_Size = 0; }
+
+ private:
+ std::vector<char> m_Buf;
+ size_t m_Size{};
+ size_t m_BlockSize{};
+ };
+
+ using UrlEntry = MemStreamEntry;
+
+ struct HeaderEntry
+ {
+ MemStreamEntry Name;
+ MemStreamEntry Value;
+ };
+
+ static http_parser_settings ParserSettings;
+
+ http_parser Parser;
+ MemStream HeaderStream;
+ std::vector<HeaderEntry> HeaderEntries;
+ HeaderEntry CurrentHeader{};
+ UrlEntry Url{};
+ bool IsUpgrade = false;
+ bool IsComplete = false;
+};
+
+http_parser_settings HttpParser::ParserSettings;
+
+class AsioWebSocketServer final : public WebSocketServer
+{
+public:
+ AsioWebSocketServer() : m_Log(zen::logging::Get("websocket")) { HttpParser::Initialize(); }
+
+ virtual ~AsioWebSocketServer() { Shutdown(); }
+
+ virtual bool Run(const WebSocketServerOptions& Options) override
+ {
+ m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoService, asio::ip::tcp::v6());
+
+ m_Acceptor->set_option(asio::ip::v6_only(false));
+ m_Acceptor->set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ asio::error_code Ec;
+ m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec);
+
+ if (Ec)
+ {
+ ZEN_ERROR("failed to bind websocket endpoint, error code '{}'", Ec.value());
+
+ return false;
+ }
+
+ m_Acceptor->listen();
+
+ BeginAccept();
+
+ StartIoThreads(Options.ThreadCount);
+
+ m_Running.store(true, std::memory_order_relaxed);
+
+ ZEN_INFO("websocket server running on port '{}'", Options.Port);
+
+ return true;
+ }
+
+ virtual void Shutdown() override
+ {
+ if (m_Running)
+ {
+ ZEN_INFO("websocket server shutting down");
+
+ m_Running = false;
+
+ m_Acceptor->close();
+ m_Acceptor.reset();
+ m_IoService.stop();
+
+ StopIoThreads();
+ }
+ }
+
+private:
+ enum class WebSocketState : uint32_t
+ {
+ kNone,
+ kHandshake,
+ kRead,
+ kWrite,
+ kError
+ };
+
+ struct WebSocketConnection : public std::enable_shared_from_this<WebSocketConnection>
+ {
+ WebSocketConnection(std::unique_ptr<asio::ip::tcp::socket>&& S, uint32_t ConnId) : Socket(std::move(S)), Id(ConnId) {}
+
+ std::unique_ptr<asio::ip::tcp::socket> Socket;
+ asio::streambuf ReadBuffer;
+ WebSocketState State;
+ HttpParser Parser;
+ uint32_t Id;
+ };
+
+ void BeginAccept()
+ {
+ auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoService);
+ asio::ip::tcp::socket& SocketRef = *Socket.get();
+
+ m_Acceptor->async_accept(SocketRef, [this, NewSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("accept error, error code '{}'", Ec.value());
+ }
+ else
+ {
+ const uint32_t Id = m_ConnectionId.fetch_add(1);
+ auto Connection = std::make_shared<WebSocketConnection>(std::move(NewSocket), Id);
+
+ Connection->State = WebSocketState::kHandshake;
+
+ BeginRead(Connection);
+ }
+
+ if (m_Running.load(std::memory_order_relaxed))
+ {
+ BeginAccept();
+ }
+ else
+ {
+ m_Acceptor->close();
+ }
+ });
+ }
+
+ void BeginRead(std::shared_ptr<WebSocketConnection> Connection)
+ {
+ Connection->ReadBuffer.prepare(64 << 10);
+
+ asio::async_read(*Connection->Socket,
+ Connection->ReadBuffer,
+ asio::transfer_at_least(1),
+ [Conn = Connection->shared_from_this(), this](const asio::error_code& Ec, std::size_t ByteCount) {
+ if (Ec)
+ {
+ ZEN_ERROR("read FAILED, connection '{}', error code '{}'", Conn->Id, Ec.value());
+ Conn->Socket->close();
+ return;
+ }
+
+ ZEN_TRACE("reading {}B from connection '{}'", ByteCount, Conn->Id);
+
+ WebSocketState NextState = WebSocketState::kError;
+
+ switch (Conn->State)
+ {
+ case WebSocketState::kHandshake:
+ NextState = ProcessHandshake(Conn);
+ break;
+ }
+
+ Conn->State = NextState;
+
+ if (Conn->State == WebSocketState::kError)
+ {
+ ZEN_TRACE("process error, connection '{}'", Conn->Id);
+ Conn->Socket->close();
+ return;
+ }
+
+ BeginRead(Conn);
+ });
+ }
+
+ WebSocketState ProcessHandshake(std::shared_ptr<WebSocketConnection> Connection)
+ {
+ HttpParser& Parser = Connection->Parser;
+ const asio::const_buffer& Buffer = Connection->ReadBuffer.data();
+
+ const size_t BytesParsed = Parser.Parse(reinterpret_cast<const char*>(Buffer.data()), Buffer.size());
+ Connection->ReadBuffer.consume(BytesParsed);
+
+ if (Parser.IsComplete)
+ {
+ if (Parser.IsUpgrade == false)
+ {
+ ZEN_DEBUG("invalid websocket handshake request, closing connection '{}'", Connection->Id);
+
+ return WebSocketState::kError;
+ }
+
+ static constexpr std::string_view WebSocketKey = "Sec-WebSocket-Key"sv;
+ static constexpr std::string_view WebSocketOriginKey = "Sec-WebSocket-Origin"sv;
+ static constexpr std::string_view WebSocketProtocolKey = "Sec-WebSocket-Protocol"sv;
+ static constexpr std::string_view WebSocketVersionKey = "Sec-WebSocket-Version"sv;
+ static constexpr std::string_view WebSocketAcceptKey = "Sec-WebSocket-Accept"sv;
+ static constexpr std::string_view UpgradeKey = "Upgrade"sv;
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ std::unordered_map<std::string_view, std::string_view> Headers;
+ Parser.GetHeaders(Headers);
+
+ ZEN_DEBUG("handshake, Origin='{}', Protocol='{}', Version='{}', Key='{}'",
+ Headers[WebSocketOriginKey],
+ Headers[WebSocketProtocolKey],
+ Headers[WebSocketVersionKey],
+ Headers[WebSocketKey]);
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << Headers[WebSocketKey] << WebSocketGuid;
+
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
+
+ SHA1 Hash = HashStream.GetHash();
+ Sb.Reset();
+
+ const uint32_t EncodedSize = Base64::GetEncodedDataSize(sizeof(SHA1::Hash));
+ Sb.AddUninitialized(EncodedSize);
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), Sb.Data());
+
+ std::string AcceptHash = Sb.ToString();
+
+ Sb.Reset();
+ Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: Upgrade\r\n"sv;
+ Sb << WebSocketProtocolKey << ": " << Headers[WebSocketProtocolKey] << "\r\n";
+ Sb << WebSocketAcceptKey << ": " << AcceptHash << "\r\n"
+ << "\r\n"sv;
+
+ IoBuffer Response = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Sb.ToView()));
+ asio::const_buffer ResponseView(Response.Data(), Response.Size());
+ const uint64_t ResponseLength = Response.Size();
+
+ asio::async_write(
+ *Connection->Socket,
+ asio::buffer(Response.Data(), Response.Size()),
+ asio::transfer_exactly(ResponseLength),
+ [this, Conn = Connection->shared_from_this(), Buf = Response](const asio::error_code& Ec, std::size_t ByteCount) {
+ if (Ec)
+ {
+ ZEN_ERROR("write {}B FAILED, error code '{}'", ByteCount, Ec.value());
+ }
+ else
+ {
+ ZEN_DEBUG("write {}B OK", ByteCount);
+ }
+ });
+
+ return WebSocketState::kRead;
+ }
+
+ return WebSocketState::kHandshake;
+ }
+
+ void StartIoThreads(uint32_t ThreadCount)
+ {
+ ZEN_DEBUG("starting '{}' websocket I/O thread(s)");
+
+ for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
+ {
+ m_ThreadPool.emplace_back([this, ThreadId = Idx + 1] {
+ try
+ {
+ m_IoService.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("process websocket request FAILED, reason '{}'", Err.what());
+ }
+
+ ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId);
+ });
+ }
+ }
+
+ void StopIoThreads()
+ {
+ for (std::thread& Thread : m_ThreadPool)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_ThreadPool.clear();
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+ spdlog::logger& m_Log;
+
+ asio::io_service m_IoService;
+ std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
+ std::atomic_bool m_Running{};
+ std::atomic_uint32_t m_ConnectionId{1};
+ std::vector<std::thread> m_ThreadPool;
+};
+
+} // namespace zen::asio_http
+
+namespace zen {
+
+std::unique_ptr<WebSocketServer>
+WebSocketServer::Create()
+{
+ return std::make_unique<asio_http::AsioWebSocketServer>();
+}
+
+} // namespace zen