// Copyright Epic Games, Inc. All Rights Reserved. #include "asiohttpserver.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include #include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # include #endif namespace zen::asio_http { using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogHttp, "http"sv); using Clock = std::chrono::steady_clock; using TimePoint = Clock::time_point; /////////////////////////////////////////////////////////////////////////////// namespace http_header { static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; static constexpr std::string_view Upgrade = "Upgrade"sv; } // namespace http_header /////////////////////////////////////////////////////////////////////////////// class HttpConnectionId { static std::atomic_uint32_t NextId; public: HttpConnectionId() = default; uint32_t Value() const { return m_Value; } auto operator<=>(const HttpConnectionId&) const = default; static HttpConnectionId New() { return HttpConnectionId(NextId.fetch_add(1)); } private: HttpConnectionId(uint32_t Value) : m_Value(Value) {} uint32_t m_Value{}; }; std::atomic_uint32_t HttpConnectionId::NextId{1}; /////////////////////////////////////////////////////////////////////////////// enum class ParseMessageStatus : uint32_t { kError, kContinue, kDone, }; struct ParseMessageResult { ParseMessageStatus Status{}; size_t ByteCount{}; std::optional Reason; }; class MessageParser { public: virtual ~MessageParser() = default; ParseMessageResult ParseMessage(MemoryView Msg); void Reset(); protected: MessageParser() = default; virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; virtual void OnReset() = 0; BinaryWriter m_Stream; }; ParseMessageResult MessageParser::ParseMessage(MemoryView Msg) { return OnParseMessage(Msg); } void MessageParser::Reset() { OnReset(); m_Stream.Reset(); } /////////////////////////////////////////////////////////////////////////////// class HttpHeaderData { public: struct Entry { uint32_t Offset{0}; uint32_t Size{0}; }; struct HeaderEntry { Entry Name{}; Entry Value{}; }; static constexpr size_t kMaxHeaderSize = 2048; static constexpr size_t kMaxHeaderCount = 16; static constexpr size_t kHeaderIndexSize = kMaxHeaderCount * sizeof(HeaderEntry); static constexpr size_t kMaxHeaderContentSize = kMaxHeaderSize - kHeaderIndexSize; HttpHeaderData() { Reset(); } std::string_view StatusLine() const { const Entry* StatusEntry = reinterpret_cast(m_Buffer); const char* Buffer = reinterpret_cast(m_Buffer) + StatusEntry->Offset; return std::string_view(Buffer, size_t(StatusEntry->Size)); } std::string_view Url() const { const Entry* UrlEntry = reinterpret_cast(m_Buffer + sizeof(Entry)); const char* Buffer = reinterpret_cast(m_Buffer) + UrlEntry->Offset; return std::string_view(Buffer, size_t(UrlEntry->Size)); } void Reset() { memset(m_Buffer, 0, kMaxHeaderSize); m_CursorView = MutableMemoryView(&m_Buffer[kHeaderIndexSize], kMaxHeaderContentSize); m_HeaderCount = 0; } void AppendData(MemoryView Data) { m_CursorView.CopyFrom(Data); m_CursorView += Data.GetSize(); } void AppendStatus(Entry StatusEntry) { MutableMemoryView BufferView(&m_Buffer, sizeof(Entry)); BufferView.CopyFrom(MemoryView(&StatusEntry, sizeof(Entry))); } void AppendUrl(Entry StatusEntry) { MutableMemoryView BufferView(&m_Buffer + sizeof(Entry), sizeof(Entry)); BufferView.CopyFrom(MemoryView(&StatusEntry, sizeof(Entry))); } void AppendHeader(HeaderEntry Header) { MutableMemoryView HeaderView(&m_Buffer + ((m_HeaderCount + 1) * sizeof(HeaderEntry)), sizeof(HeaderEntry)); HeaderView.CopyFrom(MemoryView(&Header, sizeof(HeaderEntry))); m_HeaderCount++; } private: uint8_t m_Buffer[kMaxHeaderSize]; MutableMemoryView m_CursorView; size_t m_HeaderCount{0}; }; /////////////////////////////////////////////////////////////////////////////// class HttpRequestMessage final : public HttpServerRequest { public: HttpRequestMessage() = default; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) override; private: virtual Oid ParseSessionId() const; virtual uint32_t ParseRequestId() const; }; void HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) { ZEN_UNUSED(ResponseCode, ContentType, Blobs); } void HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode) { ZEN_UNUSED(ResponseCode); } void HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { ZEN_UNUSED(ResponseCode, ContentType, ResponseString); } void HttpRequestMessage::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) { ZEN_UNUSED(ResponseCode, ContentType, Payload); } Oid HttpRequestMessage::ParseSessionId() const { return Oid::NewOid(); } uint32_t HttpRequestMessage::ParseRequestId() const { return 0u; } /////////////////////////////////////////////////////////////////////////////// enum class HttpMessageParserType { kRequest, kResponse, kBoth }; class HttpMessageParser final : public MessageParser { public: using HttpHeaders = std::unordered_map; HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } virtual ~HttpMessageParser() = default; int32_t StatusCode() const { return m_Parser.status_code; } bool IsUpgrade() const { return m_Parser.upgrade != 0; } HttpHeaders& Headers() { return m_Headers; } MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } std::string_view StatusText() const { return std::string_view(reinterpret_cast(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); } bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); private: void Initialize(); virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; virtual void OnReset() override; int OnMessageBegin(); int OnUrl(MemoryView Url); int OnStatus(MemoryView Status); int OnHeaderField(MemoryView HeaderField); int OnHeaderValue(MemoryView HeaderValue); int OnHeadersComplete(); int OnBody(MemoryView Body); int OnMessageComplete(); struct StreamEntry { uint64_t Offset{}; uint64_t Size{}; }; struct HeaderStreamEntry { StreamEntry Field{}; StreamEntry Value{}; }; HttpMessageParserType m_Type; http_parser m_Parser; StreamEntry m_UrlEntry; StreamEntry m_StatusEntry; StreamEntry m_BodyEntry; HeaderStreamEntry m_CurrentHeader; std::vector m_HeaderEntries; HttpHeaders m_Headers; bool m_IsMsgComplete{false}; static http_parser_settings ParserSettings; }; http_parser_settings HttpMessageParser::ParserSettings = { .on_message_begin = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageBegin(); }, .on_url = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnUrl(MemoryView(Data, Size)); }, .on_status = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnStatus(MemoryView(Data, Size)); }, .on_header_field = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderField(MemoryView(Data, Size)); }, .on_header_value = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, .on_headers_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnHeadersComplete(); }, .on_body = [](http_parser* P, const char* Data, size_t Size) { return reinterpret_cast(P->data)->OnBody(MemoryView(Data, Size)); }, .on_message_complete = [](http_parser* P) { return reinterpret_cast(P->data)->OnMessageComplete(); }}; void HttpMessageParser::Initialize() { http_parser_init(&m_Parser, m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE : HTTP_BOTH); m_Parser.data = this; m_UrlEntry = {}; m_StatusEntry = {}; m_CurrentHeader = {}; m_BodyEntry = {}; m_IsMsgComplete = false; m_HeaderEntries.clear(); } ParseMessageResult HttpMessageParser::OnParseMessage(MemoryView Msg) { const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast(Msg.GetData()), Msg.GetSize()); auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; if (m_Parser.http_errno != 0) { Status = ParseMessageStatus::kError; } return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; } void HttpMessageParser::OnReset() { Initialize(); } int HttpMessageParser::OnMessageBegin() { ZEN_ASSERT(m_IsMsgComplete == false); ZEN_ASSERT(m_HeaderEntries.empty()); ZEN_ASSERT(m_Headers.empty()); return 0; } int HttpMessageParser::OnStatus(MemoryView Status) { m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; m_Stream.Write(Status); return 0; } int HttpMessageParser::OnUrl(MemoryView Url) { m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; m_Stream.Write(Url); return 0; } int HttpMessageParser::OnHeaderField(MemoryView HeaderField) { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } if (m_CurrentHeader.Field.Size == 0) { m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Field.Size += HeaderField.GetSize(); m_Stream.Write(HeaderField); return 0; } int HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) { if (m_CurrentHeader.Value.Size == 0) { m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); } m_CurrentHeader.Value.Size += HeaderValue.GetSize(); m_Stream.Write(HeaderValue); return 0; } int HttpMessageParser::OnHeadersComplete() { if (m_CurrentHeader.Value.Size > 0) { m_HeaderEntries.push_back(m_CurrentHeader); m_CurrentHeader = {}; } m_Headers.clear(); m_Headers.reserve(m_HeaderEntries.size()); const char* StreamData = reinterpret_cast(m_Stream.Data()); for (const auto& Entry : m_HeaderEntries) { auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); m_Headers.try_emplace(std::move(Field), std::move(Value)); } return 0; } int HttpMessageParser::OnBody(MemoryView Body) { m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; m_Stream.Write(Body); return 0; } int HttpMessageParser::OnMessageComplete() { m_IsMsgComplete = true; return 0; } bool HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) { static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; OutAcceptHash = std::string(); if (m_Headers.contains(http_header::SecWebSocketKey) == false) { OutReason = "Missing header Sec-WebSocket-Key"; return false; } if (m_Headers.contains(http_header::Upgrade) == false) { OutReason = "Missing header Upgrade"; return false; } ExtendableStringBuilder<128> Sb; Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; SHA1Stream HashStream; HashStream.Append(Sb.Data(), Sb.Size()); SHA1 Hash = HashStream.GetHash(); OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); return true; } /////////////////////////////////////////////////////////////////////////////// enum class ConnectionFlags : uint32_t { kNone = 0, kHttp = 1 << 0, kWebSocket = 1 << 1, kConnected = 1 << 2 }; ENUM_CLASS_FLAGS(ConnectionFlags); struct HttpConnection : public std::enable_shared_from_this { HttpConnection(HttpConnectionId ConnId, asio::ip::tcp::socket&& S) : Id(ConnId) , Socket(std::move(S)) , StartTime(Clock::now()) , FlagsValue(uint32_t(ConnectionFlags::kHttp | ConnectionFlags::kConnected)) , Parser(new HttpMessageParser(HttpMessageParserType::kRequest)) { } std::shared_ptr AsShared() { return shared_from_this(); } std::string RemoteAddr() const { return Socket.remote_endpoint().address().to_string(); } ConnectionFlags Flags() const { return static_cast(FlagsValue.load(std::memory_order_relaxed)); } HttpConnectionId Id; std::atomic FlagsValue; asio::ip::tcp::socket Socket; TimePoint StartTime; std::unique_ptr Parser; asio::streambuf ReadBuffer; std::mutex WriteMutex; }; /////////////////////////////////////////////////////////////////////////////// class AsioThreadPool { public: AsioThreadPool(asio::io_context& IoCtx) : m_IoCtx(IoCtx) {} void Start(uint32_t ThreadCount); void Stop(); private: asio::io_context& m_IoCtx; std::vector m_Threads; std::atomic_bool m_Running{false}; }; void AsioThreadPool::Start(uint32_t ThreadCount) { ZEN_ASSERT(m_Threads.empty()); ZEN_LOG_DEBUG(LogHttp, "starting '{}' HTTP I/O thread(s)", ThreadCount); m_Running = true; for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) { m_Threads.emplace_back([this, ThreadId = Idx + 1] { for (;;) { if (m_Running == false) { break; } try { m_IoCtx.run(); } catch (std::exception& Err) { ZEN_LOG_ERROR(LogHttp, "process HTTP I/O FAILED, reason '{}'", Err.what()); } } }); } } void AsioThreadPool::Stop() { if (m_Running) { m_Running = false; for (std::thread& Thread : m_Threads) { if (Thread.joinable()) { Thread.join(); } } m_Threads.clear(); } } /////////////////////////////////////////////////////////////////////////////// class AsioHttpServer final : public HttpServer { public: AsioHttpServer(const zen::AsioHttpServerOptions& Options); virtual void RegisterService(HttpService& Service) override; virtual int Initialize(int Port) override; virtual void Run(bool IsInteractiveSession) override; virtual void RequestExit() override; private: void AcceptConnection(); void CloseConnection(std::shared_ptr& Connection, const asio::error_code& Ec); void ReadMessage(std::shared_ptr& Connection); struct IdHasher { size_t operator()(HttpConnectionId Id) const { return size_t(Id.Value()); } }; using ConnectionMap = std::unordered_map, IdHasher>; zen::AsioHttpServerOptions m_Options; asio::io_service m_IoCtx; std::unique_ptr m_ThreadPool; std::unique_ptr m_Acceptor; std::optional m_ListeningSocket; std::mutex m_ShutdownMutex; std::condition_variable m_ShutdownSignal; std::vector m_Services; ConnectionMap m_Connections; std::shared_mutex m_ConnMutex; std::atomic_bool m_Running{false}; }; AsioHttpServer::AsioHttpServer(const AsioHttpServerOptions& Options) : m_Options(Options) { } void AsioHttpServer::RegisterService(HttpService& Service) { m_Services.push_back(&Service); } int AsioHttpServer::Initialize(int Port) { m_Acceptor = std::make_unique(m_IoCtx, asio::ip::tcp::v6()); m_Acceptor->set_option(asio::ip::v6_only(false)); m_Acceptor->set_option(asio::socket_base::reuse_address(true)); m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); m_Acceptor->set_option(asio::socket_base::receive_buffer_size(m_Options.ReceiveBufferSize)); m_Acceptor->set_option(asio::socket_base::send_buffer_size(m_Options.SendBufferSize)); #if ZEN_PLATFORM_WINDOWS // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. // This must be used by both the client and server side, and is only effective in the absence of // Windows Filtering Platform (WFP) callouts which can be installed by security software. // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path SOCKET NativeSocket = m_Acceptor->native_handle(); int LoopbackOptionValue = 1; DWORD OptionNumberOfBytesReturned = 0; WSAIoctl(NativeSocket, SIO_LOOPBACK_FAST_PATH, &LoopbackOptionValue, sizeof(LoopbackOptionValue), NULL, 0, &OptionNumberOfBytesReturned, 0, 0); #endif asio::error_code Ec; m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), asio::ip::port_type(Port)), Ec); if (Ec) { ZEN_LOG_ERROR(LogHttp, "bind endpoint FAILED, reason '{}'", Ec.message()); return -1; } m_Acceptor->listen(); m_Running = true; ZEN_LOG_INFO(LogHttp, "web socket server running on port '{}'", Port); AcceptConnection(); m_ThreadPool = std::make_unique(m_IoCtx); m_ThreadPool->Start(m_Options.ThreadCount); return 0; } void AsioHttpServer::Run(bool IsInteractiveSession) { if (IsInteractiveSession) { zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); } for (;;) { if (IsApplicationExitRequested()) { break; } std::unique_lock Lock(m_ShutdownMutex); m_ShutdownSignal.wait_for(Lock, std::chrono::seconds(2)); } } void AsioHttpServer::RequestExit() { m_ShutdownSignal.notify_one(); } void AsioHttpServer::AcceptConnection() { m_ListeningSocket.emplace(m_IoCtx); m_Acceptor->async_accept(m_ListeningSocket.value(), [this](const asio::error_code& Ec) mutable { if (m_Running) { if (Ec) { ZEN_LOG_WARN(LogHttp, "accept connection FAILED, reason '{}'", Ec.message()); } else { auto Connection = std::make_shared(HttpConnectionId::New(), std::move(m_ListeningSocket.value())); ZEN_LOG_TRACE(LogHttp, "accept connection '#{} {}' OK", Connection->Id.Value(), Connection->RemoteAddr()); { std::unique_lock _(m_ConnMutex); m_Connections[Connection->Id] = Connection; } ReadMessage(Connection); } AcceptConnection(); } }); } void AsioHttpServer::CloseConnection(std::shared_ptr& Connection, const asio::error_code& Ec) { if (Ec) { ZEN_LOG_INFO(LogHttp, "connection '{}' closed, reason '{} ({})'", Connection->Id.Value(), Ec.message(), Ec.value()); } else { ZEN_LOG_INFO(LogHttp, "connection '{}' closed", Connection->Id.Value()); } const HttpConnectionId Id = Connection->Id; { std::unique_lock _(m_ConnMutex); m_Connections.erase(Id); } } void AsioHttpServer::ReadMessage(std::shared_ptr& Connection) { Connection->ReadBuffer.prepare(m_Options.ReceiveBufferSize); asio::async_read(Connection->Socket, Connection->ReadBuffer, asio::transfer_at_least(1), [this, Connection](const asio::error_code& Ec, std::size_t) mutable { if (Ec) { return CloseConnection(Connection, Ec); } const ConnectionFlags Flags = Connection->Flags(); HttpMessageParser& Parser = *reinterpret_cast(Connection->Parser.get()); uint64_t RemainingBytes = Connection->ReadBuffer.size(); while (RemainingBytes > 0) { MemoryView MessageData = MemoryView(Connection->ReadBuffer.data().data(), RemainingBytes); const ParseMessageResult Result = Parser.ParseMessage(MessageData); Connection->ReadBuffer.consume(Result.ByteCount); RemainingBytes = Connection->ReadBuffer.size(); if (Result.Status == ParseMessageStatus::kError) { ZEN_LOG_WARN(LogHttp, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); return CloseConnection(Connection, std::error_code()); } if (Result.Status == ParseMessageStatus::kContinue) { ZEN_ASSERT(RemainingBytes == 0); continue; } } ReadMessage(Connection); // ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); // Connection->ReadBuffer.consume(Result.ByteCount); // if (Result.Status == ParseMessageStatus::kContinue) //{ // return ReadMessage(Connection); //} // if (Result.Status == ParseMessageStatus::kError) //{ // const std::string Reason = Result.Reason.has_value() ? std::move(Result.Reason.value()) : "HTTP parse error"; // ZEN_LOG_WARN(LogHttp, "parse message FAILED, connection #{}, reason '{}'", Connection->Id.Value(), Reason); // return CloseConnection(Connection, std::error_code()); //} }); } } // namespace zen::asio_http namespace zen { Ref CreateAsioHttpServer(const AsioHttpServerOptions& Options) { return new zen::asio_http::AsioHttpServer(Options); } } // namespace zen