aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/websocketserver.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-15 14:01:27 +0100
committerPer Larsson <[email protected]>2022-02-15 14:01:27 +0100
commit921ce02cca7c15113452fde59bafc8fb58663b98 (patch)
treeaa156785a389f0713a9999a68aaad25b851c634b /zenhttp/websocketserver.cpp
parentInitial websocket support. (diff)
downloadzen-921ce02cca7c15113452fde59bafc8fb58663b98.tar.xz
zen-921ce02cca7c15113452fde59bafc8fb58663b98.zip
Refactored websocket server and added static logger support.
Diffstat (limited to 'zenhttp/websocketserver.cpp')
-rw-r--r--zenhttp/websocketserver.cpp471
1 files changed, 0 insertions, 471 deletions
diff --git a/zenhttp/websocketserver.cpp b/zenhttp/websocketserver.cpp
deleted file mode 100644
index 776ed1019..000000000
--- a/zenhttp/websocketserver.cpp
+++ /dev/null
@@ -1,471 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/websocketserver.h>
-
-#include <zencore/base64.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/sha1.h>
-#include <zencore/string.h>
-
-#include <span>
-#include <thread>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <fmt/format.h>
-#include <http_parser.h>
-#include <asio.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-namespace zen::asio_http {
-
-using namespace std::literals;
-
-struct HttpParser
-{
- HttpParser()
- {
- http_parser_init(&Parser, HTTP_REQUEST);
- Parser.data = this;
- }
-
- size_t Parse(const char* Data, const size_t Size) { return http_parser_execute(&Parser, &ParserSettings, Data, Size); }
-
- void GetHeaders(std::unordered_map<std::string_view, std::string_view>& OutHeaders)
- {
- OutHeaders.reserve(HeaderEntries.size());
-
- for (const auto& E : HeaderEntries)
- {
- auto Name = std::string_view(HeaderStream.Data() + E.Name.Offset, E.Name.Size);
- auto Value = std::string_view(HeaderStream.Data() + E.Value.Offset, E.Value.Size);
-
- OutHeaders[Name] = Value;
- }
- }
-
- static void Initialize()
- {
- ParserSettings = {.on_message_begin =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- Parser.Url = UrlEntry{};
- Parser.CurrentHeader = HeaderEntry{};
- Parser.IsUpgrade = false;
- Parser.IsComplete = false;
-
- Parser.HeaderStream.Clear();
- Parser.HeaderEntries.clear();
-
- return 0;
- },
- .on_url =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- Parser.Url.Offset = Parser.HeaderStream.Pos();
- Parser.Url.Size = Size;
-
- Parser.HeaderStream.Append(Data, uint32_t(Size));
-
- return 0;
- },
- .on_header_field =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size > 0)
- {
- Parser.HeaderEntries.push_back(Parser.CurrentHeader);
- Parser.CurrentHeader = HeaderEntry{};
- }
-
- if (Parser.CurrentHeader.Name.Size == 0)
- {
- Parser.CurrentHeader.Name.Offset = Parser.HeaderStream.Pos();
- }
-
- Parser.CurrentHeader.Name.Size += Size;
-
- Parser.HeaderStream.Append(Data, Size);
-
- return 0;
- },
- .on_header_value =
- [](http_parser* P, const char* Data, size_t Size) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size == 0)
- {
- Parser.CurrentHeader.Value.Offset = Parser.HeaderStream.Pos();
- }
-
- Parser.CurrentHeader.Value.Size += Size;
-
- Parser.HeaderStream.Append(Data, Size);
-
- return 0;
- },
- .on_headers_complete =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
-
- if (Parser.CurrentHeader.Value.Size > 0)
- {
- Parser.HeaderEntries.push_back(Parser.CurrentHeader);
- Parser.CurrentHeader = HeaderEntry{};
- }
-
- Parser.IsUpgrade = P->upgrade > 0;
-
- return 0;
- },
- .on_message_complete =
- [](http_parser* P) {
- HttpParser& Parser = *reinterpret_cast<HttpParser*>(P->data);
- Parser.IsComplete = true;
- Parser.IsUpgrade = P->upgrade > 0;
- return 0;
- }};
- }
-
- struct MemStreamEntry
- {
- size_t Offset{};
- size_t Size{};
- };
-
- class MemStream
- {
- public:
- MemStream(size_t BlockSize = 1024) : m_BlockSize(BlockSize) {}
-
- void Append(const char* Data, size_t Size)
- {
- const size_t NewSize = m_Size + Size;
-
- if (NewSize > m_Buf.size())
- {
- m_Buf.resize(m_Buf.size() + m_BlockSize);
- }
-
- memcpy(m_Buf.data() + m_Size, Data, Size);
- m_Size += Size;
- }
-
- const char* Data() const { return m_Buf.data(); }
- size_t Pos() const { return m_Size; }
- void Clear() { m_Size = 0; }
-
- private:
- std::vector<char> m_Buf;
- size_t m_Size{};
- size_t m_BlockSize{};
- };
-
- using UrlEntry = MemStreamEntry;
-
- struct HeaderEntry
- {
- MemStreamEntry Name;
- MemStreamEntry Value;
- };
-
- static http_parser_settings ParserSettings;
-
- http_parser Parser;
- MemStream HeaderStream;
- std::vector<HeaderEntry> HeaderEntries;
- HeaderEntry CurrentHeader{};
- UrlEntry Url{};
- bool IsUpgrade = false;
- bool IsComplete = false;
-};
-
-http_parser_settings HttpParser::ParserSettings;
-
-class AsioWebSocketServer final : public WebSocketServer
-{
-public:
- AsioWebSocketServer() : m_Log(zen::logging::Get("websocket")) { HttpParser::Initialize(); }
-
- virtual ~AsioWebSocketServer() { Shutdown(); }
-
- virtual bool Run(const WebSocketServerOptions& Options) override
- {
- m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoService, asio::ip::tcp::v6());
-
- m_Acceptor->set_option(asio::ip::v6_only(false));
- m_Acceptor->set_option(asio::socket_base::reuse_address(true));
- m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- asio::error_code Ec;
- m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Options.Port), Ec);
-
- if (Ec)
- {
- ZEN_ERROR("failed to bind websocket endpoint, error code '{}'", Ec.value());
-
- return false;
- }
-
- m_Acceptor->listen();
-
- BeginAccept();
-
- StartIoThreads(Options.ThreadCount);
-
- m_Running.store(true, std::memory_order_relaxed);
-
- ZEN_INFO("websocket server running on port '{}'", Options.Port);
-
- return true;
- }
-
- virtual void Shutdown() override
- {
- if (m_Running)
- {
- ZEN_INFO("websocket server shutting down");
-
- m_Running = false;
-
- m_Acceptor->close();
- m_Acceptor.reset();
- m_IoService.stop();
-
- StopIoThreads();
- }
- }
-
-private:
- enum class WebSocketState : uint32_t
- {
- kNone,
- kHandshake,
- kRead,
- kWrite,
- kError
- };
-
- struct WebSocketConnection : public std::enable_shared_from_this<WebSocketConnection>
- {
- WebSocketConnection(std::unique_ptr<asio::ip::tcp::socket>&& S, uint32_t ConnId) : Socket(std::move(S)), Id(ConnId) {}
-
- std::unique_ptr<asio::ip::tcp::socket> Socket;
- asio::streambuf ReadBuffer;
- WebSocketState State;
- HttpParser Parser;
- uint32_t Id;
- };
-
- void BeginAccept()
- {
- auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoService);
- asio::ip::tcp::socket& SocketRef = *Socket.get();
-
- m_Acceptor->async_accept(SocketRef, [this, NewSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
- if (Ec)
- {
- ZEN_WARN("accept error, error code '{}'", Ec.value());
- }
- else
- {
- const uint32_t Id = m_ConnectionId.fetch_add(1);
- auto Connection = std::make_shared<WebSocketConnection>(std::move(NewSocket), Id);
-
- Connection->State = WebSocketState::kHandshake;
-
- BeginRead(Connection);
- }
-
- if (m_Running.load(std::memory_order_relaxed))
- {
- BeginAccept();
- }
- else
- {
- m_Acceptor->close();
- }
- });
- }
-
- void BeginRead(std::shared_ptr<WebSocketConnection> Connection)
- {
- Connection->ReadBuffer.prepare(64 << 10);
-
- asio::async_read(*Connection->Socket,
- Connection->ReadBuffer,
- asio::transfer_at_least(1),
- [Conn = Connection->shared_from_this(), this](const asio::error_code& Ec, std::size_t ByteCount) {
- if (Ec)
- {
- ZEN_ERROR("read FAILED, connection '{}', error code '{}'", Conn->Id, Ec.value());
- Conn->Socket->close();
- return;
- }
-
- ZEN_TRACE("reading {}B from connection '{}'", ByteCount, Conn->Id);
-
- WebSocketState NextState = WebSocketState::kError;
-
- switch (Conn->State)
- {
- case WebSocketState::kHandshake:
- NextState = ProcessHandshake(Conn);
- break;
- }
-
- Conn->State = NextState;
-
- if (Conn->State == WebSocketState::kError)
- {
- ZEN_TRACE("process error, connection '{}'", Conn->Id);
- Conn->Socket->close();
- return;
- }
-
- BeginRead(Conn);
- });
- }
-
- WebSocketState ProcessHandshake(std::shared_ptr<WebSocketConnection> Connection)
- {
- HttpParser& Parser = Connection->Parser;
- const asio::const_buffer& Buffer = Connection->ReadBuffer.data();
-
- const size_t BytesParsed = Parser.Parse(reinterpret_cast<const char*>(Buffer.data()), Buffer.size());
- Connection->ReadBuffer.consume(BytesParsed);
-
- if (Parser.IsComplete)
- {
- if (Parser.IsUpgrade == false)
- {
- ZEN_DEBUG("invalid websocket handshake request, closing connection '{}'", Connection->Id);
-
- return WebSocketState::kError;
- }
-
- static constexpr std::string_view WebSocketKey = "Sec-WebSocket-Key"sv;
- static constexpr std::string_view WebSocketOriginKey = "Sec-WebSocket-Origin"sv;
- static constexpr std::string_view WebSocketProtocolKey = "Sec-WebSocket-Protocol"sv;
- static constexpr std::string_view WebSocketVersionKey = "Sec-WebSocket-Version"sv;
- static constexpr std::string_view WebSocketAcceptKey = "Sec-WebSocket-Accept"sv;
- static constexpr std::string_view UpgradeKey = "Upgrade"sv;
- static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
-
- std::unordered_map<std::string_view, std::string_view> Headers;
- Parser.GetHeaders(Headers);
-
- ZEN_DEBUG("handshake, Origin='{}', Protocol='{}', Version='{}', Key='{}'",
- Headers[WebSocketOriginKey],
- Headers[WebSocketProtocolKey],
- Headers[WebSocketVersionKey],
- Headers[WebSocketKey]);
-
- ExtendableStringBuilder<128> Sb;
- Sb << Headers[WebSocketKey] << WebSocketGuid;
-
- SHA1Stream HashStream;
- HashStream.Append(Sb.Data(), Sb.Size());
-
- SHA1 Hash = HashStream.GetHash();
- Sb.Reset();
-
- const uint32_t EncodedSize = Base64::GetEncodedDataSize(sizeof(SHA1::Hash));
- Sb.AddUninitialized(EncodedSize);
- Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), Sb.Data());
-
- std::string AcceptHash = Sb.ToString();
-
- Sb.Reset();
- Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
- Sb << "Upgrade: websocket\r\n"sv;
- Sb << "Connection: Upgrade\r\n"sv;
- Sb << WebSocketProtocolKey << ": " << Headers[WebSocketProtocolKey] << "\r\n";
- Sb << WebSocketAcceptKey << ": " << AcceptHash << "\r\n"
- << "\r\n"sv;
-
- IoBuffer Response = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Sb.ToView()));
- asio::const_buffer ResponseView(Response.Data(), Response.Size());
- const uint64_t ResponseLength = Response.Size();
-
- asio::async_write(
- *Connection->Socket,
- asio::buffer(Response.Data(), Response.Size()),
- asio::transfer_exactly(ResponseLength),
- [this, Conn = Connection->shared_from_this(), Buf = Response](const asio::error_code& Ec, std::size_t ByteCount) {
- if (Ec)
- {
- ZEN_ERROR("write {}B FAILED, error code '{}'", ByteCount, Ec.value());
- }
- else
- {
- ZEN_DEBUG("write {}B OK", ByteCount);
- }
- });
-
- return WebSocketState::kRead;
- }
-
- return WebSocketState::kHandshake;
- }
-
- void StartIoThreads(uint32_t ThreadCount)
- {
- ZEN_DEBUG("starting '{}' websocket I/O thread(s)");
-
- for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
- {
- m_ThreadPool.emplace_back([this, ThreadId = Idx + 1] {
- try
- {
- m_IoService.run();
- }
- catch (std::exception& Err)
- {
- ZEN_ERROR("process websocket request FAILED, reason '{}'", Err.what());
- }
-
- ZEN_DEBUG("websocket I/O thread '{}' exiting", ThreadId);
- });
- }
- }
-
- void StopIoThreads()
- {
- for (std::thread& Thread : m_ThreadPool)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
-
- m_ThreadPool.clear();
- }
-
- spdlog::logger& Log() { return m_Log; }
- spdlog::logger& m_Log;
-
- asio::io_service m_IoService;
- std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
- std::atomic_bool m_Running{};
- std::atomic_uint32_t m_ConnectionId{1};
- std::vector<std::thread> m_ThreadPool;
-};
-
-} // namespace zen::asio_http
-
-namespace zen {
-
-std::unique_ptr<WebSocketServer>
-WebSocketServer::Create()
-{
- return std::make_unique<asio_http::AsioWebSocketServer>();
-}
-
-} // namespace zen