diff options
Diffstat (limited to 'zenhttp/asiohttpserver.cpp')
| -rw-r--r-- | zenhttp/asiohttpserver.cpp | 799 |
1 files changed, 799 insertions, 0 deletions
diff --git a/zenhttp/asiohttpserver.cpp b/zenhttp/asiohttpserver.cpp new file mode 100644 index 000000000..9185f8856 --- /dev/null +++ b/zenhttp/asiohttpserver.cpp @@ -0,0 +1,799 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "asiohttpserver.h" + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/enumflags.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/sha1.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/trace.h> + +#include <zenhttp/httpserver.h> + +#include <chrono> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogHttp, "http"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +class HttpConnectionId +{ + static std::atomic_uint32_t NextId; + +public: + HttpConnectionId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const HttpConnectionId&) const = default; + + static HttpConnectionId New() { return HttpConnectionId(NextId.fetch_add(1)); } + +private: + HttpConnectionId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +std::atomic_uint32_t HttpConnectionId::NextId{1}; + +/////////////////////////////////////////////////////////////////////////////// +enum class ParseMessageStatus : uint32_t +{ + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + BinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Reset(); +} + +/////////////////////////////////////////////////////////////////////////////// +class HttpRequestMessage final : public HttpServerRequest +{ +public: + HttpRequestMessage() = default; + + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) override; + +private: + virtual Oid ParseSessionId() const; + virtual uint32_t ParseRequestId() const; +}; + +void +HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_UNUSED(ResponseCode, ContentType, Blobs); +} + +void +HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_UNUSED(ResponseCode); +} + +void +HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_UNUSED(ResponseCode, ContentType, ResponseString); +} + +void +HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) +{ + ZEN_UNUSED(ResponseCode, ContentType, Payload); +} + +Oid +HttpRequestMessage::ParseSessionId() const +{ + return Oid::NewOid(); +} + +uint32_t +HttpRequestMessage::ParseRequestId() const +{ + return 0u; +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const + { + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); + } + + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + void Initialize(); + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry + { + uint64_t Offset{}; + uint64_t Size{}; + }; + + struct HeaderStreamEntry + { + StreamEntry Field{}; + StreamEntry Value{}; + }; + + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; + + static http_parser_settings ParserSettings; +}; + +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) + { + Status = ParseMessageStatus::kError; + } + + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} + +void +HttpMessageParser::OnReset() +{ + Initialize(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} + +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; + + m_Stream.Write(Url); + + return 0; +} + +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + if (m_CurrentHeader.Field.Size == 0) + { + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) + { + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); + + m_Stream.Write(HeaderValue); + + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); + + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); + + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } + + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) +{ + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); + + if (m_Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return false; + } + + if (m_Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); + + return true; +} +/////////////////////////////////////////////////////////////////////////////// +enum class ConnectionFlags : uint32_t +{ + kNone = 0, + kHttp = 1 << 0, + kWebSocket = 1 << 1, + kConnected = 1 << 2 +}; + +ENUM_CLASS_FLAGS(ConnectionFlags); + +struct HttpConnection : public std::enable_shared_from_this<HttpConnection> +{ + HttpConnection(HttpConnectionId ConnId, asio::ip::tcp::socket&& S) + : Id(ConnId) + , Socket(std::move(S)) + , StartTime(Clock::now()) + , FlagsValue(uint32_t(ConnectionFlags::kHttp | ConnectionFlags::kConnected)) + , Parser(new HttpMessageParser(HttpMessageParserType::kRequest)) + { + } + + std::shared_ptr<HttpConnection> AsShared() { return shared_from_this(); } + std::string RemoteAddr() const { return Socket.remote_endpoint().address().to_string(); } + ConnectionFlags Flags() const { return static_cast<ConnectionFlags>(FlagsValue.load(std::memory_order_relaxed)); } + + HttpConnectionId Id; + std::atomic<uint32_t> FlagsValue; + asio::ip::tcp::socket Socket; + TimePoint StartTime; + std::unique_ptr<MessageParser> Parser; + asio::streambuf ReadBuffer; + std::mutex WriteMutex; +}; + +/////////////////////////////////////////////////////////////////////////////// +class AsioThreadPool +{ +public: + AsioThreadPool(asio::io_context& IoCtx) : m_IoCtx(IoCtx) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_context& m_IoCtx; + std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; +}; + +void +AsioThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(LogHttp, "starting '{}' HTTP I/O thread(s)", ThreadCount); + + m_Running = true; + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + for (;;) + { + if (m_Running == false) + { + break; + } + + try + { + m_IoCtx.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogHttp, "process HTTP I/O FAILED, reason '{}'", Err.what()); + } + } + }); + } +} + +void +AsioThreadPool::Stop() +{ + if (m_Running) + { + m_Running = false; + + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class AsioHttpServer final : public HttpServer +{ +public: + AsioHttpServer(const zen::AsioHttpServerOptions& Options); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int Port) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + +private: + void AcceptConnection(); + void CloseConnection(std::shared_ptr<HttpConnection>& Connection, const asio::error_code& Ec); + void ReadMessage(std::shared_ptr<HttpConnection>& Connection); + + struct IdHasher + { + size_t operator()(HttpConnectionId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<HttpConnectionId, std::shared_ptr<HttpConnection>, IdHasher>; + + zen::AsioHttpServerOptions m_Options; + asio::io_service m_IoCtx; + std::unique_ptr<AsioThreadPool> m_ThreadPool; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::optional<asio::ip::tcp::socket> m_ListeningSocket; + std::mutex m_ShutdownMutex; + std::condition_variable m_ShutdownSignal; + std::vector<HttpService*> m_Services; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::atomic_bool m_Running{false}; +}; + +AsioHttpServer::AsioHttpServer(const AsioHttpServerOptions& Options) : m_Options(Options) +{ +} + +void +AsioHttpServer::RegisterService(HttpService& Service) +{ + m_Services.push_back(&Service); +} + +int +AsioHttpServer::Initialize(int Port) +{ + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoCtx, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(m_Options.ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(m_Options.SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), asio::ip::port_type(Port)), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(LogHttp, "bind endpoint FAILED, reason '{}'", Ec.message()); + + return -1; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(LogHttp, "web socket server running on port '{}'", Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<AsioThreadPool>(m_IoCtx); + m_ThreadPool->Start(m_Options.ThreadCount); + + return 0; +} + +void +AsioHttpServer::Run(bool IsInteractiveSession) +{ + if (IsInteractiveSession) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + for (;;) + { + if (IsApplicationExitRequested()) + { + break; + } + + std::unique_lock Lock(m_ShutdownMutex); + m_ShutdownSignal.wait_for(Lock, std::chrono::seconds(2)); + } +} + +void +AsioHttpServer::RequestExit() +{ + m_ShutdownSignal.notify_one(); +} + +void +AsioHttpServer::AcceptConnection() +{ + m_ListeningSocket.emplace(m_IoCtx); + + m_Acceptor->async_accept(m_ListeningSocket.value(), [this](const asio::error_code& Ec) mutable { + if (m_Running) + { + if (Ec) + { + ZEN_LOG_WARN(LogHttp, "accept connection FAILED, reason '{}'", Ec.message()); + } + else + { + auto Connection = std::make_shared<HttpConnection>(HttpConnectionId::New(), std::move(m_ListeningSocket.value())); + + ZEN_LOG_TRACE(LogHttp, "accept connection '#{} {}' OK", Connection->Id.Value(), Connection->RemoteAddr()); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id] = Connection; + } + + ReadMessage(Connection); + } + + AcceptConnection(); + } + }); +} + +void +AsioHttpServer::CloseConnection(std::shared_ptr<HttpConnection>& Connection, const asio::error_code& Ec) +{ + if (Ec) + { + ZEN_LOG_INFO(LogHttp, "connection '{}' closed, reason '{} ({})'", Connection->Id.Value(), Ec.message(), Ec.value()); + } + else + { + ZEN_LOG_INFO(LogHttp, "connection '{}' closed", Connection->Id.Value()); + } + + const HttpConnectionId Id = Connection->Id; + + { + std::unique_lock _(m_ConnMutex); + m_Connections.erase(Id); + } +} + +void +AsioHttpServer::ReadMessage(std::shared_ptr<HttpConnection>& Connection) +{ + Connection->ReadBuffer.prepare(m_Options.ReceiveBufferSize); + + asio::async_read(Connection->Socket, + Connection->ReadBuffer, + asio::transfer_at_least(1), + [this, Connection](const asio::error_code& Ec, std::size_t) mutable { + if (Ec) + { + return CloseConnection(Connection, Ec); + } + + const ConnectionFlags Flags = Connection->Flags(); + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser.get()); + + uint64_t RemainingBytes = Connection->ReadBuffer.size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Connection->ReadBuffer.data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Connection->ReadBuffer.consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer.size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogHttp, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + } + + ReadMessage(Connection); + + // ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + // Connection->ReadBuffer.consume(Result.ByteCount); + + // if (Result.Status == ParseMessageStatus::kContinue) + //{ + // return ReadMessage(Connection); + //} + + // if (Result.Status == ParseMessageStatus::kError) + //{ + // const std::string Reason = Result.Reason.has_value() ? std::move(Result.Reason.value()) : "HTTP parse error"; + // ZEN_LOG_WARN(LogHttp, "parse message FAILED, connection #{}, reason '{}'", Connection->Id.Value(), Reason); + + // return CloseConnection(Connection, std::error_code()); + //} + }); +} + +} // namespace zen::asio_http + +namespace zen { + +Ref<HttpServer> +CreateAsioHttpServer(const AsioHttpServerOptions& Options) +{ + return new zen::asio_http::AsioHttpServer(Options); +} + +} // namespace zen |