From 74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 13 Oct 2023 09:55:27 +0200 Subject: restructured zenhttp (#472) separating the http server implementations into a directory and moved diagsvcs into zenserver since it's somewhat hard-coded for it --- src/zenhttp/diagsvcs.cpp | 135 --- src/zenhttp/httpasio.cpp | 1052 ----------------- src/zenhttp/httpasio.h | 11 - src/zenhttp/httpnull.cpp | 88 -- src/zenhttp/httpnull.h | 30 - src/zenhttp/httpparser.cpp | 370 ------ src/zenhttp/httpparser.h | 112 -- src/zenhttp/httpplugin.cpp | 781 ------------- src/zenhttp/httpserver.cpp | 7 +- src/zenhttp/httpsys.cpp | 2012 -------------------------------- src/zenhttp/httpsys.h | 28 - src/zenhttp/include/zenhttp/diagsvcs.h | 119 -- src/zenhttp/iothreadpool.cpp | 54 - src/zenhttp/iothreadpool.h | 31 - src/zenhttp/servers/httpasio.cpp | 1052 +++++++++++++++++ src/zenhttp/servers/httpasio.h | 11 + src/zenhttp/servers/httpnull.cpp | 88 ++ src/zenhttp/servers/httpnull.h | 30 + src/zenhttp/servers/httpparser.cpp | 370 ++++++ src/zenhttp/servers/httpparser.h | 112 ++ src/zenhttp/servers/httpplugin.cpp | 781 +++++++++++++ src/zenhttp/servers/httpsys.cpp | 2012 ++++++++++++++++++++++++++++++++ src/zenhttp/servers/httpsys.h | 28 + src/zenhttp/servers/iothreadpool.cpp | 54 + src/zenhttp/servers/iothreadpool.h | 31 + src/zenhttp/xmake.lua | 2 +- src/zenserver/diag/diagsvcs.cpp | 135 +++ src/zenserver/diag/diagsvcs.h | 119 ++ src/zenserver/zenserver.h | 2 +- 29 files changed, 4828 insertions(+), 4829 deletions(-) delete mode 100644 src/zenhttp/diagsvcs.cpp delete mode 100644 src/zenhttp/httpasio.cpp delete mode 100644 src/zenhttp/httpasio.h delete mode 100644 src/zenhttp/httpnull.cpp delete mode 100644 src/zenhttp/httpnull.h delete mode 100644 src/zenhttp/httpparser.cpp delete mode 100644 src/zenhttp/httpparser.h delete mode 100644 src/zenhttp/httpplugin.cpp delete mode 100644 src/zenhttp/httpsys.cpp delete mode 100644 src/zenhttp/httpsys.h delete mode 100644 src/zenhttp/include/zenhttp/diagsvcs.h delete mode 100644 src/zenhttp/iothreadpool.cpp delete mode 100644 src/zenhttp/iothreadpool.h create mode 100644 src/zenhttp/servers/httpasio.cpp create mode 100644 src/zenhttp/servers/httpasio.h create mode 100644 src/zenhttp/servers/httpnull.cpp create mode 100644 src/zenhttp/servers/httpnull.h create mode 100644 src/zenhttp/servers/httpparser.cpp create mode 100644 src/zenhttp/servers/httpparser.h create mode 100644 src/zenhttp/servers/httpplugin.cpp create mode 100644 src/zenhttp/servers/httpsys.cpp create mode 100644 src/zenhttp/servers/httpsys.h create mode 100644 src/zenhttp/servers/iothreadpool.cpp create mode 100644 src/zenhttp/servers/iothreadpool.h create mode 100644 src/zenserver/diag/diagsvcs.cpp create mode 100644 src/zenserver/diag/diagsvcs.h (limited to 'src') diff --git a/src/zenhttp/diagsvcs.cpp b/src/zenhttp/diagsvcs.cpp deleted file mode 100644 index 9a547aa47..000000000 --- a/src/zenhttp/diagsvcs.cpp +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zenhttp/diagsvcs.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace zen { - -using namespace std::literals; - -static bool -ReadLogFile(const std::string& Path, StringBuilderBase& Out) -{ - try - { - constexpr auto ReadSize = std::size_t{4096}; - auto FileStream = std::ifstream{Path}; - - std::string Buf(ReadSize, '\0'); - while (FileStream.read(&Buf[0], ReadSize)) - { - Out.Append(std::string_view(&Buf[0], FileStream.gcount())); - } - Out.Append(std::string_view(&Buf[0], FileStream.gcount())); - - return true; - } - catch (std::exception&) - { - Out.Reset(); - return false; - } -} - -HttpHealthService::HttpHealthService() -{ - m_Router.RegisterRoute( - "", - [](HttpRouterRequest& RoutedReq) { - HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "info", - [this](HttpRouterRequest& RoutedReq) { - HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - - CbObjectWriter Writer; - - { - RwLock::SharedLockScope _(m_InfoLock); - Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string(); - Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); - Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; - Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; - } - - HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "log", - [this](HttpRouterRequest& RoutedReq) { - HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - - zen::Log().flush(); - - std::filesystem::path Path = [&] { - RwLock::SharedLockScope _(m_InfoLock); - return m_HealthInfo.AbsLogPath.empty() ? m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath; - }(); - - ExtendableStringBuilder<4096> Sb; - if (ReadLogFile(Path.string(), Sb) && Sb.Size() > 0) - { - HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView()); - } - else - { - HttpReq.WriteResponse(HttpResponseCode::NotFound); - } - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "version", - [this](HttpRouterRequest& RoutedReq) { - HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - if (HttpReq.GetQueryParams().GetValue("detailed") == "true") - { - HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL); - } - else - { - HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION); - } - }, - HttpVerb::kGet); -} - -void -HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info) -{ - RwLock::ExclusiveLockScope _(m_InfoLock); - m_HealthInfo = std::move(Info); -} - -const char* -HttpHealthService::BaseUri() const -{ - return "/health/"; -} - -void -HttpHealthService::HandleRequest(HttpServerRequest& Request) -{ - if (!m_Router.HandleRequest(Request)) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); - } -} - -} // namespace zen diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp deleted file mode 100644 index 0c6b189f9..000000000 --- a/src/zenhttp/httpasio.cpp +++ /dev/null @@ -1,1052 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpasio.h" - -#include -#include -#include -#include -#include - -#include "httpparser.h" - -#include -#include -#include -#include - -ZEN_THIRD_PARTY_INCLUDES_START -#if ZEN_PLATFORM_WINDOWS -# include -# include -#endif -#include -ZEN_THIRD_PARTY_INCLUDES_END - -#define ASIO_VERBOSE_TRACE 0 - -#if ASIO_VERBOSE_TRACE -# define ZEN_TRACE_VERBOSE ZEN_TRACE -#else -# define ZEN_TRACE_VERBOSE(fmtstr, ...) -#endif - -namespace zen::asio_http { - -using namespace std::literals; - -struct HttpAcceptor; -struct HttpResponse; -struct HttpServerConnection; - -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); - -inline spdlog::logger& -InitLogger() -{ - spdlog::logger& Logger = logging::Get("asio"); - // Logger.set_level(spdlog::level::trace); - return Logger; -} - -inline spdlog::logger& -Log() -{ - static spdlog::logger& g_Logger = InitLogger(); - return g_Logger; -} - -////////////////////////////////////////////////////////////////////////// - -struct HttpAsioServerImpl -{ -public: - HttpAsioServerImpl(); - ~HttpAsioServerImpl(); - - int Start(uint16_t Port, int ThreadCount); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); - - asio::io_service m_IoService; - asio::io_service::work m_Work{m_IoService}; - std::unique_ptr m_Acceptor; - std::vector m_ThreadPool; - - struct ServiceEntry - { - std::string ServiceUrlPath; - HttpService* Service; - }; - - RwLock m_Lock; - std::vector m_UriHandlers; -}; - -/** - * This is the class which request handlers use to interact with the server instance - */ - -class HttpAsioServerRequest : public HttpServerRequest -{ -public: - HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); - ~HttpAsioServerRequest(); - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; - - virtual IoBuffer ReadPayload() override; - virtual void WriteResponse(HttpResponseCode ResponseCode) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; - virtual bool TryGetRanges(HttpRanges& Ranges) override; - - using HttpServerRequest::WriteResponse; - - HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; - HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; - - HttpRequestParser& m_Request; - IoBuffer m_PayloadBuffer; - std::unique_ptr m_Response; -}; - -struct HttpResponse -{ -public: - HttpResponse() = default; - explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} - - void InitializeForPayload(uint16_t ResponseCode, std::span BlobList) - { - ZEN_TRACE_CPU("asio::InitializeForPayload"); - - m_ResponseCode = ResponseCode; - const uint32_t ChunkCount = gsl::narrow(BlobList.size()); - - m_DataBuffers.reserve(ChunkCount); - - for (IoBuffer& Buffer : BlobList) - { -#if 1 - m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); -#else - IoBuffer TempBuffer = std::move(Buffer); - TempBuffer.MakeOwned(); - m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); -#endif - } - - uint64_t LocalDataSize = 0; - - m_AsioBuffers.push_back({}); // Placeholder for header - - for (IoBuffer& Buffer : m_DataBuffers) - { - uint64_t BufferDataSize = Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - LocalDataSize += BufferDataSize; - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - // TODO: Use direct file transfer, via TransmitFile/sendfile - // - // this looks like it requires some custom asio plumbing however - - m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); - } - else - { - // Send from memory - - m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); - } - } - m_ContentLength = LocalDataSize; - - auto Headers = GetHeaders(); - m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); - } - - uint16_t ResponseCode() const { return m_ResponseCode; } - uint64_t ContentLength() const { return m_ContentLength; } - - const std::vector& AsioBuffers() const { return m_AsioBuffers; } - - std::string_view GetHeaders() - { - m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" - << "Content-Length: " << ContentLength() << "\r\n"sv; - - if (!m_IsKeepAlive) - { - m_Headers << "Connection: close\r\n"sv; - } - - m_Headers << "\r\n"sv; - - return m_Headers; - } - - void SuppressPayload() { m_AsioBuffers.resize(1); } - -private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; - std::vector m_DataBuffers; - std::vector m_AsioBuffers; - ExtendableStringBuilder<160> m_Headers; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this -{ - HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket); - ~HttpServerConnection(); - - std::shared_ptr AsSharedPtr() { return shared_from_this(); } - - // HttpConnectionBase implementation - - virtual void TerminateConnection() override; - virtual void HandleRequest() override; - - void HandleNewRequest(); - -private: - enum class RequestState - { - kInitialState, - kInitialRead, - kReadingMore, - kWriting, - kWritingFinal, - kDone, - kTerminated - }; - - RequestState m_RequestState = RequestState::kInitialState; - HttpRequestParser m_RequestData{*this}; - - void EnqueueRead(); - void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); - void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); - void CloseConnection(); - - HttpAsioServerImpl& m_Server; - asio::streambuf m_RequestBuffer; - std::unique_ptr m_Socket; - std::atomic m_RequestCounter{0}; - uint32_t m_ConnectionId = 0; - Ref m_PackageHandler; - - RwLock m_ResponsesLock; - std::deque> m_Responses; -}; - -std::atomic g_ConnectionIdCounter{0}; - -HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket) -: m_Server(Server) -, m_Socket(std::move(Socket)) -, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) -{ - ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); -} - -HttpServerConnection::~HttpServerConnection() -{ - ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); -} - -void -HttpServerConnection::HandleNewRequest() -{ - EnqueueRead(); -} - -void -HttpServerConnection::TerminateConnection() -{ - if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) - { - return; - } - - m_RequestState = RequestState::kTerminated; - ZEN_ASSERT(m_Socket); - - // Terminating, we don't care about any errors when closing socket - std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); - m_Socket->close(Ec); -} - -void -HttpServerConnection::EnqueueRead() -{ - if (m_RequestState == RequestState::kInitialRead) - { - m_RequestState = RequestState::kReadingMore; - } - else - { - m_RequestState = RequestState::kInitialRead; - } - - m_RequestBuffer.prepare(64 * 1024); - - asio::async_read(*m_Socket.get(), - m_RequestBuffer, - asio::transfer_at_least(1), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); -} - -void -HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) -{ - if (Ec) - { - if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) - { - ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); - return; - } - else - { - ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); - return TerminateConnection(); - } - } - - ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", - m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), - zen::GetCurrentThreadId(), - NiceBytes(ByteCount)); - - while (m_RequestBuffer.size()) - { - const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); - - size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); - if (Result == ~0ull) - { - return TerminateConnection(); - } - - m_RequestBuffer.consume(Result); - } - - switch (m_RequestState) - { - case RequestState::kDone: - case RequestState::kWritingFinal: - case RequestState::kTerminated: - break; - - default: - EnqueueRead(); - break; - } -} - -void -HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) -{ - if (Ec) - { - ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); - TerminateConnection(); - } - else - { - ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", - m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), - zen::GetCurrentThreadId(), - NiceBytes(ByteCount)); - - if (!m_RequestData.IsKeepAlive()) - { - CloseConnection(); - } - else - { - if (Pop) - { - RwLock::ExclusiveLockScope _(m_ResponsesLock); - m_Responses.pop_front(); - } - - m_RequestCounter.fetch_add(1); - } - } -} - -void -HttpServerConnection::CloseConnection() -{ - if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) - { - return; - } - ZEN_ASSERT(m_Socket); - m_RequestState = RequestState::kDone; - - std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); - if (Ec) - { - ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); - } - m_Socket->close(Ec); - if (Ec) - { - ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); - } -} - -void -HttpServerConnection::HandleRequest() -{ - if (!m_RequestData.IsKeepAlive()) - { - m_RequestState = RequestState::kWritingFinal; - - std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); - - if (Ec) - { - ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); - } - } - else - { - m_RequestState = RequestState::kWriting; - } - - if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) - { - ZEN_TRACE_CPU("asio::HandleRequest"); - - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); - - ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); - - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) - { - try - { - Service->HandleRequest(Request); - } - catch (std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) - { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); - } - else - { - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); - } - } - catch (std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); - } - } - - if (std::unique_ptr Response = std::move(Request.m_Response)) - { - // Transmit the response - - if (m_RequestData.RequestVerb() == HttpVerb::kHead) - { - Response->SuppressPayload(); - } - - auto ResponseBuffers = Response->AsioBuffers(); - - uint64_t ResponseLength = 0; - - for (auto& Buffer : ResponseBuffers) - { - ResponseLength += Buffer.size(); - } - - { - RwLock::ExclusiveLockScope _(m_ResponsesLock); - m_Responses.push_back(std::move(Response)); - } - - // TODO: should cork/uncork for Linux? - - { - ZEN_TRACE_CPU("asio::async_write"); - asio::async_write(*m_Socket.get(), - ResponseBuffers, - asio::transfer_exactly(ResponseLength), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, true); - }); - } - return; - } - } - - if (m_RequestData.RequestVerb() == HttpVerb::kHead) - { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - asio::async_write( - *m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); - } - else - { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write( - *m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); - } -} - -////////////////////////////////////////////////////////////////////////// - -struct HttpAcceptor -{ - HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) - : m_Server(Server) - , m_IoService(IoService) - , m_Acceptor(m_IoService, asio::ip::tcp::v6()) - { - m_Acceptor.set_option(asio::ip::v6_only(false)); -#if ZEN_PLATFORM_WINDOWS - // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms - typedef asio::detail::socket_option::boolean excluse_address; - m_Acceptor.set_option(excluse_address(true)); -#else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); -#endif // ZEN_PLATFORM_WINDOWS - - m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - uint16_t EffectivePort = BasePort; - - asio::error_code BindErrorCode; - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - // Sharing violation implies the port is being used by another process - for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - } - if (BindErrorCode == asio::error::access_denied) - { - EffectivePort = 0; - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - } - if (BindErrorCode) - { - ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); - } - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor.native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); -#endif - m_Acceptor.listen(); - - ZEN_INFO("Started asio server at port '{}'", EffectivePort); - } - - void Start() - { - m_Acceptor.listen(); - InitAccept(); - } - - void Stop() { m_IsStopped = true; } - - void InitAccept() - { - auto SocketPtr = std::make_unique(m_IoService); - asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); - - m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { - if (Ec) - { - ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - m_Acceptor.local_endpoint().address().to_string(), - m_Acceptor.local_endpoint().port(), - Ec.message()); - } - else - { - // New connection established, pass socket ownership into connection object - // and initiate request handling loop. The connection lifetime is - // managed by the async read/write loop by passing the shared - // reference to the callbacks. - - Socket->set_option(asio::ip::tcp::no_delay(true)); - Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - auto Conn = std::make_shared(m_Server, std::move(Socket)); - Conn->HandleNewRequest(); - } - - if (!m_IsStopped.load()) - { - InitAccept(); - } - else - { - std::error_code CloseEc; - m_Acceptor.close(CloseEc); - if (CloseEc) - { - ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); - } - } - }); - } - - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } - -private: - HttpAsioServerImpl& m_Server; - asio::io_service& m_IoService; - asio::ip::tcp::acceptor m_Acceptor; - std::atomic m_IsStopped{false}; -}; - -////////////////////////////////////////////////////////////////////////// - -HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) -: m_Request(Request) -, m_PayloadBuffer(std::move(PayloadBuffer)) -{ - const int PrefixLength = Service.UriPrefixLength(); - - std::string_view Uri = Request.Url(); - Uri.remove_prefix(std::min(PrefixLength, static_cast(Uri.size()))); - m_Uri = Uri; - m_UriWithExtension = Uri; - m_QueryString = Request.QueryString(); - - m_Verb = Request.RequestVerb(); - m_ContentLength = Request.Body().Size(); - m_ContentType = Request.ContentType(); - - HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; - - // Parse any extension, to allow requesting a particular response encoding via the URL - - { - std::string_view UriSuffix8{m_Uri}; - - const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); - - if (LastComponentIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastComponentIndex); - } - - const size_t LastDotIndex = UriSuffix8.find_last_of('.'); - - if (LastDotIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastDotIndex + 1); - - AcceptContentType = ParseContentType(UriSuffix8); - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); - } - } - } - - // It an explicit content type extension was specified then we'll use that over any - // Accept: header value that may be present - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_AcceptType = AcceptContentType; - } - else - { - m_AcceptType = Request.AcceptType(); - } -} - -HttpAsioServerRequest::~HttpAsioServerRequest() -{ -} - -Oid -HttpAsioServerRequest::ParseSessionId() const -{ - return m_Request.SessionId(); -} - -uint32_t -HttpAsioServerRequest::ParseRequestId() const -{ - return m_Request.RequestId(); -} - -IoBuffer -HttpAsioServerRequest::ReadPayload() -{ - return m_PayloadBuffer; -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpResponse(HttpContentType::kBinary)); - std::array Empty; - - m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpResponse(ContentType)); - m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(!m_Response); - m_Response.reset(new HttpResponse(ContentType)); - - IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); - std::array SingleBufferList({MessageBuffer}); - - m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); -} - -void -HttpAsioServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) -{ - ZEN_ASSERT(!m_Response); - - // Not one bit async, innit - ContinuationHandler(*this); -} - -bool -HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) -{ - return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); -} - -////////////////////////////////////////////////////////////////////////// - -HttpAsioServerImpl::HttpAsioServerImpl() -{ -} - -HttpAsioServerImpl::~HttpAsioServerImpl() -{ -} - -int -HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) -{ - ZEN_ASSERT(ThreadCount > 0); - - ZEN_INFO("starting asio http with {} service threads", ThreadCount); - - m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); - m_Acceptor->Start(); - - // This should consist of a set of minimum threads and grow on demand to - // meet concurrency needs? Right now we end up allocating a large number - // of threads even if we never end up using all of them, which seems - // wasteful. It's also not clear how the demand for concurrency should - // be balanced with the engine side - ideally we'd have some kind of - // global scheduling to prevent one side from preventing the other side - // from making progress. Or at the very least, thread priorities should - // be considered. - - for (int i = 0; i < ThreadCount; ++i) - { - m_ThreadPool.emplace_back([this, Index = i + 1] { - SetCurrentThreadName(fmt::format("asio_io_{}", Index)); - - try - { - m_IoService.run(); - } - catch (std::exception& e) - { - ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); - } - }); - } - - ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); - - return m_Acceptor->GetAcceptPort(); -} - -void -HttpAsioServerImpl::Stop() -{ - m_Acceptor->Stop(); - m_IoService.stop(); - for (auto& Thread : m_ThreadPool) - { - Thread.join(); - } -} - -void -HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) -{ - std::string_view UrlPath(InUrlPath); - Service.SetUriPrefixLength(UrlPath.size()); - if (!UrlPath.empty() && UrlPath.back() == '/') - { - UrlPath.remove_suffix(1); - } - - RwLock::ExclusiveLockScope _(m_Lock); - m_UriHandlers.push_back({std::string(UrlPath), &Service}); -} - -HttpService* -HttpAsioServerImpl::RouteRequest(std::string_view Url) -{ - RwLock::SharedLockScope _(m_Lock); - - HttpService* CandidateService = nullptr; - std::string::size_type CandidateMatchSize = 0; - for (const ServiceEntry& SvcEntry : m_UriHandlers) - { - const std::string& SvcUrl = SvcEntry.ServiceUrlPath; - const std::string::size_type SvcUrlSize = SvcUrl.size(); - if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && - ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) - { - CandidateMatchSize = SvcUrl.size(); - CandidateService = SvcEntry.Service; - } - } - - return CandidateService; -} - -} // namespace zen::asio_http - -////////////////////////////////////////////////////////////////////////// - -namespace zen { - -class HttpAsioServer : public HttpServer -{ -public: - HttpAsioServer(unsigned int ThreadCount); - ~HttpAsioServer(); - - virtual void RegisterService(HttpService& Service) override; - virtual int Initialize(int BasePort) override; - virtual void Run(bool IsInteractiveSession) override; - virtual void RequestExit() override; - virtual void Close() override; - -private: - Event m_ShutdownEvent; - int m_BasePort = 0; - unsigned int m_ThreadCount = 0; - - std::unique_ptr m_Impl; -}; - -HttpAsioServer::HttpAsioServer(unsigned int ThreadCount) -: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) -, m_Impl(std::make_unique()) -{ - ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); -} - -HttpAsioServer::~HttpAsioServer() -{ - if (m_Impl) - { - ZEN_ERROR("~HttpAsioServer() called without calling Close() first"); - } -} - -void -HttpAsioServer::Close() -{ - try - { - m_Impl->Stop(); - } - catch (std::exception& ex) - { - ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); - } - m_Impl.reset(); -} - -void -HttpAsioServer::RegisterService(HttpService& Service) -{ - m_Impl->RegisterService(Service.BaseUri(), Service); -} - -int -HttpAsioServer::Initialize(int BasePort) -{ - m_BasePort = m_Impl->Start(gsl::narrow(BasePort), m_ThreadCount); - return m_BasePort; -} - -void -HttpAsioServer::Run(bool IsInteractive) -{ - const bool TestMode = !IsInteractive; - - int WaitTimeout = -1; - if (!TestMode) - { - WaitTimeout = 1000; - } - -#if ZEN_PLATFORM_WINDOWS - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); - } - - do - { - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#else - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); - } - - do - { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#endif -} - -void -HttpAsioServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} - -Ref -CreateHttpAsioServer(unsigned int ThreadCount) -{ - return Ref{new HttpAsioServer(ThreadCount)}; -} - -} // namespace zen diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h deleted file mode 100644 index 2366f3437..000000000 --- a/src/zenhttp/httpasio.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -namespace zen { - -Ref CreateHttpAsioServer(unsigned int ThreadCount); - -} // namespace zen diff --git a/src/zenhttp/httpnull.cpp b/src/zenhttp/httpnull.cpp deleted file mode 100644 index 658f51831..000000000 --- a/src/zenhttp/httpnull.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpnull.h" - -#include - -#if ZEN_PLATFORM_WINDOWS -# include -#endif - -namespace zen { - -HttpNullServer::HttpNullServer() -{ -} - -HttpNullServer::~HttpNullServer() -{ -} - -void -HttpNullServer::RegisterService(HttpService& Service) -{ - ZEN_UNUSED(Service); -} - -int -HttpNullServer::Initialize(int BasePort) -{ - return BasePort; -} - -void -HttpNullServer::Run(bool IsInteractiveSession) -{ - const bool TestMode = !IsInteractiveSession; - - int WaitTimeout = -1; - if (!TestMode) - { - WaitTimeout = 1000; - } - -#if ZEN_PLATFORM_WINDOWS - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); - } - - do - { - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#else - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); - } - - do - { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#endif -} - -void -HttpNullServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} - -void -HttpNullServer::Close() -{ -} - -} // namespace zen diff --git a/src/zenhttp/httpnull.h b/src/zenhttp/httpnull.h deleted file mode 100644 index 965e729f7..000000000 --- a/src/zenhttp/httpnull.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include -#include - -namespace zen { - -/** - * @brief Null implementation of "http" server. Does nothing - */ - -class HttpNullServer : public HttpServer -{ -public: - HttpNullServer(); - ~HttpNullServer(); - - virtual void RegisterService(HttpService& Service) override; - virtual int Initialize(int BasePort) override; - virtual void Run(bool IsInteractiveSession) override; - virtual void RequestExit() override; - virtual void Close() override; - -private: - Event m_ShutdownEvent; -}; - -} // namespace zen diff --git a/src/zenhttp/httpparser.cpp b/src/zenhttp/httpparser.cpp deleted file mode 100644 index 6b987151a..000000000 --- a/src/zenhttp/httpparser.cpp +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpparser.h" - -#include -#include - -namespace zen { - -using namespace std::literals; - -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); - -////////////////////////////////////////////////////////////////////////// -// -// HttpRequestParser -// - -http_parser_settings HttpRequestParser::s_ParserSettings{ - .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; - -HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) -{ - http_parser_init(&m_Parser, HTTP_REQUEST); - m_Parser.data = this; - - ResetState(); -} - -HttpRequestParser::~HttpRequestParser() -{ -} - -size_t -HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) -{ - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) - { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; - } - - return ConsumedBytes; -} - -int -HttpRequestParser::OnUrl(const char* Data, size_t Bytes) -{ - if (!m_Url) - { - ZEN_ASSERT_SLOW(m_UrlLength == 0); - m_Url = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_UrlLength += Bytes; - - return 0; -} - -int -HttpRequestParser::OnHeader(const char* Data, size_t Bytes) -{ - if (m_CurrentHeaderValueLength) - { - AppendCurrentHeader(); - - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = m_HeaderCursor; - } - else if (m_CurrentHeaderName == nullptr) - { - m_CurrentHeaderName = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderNameLength += Bytes; - - return 0; -} - -void -HttpRequestParser::AppendCurrentHeader() -{ - std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); - std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); - - const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); - - if (HeaderHash == HashContentLength) - { - m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::FromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = (int8_t)m_Headers.size(); - } - - m_Headers.emplace_back(HeaderName, HeaderValue); -} - -int -HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) -{ - if (m_CurrentHeaderValueLength == 0) - { - m_CurrentHeaderValue = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderValueLength += Bytes; - - return 0; -} - -static void -NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) -{ - bool LastCharWasSeparator = false; - for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) - { - const char UrlChar = Url[UrlIndex]; - const bool IsSeparator = (UrlChar == '/'); - - if (IsSeparator && LastCharWasSeparator) - { - if (NormalizedUrl.empty()) - { - NormalizedUrl.reserve(UrlLength); - NormalizedUrl.append(Url, UrlIndex); - } - - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } - } - else if (!NormalizedUrl.empty()) - { - NormalizedUrl.push_back(UrlChar); - } - - LastCharWasSeparator = IsSeparator; - } -} - -int -HttpRequestParser::OnHeadersComplete() -{ - try - { - if (m_CurrentHeaderValueLength) - { - AppendCurrentHeader(); - } - - m_KeepAlive = !!http_should_keep_alive(&m_Parser); - - switch (m_Parser.method) - { - case HTTP_GET: - m_RequestVerb = HttpVerb::kGet; - break; - - case HTTP_POST: - m_RequestVerb = HttpVerb::kPost; - break; - - case HTTP_PUT: - m_RequestVerb = HttpVerb::kPut; - break; - - case HTTP_DELETE: - m_RequestVerb = HttpVerb::kDelete; - break; - - case HTTP_HEAD: - m_RequestVerb = HttpVerb::kHead; - break; - - case HTTP_COPY: - m_RequestVerb = HttpVerb::kCopy; - break; - - case HTTP_OPTIONS: - m_RequestVerb = HttpVerb::kOptions; - break; - - default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); - break; - } - - std::string_view Url(m_Url, m_UrlLength); - - if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) - { - m_UrlLength = QuerySplit; - m_QueryString = m_Url + QuerySplit + 1; - m_QueryLength = Url.size() - QuerySplit - 1; - } - - NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); - - if (m_ContentLengthHeaderIndex >= 0) - { - std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; - uint64_t ContentLength = 0; - std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); - - if (ContentLength) - { - m_BodyBuffer = IoBuffer(ContentLength); - } - - m_BodyBuffer.SetContentType(ContentType()); - - m_BodyPosition = 0; - } - } - catch (const std::exception& Ex) - { - ZEN_WARN("HttpRequestParser::OnHeadersComplete failed. Reason '{}'", Ex.what()); - return -1; - } - return 0; -} - -int -HttpRequestParser::OnBody(const char* Data, size_t Bytes) -{ - if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) - { - ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", - (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; - } - memcpy(reinterpret_cast(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); - m_BodyPosition += Bytes; - - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - - return 0; -} - -void -HttpRequestParser::ResetState() -{ - m_HeaderCursor = m_HeaderBuffer; - m_CurrentHeaderName = nullptr; - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValue = nullptr; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = nullptr; - m_Url = nullptr; - m_UrlLength = 0; - m_QueryString = nullptr; - m_QueryLength = 0; - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; - m_Headers.clear(); - m_NormalizedUrl.clear(); -} - -int -HttpRequestParser::OnMessageBegin() -{ - return 0; -} - -int -HttpRequestParser::OnMessageComplete() -{ - m_Connection.HandleRequest(); - - ResetState(); - - return 0; -} - -} // namespace zen diff --git a/src/zenhttp/httpparser.h b/src/zenhttp/httpparser.h deleted file mode 100644 index 219ac351d..000000000 --- a/src/zenhttp/httpparser.h +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include -#include - -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen { - -class HttpRequestParserCallbacks -{ -public: - virtual ~HttpRequestParserCallbacks() = default; - virtual void HandleRequest() = 0; - virtual void TerminateConnection() = 0; -}; - -struct HttpRequestParser -{ - explicit HttpRequestParser(HttpRequestParserCallbacks& Connection); - ~HttpRequestParser(); - - size_t ConsumeData(const char* InputData, size_t DataSize); - void ResetState(); - - HttpVerb RequestVerb() const { return m_RequestVerb; } - bool IsKeepAlive() const { return m_KeepAlive; } - std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } - std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } - IoBuffer Body() { return m_BodyBuffer; } - - inline HttpContentType ContentType() - { - if (m_ContentTypeHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); - } - - inline HttpContentType AcceptType() - { - if (m_AcceptHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); - } - - Oid SessionId() const { return m_SessionId; } - int RequestId() const { return m_RequestId; } - - std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } - -private: - struct HeaderEntry - { - HeaderEntry() = default; - - HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} - - std::string_view Name; - std::string_view Value; - }; - - HttpRequestParserCallbacks& m_Connection; - char* m_HeaderCursor = m_HeaderBuffer; - char* m_Url = nullptr; - size_t m_UrlLength = 0; - char* m_QueryString = nullptr; - size_t m_QueryLength = 0; - char* m_CurrentHeaderName = nullptr; // Used while parsing headers - size_t m_CurrentHeaderNameLength = 0; - char* m_CurrentHeaderValue = nullptr; // Used while parsing headers - size_t m_CurrentHeaderValueLength = 0; - std::vector m_Headers; - int8_t m_ContentLengthHeaderIndex; - int8_t m_AcceptHeaderIndex; - int8_t m_ContentTypeHeaderIndex; - int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; - bool m_KeepAlive = false; - bool m_Expect100Continue = false; - int m_RequestId = -1; - Oid m_SessionId{}; - IoBuffer m_BodyBuffer; - uint64_t m_BodyPosition = 0; - http_parser m_Parser; - char m_HeaderBuffer[1024]; - std::string m_NormalizedUrl; - - void AppendCurrentHeader(); - - int OnMessageBegin(); - int OnUrl(const char* Data, size_t Bytes); - int OnHeader(const char* Data, size_t Bytes); - int OnHeaderValue(const char* Data, size_t Bytes); - int OnHeadersComplete(); - int OnBody(const char* Data, size_t Bytes); - int OnMessageComplete(); - - static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast(Parser->data); } - static http_parser_settings s_ParserSettings; -}; - -} // namespace zen diff --git a/src/zenhttp/httpplugin.cpp b/src/zenhttp/httpplugin.cpp deleted file mode 100644 index 2e934473e..000000000 --- a/src/zenhttp/httpplugin.cpp +++ /dev/null @@ -1,781 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include - -#if ZEN_WITH_PLUGINS - -# include "httpparser.h" - -# include -# include -# include -# include -# include - -# include -# include - -# if ZEN_PLATFORM_WINDOWS -# include -# endif - -# define PLUGIN_VERBOSE_TRACE 1 - -# if PLUGIN_VERBOSE_TRACE -# define ZEN_TRACE_VERBOSE ZEN_TRACE -# else -# define ZEN_TRACE_VERBOSE(fmtstr, ...) -# endif - -namespace zen { - -struct HttpPluginServerImpl; -struct HttpPluginResponse; - -using namespace std::literals; - -////////////////////////////////////////////////////////////////////////// - -struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted -{ - virtual uint32_t AddRef() const override; - virtual uint32_t Release() const override; - - virtual void OnBytesRead(const void* Buffer, size_t DataSize) override; - - // HttpRequestParserCallbacks - - virtual void HandleRequest() override; - virtual void TerminateConnection() override; - - void Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server); - -private: - enum class RequestState - { - kInitialState, - kInitialRead, - kReadingMore, - kWriting, // Currently writing response, connection will be re-used - kWritingFinal, // Writing response, connection will be closed - kDone, - kTerminated - }; - - RequestState m_RequestState = RequestState::kInitialState; - HttpRequestParser m_RequestParser{*this}; - - uint32_t m_ConnectionId = 0; - Ref m_PackageHandler; - - TransportConnection* m_TransportConnection = nullptr; - HttpPluginServerImpl* m_Server = nullptr; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpPluginServerImpl : public TransportServer -{ - HttpPluginServerImpl(); - ~HttpPluginServerImpl(); - - void AddPlugin(Ref Plugin); - void RemovePlugin(Ref Plugin); - - void Start(); - void Stop(); - - void RegisterService(const char* InUrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); - - struct ServiceEntry - { - std::string ServiceUrlPath; - HttpService* Service; - }; - - RwLock m_Lock; - std::vector m_UriHandlers; - std::vector> m_Plugins; - - // TransportServer - - virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) override; -}; - -/** This is the class which request handlers interface with when - generating responses - */ - -class HttpPluginServerRequest : public HttpServerRequest -{ -public: - HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); - ~HttpPluginServerRequest(); - - HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; - HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; - - virtual IoBuffer ReadPayload() override; - virtual void WriteResponse(HttpResponseCode ResponseCode) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; - virtual bool TryGetRanges(HttpRanges& Ranges) override; - - using HttpServerRequest::WriteResponse; - - HttpRequestParser& m_Request; - IoBuffer m_PayloadBuffer; - std::unique_ptr m_Response; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpPluginResponse -{ -public: - HttpPluginResponse() = default; - explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} - - void InitializeForPayload(uint16_t ResponseCode, std::span BlobList); - - inline uint16_t ResponseCode() const { return m_ResponseCode; } - inline uint64_t ContentLength() const { return m_ContentLength; } - - const std::vector& ResponseBuffers() const { return m_ResponseBuffers; } - void SuppressPayload() { m_ResponseBuffers.resize(1); } - -private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; - std::vector m_ResponseBuffers; - ExtendableStringBuilder<160> m_Headers; - - std::string_view GetHeaders(); -}; - -void -HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span BlobList) -{ - ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); - - m_ResponseCode = ResponseCode; - const uint32_t ChunkCount = gsl::narrow(BlobList.size()); - - m_ResponseBuffers.reserve(ChunkCount + 1); - m_ResponseBuffers.push_back({}); // Placeholder for header - - uint64_t TotalDataSize = 0; - - for (IoBuffer& Buffer : BlobList) - { - uint64_t BufferDataSize = Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - TotalDataSize += BufferDataSize; - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - // TODO: Use direct file transfer, via TransmitFile/sendfile - - m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); - } - else - { - // Send from memory - - m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); - } - } - m_ContentLength = TotalDataSize; - - auto Headers = GetHeaders(); - m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size()); -} - -std::string_view -HttpPluginResponse::GetHeaders() -{ - m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" - << "Content-Length: " << ContentLength() << "\r\n"sv; - - if (!m_IsKeepAlive) - { - m_Headers << "Connection: close\r\n"sv; - } - - m_Headers << "\r\n"sv; - - return m_Headers; -} - -////////////////////////////////////////////////////////////////////////// - -void -HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server) -{ - m_TransportConnection = Transport; - m_Server = &Server; -} - -uint32_t -HttpPluginConnectionHandler::AddRef() const -{ - return RefCounted::AddRef(); -} - -uint32_t -HttpPluginConnectionHandler::Release() const -{ - return RefCounted::Release(); -} - -void -HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) -{ - while (AvailableBytes) - { - const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes); - - if (ConsumedBytes == ~0ull) - { - // terminate connection - - return TerminateConnection(); - } - - Buffer = reinterpret_cast(Buffer) + ConsumedBytes; - AvailableBytes -= ConsumedBytes; - } -} - -// HttpRequestParserCallbacks - -void -HttpPluginConnectionHandler::HandleRequest() -{ - if (!m_RequestParser.IsKeepAlive()) - { - // Once response has been written, connection is done - m_RequestState = RequestState::kWritingFinal; - - // We're not going to read any more data from this socket - - const bool Receive = true; - const bool Transmit = false; - m_TransportConnection->Shutdown(Receive, Transmit); - } - else - { - m_RequestState = RequestState::kWriting; - } - - auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t { - const char* Buffer = reinterpret_cast(InBuffer.GetData()); - size_t Bytes = InBuffer.GetSize(); - - return m_TransportConnection->WriteBytes(Buffer, Bytes); - }; - - // Generate response - - if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url())) - { - ZEN_TRACE_CPU("http_plugin::HandleRequest"); - - HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); - - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) - { - try - { - Service->HandleRequest(Request); - } - catch (std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) - { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); - } - else - { - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); - } - } - catch (std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); - } - } - - if (std::unique_ptr Response = std::move(Request.m_Response)) - { - // Transmit the response - - if (m_RequestParser.RequestVerb() == HttpVerb::kHead) - { - Response->SuppressPayload(); - } - - const std::vector& ResponseBuffers = Response->ResponseBuffers(); - - //// TODO: should cork/uncork for Linux? - - for (const IoBuffer& Buffer : ResponseBuffers) - { - int64_t SentBytes = SendBuffer(Buffer); - - if (SentBytes < 0) - { - TerminateConnection(); - - return; - } - } - - return; - } - } - - // No route found for request - - std::string_view Response; - - if (m_RequestParser.RequestVerb() == HttpVerb::kHead) - { - if (m_RequestParser.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - } - else - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - } - else - { - if (m_RequestParser.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - } - else - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - } - - const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response))); - - if (SentBytes < 0) - { - TerminateConnection(); - - return; - } -} - -void -HttpPluginConnectionHandler::TerminateConnection() -{ - ZEN_ASSERT(m_TransportConnection); - m_TransportConnection->CloseConnection(); -} - -////////////////////////////////////////////////////////////////////////// - -HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) -: m_Request(Request) -, m_PayloadBuffer(std::move(PayloadBuffer)) -{ - const int PrefixLength = Service.UriPrefixLength(); - - std::string_view Uri = Request.Url(); - Uri.remove_prefix(std::min(PrefixLength, static_cast(Uri.size()))); - m_Uri = Uri; - m_UriWithExtension = Uri; - m_QueryString = Request.QueryString(); - - m_Verb = Request.RequestVerb(); - m_ContentLength = Request.Body().Size(); - m_ContentType = Request.ContentType(); - - HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; - - // Parse any extension, to allow requesting a particular response encoding via the URL - - { - std::string_view UriSuffix8{m_Uri}; - - const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); - - if (LastComponentIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastComponentIndex); - } - - const size_t LastDotIndex = UriSuffix8.find_last_of('.'); - - if (LastDotIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastDotIndex + 1); - - AcceptContentType = ParseContentType(UriSuffix8); - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); - } - } - } - - // It an explicit content type extension was specified then we'll use that over any - // Accept: header value that may be present - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_AcceptType = AcceptContentType; - } - else - { - m_AcceptType = Request.AcceptType(); - } -} - -HttpPluginServerRequest::~HttpPluginServerRequest() -{ -} - -Oid -HttpPluginServerRequest::ParseSessionId() const -{ - return m_Request.SessionId(); -} - -uint32_t -HttpPluginServerRequest::ParseRequestId() const -{ - return m_Request.RequestId(); -} - -IoBuffer -HttpPluginServerRequest::ReadPayload() -{ - return m_PayloadBuffer; -} - -void -HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); - std::array Empty; - - m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); -} - -void -HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpPluginResponse(ContentType)); - m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); -} - -void -HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(!m_Response); - m_Response.reset(new HttpPluginResponse(ContentType)); - - IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); - std::array SingleBufferList({MessageBuffer}); - - m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); -} - -void -HttpPluginServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) -{ - ZEN_ASSERT(!m_Response); - - // Not one bit async, innit - ContinuationHandler(*this); -} - -bool -HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges) -{ - return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); -} - -////////////////////////////////////////////////////////////////////////// - -HttpPluginServerImpl::HttpPluginServerImpl() -{ -} - -HttpPluginServerImpl::~HttpPluginServerImpl() -{ -} - -TransportServerConnection* -HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) -{ - HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; - Handler->Initialize(Connection, *this); - return Handler; -} - -void -HttpPluginServerImpl::Start() -{ - RwLock::ExclusiveLockScope _(m_Lock); - - for (auto& Plugin : m_Plugins) - { - try - { - Plugin->Initialize(this); - } - catch (std::exception& Ex) - { - ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); - } - } -} - -void -HttpPluginServerImpl::Stop() -{ - RwLock::ExclusiveLockScope _(m_Lock); - - for (auto& Plugin : m_Plugins) - { - try - { - Plugin->Shutdown(); - } - catch (std::exception& Ex) - { - ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); - } - - Plugin = nullptr; - } - - m_Plugins.clear(); -} - -void -HttpPluginServerImpl::AddPlugin(Ref Plugin) -{ - RwLock::ExclusiveLockScope _(m_Lock); - m_Plugins.emplace_back(std::move(Plugin)); -} - -void -HttpPluginServerImpl::RemovePlugin(Ref Plugin) -{ - RwLock::ExclusiveLockScope _(m_Lock); - auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); - if (It != m_Plugins.end()) - { - m_Plugins.erase(It); - } -} - -void -HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) -{ - std::string_view UrlPath(InUrlPath); - Service.SetUriPrefixLength(UrlPath.size()); - if (!UrlPath.empty() && UrlPath.back() == '/') - { - UrlPath.remove_suffix(1); - } - - RwLock::ExclusiveLockScope _(m_Lock); - m_UriHandlers.push_back({std::string(UrlPath), &Service}); -} - -HttpService* -HttpPluginServerImpl::RouteRequest(std::string_view Url) -{ - RwLock::SharedLockScope _(m_Lock); - - HttpService* CandidateService = nullptr; - std::string::size_type CandidateMatchSize = 0; - for (const ServiceEntry& SvcEntry : m_UriHandlers) - { - const std::string& SvcUrl = SvcEntry.ServiceUrlPath; - const std::string::size_type SvcUrlSize = SvcUrl.size(); - if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && - ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) - { - CandidateMatchSize = SvcUrl.size(); - CandidateService = SvcEntry.Service; - } - } - - return CandidateService; -} - -////////////////////////////////////////////////////////////////////////// - -HttpPluginServer::HttpPluginServer(unsigned int ThreadCount) -: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) -, m_Impl(new HttpPluginServerImpl) -{ -} - -HttpPluginServer::~HttpPluginServer() -{ - if (m_Impl) - { - ZEN_ERROR("~HttpPluginServer() called without calling Close() first"); - } -} - -int -HttpPluginServer::Initialize(int BasePort) -{ - try - { - m_Impl->Start(); - } - catch (std::exception& ex) - { - ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); - } - - return BasePort; -} - -void -HttpPluginServer::Close() -{ - try - { - m_Impl->Stop(); - } - catch (std::exception& ex) - { - ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); - } - - delete m_Impl; - m_Impl = nullptr; -} - -void -HttpPluginServer::Run(bool IsInteractive) -{ - const bool TestMode = !IsInteractive; - - int WaitTimeout = -1; - if (!TestMode) - { - WaitTimeout = 1000; - } - -# if ZEN_PLATFORM_WINDOWS - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit"); - } - - do - { - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -# else - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit"); - } - - do - { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -# endif -} - -void -HttpPluginServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} - -void -HttpPluginServer::RegisterService(HttpService& Service) -{ - m_Impl->RegisterService(Service.BaseUri(), Service); -} - -void -HttpPluginServer::AddPlugin(Ref Plugin) -{ - m_Impl->AddPlugin(Plugin); -} - -void -HttpPluginServer::RemovePlugin(Ref Plugin) -{ - m_Impl->RemovePlugin(Plugin); -} - -} // namespace zen -#endif diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 7ea7cf91d..cd62ea157 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -2,10 +2,9 @@ #include -#include "httpasio.h" -#include "httpnull.h" -#include "httpsys.h" - +#include "servers/httpasio.h" +#include "servers/httpnull.h" +#include "servers/httpsys.h" #include "zenhttp/httpplugin.h" #if ZEN_WITH_PLUGINS diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp deleted file mode 100644 index c1b4717cb..000000000 --- a/src/zenhttp/httpsys.cpp +++ /dev/null @@ -1,2012 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpsys.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if ZEN_WITH_HTTPSYS -# define _WINSOCKAPI_ -# include -# include -# include "iothreadpool.h" - -# include - -namespace spdlog { -class logger; -} - -namespace zen { - -/** - * @brief Windows implementation of HTTP server based on http.sys - * - * This requires elevation to function by default but system configuration - * can soften this requirement. - * - * See README.md for details. - */ -class HttpSysServer : public HttpServer -{ - friend class HttpSysTransaction; - -public: - explicit HttpSysServer(const HttpSysConfig& Config); - ~HttpSysServer(); - - // HttpServer interface implementation - - virtual int Initialize(int BasePort) override; - virtual void Run(bool TestMode) override; - virtual void RequestExit() override; - virtual void RegisterService(HttpService& Service) override; - virtual void Close() override; - - WorkerThreadPool& WorkPool(); - - inline bool IsOk() const { return m_IsOk; } - inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } - -private: - int InitializeServer(int BasePort); - void Cleanup(); - - void StartServer(); - void OnHandlingNewRequest(); - void IssueNewRequestMaybe(); - - void RegisterService(const char* Endpoint, HttpService& Service); - void UnregisterService(const char* Endpoint, HttpService& Service); - -private: - spdlog::logger& m_Log; - spdlog::logger& m_RequestLog; - spdlog::logger& Log() { return m_Log; } - - bool m_IsOk = false; - bool m_IsHttpInitialized = false; - bool m_IsRequestLoggingEnabled = false; - bool m_IsAsyncResponseEnabled = true; - - std::unique_ptr m_IoThreadPool; - - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; - - std::vector m_BaseUris; // eg: http://*:nnnn/ - HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; - HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; - HANDLE m_RequestQueueHandle = 0; - std::atomic_int32_t m_PendingRequests{0}; - std::atomic_int32_t m_IsShuttingDown{0}; - int32_t m_MinPendingRequests = 16; - int32_t m_MaxPendingRequests = 128; - Event m_ShutdownEvent; - HttpSysConfig m_InitialConfig; -}; - -} // namespace zen -#endif - -#if ZEN_WITH_HTTPSYS - -# include -# include -# pragma comment(lib, "httpapi.lib") - -std::wstring -UTF8_to_UTF16(const char* InPtr) -{ - std::wstring OutString; - unsigned int Codepoint; - - while (*InPtr != 0) - { - unsigned char InChar = static_cast(*InPtr); - - if (InChar <= 0x7f) - Codepoint = InChar; - else if (InChar <= 0xbf) - Codepoint = (Codepoint << 6) | (InChar & 0x3f); - else if (InChar <= 0xdf) - Codepoint = InChar & 0x1f; - else if (InChar <= 0xef) - Codepoint = InChar & 0x0f; - else - Codepoint = InChar & 0x07; - - ++InPtr; - - if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) - { - if (Codepoint > 0xffff) - { - OutString.append(1, static_cast(0xd800 + (Codepoint >> 10))); - OutString.append(1, static_cast(0xdc00 + (Codepoint & 0x03ff))); - } - else if (Codepoint < 0xd800 || Codepoint >= 0xe000) - { - OutString.append(1, static_cast(Codepoint)); - } - } - } - - return OutString; -} - -namespace zen { - -using namespace std::literals; - -class HttpSysServer; -class HttpSysTransaction; -class HttpMessageResponseRequest; - -////////////////////////////////////////////////////////////////////////// - -HttpVerb -TranslateHttpVerb(HTTP_VERB ReqVerb) -{ - switch (ReqVerb) - { - case HttpVerbOPTIONS: - return HttpVerb::kOptions; - - case HttpVerbGET: - return HttpVerb::kGet; - - case HttpVerbHEAD: - return HttpVerb::kHead; - - case HttpVerbPOST: - return HttpVerb::kPost; - - case HttpVerbPUT: - return HttpVerb::kPut; - - case HttpVerbDELETE: - return HttpVerb::kDelete; - - case HttpVerbCOPY: - return HttpVerb::kCopy; - - default: - // TODO: invalid request? - return (HttpVerb)0; - } -} - -uint64_t -GetContentLength(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - uint64_t ContentLength = 0; - std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); - return ContentLength; -}; - -HttpContentType -GetContentType(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; - return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); -}; - -HttpContentType -GetAcceptType(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; - return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); -}; - -/** - * @brief Base class for any pending or active HTTP transactions - */ -class HttpSysRequestHandler -{ -public: - explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} - virtual ~HttpSysRequestHandler() = default; - - virtual void IssueRequest(std::error_code& ErrorCode) = 0; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; - HttpSysTransaction& Transaction() { return m_Transaction; } - - HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; - HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; - -private: - HttpSysTransaction& m_Transaction; -}; - -/** - * This is the handler for the initial HTTP I/O request which will receive the headers - * and however much of the remaining payload might fit in the embedded request buffer. - * - * It is also used to receive any entity body data relating to the request - * - */ -struct InitialRequestHandler : public HttpSysRequestHandler -{ - inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } - inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } - inline bool IsInitialRequest() const { return m_IsInitialRequest; } - - InitialRequestHandler(HttpSysTransaction& InRequest); - ~InitialRequestHandler(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - - bool m_IsInitialRequest = true; - uint64_t m_CurrentPayloadOffset = 0; - uint64_t m_ContentLength = ~uint64_t(0); - IoBuffer m_PayloadBuffer; - UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; -}; - -/** - * This is the class which request handlers use to interact with the server instance - */ - -class HttpSysServerRequest : public HttpServerRequest -{ -public: - HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); - ~HttpSysServerRequest(); - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; - - virtual IoBuffer ReadPayload() override; - virtual void WriteResponse(HttpResponseCode ResponseCode) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; - virtual bool TryGetRanges(HttpRanges& Ranges) override; - - using HttpServerRequest::WriteResponse; - - HttpSysServerRequest(const HttpSysServerRequest&) = delete; - HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; -}; - -/** HTTP transaction - - There will be an instance of this per pending and in-flight HTTP transaction - - */ -class HttpSysTransaction final -{ -public: - HttpSysTransaction(HttpSysServer& Server); - virtual ~HttpSysTransaction(); - - enum class Status - { - kDone, - kRequestPending - }; - - [[nodiscard]] Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); - - static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, - PVOID pContext /* HttpSysServer */, - PVOID pOverlapped, - ULONG IoResult, - ULONG_PTR NumberOfBytesTransferred, - PTP_IO Io); - - void IssueInitialRequest(std::error_code& ErrorCode); - bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); - - PTP_IO Iocp(); - HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } - inline HttpSysServer& Server() { return m_HttpServer; } - inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } - - HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); - - HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } - - struct CompletionMutexScope - { - CompletionMutexScope(HttpSysTransaction& Tx) : Lock(Tx.m_CompletionMutex) {} - ~CompletionMutexScope() = default; - - RwLock::ExclusiveLockScope Lock; - }; - -private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; - - // Tracks which handler is due to handle the next I/O completion event - HttpSysRequestHandler* m_CompletionHandler = nullptr; - RwLock m_CompletionMutex; - InitialRequestHandler m_InitialHttpHandler{*this}; - std::optional m_HandlerRequest; - Ref m_PackageHandler; -}; - -/** - * @brief HTTP request response I/O request handler - * - * Asynchronously streams out a response to an HTTP request via compound - * responses from memory or directly from file - */ - -class HttpMessageResponseRequest : public HttpSysRequestHandler -{ -public: - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - const void* Payload, - size_t PayloadSize); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - std::span Blobs); - ~HttpMessageResponseRequest(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - void SuppressResponseBody(); // typically used for HEAD requests - -private: - std::vector m_HttpDataChunks; - uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes - uint16_t m_ResponseCode = 0; - uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists - uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends - bool m_IsInitialResponse = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - std::vector m_DataBuffers; - - void InitializeForPayload(uint16_t ResponseCode, std::span Blobs); -}; - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) -: HttpSysRequestHandler(InRequest) -{ - std::array EmptyBufferList; - - InitializeForPayload(ResponseCode, EmptyBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) -: HttpSysRequestHandler(InRequest) -, m_ContentType(HttpContentType::kText) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); - std::array SingleBufferList({MessageBuffer}); - - InitializeForPayload(ResponseCode, SingleBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - const void* Payload, - size_t PayloadSize) -: HttpSysRequestHandler(InRequest) -, m_ContentType(ContentType) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); - std::array SingleBufferList({MessageBuffer}); - - InitializeForPayload(ResponseCode, SingleBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - std::span BlobList) -: HttpSysRequestHandler(InRequest) -, m_ContentType(ContentType) -{ - InitializeForPayload(ResponseCode, BlobList); -} - -HttpMessageResponseRequest::~HttpMessageResponseRequest() -{ -} - -void -HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span BlobList) -{ - ZEN_TRACE_CPU("httpsys::InitializeForPayload"); - - const uint32_t ChunkCount = gsl::narrow(BlobList.size()); - - m_HttpDataChunks.reserve(ChunkCount); - m_DataBuffers.reserve(ChunkCount); - - for (IoBuffer& Buffer : BlobList) - { - m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); - } - - // Initialize the full array up front - - uint64_t LocalDataSize = 0; - - for (IoBuffer& Buffer : m_DataBuffers) - { - uint64_t BufferDataSize = Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - LocalDataSize += BufferDataSize; - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - // Use direct file transfer - - auto& Chunk = m_HttpDataChunks.emplace_back(); - - Chunk.DataChunkType = HttpDataChunkFromFileHandle; - Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; - Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; - Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; - } - else - { - // Send from memory, need to make sure we chunk the buffer up since - // the underlying data structure only accepts 32-bit chunk sizes for - // memory chunks. When this happens the vector will be reallocated, - // which is fine since this will be a pretty rare case and sending - // the data is going to take a lot longer than a memory allocation :) - - const uint8_t* WriteCursor = reinterpret_cast(Buffer.Data()); - - while (BufferDataSize) - { - const ULONG ThisChunkSize = gsl::narrow(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); - - auto& Chunk = m_HttpDataChunks.emplace_back(); - - Chunk.DataChunkType = HttpDataChunkFromMemory; - Chunk.FromMemory.pBuffer = (void*)WriteCursor; - Chunk.FromMemory.BufferLength = ThisChunkSize; - - BufferDataSize -= ThisChunkSize; - WriteCursor += ThisChunkSize; - } - } - } - - m_RemainingChunkCount = gsl::narrow(m_HttpDataChunks.size()); - m_TotalDataSize = LocalDataSize; - - if (m_TotalDataSize == 0 && ResponseCode == 200) - { - // Some HTTP clients really don't like empty responses unless a 204 response is sent - m_ResponseCode = uint16_t(HttpResponseCode::NoContent); - } - else - { - m_ResponseCode = ResponseCode; - } -} - -void -HttpMessageResponseRequest::SuppressResponseBody() -{ - m_RemainingChunkCount = 0; - m_HttpDataChunks.clear(); - m_DataBuffers.clear(); -} - -HttpSysRequestHandler* -HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - ZEN_UNUSED(NumberOfBytesTransferred); - - if (IoResult != NO_ERROR) - { - ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); - - // if one transmit failed there's really no need to go on - return nullptr; - } - - if (m_RemainingChunkCount == 0) - { - return nullptr; // All done - } - - return this; -} - -void -HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) -{ - ZEN_TRACE_CPU("httpsys::Response::IssueRequest"); - HttpSysTransaction& Tx = Transaction(); - HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); - PTP_IO const Iocp = Tx.Iocp(); - - StartThreadpoolIo(Iocp); - - // Split payload into batches to play well with the underlying API - - const int MaxChunksPerCall = 9999; - - const int ThisRequestChunkCount = std::min(m_RemainingChunkCount, MaxChunksPerCall); - const int ThisRequestChunkOffset = m_NextDataChunkOffset; - - m_RemainingChunkCount -= ThisRequestChunkCount; - m_NextDataChunkOffset += ThisRequestChunkCount; - - /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? - - From the docs: - - This flag enables buffering of data in the kernel on a per-response basis. It should - be used by an application doing synchronous I/O, or by a an application doing - asynchronous I/O with no more than one send outstanding at a time. - - Applications using asynchronous I/O which may have more than one send outstanding at - a time should not use this flag. - - When this flag is set, it should be used consistently in calls to the - HttpSendHttpResponse function as well. - */ - - ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; - - if (m_RemainingChunkCount) - { - // We need to make more calls to send the full amount of data - SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; - } - - ULONG SendResult = 0; - - if (m_IsInitialResponse) - { - // Populate response structure - - HTTP_RESPONSE HttpResponse = {}; - - HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); - HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; - - // Server header - // - // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header - // - // This is controlled via a registry key 'DisableServerHeader', at: - // - // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters - // - // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether - // (only the latter appears to do anything in my testing, on Windows 10). - // - // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) - // - - PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; - ServerHeader->pRawValue = "Zen"; - ServerHeader->RawValueLength = (USHORT)3; - - // Content-length header - - char ContentLengthString[32]; - _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); - - PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; - ContentLengthHeader->pRawValue = ContentLengthString; - ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); - - // Content-type header - - PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); - - ContentTypeHeader->pRawValue = ContentTypeString.data(); - ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); - - std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); - - HttpResponse.StatusCode = m_ResponseCode; - HttpResponse.pReason = ReasonString.data(); - HttpResponse.ReasonLength = (USHORT)ReasonString.size(); - - // Cache policy - - HTTP_CACHE_POLICY CachePolicy; - - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; - CachePolicy.SecondsToLive = 0; - - // Initial response API call - - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); - - m_IsInitialResponse = false; - } - else - { - // Subsequent response API calls - - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - (USHORT)ThisRequestChunkCount, // EntityChunkCount - &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks - NULL, // BytesSent - NULL, // Reserved1 - 0, // Reserved2 - Tx.Overlapped(), // Overlapped - NULL // LogData - ); - } - - auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { - for (int i = 0; i < ThisRequestChunkCount; ++i) - { - const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i]; - - if (i > 0) - { - ResponseDetails << " + "; - } - - switch (Chunk.DataChunkType) - { - case HttpDataChunkFromMemory: - ResponseDetails << "mem:" << uint64_t(Chunk.FromMemory.BufferLength); - break; - - case HttpDataChunkFromFileHandle: - ResponseDetails << "file:"; - { - ResponseDetails << uint64_t(Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart) << "," - << uint64_t(Chunk.FromFileHandle.ByteRange.Length.QuadPart) << ","; - - std::error_code PathEc; - HANDLE FileHandle = Chunk.FromFileHandle.FileHandle; - std::filesystem::path Path = PathFromHandle(FileHandle, PathEc); - - if (PathEc) - { - ResponseDetails << "bad_file(handle=" << reinterpret_cast(FileHandle) << ",error=" << PathEc.message() - << ")"; - } - else - { - const uint64_t FileSize = FileSizeFromHandle(FileHandle); - ResponseDetails << Path.u8string() << "(" << FileSize << ") handle=" << reinterpret_cast(FileHandle); - } - } - break; - - case HttpDataChunkFromFragmentCache: - ResponseDetails << "frag:???"; // We do not use these - break; - - case HttpDataChunkFromFragmentCacheEx: - ResponseDetails << "frax:???"; // We do not use these - break; - -# if 0 // Not available in older Windows SDKs - case HttpDataChunkTrailers: - ResponseDetails << "trls:???"; // We do not use these - break; -# endif - - default: - ResponseDetails << "???: " << Chunk.DataChunkType; - break; - } - } - }; - - if (SendResult == NO_ERROR) - { - // Synchronous completion, but the completion event will still be posted to IOCP - - ErrorCode.clear(); - } - else if (SendResult == ERROR_IO_PENDING) - { - // Asynchronous completion, a completion notification will be posted to IOCP - - ErrorCode.clear(); - } - else - { - ErrorCode = MakeErrorCode(SendResult); - - // An error occurred, no completion will be posted to IOCP - - CancelThreadpoolIo(Iocp); - - // Emit diagnostics - - ExtendableStringBuilder<256> ResponseDetails; - EmitReponseDetails(ResponseDetails); - - ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}", - SendResult, - ErrorCode.message(), - HttpReq->pRawUrl, - Tx.ServerRequest().SessionId(), - HttpReq->RequestId, - ResponseDetails); - } -} - -/** HTTP completion handler for async work - - This is used to allow work to be taken off the request handler threads - and to support posting responses asynchronously. - */ - -class HttpAsyncWorkRequest : public HttpSysRequestHandler -{ -public: - HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response); - ~HttpAsyncWorkRequest(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - -private: - struct AsyncWorkItem : public IWork - { - virtual void Execute() override; - - AsyncWorkItem(HttpSysTransaction& InTx, std::function&& InHandler) - : Tx(InTx) - , Handler(std::move(InHandler)) - { - } - - HttpSysTransaction& Tx; - std::function Handler; - }; - - Ref m_WorkItem; -}; - -HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response) -: HttpSysRequestHandler(Tx) -{ - m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); -} - -HttpAsyncWorkRequest::~HttpAsyncWorkRequest() -{ -} - -void -HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) -{ - ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest"); - ErrorCode.clear(); - - Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); -} - -HttpSysRequestHandler* -HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - // This ought to not be called since there should be no outstanding I/O request - // when this completion handler is active - - ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); - - return this; -} - -void -HttpAsyncWorkRequest::AsyncWorkItem::Execute() -{ - ZEN_TRACE_CPU("httpsys::async_execute"); - - try - { - // We need to hold this lock while we're issuing new requests in order to - // prevent race conditions between the thread we are running on and any - // IOCP service threads. Otherwise the IOCP completion handler can end - // up deleting the transaction object before we are done with it! - HttpSysTransaction::CompletionMutexScope _(Tx); - HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); - - ThisRequest.m_NextCompletionHandler = nullptr; - - { - ZEN_TRACE_CPU("httpsys::HandleRequest"); - Handler(ThisRequest); - } - - // TODO: should Handler be destroyed at this point to ensure there - // are no outstanding references into state which could be - // deleted asynchronously as a result of issuing the response? - - if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) - { - return (void)Tx.IssueNextRequest(NextHandler); - } - else if (!ThisRequest.IsHandled()) - { - return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); - } - else - { - // "Handled" but no request handler? Shouldn't ever happen - return (void)Tx.IssueNextRequest( - new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); - } - } - catch (std::exception& Ex) - { - return (void)Tx.IssueNextRequest( - new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); - } -} - -/** - _________ - / _____/ ______________ __ ___________ - \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ - / \ ___/| | \/\ /\ ___/| | \/ - /_______ /\___ >__| \_/ \___ >__| - \/ \/ \/ -*/ - -HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig) -: m_Log(logging::Get("http")) -, m_RequestLog(logging::Get("http_requests")) -, m_IsRequestLoggingEnabled(InConfig.IsRequestLoggingEnabled) -, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled) -, m_InitialConfig(InConfig) -{ - // Initialize thread pool - - int MinThreadCount; - int MaxThreadCount; - - if (m_InitialConfig.ThreadCount == 0) - { - MinThreadCount = Max(8u, std::thread::hardware_concurrency()); - } - else - { - MinThreadCount = m_InitialConfig.ThreadCount; - } - - MaxThreadCount = MinThreadCount * 2; - - if (m_InitialConfig.IsDedicatedServer) - { - // In order to limit the potential impact of threads stuck - // in locks we allow the thread pool to be oversubscribed - // by a fair amount - - MaxThreadCount *= 2; - } - - m_IoThreadPool = std::make_unique(MinThreadCount, MaxThreadCount); - - if (m_InitialConfig.AsyncWorkThreadCount == 0) - { - m_InitialConfig.AsyncWorkThreadCount = 16; - } - - ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); - - if (Result != NO_ERROR) - { - return; - } - - m_IsHttpInitialized = true; - m_IsOk = true; - - ZEN_INFO("http.sys server started in {} mode, using {}-{} I/O threads and {} async worker threads", - m_InitialConfig.IsDedicatedServer ? "DEDICATED" : "NORMAL", - MinThreadCount, - MaxThreadCount, - m_InitialConfig.AsyncWorkThreadCount); -} - -HttpSysServer::~HttpSysServer() -{ - if (m_IsHttpInitialized) - { - ZEN_ERROR("~HttpSysServer() called without calling Close() first"); - } - - delete m_AsyncWorkPool; - m_AsyncWorkPool = nullptr; -} - -void -HttpSysServer::Close() -{ - if (m_IsHttpInitialized) - { - Cleanup(); - - HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); - m_IsHttpInitialized = false; - } -} - -int -HttpSysServer::InitializeServer(int BasePort) -{ - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - int EffectivePort = BasePort; - - Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - - // Sharing violation implies the port is being used by another process - for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - WildcardUrlPath.Reset(); - WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; - - Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - } - - m_BaseUris.clear(); - if (Result == NO_ERROR) - { - m_BaseUris.push_back(WildcardUrlPath.c_str()); - } - else if (Result == ERROR_ACCESS_DENIED) - { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh - // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: - // netsh http add urlacl url=http://*:8558/ user= - - ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); - - const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - - for (const std::u8string_view Host : Hosts) - { - WideStringBuilder<64> LocalUrlPath; - LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - - if (InternalResult == NO_ERROR) - { - ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); - - m_BaseUris.push_back(LocalUrlPath.c_str()); - } - else - { - break; - } - } - } - } - - if (m_BaseUris.empty()) - { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; - - WideStringBuilder<64> QueueName; - QueueName << "zenserver_" << EffectivePort; - - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ QueueName.c_str(), - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); - - return EffectivePort; - } - - HttpBindingInfo.Flags.Present = 1; - HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; - - Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); - - return EffectivePort; - } - - // Configure rejection method. Default is to drop the connection, it's better if we - // return an explicit error code when the queue cannot accept more requests - - { - HTTP_503_RESPONSE_VERBOSITY VerbosityInformation = Http503ResponseVerbosityLimited; - - Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, - HttpServer503VerbosityProperty, - &VerbosityInformation, - sizeof VerbosityInformation, - 0, - 0); - } - - // Tune the maximum number of pending requests in the http.sys request queue. By default - // the value is 1000 which is plenty for single user machines but for dedicated servers - // serving many users it makes sense to increase this to a higher number to help smooth - // out intermittent stalls like we might experience when GC is triggered - - if (m_InitialConfig.IsDedicatedServer) - { - ULONG QueueLength = 50000; - - Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, HttpServerQueueLengthProperty, &QueueLength, sizeof QueueLength, 0, 0); - - if (Result != NO_ERROR) - { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); - } - } - - // Create I/O completion port - - std::error_code ErrorCode; - m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); - - if (ErrorCode) - { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); - } - else - { - m_IsOk = true; - - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); - } - - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build - -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) - { - ZEN_DEBUG("HTTP3 is available"); - } - else - { - ZEN_DEBUG("HTTP3 is NOT available"); - } -# endif - - return EffectivePort; -} - -void -HttpSysServer::Cleanup() -{ - ++m_IsShuttingDown; - - if (m_RequestQueueHandle) - { - HttpCloseRequestQueue(m_RequestQueueHandle); - m_RequestQueueHandle = nullptr; - } - - if (m_HttpUrlGroupId) - { - HttpCloseUrlGroup(m_HttpUrlGroupId); - m_HttpUrlGroupId = 0; - } - - if (m_HttpSessionId) - { - HttpCloseServerSession(m_HttpSessionId); - m_HttpSessionId = 0; - } -} - -WorkerThreadPool& -HttpSysServer::WorkPool() -{ - if (!m_AsyncWorkPool) - { - RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - - if (!m_AsyncWorkPool) - { - m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); - } - } - - return *m_AsyncWorkPool; -} - -void -HttpSysServer::StartServer() -{ - const int InitialRequestCount = 32; - - for (int i = 0; i < InitialRequestCount; ++i) - { - IssueNewRequestMaybe(); - } -} - -void -HttpSysServer::Run(bool IsInteractive) -{ - if (IsInteractive) - { - zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); - } - - do - { - // int WaitTimeout = -1; - int WaitTimeout = 100; - - if (IsInteractive) - { - WaitTimeout = 1000; - - if (_kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); -} - -void -HttpSysServer::OnHandlingNewRequest() -{ - if (--m_PendingRequests > m_MinPendingRequests) - { - // We have more than the minimum number of requests pending, just let someone else - // enqueue new requests. This should be the common case as we check if we need to - // enqueue more requests before exiting the completion handler. - return; - } - - IssueNewRequestMaybe(); -} - -void -HttpSysServer::IssueNewRequestMaybe() -{ - if (m_IsShuttingDown.load(std::memory_order::acquire)) - { - return; - } - - if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) - { - return; - } - - std::unique_ptr Request = std::make_unique(*this); - - std::error_code ErrorCode; - Request->IssueInitialRequest(ErrorCode); - - if (ErrorCode) - { - // No request was actually issued. What is the appropriate response? - - return; - } - - // This may end up exceeding the MaxPendingRequests limit, but it's not - // really a problem. I'm doing it this way mostly to avoid dealing with - // exceptions here - ++m_PendingRequests; - - Request.release(); -} - -void -HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) -{ - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); - - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); - - return; - } - } -} - -void -HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) -{ - ZEN_UNUSED(Service); - - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); - } - } -} - -////////////////////////////////////////////////////////////////////////// - -HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) -{ -} - -HttpSysTransaction::~HttpSysTransaction() -{ -} - -PTP_IO -HttpSysTransaction::Iocp() -{ - return m_HttpServer.m_IoThreadPool->Iocp(); -} - -HANDLE -HttpSysTransaction::RequestQueueHandle() -{ - return m_HttpServer.m_RequestQueueHandle; -} - -void -HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) -{ - m_InitialHttpHandler.IssueRequest(ErrorCode); -} - -thread_local bool t_IsHttpSysThreadNamed = false; -static std::atomic HttpSysThreadIndex = 0; - -static void -NameCurrentHttpSysThread() -{ - t_IsHttpSysThreadNamed = true; - const int ThreadIndex = ++HttpSysThreadIndex; - zen::ExtendableStringBuilder<16> ThreadName; - ThreadName << "httpio_" << ThreadIndex; - SetCurrentThreadName(ThreadName); -} - -void -HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, - PVOID pContext /* HttpSysServer */, - PVOID pOverlapped, - ULONG IoResult, - ULONG_PTR NumberOfBytesTransferred, - PTP_IO Io) -{ - ZEN_UNUSED(Io, Instance); - - // Assign names to threads for context - - if (!t_IsHttpSysThreadNamed) - { - NameCurrentHttpSysThread(); - } - - // Note that for a given transaction we may be in this completion function on more - // than one thread at any given moment. This means we need to be careful about what - // happens in here - - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); - - if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) - { - delete Transaction; - } - - // Ensure new requests are enqueued as necessary. We do this here instead - // of inside the transaction completion handler now to avoid spending time - // in unrelated API calls while holding the transaction lock - - if (HttpSysServer* HttpServer = reinterpret_cast(pContext)) - { - HttpServer->IssueNewRequestMaybe(); - } -} - -bool -HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) -{ - ZEN_TRACE_CPU("httpsys::Transaction::IssueNextRequest"); - - HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; - m_CompletionHandler = NewCompletionHandler; - - auto _ = MakeGuard([this, CurrentHandler] { - if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) - { - delete CurrentHandler; - } - }); - - if (NewCompletionHandler == nullptr) - { - return false; - } - - try - { - std::error_code ErrorCode; - m_CompletionHandler->IssueRequest(ErrorCode); - - if (!ErrorCode) - { - return true; - } - - ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); - } - catch (std::exception& Ex) - { - ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); - } - - // something went wrong, no request is pending - m_CompletionHandler = nullptr; - - return false; -} - -HttpSysTransaction::Status -HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - // We use this to ensure sequential execution of completion handlers - // for any given transaction. It also ensures all member variables are - // in a consistent state for the current thread - - RwLock::ExclusiveLockScope _(m_CompletionMutex); - - bool IsRequestPending = false; - - if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) - { - if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) - { - // Ensure we have a sufficient number of pending requests outstanding - m_HttpServer.OnHandlingNewRequest(); - } - - auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); - - IsRequestPending = IssueNextRequest(NewCompletionHandler); - } - - if (IsRequestPending) - { - // There is another request pending on this transaction, so it needs to remain valid - return Status::kRequestPending; - } - - if (m_HttpServer.m_IsRequestLoggingEnabled) - { - if (m_HandlerRequest.has_value()) - { - m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); - } - } - - // Transaction done, caller should clean up (delete) this instance - return Status::kDone; -} - -HttpSysServerRequest& -HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) -{ - HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); - - // Default request handling - - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) - { - Service.HandleRequest(ThisRequest); - } - - return ThisRequest; -} - -////////////////////////////////////////////////////////////////////////// - -HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) -: m_HttpTx(Tx) -, m_PayloadBuffer(std::move(PayloadBuffer)) -{ - const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); - - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); - - HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; - - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This is overhead which I'd prefer to avoid but for now we just have - // to live with it - - WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow(AbsPathLength - PrefixLength)}, - m_UriUtf8); - - std::string_view UriSuffix8{m_UriUtf8}; - - m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access - m_Uri = UriSuffix8; - - const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); - - if (LastComponentIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastComponentIndex); - } - - const size_t LastDotIndex = UriSuffix8.find_last_of('.'); - - if (LastDotIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastDotIndex + 1); - - AcceptContentType = ParseContentType(UriSuffix8); - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_Uri.remove_suffix(UriSuffix8.size() + 1); - } - } - } - else - { - m_UriUtf8.Reset(); - m_Uri = {}; - m_UriWithExtension = {}; - } - - if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; // We skip the leading question mark - - WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); - } - else - { - m_QueryStringUtf8.Reset(); - } - - m_QueryString = std::string_view(m_QueryStringUtf8); - m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); - m_ContentLength = GetContentLength(HttpRequestPtr); - m_ContentType = GetContentType(HttpRequestPtr); - - // It an explicit content type extension was specified then we'll use that over any - // Accept: header value that may be present - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_AcceptType = AcceptContentType; - } - else - { - m_AcceptType = GetAcceptType(HttpRequestPtr); - } - - if (m_Verb == HttpVerb::kHead) - { - SetSuppressResponseBody(); - } -} - -HttpSysServerRequest::~HttpSysServerRequest() -{ -} - -Oid -HttpSysServerRequest::ParseSessionId() const -{ - const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); - - for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) - { - HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; - std::string_view HeaderName{Header.pName, Header.NameLength}; - - if (HeaderName == "UE-Session"sv) - { - if (Header.RawValueLength == Oid::StringLength) - { - return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); - } - } - } - - return {}; -} - -uint32_t -HttpSysServerRequest::ParseRequestId() const -{ - const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); - - for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) - { - HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; - std::string_view HeaderName{Header.pName, Header.NameLength}; - - if (HeaderName == "UE-Request"sv) - { - std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; - uint32_t RequestId = 0; - std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); - return RequestId; - } - } - - return 0; -} - -IoBuffer -HttpSysServerRequest::ReadPayload() -{ - return m_PayloadBuffer; -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = - new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) -{ - if (m_HttpTx.Server().IsAsyncResponseEnabled()) - { - m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); - } - else - { - ContinuationHandler(m_HttpTx.ServerRequest()); - } -} - -bool -HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) -{ - HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); - const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; - - return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); -} - -////////////////////////////////////////////////////////////////////////// - -InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) -{ -} - -InitialRequestHandler::~InitialRequestHandler() -{ -} - -void -InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) -{ - ZEN_TRACE_CPU("httpsys::Request::IssueRequest"); - - HttpSysTransaction& Tx = Transaction(); - PTP_IO Iocp = Tx.Iocp(); - HTTP_REQUEST* HttpReq = Tx.HttpRequest(); - - StartThreadpoolIo(Iocp); - - ULONG HttpApiResult; - - if (IsInitialRequest()) - { - HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), - HTTP_NULL_ID, - HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, - HttpReq, - RequestBufferSize(), - NULL, - Tx.Overlapped()); - } - else - { - // The http.sys team recommends limiting the size to 128KB - static const uint64_t kMaxBytesPerApiCall = 128 * 1024; - - uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - void* BufferWriteCursor = reinterpret_cast(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; - - HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - BufferWriteCursor, - gsl::narrow(BytesToReadThisCall), - nullptr, // BytesReturned - Tx.Overlapped()); - } - - if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) - { - CancelThreadpoolIo(Iocp); - - ErrorCode = MakeErrorCode(HttpApiResult); - - ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); - - return; - } - - ErrorCode.clear(); -} - -HttpSysRequestHandler* -InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); - - switch (IoResult) - { - default: - case ERROR_OPERATION_ABORTED: - return nullptr; - - case ERROR_MORE_DATA: // Insufficient buffer space - case NO_ERROR: - break; - } - - ZEN_TRACE_CPU("httpsys::HandleCompletion"); - - // Route request - - try - { - HTTP_REQUEST* HttpReq = HttpRequest(); - -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) - { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) - { - case HttpRequestInfoTypeRequestTiming: - { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - } - } -# endif - - if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) - { - if (m_IsInitialRequest) - { - m_ContentLength = GetContentLength(HttpReq); - const HttpContentType ContentType = GetContentType(HttpReq); - - if (m_ContentLength) - { - // Handle initial chunk read by copying any payload which has already been copied - // into our embedded request buffer - - m_PayloadBuffer = IoBuffer(m_ContentLength); - m_PayloadBuffer.SetContentType(ContentType); - - uint64_t BytesToRead = m_ContentLength; - uint8_t* const BufferBase = reinterpret_cast(m_PayloadBuffer.MutableData()); - uint8_t* BufferWriteCursor = BufferBase; - - const int EntityChunkCount = HttpReq->EntityChunkCount; - - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - - ZEN_ASSERT(BufferLength <= BytesToRead); - - memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); - - BufferWriteCursor += BufferLength; - BytesToRead -= BufferLength; - } - - m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; - } - } - else - { - m_CurrentPayloadOffset += NumberOfBytesTransferred; - } - - if (m_CurrentPayloadOffset != m_ContentLength) - { - // Body not complete, issue another read request to receive more body data - return this; - } - - // Request body received completely - - m_PayloadBuffer.MakeImmutable(); - - HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); - - if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) - { - return Response; - } - - if (!ThisRequest.IsHandled()) - { - return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); - } - } - - // Unable to route - return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); - } - catch (std::system_error& SystemError) - { - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) - { - return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); - } - - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); - return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); - } - catch (std::bad_alloc& BadAlloc) - { - return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); - } - catch (std::exception& ex) - { - ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); - return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); - } -} - -////////////////////////////////////////////////////////////////////////// -// -// HttpServer interface implementation -// - -int -HttpSysServer::Initialize(int BasePort) -{ - int EffectivePort = InitializeServer(BasePort); - StartServer(); - return EffectivePort; -} - -void -HttpSysServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} -void -HttpSysServer::RegisterService(HttpService& Service) -{ - RegisterService(Service.BaseUri(), Service); -} - -Ref -CreateHttpSysServer(HttpSysConfig Config) -{ - return Ref(new HttpSysServer(Config)); -} - -} // namespace zen -#endif diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h deleted file mode 100644 index 6a6b16525..000000000 --- a/src/zenhttp/httpsys.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#ifndef ZEN_WITH_HTTPSYS -# if ZEN_PLATFORM_WINDOWS -# define ZEN_WITH_HTTPSYS 1 -# else -# define ZEN_WITH_HTTPSYS 0 -# endif -#endif - -namespace zen { - -struct HttpSysConfig -{ - unsigned int ThreadCount = 0; - unsigned int AsyncWorkThreadCount = 0; - bool IsAsyncResponseEnabled = true; - bool IsRequestLoggingEnabled = false; - bool IsDedicatedServer = false; -}; - -Ref CreateHttpSysServer(HttpSysConfig Config); - -} // namespace zen diff --git a/src/zenhttp/include/zenhttp/diagsvcs.h b/src/zenhttp/include/zenhttp/diagsvcs.h deleted file mode 100644 index 8cc869c83..000000000 --- a/src/zenhttp/include/zenhttp/diagsvcs.h +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include -#include - -#include - -////////////////////////////////////////////////////////////////////////// - -namespace zen { - -/** HTTP test endpoint - - This is intended to be used to exercise basic HTTP communication infrastructure - which is useful for benchmarking performance of the server code and when evaluating - network performance / diagnosing connectivity issues - - */ -class HttpTestService : public HttpService -{ -public: - HttpTestService() {} - ~HttpTestService() = default; - - virtual const char* BaseUri() const override { return "/test/"; } - - virtual void HandleRequest(HttpServerRequest& Request) override - { - using namespace std::literals; - - auto Uri = Request.RelativeUri(); - - if (Uri == "hello"sv) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv); - } - else if (Uri == "1K"sv) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k); - } - else if (Uri == "1M"sv) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m); - } - else if (Uri == "1M_1k"sv) - { - std::vector Buffers; - Buffers.reserve(1024); - - for (int i = 0; i < 1024; ++i) - { - Buffers.push_back(m_1k); - } - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); - } - else if (Uri == "1G"sv) - { - std::vector Buffers; - Buffers.reserve(1024); - - for (int i = 0; i < 1024; ++i) - { - Buffers.push_back(m_1m); - } - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); - } - else if (Uri == "1G_1k"sv) - { - std::vector Buffers; - Buffers.reserve(1024 * 1024); - - for (int i = 0; i < 1024 * 1024; ++i) - { - Buffers.push_back(m_1k); - } - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); - } - } - -private: - IoBuffer m_1m{1024 * 1024}; - IoBuffer m_1k{m_1m, 0u, 1024}; -}; - -struct HealthServiceInfo -{ - std::filesystem::path DataRoot; - std::filesystem::path AbsLogPath; - std::string HttpServerClass; - std::string BuildVersion; -}; - -/** Health monitoring endpoint - - Thji - */ -class HttpHealthService : public HttpService -{ -public: - HttpHealthService(); - ~HttpHealthService() = default; - - void SetHealthInfo(HealthServiceInfo&& Info); - - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override final; - -private: - HttpRequestRouter m_Router; - RwLock m_InfoLock; - HealthServiceInfo m_HealthInfo; -}; - -} // namespace zen diff --git a/src/zenhttp/iothreadpool.cpp b/src/zenhttp/iothreadpool.cpp deleted file mode 100644 index da4b42e28..000000000 --- a/src/zenhttp/iothreadpool.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "iothreadpool.h" - -#include - -#if ZEN_PLATFORM_WINDOWS - -namespace zen { - -WinIoThreadPool::WinIoThreadPool(int InThreadCount, int InMaxThreadCount) -{ - ZEN_ASSERT(InThreadCount); - - if (InMaxThreadCount < InThreadCount) - { - InMaxThreadCount = InThreadCount; - } - - m_ThreadPool = CreateThreadpool(NULL); - - SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); - SetThreadpoolThreadMaximum(m_ThreadPool, InMaxThreadCount); - - InitializeThreadpoolEnvironment(&m_CallbackEnvironment); - - m_CleanupGroup = CreateThreadpoolCleanupGroup(); - - SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); - - SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); -} - -WinIoThreadPool::~WinIoThreadPool() -{ - CloseThreadpool(m_ThreadPool); -} - -void -WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) -{ - ZEN_ASSERT(!m_ThreadPoolIo); - - m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); - - if (!m_ThreadPoolIo) - { - ErrorCode = MakeErrorCodeFromLastError(); - } -} - -} // namespace zen - -#endif diff --git a/src/zenhttp/iothreadpool.h b/src/zenhttp/iothreadpool.h deleted file mode 100644 index e75e95e58..000000000 --- a/src/zenhttp/iothreadpool.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if ZEN_PLATFORM_WINDOWS -# include - -# include - -namespace zen { - -class WinIoThreadPool -{ -public: - WinIoThreadPool(int InThreadCount, int InMaxThreadCount); - ~WinIoThreadPool(); - - void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); - inline PTP_IO Iocp() const { return m_ThreadPoolIo; } - -private: - PTP_POOL m_ThreadPool = nullptr; - PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; - PTP_IO m_ThreadPoolIo = nullptr; - TP_CALLBACK_ENVIRON m_CallbackEnvironment; -}; - -} // namespace zen -#endif diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp new file mode 100644 index 000000000..0c6b189f9 --- /dev/null +++ b/src/zenhttp/servers/httpasio.cpp @@ -0,0 +1,1052 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include +#include +#include +#include +#include + +#include "httpparser.h" + +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include +# include +#endif +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#define ASIO_VERBOSE_TRACE 0 + +#if ASIO_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +#else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpAcceptor; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); + // Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + int Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr m_Acceptor; + std::vector m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr m_Response; +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span BlobList) + { + ZEN_TRACE_CPU("asio::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector m_DataBuffers; + std::vector m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket); + ~HttpServerConnection(); + + std::shared_ptr AsSharedPtr() { return shared_from_this(); } + + // HttpConnectionBase implementation + + virtual void TerminateConnection() override; + virtual void HandleRequest() override; + + void HandleNewRequest(); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void CloseConnection(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr m_Socket; + std::atomic m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + Ref m_PackageHandler; + + RwLock m_ResponsesLock; + std::deque> m_Responses; +}; + +std::atomic g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ + ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); +} + +void +HttpServerConnection::HandleNewRequest() +{ + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + + m_RequestState = RequestState::kTerminated; + ZEN_ASSERT(m_Socket); + + // Terminating, we don't care about any errors when closing socket + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + return; + } + else + { + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + return TerminateConnection(); + } + } + + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + if (Result == ~0ull) + { + return TerminateConnection(); + } + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) +{ + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + TerminateConnection(); + } + else + { + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + if (!m_RequestData.IsKeepAlive()) + { + CloseConnection(); + } + else + { + if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + + m_RequestCounter.fetch_add(1); + } + } +} + +void +HttpServerConnection::CloseConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + ZEN_ASSERT(m_Socket); + m_RequestState = RequestState::kDone; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + m_Socket->close(Ec); + if (Ec) + { + ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); + } +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + ZEN_TRACE_CPU("asio::HandleRequest"); + + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + { + ZEN_TRACE_CPU("asio::async_write"); + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, true); + }); + } + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + { + m_Acceptor.set_option(asio::ip::v6_only(false)); +#if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean excluse_address; + m_Acceptor.set_option(excluse_address(true)); +#else // ZEN_PLATFORM_WINDOWS + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +#endif // ZEN_PLATFORM_WINDOWS + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + m_Acceptor.listen(); + + ZEN_INFO("Started asio server at port '{}'", EffectivePort); + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + if (CloseEc) + { + ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + } + } + }); + } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +int +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + ZEN_INFO("starting asio http with {} service threads", ThreadCount); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + // This should consist of a set of minimum threads and grow on demand to + // meet concurrency needs? Right now we end up allocating a large number + // of threads even if we never end up using all of them, which seems + // wasteful. It's also not clear how the demand for concurrency should + // be balanced with the engine side - ideally we'd have some kind of + // global scheduling to prevent one side from preventing the other side + // from making progress. Or at the very least, thread priorities should + // be considered. + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName(fmt::format("asio_io_{}", Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } + + ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); + + return m_Acceptor->GetAcceptPort(); +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +class HttpAsioServer : public HttpServer +{ +public: + HttpAsioServer(unsigned int ThreadCount); + ~HttpAsioServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + unsigned int m_ThreadCount = 0; + + std::unique_ptr m_Impl; +}; + +HttpAsioServer::HttpAsioServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(std::make_unique()) +{ + ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); +} + +HttpAsioServer::~HttpAsioServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpAsioServer() called without calling Close() first"); + } +} + +void +HttpAsioServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); + } + m_Impl.reset(); +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +int +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = m_Impl->Start(gsl::narrow(BasePort), m_ThreadCount); + return m_BasePort; +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +Ref +CreateHttpAsioServer(unsigned int ThreadCount) +{ + return Ref{new HttpAsioServer(ThreadCount)}; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h new file mode 100644 index 000000000..2366f3437 --- /dev/null +++ b/src/zenhttp/servers/httpasio.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen { + +Ref CreateHttpAsioServer(unsigned int ThreadCount); + +} // namespace zen diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp new file mode 100644 index 000000000..658f51831 --- /dev/null +++ b/src/zenhttp/servers/httpnull.cpp @@ -0,0 +1,88 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpnull.h" + +#include + +#if ZEN_PLATFORM_WINDOWS +# include +#endif + +namespace zen { + +HttpNullServer::HttpNullServer() +{ +} + +HttpNullServer::~HttpNullServer() +{ +} + +void +HttpNullServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +int +HttpNullServer::Initialize(int BasePort) +{ + return BasePort; +} + +void +HttpNullServer::Run(bool IsInteractiveSession) +{ + const bool TestMode = !IsInteractiveSession; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpNullServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpNullServer::Close() +{ +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h new file mode 100644 index 000000000..965e729f7 --- /dev/null +++ b/src/zenhttp/servers/httpnull.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +namespace zen { + +/** + * @brief Null implementation of "http" server. Does nothing + */ + +class HttpNullServer : public HttpServer +{ +public: + HttpNullServer(); + ~HttpNullServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + +private: + Event m_ShutdownEvent; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp new file mode 100644 index 000000000..6b987151a --- /dev/null +++ b/src/zenhttp/servers/httpparser.cpp @@ -0,0 +1,370 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpparser.h" + +#include +#include + +namespace zen { + +using namespace std::literals; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +////////////////////////////////////////////////////////////////////////// +// +// HttpRequestParser +// + +http_parser_settings HttpRequestParser::s_ParserSettings{ + .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, + .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, + .on_status = + [](http_parser* p, const char* Data, size_t ByteCount) { + ZEN_UNUSED(p, Data, ByteCount); + return 0; + }, + .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, + .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) +{ + http_parser_init(&m_Parser, HTTP_REQUEST); + m_Parser.data = this; + + ResetState(); +} + +HttpRequestParser::~HttpRequestParser() +{ +} + +size_t +HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) +{ + const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); + + http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); + + if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + { + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); + return ~0ull; + } + + return ConsumedBytes; +} + +int +HttpRequestParser::OnUrl(const char* Data, size_t Bytes) +{ + if (!m_Url) + { + ZEN_ASSERT_SLOW(m_UrlLength == 0); + m_Url = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_UrlLength += Bytes; + + return 0; +} + +int +HttpRequestParser::OnHeader(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = m_HeaderCursor; + } + else if (m_CurrentHeaderName == nullptr) + { + m_CurrentHeaderName = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderNameLength += Bytes; + + return 0; +} + +void +HttpRequestParser::AppendCurrentHeader() +{ + std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); + std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + + if (HeaderHash == HashContentLength) + { + m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashAccept) + { + m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashContentType) + { + m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashSession) + { + m_SessionId = Oid::FromHexString(HeaderValue); + } + else if (HeaderHash == HashRequest) + { + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + } + else if (HeaderHash == HashExpect) + { + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + } + else if (HeaderHash == HashRange) + { + m_RangeHeaderIndex = (int8_t)m_Headers.size(); + } + + m_Headers.emplace_back(HeaderName, HeaderValue); +} + +int +HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength == 0) + { + m_CurrentHeaderValue = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderValueLength += Bytes; + + return 0; +} + +static void +NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) +{ + bool LastCharWasSeparator = false; + for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) + { + const char UrlChar = Url[UrlIndex]; + const bool IsSeparator = (UrlChar == '/'); + + if (IsSeparator && LastCharWasSeparator) + { + if (NormalizedUrl.empty()) + { + NormalizedUrl.reserve(UrlLength); + NormalizedUrl.append(Url, UrlIndex); + } + + if (!LastCharWasSeparator) + { + NormalizedUrl.push_back('/'); + } + } + else if (!NormalizedUrl.empty()) + { + NormalizedUrl.push_back(UrlChar); + } + + LastCharWasSeparator = IsSeparator; + } +} + +int +HttpRequestParser::OnHeadersComplete() +{ + try + { + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + } + + m_KeepAlive = !!http_should_keep_alive(&m_Parser); + + switch (m_Parser.method) + { + case HTTP_GET: + m_RequestVerb = HttpVerb::kGet; + break; + + case HTTP_POST: + m_RequestVerb = HttpVerb::kPost; + break; + + case HTTP_PUT: + m_RequestVerb = HttpVerb::kPut; + break; + + case HTTP_DELETE: + m_RequestVerb = HttpVerb::kDelete; + break; + + case HTTP_HEAD: + m_RequestVerb = HttpVerb::kHead; + break; + + case HTTP_COPY: + m_RequestVerb = HttpVerb::kCopy; + break; + + case HTTP_OPTIONS: + m_RequestVerb = HttpVerb::kOptions; + break; + + default: + ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + break; + } + + std::string_view Url(m_Url, m_UrlLength); + + if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) + { + m_UrlLength = QuerySplit; + m_QueryString = m_Url + QuerySplit + 1; + m_QueryLength = Url.size() - QuerySplit - 1; + } + + NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); + + if (m_ContentLengthHeaderIndex >= 0) + { + std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; + uint64_t ContentLength = 0; + std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); + + if (ContentLength) + { + m_BodyBuffer = IoBuffer(ContentLength); + } + + m_BodyBuffer.SetContentType(ContentType()); + + m_BodyPosition = 0; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("HttpRequestParser::OnHeadersComplete failed. Reason '{}'", Ex.what()); + return -1; + } + return 0; +} + +int +HttpRequestParser::OnBody(const char* Data, size_t Bytes) +{ + if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + { + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); + return 1; + } + memcpy(reinterpret_cast(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); + m_BodyPosition += Bytes; + + if (http_body_is_final(&m_Parser)) + { + if (m_BodyPosition != m_BodyBuffer.Size()) + { + ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + return 1; + } + } + + return 0; +} + +void +HttpRequestParser::ResetState() +{ + m_HeaderCursor = m_HeaderBuffer; + m_CurrentHeaderName = nullptr; + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValue = nullptr; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = nullptr; + m_Url = nullptr; + m_UrlLength = 0; + m_QueryString = nullptr; + m_QueryLength = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; + m_Headers.clear(); + m_NormalizedUrl.clear(); +} + +int +HttpRequestParser::OnMessageBegin() +{ + return 0; +} + +int +HttpRequestParser::OnMessageComplete() +{ + m_Connection.HandleRequest(); + + ResetState(); + + return 0; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h new file mode 100644 index 000000000..219ac351d --- /dev/null +++ b/src/zenhttp/servers/httpparser.h @@ -0,0 +1,112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class HttpRequestParserCallbacks +{ +public: + virtual ~HttpRequestParserCallbacks() = default; + virtual void HandleRequest() = 0; + virtual void TerminateConnection() = 0; +}; + +struct HttpRequestParser +{ + explicit HttpRequestParser(HttpRequestParserCallbacks& Connection); + ~HttpRequestParser(); + + size_t ConsumeData(const char* InputData, size_t DataSize); + void ResetState(); + + HttpVerb RequestVerb() const { return m_RequestVerb; } + bool IsKeepAlive() const { return m_KeepAlive; } + std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } + std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } + IoBuffer Body() { return m_BodyBuffer; } + + inline HttpContentType ContentType() + { + if (m_ContentTypeHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); + } + + inline HttpContentType AcceptType() + { + if (m_AcceptHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); + } + + Oid SessionId() const { return m_SessionId; } + int RequestId() const { return m_RequestId; } + + std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } + +private: + struct HeaderEntry + { + HeaderEntry() = default; + + HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} + + std::string_view Name; + std::string_view Value; + }; + + HttpRequestParserCallbacks& m_Connection; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_QueryString = nullptr; + size_t m_QueryLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + char m_HeaderBuffer[1024]; + std::string m_NormalizedUrl; + + void AppendCurrentHeader(); + + int OnMessageBegin(); + int OnUrl(const char* Data, size_t Bytes); + int OnHeader(const char* Data, size_t Bytes); + int OnHeaderValue(const char* Data, size_t Bytes); + int OnHeadersComplete(); + int OnBody(const char* Data, size_t Bytes); + int OnMessageComplete(); + + static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast(Parser->data); } + static http_parser_settings s_ParserSettings; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp new file mode 100644 index 000000000..2e934473e --- /dev/null +++ b/src/zenhttp/servers/httpplugin.cpp @@ -0,0 +1,781 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#if ZEN_WITH_PLUGINS + +# include "httpparser.h" + +# include +# include +# include +# include +# include + +# include +# include + +# if ZEN_PLATFORM_WINDOWS +# include +# endif + +# define PLUGIN_VERBOSE_TRACE 1 + +# if PLUGIN_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +# else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +# endif + +namespace zen { + +struct HttpPluginServerImpl; +struct HttpPluginResponse; + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted +{ + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void OnBytesRead(const void* Buffer, size_t DataSize) override; + + // HttpRequestParserCallbacks + + virtual void HandleRequest() override; + virtual void TerminateConnection() override; + + void Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, // Currently writing response, connection will be re-used + kWritingFinal, // Writing response, connection will be closed + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestParser{*this}; + + uint32_t m_ConnectionId = 0; + Ref m_PackageHandler; + + TransportConnection* m_TransportConnection = nullptr; + HttpPluginServerImpl* m_Server = nullptr; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginServerImpl : public TransportServer +{ + HttpPluginServerImpl(); + ~HttpPluginServerImpl(); + + void AddPlugin(Ref Plugin); + void RemovePlugin(Ref Plugin); + + void Start(); + void Stop(); + + void RegisterService(const char* InUrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector m_UriHandlers; + std::vector> m_Plugins; + + // TransportServer + + virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) override; +}; + +/** This is the class which request handlers interface with when + generating responses + */ + +class HttpPluginServerRequest : public HttpServerRequest +{ +public: + HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpPluginServerRequest(); + + HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; + HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr m_Response; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginResponse +{ +public: + HttpPluginResponse() = default; + explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span BlobList); + + inline uint16_t ResponseCode() const { return m_ResponseCode; } + inline uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector& ResponseBuffers() const { return m_ResponseBuffers; } + void SuppressPayload() { m_ResponseBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector m_ResponseBuffers; + ExtendableStringBuilder<160> m_Headers; + + std::string_view GetHeaders(); +}; + +void +HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span BlobList) +{ + ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow(BlobList.size()); + + m_ResponseBuffers.reserve(ChunkCount + 1); + m_ResponseBuffers.push_back({}); // Placeholder for header + + uint64_t TotalDataSize = 0; + + for (IoBuffer& Buffer : BlobList) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + TotalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + else + { + // Send from memory + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + } + m_ContentLength = TotalDataSize; + + auto Headers = GetHeaders(); + m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size()); +} + +std::string_view +HttpPluginResponse::GetHeaders() +{ + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server) +{ + m_TransportConnection = Transport; + m_Server = &Server; +} + +uint32_t +HttpPluginConnectionHandler::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +HttpPluginConnectionHandler::Release() const +{ + return RefCounted::Release(); +} + +void +HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) +{ + while (AvailableBytes) + { + const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes); + + if (ConsumedBytes == ~0ull) + { + // terminate connection + + return TerminateConnection(); + } + + Buffer = reinterpret_cast(Buffer) + ConsumedBytes; + AvailableBytes -= ConsumedBytes; + } +} + +// HttpRequestParserCallbacks + +void +HttpPluginConnectionHandler::HandleRequest() +{ + if (!m_RequestParser.IsKeepAlive()) + { + // Once response has been written, connection is done + m_RequestState = RequestState::kWritingFinal; + + // We're not going to read any more data from this socket + + const bool Receive = true; + const bool Transmit = false; + m_TransportConnection->Shutdown(Receive, Transmit); + } + else + { + m_RequestState = RequestState::kWriting; + } + + auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t { + const char* Buffer = reinterpret_cast(InBuffer.GetData()); + size_t Bytes = InBuffer.GetSize(); + + return m_TransportConnection->WriteBytes(Buffer, Bytes); + }; + + // Generate response + + if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url())) + { + ZEN_TRACE_CPU("http_plugin::HandleRequest"); + + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + const std::vector& ResponseBuffers = Response->ResponseBuffers(); + + //// TODO: should cork/uncork for Linux? + + for (const IoBuffer& Buffer : ResponseBuffers) + { + int64_t SentBytes = SendBuffer(Buffer); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } + } + + return; + } + } + + // No route found for request + + std::string_view Response; + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + } + else + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + } + + const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response))); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } +} + +void +HttpPluginConnectionHandler::TerminateConnection() +{ + ZEN_ASSERT(m_TransportConnection); + m_TransportConnection->CloseConnection(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpPluginServerRequest::~HttpPluginServerRequest() +{ +} + +Oid +HttpPluginServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpPluginServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpPluginServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + std::array Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpPluginResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpPluginServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerImpl::HttpPluginServerImpl() +{ +} + +HttpPluginServerImpl::~HttpPluginServerImpl() +{ +} + +TransportServerConnection* +HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) +{ + HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; + Handler->Initialize(Connection, *this); + return Handler; +} + +void +HttpPluginServerImpl::Start() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Initialize(this); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); + } + } +} + +void +HttpPluginServerImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); + } + + Plugin = nullptr; + } + + m_Plugins.clear(); +} + +void +HttpPluginServerImpl::AddPlugin(Ref Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Plugins.emplace_back(std::move(Plugin)); +} + +void +HttpPluginServerImpl::RemovePlugin(Ref Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); + if (It != m_Plugins.end()) + { + m_Plugins.erase(It); + } +} + +void +HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpPluginServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServer::HttpPluginServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new HttpPluginServerImpl) +{ +} + +HttpPluginServer::~HttpPluginServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpPluginServer() called without calling Close() first"); + } +} + +int +HttpPluginServer::Initialize(int BasePort) +{ + try + { + m_Impl->Start(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); + } + + return BasePort; +} + +void +HttpPluginServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); + } + + delete m_Impl; + m_Impl = nullptr; +} + +void +HttpPluginServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +# if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# endif +} + +void +HttpPluginServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpPluginServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +void +HttpPluginServer::AddPlugin(Ref Plugin) +{ + m_Impl->AddPlugin(Plugin); +} + +void +HttpPluginServer::RemovePlugin(Ref Plugin) +{ + m_Impl->RemovePlugin(Plugin); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp new file mode 100644 index 000000000..c1b4717cb --- /dev/null +++ b/src/zenhttp/servers/httpsys.cpp @@ -0,0 +1,2012 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include +# include +# include "iothreadpool.h" + +# include + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function by default but system configuration + * can soften this requirement. + * + * See README.md for details. + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(const HttpSysConfig& Config); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + virtual void Close() override; + + WorkerThreadPool& WorkPool(); + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingNewRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + std::unique_ptr m_IoThreadPool; + + RwLock m_AsyncWorkPoolInitLock; + WorkerThreadPool* m_AsyncWorkPool = nullptr; + + std::vector m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; + HttpSysConfig m_InitialConfig; +}; + +} // namespace zen +#endif + +#if ZEN_WITH_HTTPSYS + +# include +# include +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_UTF16(const char* InPtr) +{ + std::wstring OutString; + unsigned int Codepoint; + + while (*InPtr != 0) + { + unsigned char InChar = static_cast(*InPtr); + + if (InChar <= 0x7f) + Codepoint = InChar; + else if (InChar <= 0xbf) + Codepoint = (Codepoint << 6) | (InChar & 0x3f); + else if (InChar <= 0xdf) + Codepoint = InChar & 0x1f; + else if (InChar <= 0xef) + Codepoint = InChar & 0x0f; + else + Codepoint = InChar & 0x07; + + ++InPtr; + + if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) + { + if (Codepoint > 0xffff) + { + OutString.append(1, static_cast(0xd800 + (Codepoint >> 10))); + OutString.append(1, static_cast(0xdc00 + (Codepoint & 0x03ff))); + } + else if (Codepoint < 0xd800 || Codepoint >= 0xe000) + { + OutString.append(1, static_cast(Codepoint)); + } + } + } + + return OutString; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + HttpSysTransaction& Transaction() { return m_Transaction; } + + HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; + HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; + +private: + HttpSysTransaction& m_Transaction; +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpSysServerRequest(const HttpSysServerRequest&) = delete; + HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; + + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + [[nodiscard]] Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); + + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + + HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } + + struct CompletionMutexScope + { + CompletionMutexScope(HttpSysTransaction& Tx) : Lock(Tx.m_CompletionMutex) {} + ~CompletionMutexScope() = default; + + RwLock::ExclusiveLockScope Lock; + }; + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional m_HandlerRequest; + Ref m_PackageHandler; +}; + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span BlobList) +{ + ZEN_TRACE_CPU("httpsys::InitializeForPayload"); + + const uint32_t ChunkCount = gsl::narrow(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; + + if (m_TotalDataSize == 0 && ResponseCode == 200) + { + // Some HTTP clients really don't like empty responses unless a 204 response is sent + m_ResponseCode = uint16_t(HttpResponseCode::NoContent); + } + else + { + m_ResponseCode = ResponseCode; + } +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult != NO_ERROR) + { + ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Response::IssueRequest"); + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Server header + // + // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header + // + // This is controlled via a registry key 'DisableServerHeader', at: + // + // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters + // + // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether + // (only the latter appears to do anything in my testing, on Windows 10). + // + // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) + // + + PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; + ServerHeader->pRawValue = "Zen"; + ServerHeader->RawValueLength = (USHORT)3; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonString.data(); + HttpResponse.ReasonLength = (USHORT)ReasonString.size(); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { + for (int i = 0; i < ThisRequestChunkCount; ++i) + { + const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i]; + + if (i > 0) + { + ResponseDetails << " + "; + } + + switch (Chunk.DataChunkType) + { + case HttpDataChunkFromMemory: + ResponseDetails << "mem:" << uint64_t(Chunk.FromMemory.BufferLength); + break; + + case HttpDataChunkFromFileHandle: + ResponseDetails << "file:"; + { + ResponseDetails << uint64_t(Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart) << "," + << uint64_t(Chunk.FromFileHandle.ByteRange.Length.QuadPart) << ","; + + std::error_code PathEc; + HANDLE FileHandle = Chunk.FromFileHandle.FileHandle; + std::filesystem::path Path = PathFromHandle(FileHandle, PathEc); + + if (PathEc) + { + ResponseDetails << "bad_file(handle=" << reinterpret_cast(FileHandle) << ",error=" << PathEc.message() + << ")"; + } + else + { + const uint64_t FileSize = FileSizeFromHandle(FileHandle); + ResponseDetails << Path.u8string() << "(" << FileSize << ") handle=" << reinterpret_cast(FileHandle); + } + } + break; + + case HttpDataChunkFromFragmentCache: + ResponseDetails << "frag:???"; // We do not use these + break; + + case HttpDataChunkFromFragmentCacheEx: + ResponseDetails << "frax:???"; // We do not use these + break; + +# if 0 // Not available in older Windows SDKs + case HttpDataChunkTrailers: + ResponseDetails << "trls:???"; // We do not use these + break; +# endif + + default: + ResponseDetails << "???: " << Chunk.DataChunkType; + break; + } + } + }; + + if (SendResult == NO_ERROR) + { + // Synchronous completion, but the completion event will still be posted to IOCP + + ErrorCode.clear(); + } + else if (SendResult == ERROR_IO_PENDING) + { + // Asynchronous completion, a completion notification will be posted to IOCP + + ErrorCode.clear(); + } + else + { + ErrorCode = MakeErrorCode(SendResult); + + // An error occurred, no completion will be posted to IOCP + + CancelThreadpoolIo(Iocp); + + // Emit diagnostics + + ExtendableStringBuilder<256> ResponseDetails; + EmitReponseDetails(ResponseDetails); + + ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}", + SendResult, + ErrorCode.message(), + HttpReq->pRawUrl, + Tx.ServerRequest().SessionId(), + HttpReq->RequestId, + ResponseDetails); + } +} + +/** HTTP completion handler for async work + + This is used to allow work to be taken off the request handler threads + and to support posting responses asynchronously. + */ + +class HttpAsyncWorkRequest : public HttpSysRequestHandler +{ +public: + HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response); + ~HttpAsyncWorkRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + +private: + struct AsyncWorkItem : public IWork + { + virtual void Execute() override; + + AsyncWorkItem(HttpSysTransaction& InTx, std::function&& InHandler) + : Tx(InTx) + , Handler(std::move(InHandler)) + { + } + + HttpSysTransaction& Tx; + std::function Handler; + }; + + Ref m_WorkItem; +}; + +HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response) +: HttpSysRequestHandler(Tx) +{ + m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); +} + +HttpAsyncWorkRequest::~HttpAsyncWorkRequest() +{ +} + +void +HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest"); + ErrorCode.clear(); + + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); +} + +HttpSysRequestHandler* +HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // This ought to not be called since there should be no outstanding I/O request + // when this completion handler is active + + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + + return this; +} + +void +HttpAsyncWorkRequest::AsyncWorkItem::Execute() +{ + ZEN_TRACE_CPU("httpsys::async_execute"); + + try + { + // We need to hold this lock while we're issuing new requests in order to + // prevent race conditions between the thread we are running on and any + // IOCP service threads. Otherwise the IOCP completion handler can end + // up deleting the transaction object before we are done with it! + HttpSysTransaction::CompletionMutexScope _(Tx); + HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); + + ThisRequest.m_NextCompletionHandler = nullptr; + + { + ZEN_TRACE_CPU("httpsys::HandleRequest"); + Handler(ThisRequest); + } + + // TODO: should Handler be destroyed at this point to ensure there + // are no outstanding references into state which could be + // deleted asynchronously as a result of issuing the response? + + if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) + { + return (void)Tx.IssueNextRequest(NextHandler); + } + else if (!ThisRequest.IsHandled()) + { + return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); + } + else + { + // "Handled" but no request handler? Shouldn't ever happen + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); + } + } + catch (std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig) +: m_Log(logging::Get("http")) +, m_RequestLog(logging::Get("http_requests")) +, m_IsRequestLoggingEnabled(InConfig.IsRequestLoggingEnabled) +, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled) +, m_InitialConfig(InConfig) +{ + // Initialize thread pool + + int MinThreadCount; + int MaxThreadCount; + + if (m_InitialConfig.ThreadCount == 0) + { + MinThreadCount = Max(8u, std::thread::hardware_concurrency()); + } + else + { + MinThreadCount = m_InitialConfig.ThreadCount; + } + + MaxThreadCount = MinThreadCount * 2; + + if (m_InitialConfig.IsDedicatedServer) + { + // In order to limit the potential impact of threads stuck + // in locks we allow the thread pool to be oversubscribed + // by a fair amount + + MaxThreadCount *= 2; + } + + m_IoThreadPool = std::make_unique(MinThreadCount, MaxThreadCount); + + if (m_InitialConfig.AsyncWorkThreadCount == 0) + { + m_InitialConfig.AsyncWorkThreadCount = 16; + } + + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; + + ZEN_INFO("http.sys server started in {} mode, using {}-{} I/O threads and {} async worker threads", + m_InitialConfig.IsDedicatedServer ? "DEDICATED" : "NORMAL", + MinThreadCount, + MaxThreadCount, + m_InitialConfig.AsyncWorkThreadCount); +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + ZEN_ERROR("~HttpSysServer() called without calling Close() first"); + } + + delete m_AsyncWorkPool; + m_AsyncWorkPool = nullptr; +} + +void +HttpSysServer::Close() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + m_IsHttpInitialized = false; + } +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + int EffectivePort = BasePort; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + // Sharing violation implies the port is being used by another process + for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + WildcardUrlPath.Reset(); + WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + m_BaseUris.clear(); + if (Result == NO_ERROR) + { + m_BaseUris.push_back(WildcardUrlPath.c_str()); + } + else if (Result == ERROR_ACCESS_DENIED) + { + // If we can't register the wildcard path, we fall back to local paths + // This local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh + // during an install process to grant permissions to route public access to the appropriate + // port for the current user. eg: + // netsh http add urlacl url=http://*:8558/ user= + + ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + ULONG InternalResult = ERROR_SHARING_VIOLATION; + for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrlPath; + LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; + + InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (InternalResult == NO_ERROR) + { + ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); + + m_BaseUris.push_back(LocalUrlPath.c_str()); + } + else + { + break; + } + } + } + } + + if (m_BaseUris.empty()) + { + ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + WideStringBuilder<64> QueueName; + QueueName << "zenserver_" << EffectivePort; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + // Configure rejection method. Default is to drop the connection, it's better if we + // return an explicit error code when the queue cannot accept more requests + + { + HTTP_503_RESPONSE_VERBOSITY VerbosityInformation = Http503ResponseVerbosityLimited; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, + HttpServer503VerbosityProperty, + &VerbosityInformation, + sizeof VerbosityInformation, + 0, + 0); + } + + // Tune the maximum number of pending requests in the http.sys request queue. By default + // the value is 1000 which is plenty for single user machines but for dedicated servers + // serving many users it makes sense to increase this to a higher number to help smooth + // out intermittent stalls like we might experience when GC is triggered + + if (m_InitialConfig.IsDedicatedServer) + { + ULONG QueueLength = 50000; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, HttpServerQueueLengthProperty, &QueueLength, sizeof QueueLength, 0, 0); + + if (Result != NO_ERROR) + { + ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + } + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + } + else + { + m_IsOk = true; + + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + + // This is not available in all Windows SDK versions so for now we can't use recently + // released functionality. We should investigate how to get more recent SDK releases + // into the build + +# if 0 + if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) + { + ZEN_DEBUG("HTTP3 is available"); + } + else + { + ZEN_DEBUG("HTTP3 is NOT available"); + } +# endif + + return EffectivePort; +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +WorkerThreadPool& +HttpSysServer::WorkPool() +{ + if (!m_AsyncWorkPool) + { + RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); + + if (!m_AsyncWorkPool) + { + m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + } + } + + return *m_AsyncWorkPool; +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool IsInteractive) +{ + if (IsInteractive) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + // int WaitTimeout = -1; + int WaitTimeout = 100; + + if (IsInteractive) + { + WaitTimeout = 1000; + + if (_kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + UpdateLofreqTimerValue(); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingNewRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests. This should be the common case as we check if we need to + // enqueue more requests before exiting the completion handler. + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr Request = std::make_unique(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + + return; + } + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_IoThreadPool->Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +thread_local bool t_IsHttpSysThreadNamed = false; +static std::atomic HttpSysThreadIndex = 0; + +static void +NameCurrentHttpSysThread() +{ + t_IsHttpSysThreadNamed = true; + const int ThreadIndex = ++HttpSysThreadIndex; + zen::ExtendableStringBuilder<16> ThreadName; + ThreadName << "httpio_" << ThreadIndex; + SetCurrentThreadName(ThreadName); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + ZEN_UNUSED(Io, Instance); + + // Assign names to threads for context + + if (!t_IsHttpSysThreadNamed) + { + NameCurrentHttpSysThread(); + } + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } + + // Ensure new requests are enqueued as necessary. We do this here instead + // of inside the transaction completion handler now to avoid spending time + // in unrelated API calls while holding the transaction lock + + if (HttpSysServer* HttpServer = reinterpret_cast(pContext)) + { + HttpServer->IssueNewRequestMaybe(); + } +} + +bool +HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) +{ + ZEN_TRACE_CPU("httpsys::Transaction::IssueNextRequest"); + + HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; + m_CompletionHandler = NewCompletionHandler; + + auto _ = MakeGuard([this, CurrentHandler] { + if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) + { + delete CurrentHandler; + } + }); + + if (NewCompletionHandler == nullptr) + { + return false; + } + + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (!ErrorCode) + { + return true; + } + + ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + } + catch (std::exception& Ex) + { + ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + } + + // something went wrong, no request is pending + m_CompletionHandler = nullptr; + + return false; +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingNewRequest(); + } + + auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + IsRequestPending = IssueNextRequest(NewCompletionHandler); + } + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + if (m_HttpServer.m_IsRequestLoggingEnabled) + { + if (m_HandlerRequest.has_value()) + { + m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); + } + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + // Default request handling + + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This is overhead which I'd prefer to avoid but for now we just have + // to live with it + + WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow(AbsPathLength - PrefixLength)}, + m_UriUtf8); + + std::string_view UriSuffix8{m_UriUtf8}; + + m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access + m_Uri = UriSuffix8; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(UriSuffix8.size() + 1); + } + } + } + else + { + m_UriUtf8.Reset(); + m_Uri = {}; + m_UriWithExtension = {}; + } + + if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; // We skip the leading question mark + + WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_QueryString = std::string_view(m_QueryStringUtf8); + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = GetAcceptType(HttpRequestPtr); + } + + if (m_Verb == HttpVerb::kHead) + { + SetSuppressResponseBody(); + } +} + +HttpSysServerRequest::~HttpSysServerRequest() +{ +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) +{ + if (m_HttpTx.Server().IsAsyncResponseEnabled()) + { + m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); + } + else + { + ContinuationHandler(m_HttpTx.ServerRequest()); + } +} + +bool +HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; + + return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Request::IssueRequest"); + + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + ErrorCode = MakeErrorCode(HttpApiResult); + + ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + + return; + } + + ErrorCode.clear(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + default: + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: // Insufficient buffer space + case NO_ERROR: + break; + } + + ZEN_TRACE_CPU("httpsys::HandleCompletion"); + + // Route request + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeAuth: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeChannelBind: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslProtocol: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeQuicStats: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset != m_ContentLength) + { + // Body not complete, issue another read request to receive more body data + return this; + } + + // Request body received completely + + m_PayloadBuffer.MakeImmutable(); + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) + { + return Response; + } + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::system_error& SystemError) + { + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); + } + + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); + } + catch (std::bad_alloc& BadAlloc) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +int +HttpSysServer::Initialize(int BasePort) +{ + int EffectivePort = InitializeServer(BasePort); + StartServer(); + return EffectivePort; +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +Ref +CreateHttpSysServer(HttpSysConfig Config) +{ + return Ref(new HttpSysServer(Config)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h new file mode 100644 index 000000000..6a6b16525 --- /dev/null +++ b/src/zenhttp/servers/httpsys.h @@ -0,0 +1,28 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +namespace zen { + +struct HttpSysConfig +{ + unsigned int ThreadCount = 0; + unsigned int AsyncWorkThreadCount = 0; + bool IsAsyncResponseEnabled = true; + bool IsRequestLoggingEnabled = false; + bool IsDedicatedServer = false; +}; + +Ref CreateHttpSysServer(HttpSysConfig Config); + +} // namespace zen diff --git a/src/zenhttp/servers/iothreadpool.cpp b/src/zenhttp/servers/iothreadpool.cpp new file mode 100644 index 000000000..da4b42e28 --- /dev/null +++ b/src/zenhttp/servers/iothreadpool.cpp @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "iothreadpool.h" + +#include + +#if ZEN_PLATFORM_WINDOWS + +namespace zen { + +WinIoThreadPool::WinIoThreadPool(int InThreadCount, int InMaxThreadCount) +{ + ZEN_ASSERT(InThreadCount); + + if (InMaxThreadCount < InThreadCount) + { + InMaxThreadCount = InThreadCount; + } + + m_ThreadPool = CreateThreadpool(NULL); + + SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); + SetThreadpoolThreadMaximum(m_ThreadPool, InMaxThreadCount); + + InitializeThreadpoolEnvironment(&m_CallbackEnvironment); + + m_CleanupGroup = CreateThreadpoolCleanupGroup(); + + SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); + + SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); +} + +WinIoThreadPool::~WinIoThreadPool() +{ + CloseThreadpool(m_ThreadPool); +} + +void +WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +{ + ZEN_ASSERT(!m_ThreadPoolIo); + + m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); + + if (!m_ThreadPoolIo) + { + ErrorCode = MakeErrorCodeFromLastError(); + } +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/servers/iothreadpool.h b/src/zenhttp/servers/iothreadpool.h new file mode 100644 index 000000000..e75e95e58 --- /dev/null +++ b/src/zenhttp/servers/iothreadpool.h @@ -0,0 +1,31 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_PLATFORM_WINDOWS +# include + +# include + +namespace zen { + +class WinIoThreadPool +{ +public: + WinIoThreadPool(int InThreadCount, int InMaxThreadCount); + ~WinIoThreadPool(); + + void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); + inline PTP_IO Iocp() const { return m_ThreadPoolIo; } + +private: + PTP_POOL m_ThreadPool = nullptr; + PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; + PTP_IO m_ThreadPoolIo = nullptr; + TP_CALLBACK_ENVIRON m_CallbackEnvironment; +}; + +} // namespace zen +#endif diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 588fd8b87..e90fdfd1c 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -5,7 +5,7 @@ target('zenhttp') set_group("libs") add_headerfiles("**.h") add_files("**.cpp") - add_files("httpsys.cpp", {unity_ignored=true}) + add_files("servers/httpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "transport-sdk") add_packages( diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp new file mode 100644 index 000000000..93c2eafc3 --- /dev/null +++ b/src/zenserver/diag/diagsvcs.cpp @@ -0,0 +1,135 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "diagsvcs.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace zen { + +using namespace std::literals; + +static bool +ReadLogFile(const std::string& Path, StringBuilderBase& Out) +{ + try + { + constexpr auto ReadSize = std::size_t{4096}; + auto FileStream = std::ifstream{Path}; + + std::string Buf(ReadSize, '\0'); + while (FileStream.read(&Buf[0], ReadSize)) + { + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + } + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + + return true; + } + catch (std::exception&) + { + Out.Reset(); + return false; + } +} + +HttpHealthService::HttpHealthService() +{ + m_Router.RegisterRoute( + "", + [](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "info", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + CbObjectWriter Writer; + + { + RwLock::SharedLockScope _(m_InfoLock); + Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string(); + Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); + Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; + Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; + } + + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "log", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + zen::Log().flush(); + + std::filesystem::path Path = [&] { + RwLock::SharedLockScope _(m_InfoLock); + return m_HealthInfo.AbsLogPath.empty() ? m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath; + }(); + + ExtendableStringBuilder<4096> Sb; + if (ReadLogFile(Path.string(), Sb) && Sb.Size() > 0) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView()); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "version", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + if (HttpReq.GetQueryParams().GetValue("detailed") == "true") + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION); + } + }, + HttpVerb::kGet); +} + +void +HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info) +{ + RwLock::ExclusiveLockScope _(m_InfoLock); + m_HealthInfo = std::move(Info); +} + +const char* +HttpHealthService::BaseUri() const +{ + return "/health/"; +} + +void +HttpHealthService::HandleRequest(HttpServerRequest& Request) +{ + if (!m_Router.HandleRequest(Request)) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + } +} + +} // namespace zen diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h new file mode 100644 index 000000000..8cc869c83 --- /dev/null +++ b/src/zenserver/diag/diagsvcs.h @@ -0,0 +1,119 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +/** HTTP test endpoint + + This is intended to be used to exercise basic HTTP communication infrastructure + which is useful for benchmarking performance of the server code and when evaluating + network performance / diagnosing connectivity issues + + */ +class HttpTestService : public HttpService +{ +public: + HttpTestService() {} + ~HttpTestService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + + virtual void HandleRequest(HttpServerRequest& Request) override + { + using namespace std::literals; + + auto Uri = Request.RelativeUri(); + + if (Uri == "hello"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv); + } + else if (Uri == "1K"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k); + } + else if (Uri == "1M"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m); + } + else if (Uri == "1M_1k"sv) + { + std::vector Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G"sv) + { + std::vector Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1m); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G_1k"sv) + { + std::vector Buffers; + Buffers.reserve(1024 * 1024); + + for (int i = 0; i < 1024 * 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + } + +private: + IoBuffer m_1m{1024 * 1024}; + IoBuffer m_1k{m_1m, 0u, 1024}; +}; + +struct HealthServiceInfo +{ + std::filesystem::path DataRoot; + std::filesystem::path AbsLogPath; + std::string HttpServerClass; + std::string BuildVersion; +}; + +/** Health monitoring endpoint + + Thji + */ +class HttpHealthService : public HttpService +{ +public: + HttpHealthService(); + ~HttpHealthService() = default; + + void SetHealthInfo(HealthServiceInfo&& Info); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override final; + +private: + HttpRequestRouter m_Router; + RwLock m_InfoLock; + HealthServiceInfo m_HealthInfo; +}; + +} // namespace zen diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index 0c28c1229..cc5b53cef 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -21,7 +21,6 @@ ZEN_THIRD_PARTY_INCLUDES_END #include #include -#include #include #include #include @@ -29,6 +28,7 @@ ZEN_THIRD_PARTY_INCLUDES_END #include "admin/admin.h" #include "cache/httpstructuredcache.h" #include "cache/structuredcachestore.h" +#include "diag/diagsvcs.h" #include "frontend/frontend.h" #include "httpcidstore.h" #include "objectstore/objectstore.h" -- cgit v1.2.3