aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketasio.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-15 14:07:40 +0100
committerPer Larsson <[email protected]>2022-02-15 14:07:40 +0100
commit386f56cd3b7f06fc30318adcbdc0753ddc02c127 (patch)
treebdf94e1cc1b7076745273b1633730d2710a2e305 /zenhttp/websocketasio.cpp
parentRefactored websocket server and added static logger support. (diff)
downloadzen-386f56cd3b7f06fc30318adcbdc0753ddc02c127.tar.xz
zen-386f56cd3b7f06fc30318adcbdc0753ddc02c127.zip
Renamed asio web socket impl.
Diffstat (limited to 'zenhttp/websocketasio.cpp')
-rw-r--r--zenhttp/websocketasio.cpp683
1 files changed, 683 insertions, 0 deletions
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
new file mode 100644
index 000000000..bb3999780
--- /dev/null
+++ b/zenhttp/websocketasio.cpp
@@ -0,0 +1,683 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/websocketserver.h>
+
+#include <zencore/base64.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/sha1.h>
+#include <zencore/string.h>
+
+#include <chrono>
+#include <compare>
+#include <optional>
+#include <shared_mutex>
+#include <span>
+#include <system_error>
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_ws {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(WsLog, "websocket"sv);
+
+using Clock = std::chrono::steady_clock;
+using TimePoint = Clock::time_point;
+
+///////////////////////////////////////////////////////////////////////////////
+namespace http_header {
+ static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
+ static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
+ static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
+ static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
+ static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
+ static constexpr std::string_view Upgrade = "Upgrade"sv;
+} // namespace http_header
+
+///////////////////////////////////////////////////////////////////////////////
+struct HttpParser
+{
+ HttpParser()
+ {
+ http_parser_init(&Parser, HTTP_REQUEST);
+ Parser.data = this;
+ }
+
+ size_t Parse(asio::const_buffer Buffer)
+ {
+ return http_parser_execute(&Parser, &ParserSettings, reinterpret_cast<const char*>(Buffer.data()), Buffer.size());
+ }
+
+ void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders)
+ {
+ OutHeaders.reserve(HeaderEntries.size());
+
+ for (const auto& E : HeaderEntries)
+ {
+ auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size);
+ auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size);
+
+ OutHeaders[Name] = Value;
+ }
+ }
+
+ std::string ValidateWebSocketHandshake(std::unordered_map<std::string_view, std::string_view>& Headers, std::string& OutReason)
+ {
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ std::string AcceptHash;
+
+ if (Headers.contains(http_header::SecWebSocketKey) == false)
+ {
+ OutReason = "Missing header Sec-WebSocket-Key";
+ return AcceptHash;
+ }
+
+ if (Headers.contains(http_header::Upgrade) == false)
+ {
+ OutReason = "Missing header Upgrade";
+ return AcceptHash;
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << Headers[http_header::SecWebSocketKey] << WebSocketGuid;
+
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
+
+ SHA1 Hash = HashStream.GetHash();
+
+ AcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), AcceptHash.data());
+
+ return AcceptHash;
+ }
+
+ static void Initialize()
+ {
+ ParserSettings = {.on_message_begin =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ Parser.Url = UrlEntry{};
+ Parser.CurrentHeader = HeaderEntry{};
+ Parser.IsUpgrade = false;
+ Parser.IsComplete = false;
+
+ Parser.HeaderStream.Clear();
+ Parser.HeaderEntries.clear();
+
+ return 0;
+ },
+ .on_url =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ Parser.Url.Offset = Parser.HeaderStream.Pos();
+ Parser.Url.Size = Size;
+
+ Parser.HeaderStream.Append(Data, uint32_t(Size));
+
+ return 0;
+ },
+ .on_header_field =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size > 0)
+ {
+ Parser.HeaderEntries.push_back(Parser.CurrentHeader);
+ Parser.CurrentHeader = HeaderEntry{};
+ }
+
+ if (Parser.CurrentHeader.Name.Size == 0)
+ {
+ Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos();
+ }
+
+ Parser.CurrentHeader.Name.Size += Size;
+
+ Parser.HeaderStream.Append(Data, Size);
+
+ return 0;
+ },
+ .on_header_value =
+ [](http_parser* P, const char* Data, size_t Size) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size == 0)
+ {
+ Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos();
+ }
+
+ Parser.CurrentHeader.Value.Size += Size;
+
+ Parser.HeaderStream.Append(Data, Size);
+
+ return 0;
+ },
+ .on_headers_complete =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+
+ if (Parser.CurrentHeader.Value.Size > 0)
+ {
+ Parser.HeaderEntries.push_back(Parser.CurrentHeader);
+ Parser.CurrentHeader = HeaderEntry{};
+ }
+
+ Parser.IsUpgrade = P->upgrade > 0;
+
+ return 0;
+ },
+ .on_message_complete =
+ [](http_parser* P) {
+ HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
+ Parser.IsComplete = true;
+ Parser.IsUpgrade = P->upgrade > 0;
+ return 0;
+ }};
+ }
+
+ struct MemStreamEntry
+ {
+ size_t Offset{};
+ size_t Size{};
+ };
+
+ class MemStream
+ {
+ public:
+ MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {}
+
+ void Append(const char* Data, size_t Size)
+ {
+ const size_t NewSize = m_Size + Size;
+
+ if (NewSize > m_Buf.size())
+ {
+ m_Buf.resize(m_Buf.size() + m_BlockSize);
+ }
+
+ memcpy(m_Buf.data() + m_Size, Data, Size);
+ m_Size += Size;
+ }
+
+ const char* Data() const { return m_Buf.data(); }
+ size_t Pos() const { return m_Size; }
+ void Clear() { m_Size = 0; }
+
+ private:
+ std::vector<char> m_Buf;
+ size_t m_Size{};
+ size_t m_BlockSize{};
+ };
+
+ using UrlEntry = MemStreamEntry;
+
+ struct HeaderEntry
+ {
+ MemStreamEntry Name;
+ MemStreamEntry Value;
+ };
+
+ static http_parser_settings ParserSettings;
+
+ http_parser Parser;
+ MemStream HeaderStream;
+ std::vector<HeaderEntry> HeaderEntries;
+ HeaderEntry CurrentHeader{};
+ UrlEntry Url{};
+ bool IsUpgrade = false;
+ bool IsComplete = false;
+};
+
+http_parser_settings HttpParser::ParserSettings;
+
+///////////////////////////////////////////////////////////////////////////////
+enum class WsConnectionState : uint32_t
+{
+ kDisconnected,
+ kHandshaking,
+ kConnected
+};
+
+///////////////////////////////////////////////////////////////////////////////
+class WsConnectionId
+{
+ static std::atomic_uint32_t WsConnectionCounter;
+
+public:
+ WsConnectionId() = default;
+
+ uint32_t Value() const { return m_Value; }
+
+ auto operator<=>(const WsConnectionId& RHS) const = default;
+
+ static WsConnectionId New() { return WsConnectionId(WsConnectionCounter.fetch_add(1)); }
+
+private:
+ WsConnectionId(uint32_t Value) : m_Value(Value) {}
+
+ uint32_t m_Value{};
+};
+
+std::atomic_uint32_t WsConnectionId::WsConnectionCounter{1};
+
+class WsServer;
+
+///////////////////////////////////////////////////////////////////////////////
+class WsConnection : public std::enable_shared_from_this<WsConnection>
+{
+public:
+ WsConnection(WsServer& Server, WsConnectionId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ : m_Server(Server)
+ , m_Id(Id)
+ , m_Socket(std::move(Socket))
+ , m_StartTime(Clock::now())
+ , m_Status()
+ {
+ }
+
+ ~WsConnection();
+
+ WsConnectionId Id() const { return m_Id; }
+ asio::ip::tcp::socket& Socket() { return *m_Socket; }
+ TimePoint StartTime() const { return m_StartTime; }
+ std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ HttpParser& ParserHttp() { return *m_HttpParser; }
+ WsConnectionState Close();
+ WsConnectionState State() const { return static_cast<WsConnectionState>(m_Status.load(std::memory_order_relaxed)); }
+
+ WsConnectionState SetState(WsConnectionState NewState) { return static_cast<WsConnectionState>(m_Status.exchange(uint32_t(NewState))); }
+
+ void InitializeHttpParser() { m_HttpParser = std::make_unique<HttpParser>(); }
+ void ReleaseHttpParser() { m_HttpParser.reset(); }
+
+private:
+ WsServer& m_Server;
+ WsConnectionId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<HttpParser> m_HttpParser;
+ TimePoint m_StartTime;
+ std::atomic_uint32_t m_Status;
+ asio::streambuf m_ReadBuffer;
+};
+
+WsConnectionState
+WsConnection::Close()
+{
+ using enum WsConnectionState;
+
+ const auto PrevState = SetState(kDisconnected);
+
+ if (PrevState != kDisconnected && m_Socket->is_open())
+ {
+ m_Socket->close();
+ }
+
+ return PrevState;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsThreadPool
+{
+public:
+ WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
+ void Start(uint32_t ThreadCount);
+ void Stop();
+
+private:
+ asio::io_service& m_IoSvc;
+ std::vector<std::thread> m_Threads;
+};
+
+void
+WsThreadPool::Start(uint32_t ThreadCount)
+{
+ ZEN_ASSERT(m_Threads.empty());
+
+ ZEN_LOG_DEBUG(WsLog, "starting '{}' websocket I/O thread(s)", ThreadCount);
+
+ for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
+ {
+ m_Threads.emplace_back([this, ThreadId = Idx + 1] {
+ try
+ {
+ m_IoSvc.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_ERROR(WsLog, "process websocket I/O FAILED, reason '{}'", Err.what());
+ }
+
+ ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId);
+ });
+ }
+}
+
+void
+WsThreadPool::Stop()
+{
+ for (std::thread& Thread : m_Threads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_Threads.clear();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsServer final : public WebSocketServer
+{
+public:
+ WsServer() = default;
+ virtual ~WsServer() { Shutdown(); }
+
+ virtual bool Run(const WebSocketServerOptions& Options) override;
+ virtual void Shutdown() override;
+
+private:
+ friend class WsConnection;
+
+ void AcceptConnection();
+ void CloseConnection(WsConnection& Connection, const std::error_code& Ec);
+ void RemoveConnection(WsConnection& Connection);
+ void ReadConnection(WsConnection& Connection);
+
+ struct IdHasher
+ {
+ size_t operator()(WsConnectionId Id) const { return size_t(Id.Value()); }
+ };
+
+ using ConnectionMap = std::unordered_map<WsConnectionId, std::shared_ptr<WsConnection>, IdHasher>;
+
+ asio::io_service m_IoSvc;
+ std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
+ std::unique_ptr<WsThreadPool> m_ThreadPool;
+ ConnectionMap m_Connections;
+ std::shared_mutex m_ConnMutex;
+ std::atomic_bool m_Running{};
+};
+
+WsConnection::~WsConnection()
+{
+ m_Server.RemoveConnection(*this);
+}
+
+bool
+WsServer::Run(const WebSocketServerOptions& Options)
+{
+ HttpParser::Initialize();
+
+ m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
+
+ m_Acceptor->set_option(asio::ip::v6_only(false));
+ m_Acceptor->set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ asio::error_code Ec;
+ m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec);
+
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(WsLog, "failed to bind websocket endpoint, error code '{}'", Ec.value());
+
+ return false;
+ }
+
+ m_Acceptor->listen();
+ m_Running = true;
+
+ ZEN_LOG_INFO(WsLog, "web socket server running on port '{}'", Options.Port);
+
+ AcceptConnection();
+
+ m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
+ m_ThreadPool->Start(Options.ThreadCount);
+
+ return true;
+}
+
+void
+WsServer::Shutdown()
+{
+ if (m_Running)
+ {
+ ZEN_LOG_INFO(WsLog, "websocket server shutting down");
+
+ m_Running = false;
+
+ m_Acceptor->close();
+ m_Acceptor.reset();
+ m_IoSvc.stop();
+
+ m_ThreadPool->Stop();
+ }
+}
+
+void
+WsServer::AcceptConnection()
+{
+ auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
+ asio::ip::tcp::socket& SocketRef = *Socket.get();
+
+ m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(WsLog, "accept connection FAILED, error code '{}'", Ec.value());
+ }
+ else
+ {
+ auto ConnId = WsConnectionId::New();
+
+ ZEN_LOG_DEBUG(WsLog, "accept connection OK, ID '{}'", ConnId.Value());
+
+ auto Connection = std::make_shared<WsConnection>(*this, ConnId, std::move(ConnectedSocket));
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ m_Connections[ConnId] = Connection;
+ }
+
+ Connection->InitializeHttpParser();
+ Connection->SetState(WsConnectionState::kHandshaking);
+
+ ReadConnection(*Connection);
+ }
+
+ if (m_Running)
+ {
+ AcceptConnection();
+ }
+ });
+}
+
+void
+WsServer::CloseConnection(WsConnection& Connection, const std::error_code& Ec)
+{
+ if (const auto State = Connection.Close(); State != WsConnectionState::kDisconnected)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_INFO(WsLog,
+ "closing connection '{}' ERROR, reason '{}' error code '{}'",
+ Connection.Id().Value(),
+ Ec.message(),
+ Ec.value());
+ }
+ else
+ {
+ ZEN_LOG_INFO(WsLog, "closing connection '{}'", Connection.Id().Value());
+ }
+ }
+}
+
+void
+WsServer::RemoveConnection(WsConnection& Connection)
+{
+ ZEN_LOG_INFO(WsLog, "removing connection '{}'", Connection.Id().Value());
+}
+
+void
+WsServer::ReadConnection(WsConnection& Connection)
+{
+ Connection.ReadBuffer().prepare(64 << 10);
+
+ asio::async_read(
+ Connection.Socket(),
+ Connection.ReadBuffer(),
+ asio::transfer_at_least(1),
+ [this, &Connection](const asio::error_code& ReadEc, std::size_t ByteCount) mutable {
+ if (ReadEc)
+ {
+ return CloseConnection(Connection, ReadEc);
+ }
+
+ ZEN_LOG_DEBUG(WsLog, "reading {}B from connection '{}'", ByteCount, Connection.Id().Value());
+
+ using enum WsConnectionState;
+
+ switch (Connection.State())
+ {
+ case kHandshaking:
+ {
+ HttpParser& Parser = Connection.ParserHttp();
+ const size_t Consumed = Parser.Parse(Connection.ReadBuffer().data());
+ Connection.ReadBuffer().consume(Consumed);
+
+ if (Parser.IsComplete == false)
+ {
+ return ReadConnection(Connection);
+ }
+
+ if (Parser.IsUpgrade == false)
+ {
+ ZEN_LOG_DEBUG(WsLog,
+ "handshake with connection '{}' FAILED, reason 'not an upgrade request'",
+ Connection.Id().Value());
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
+
+ return async_write(Connection.Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ CloseConnection(Connection, WriteEc);
+ }
+ else
+ {
+ Connection.InitializeHttpParser();
+ Connection.SetState(WsConnectionState::kHandshaking);
+
+ ReadConnection(Connection);
+ }
+ });
+ }
+
+ std::unordered_map<std::string_view, std::string_view> Headers;
+ Parser.GetHeaders(Headers);
+
+ std::string Reason;
+ std::string AcceptHash = Parser.ValidateWebSocketHandshake(Headers, Reason);
+
+ if (AcceptHash.empty())
+ {
+ ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' FAILED, reason '{}'", Connection.Id().Value(), Reason);
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
+
+ return async_write(Connection.Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ CloseConnection(Connection, WriteEc);
+ }
+ else
+ {
+ Connection.InitializeHttpParser();
+ Connection.SetState(WsConnectionState::kHandshaking);
+
+ ReadConnection(Connection);
+ }
+ });
+ }
+
+ ExtendableStringBuilder<128> Sb;
+
+ Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: Upgrade\r\n"sv;
+
+ // TODO: Verify protocol
+ if (Headers.contains(http_header::SecWebSocketProtocol))
+ {
+ Sb << http_header::SecWebSocketProtocol << ": " << Headers[http_header::SecWebSocketProtocol] << "\r\n";
+ }
+
+ Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
+ Sb << "\r\n"sv;
+
+ std::string Response = Sb.ToString();
+ asio::const_buffer Buffer = asio::buffer(Response);
+
+ ZEN_LOG_DEBUG(WsLog, "accepting handshake from connection '{}'", Connection.Id().Value());
+
+ async_write(Connection.Socket(),
+ Buffer,
+ [this, &Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ CloseConnection(Connection, WriteEc);
+ }
+ else
+ {
+ ZEN_LOG_DEBUG(WsLog, "handshake with connection '{}' OK", Connection.Id().Value());
+
+ Connection.ReleaseHttpParser();
+ Connection.SetState(kConnected);
+
+ ReadConnection(Connection);
+ }
+ });
+ }
+ break;
+
+ case kConnected:
+ {
+ // TODO: Implement RPC API
+ }
+ break;
+
+ default:
+ break;
+ };
+ });
+}
+
+} // namespace zen::asio_ws
+
+namespace zen {
+
+std::unique_ptr<WebSocketServer>
+WebSocketServer::Create()
+{
+ return std::make_unique<asio_ws::WsServer>();
+}
+
+} // namespace zen