diff options
| author | Stefan Boberg <[email protected]> | 2023-10-13 09:55:27 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-13 09:55:27 +0200 |
| commit | 74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d (patch) | |
| tree | acae59dac67b4d051403f35e580201c214ec4fda /src/zenhttp/servers | |
| parent | faster oplog iteration (#471) (diff) | |
| download | zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.tar.xz zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.zip | |
restructured zenhttp (#472)
separating the http server implementations into a directory and moved diagsvcs into zenserver since it's somewhat hard-coded for it
Diffstat (limited to 'src/zenhttp/servers')
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 1052 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.h | 11 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.cpp | 88 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.h | 30 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.cpp | 370 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.h | 112 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 781 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 2012 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.h | 28 | ||||
| -rw-r--r-- | src/zenhttp/servers/iothreadpool.cpp | 54 | ||||
| -rw-r--r-- | src/zenhttp/servers/iothreadpool.h | 31 |
11 files changed, 4569 insertions, 0 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp new file mode 100644 index 000000000..0c6b189f9 --- /dev/null +++ b/src/zenhttp/servers/httpasio.cpp @@ -0,0 +1,1052 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zencore/trace.h> +#include <zenhttp/httpserver.h> + +#include "httpparser.h" + +#include <deque> +#include <memory> +#include <string_view> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# include <mstcpip.h> +#endif +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#define ASIO_VERBOSE_TRACE 0 + +#if ASIO_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +#else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpAcceptor; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); + // Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + int Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpResponse> m_Response; +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) + { + ZEN_TRACE_CPU("asio::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_DataBuffers; + std::vector<asio::const_buffer> m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~HttpServerConnection(); + + std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + + // HttpConnectionBase implementation + + virtual void TerminateConnection() override; + virtual void HandleRequest() override; + + void HandleNewRequest(); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void CloseConnection(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::atomic<uint32_t> m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + RwLock m_ResponsesLock; + std::deque<std::unique_ptr<HttpResponse>> m_Responses; +}; + +std::atomic<uint32_t> g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ + ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); +} + +void +HttpServerConnection::HandleNewRequest() +{ + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + + m_RequestState = RequestState::kTerminated; + ZEN_ASSERT(m_Socket); + + // Terminating, we don't care about any errors when closing socket + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + return; + } + else + { + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + return TerminateConnection(); + } + } + + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + if (Result == ~0ull) + { + return TerminateConnection(); + } + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) +{ + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + TerminateConnection(); + } + else + { + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + if (!m_RequestData.IsKeepAlive()) + { + CloseConnection(); + } + else + { + if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + + m_RequestCounter.fetch_add(1); + } + } +} + +void +HttpServerConnection::CloseConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + ZEN_ASSERT(m_Socket); + m_RequestState = RequestState::kDone; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + m_Socket->close(Ec); + if (Ec) + { + ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); + } +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + ZEN_TRACE_CPU("asio::HandleRequest"); + + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + { + ZEN_TRACE_CPU("asio::async_write"); + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, true); + }); + } + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + { + m_Acceptor.set_option(asio::ip::v6_only(false)); +#if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; + m_Acceptor.set_option(excluse_address(true)); +#else // ZEN_PLATFORM_WINDOWS + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +#endif // ZEN_PLATFORM_WINDOWS + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + m_Acceptor.listen(); + + ZEN_INFO("Started asio server at port '{}'", EffectivePort); + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + if (CloseEc) + { + ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + } + } + }); + } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +int +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + ZEN_INFO("starting asio http with {} service threads", ThreadCount); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + // This should consist of a set of minimum threads and grow on demand to + // meet concurrency needs? Right now we end up allocating a large number + // of threads even if we never end up using all of them, which seems + // wasteful. It's also not clear how the demand for concurrency should + // be balanced with the engine side - ideally we'd have some kind of + // global scheduling to prevent one side from preventing the other side + // from making progress. Or at the very least, thread priorities should + // be considered. + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName(fmt::format("asio_io_{}", Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } + + ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); + + return m_Acceptor->GetAcceptPort(); +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +class HttpAsioServer : public HttpServer +{ +public: + HttpAsioServer(unsigned int ThreadCount); + ~HttpAsioServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + unsigned int m_ThreadCount = 0; + + std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl; +}; + +HttpAsioServer::HttpAsioServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) +{ + ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); +} + +HttpAsioServer::~HttpAsioServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpAsioServer() called without calling Close() first"); + } +} + +void +HttpAsioServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); + } + m_Impl.reset(); +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +int +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), m_ThreadCount); + return m_BasePort; +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +Ref<HttpServer> +CreateHttpAsioServer(unsigned int ThreadCount) +{ + return Ref<HttpServer>{new HttpAsioServer(ThreadCount)}; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h new file mode 100644 index 000000000..2366f3437 --- /dev/null +++ b/src/zenhttp/servers/httpasio.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +Ref<HttpServer> CreateHttpAsioServer(unsigned int ThreadCount); + +} // namespace zen diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp new file mode 100644 index 000000000..658f51831 --- /dev/null +++ b/src/zenhttp/servers/httpnull.cpp @@ -0,0 +1,88 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpnull.h" + +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +#endif + +namespace zen { + +HttpNullServer::HttpNullServer() +{ +} + +HttpNullServer::~HttpNullServer() +{ +} + +void +HttpNullServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +int +HttpNullServer::Initialize(int BasePort) +{ + return BasePort; +} + +void +HttpNullServer::Run(bool IsInteractiveSession) +{ + const bool TestMode = !IsInteractiveSession; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpNullServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpNullServer::Close() +{ +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h new file mode 100644 index 000000000..965e729f7 --- /dev/null +++ b/src/zenhttp/servers/httpnull.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * @brief Null implementation of "http" server. Does nothing + */ + +class HttpNullServer : public HttpServer +{ +public: + HttpNullServer(); + ~HttpNullServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + +private: + Event m_ShutdownEvent; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp new file mode 100644 index 000000000..6b987151a --- /dev/null +++ b/src/zenhttp/servers/httpparser.cpp @@ -0,0 +1,370 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpparser.h" + +#include <zencore/logging.h> +#include <zencore/string.h> + +namespace zen { + +using namespace std::literals; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +////////////////////////////////////////////////////////////////////////// +// +// HttpRequestParser +// + +http_parser_settings HttpRequestParser::s_ParserSettings{ + .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, + .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, + .on_status = + [](http_parser* p, const char* Data, size_t ByteCount) { + ZEN_UNUSED(p, Data, ByteCount); + return 0; + }, + .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, + .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) +{ + http_parser_init(&m_Parser, HTTP_REQUEST); + m_Parser.data = this; + + ResetState(); +} + +HttpRequestParser::~HttpRequestParser() +{ +} + +size_t +HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) +{ + const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); + + http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); + + if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + { + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); + return ~0ull; + } + + return ConsumedBytes; +} + +int +HttpRequestParser::OnUrl(const char* Data, size_t Bytes) +{ + if (!m_Url) + { + ZEN_ASSERT_SLOW(m_UrlLength == 0); + m_Url = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_UrlLength += Bytes; + + return 0; +} + +int +HttpRequestParser::OnHeader(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = m_HeaderCursor; + } + else if (m_CurrentHeaderName == nullptr) + { + m_CurrentHeaderName = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderNameLength += Bytes; + + return 0; +} + +void +HttpRequestParser::AppendCurrentHeader() +{ + std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); + std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + + if (HeaderHash == HashContentLength) + { + m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashAccept) + { + m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashContentType) + { + m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashSession) + { + m_SessionId = Oid::FromHexString(HeaderValue); + } + else if (HeaderHash == HashRequest) + { + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + } + else if (HeaderHash == HashExpect) + { + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + } + else if (HeaderHash == HashRange) + { + m_RangeHeaderIndex = (int8_t)m_Headers.size(); + } + + m_Headers.emplace_back(HeaderName, HeaderValue); +} + +int +HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength == 0) + { + m_CurrentHeaderValue = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderValueLength += Bytes; + + return 0; +} + +static void +NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) +{ + bool LastCharWasSeparator = false; + for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) + { + const char UrlChar = Url[UrlIndex]; + const bool IsSeparator = (UrlChar == '/'); + + if (IsSeparator && LastCharWasSeparator) + { + if (NormalizedUrl.empty()) + { + NormalizedUrl.reserve(UrlLength); + NormalizedUrl.append(Url, UrlIndex); + } + + if (!LastCharWasSeparator) + { + NormalizedUrl.push_back('/'); + } + } + else if (!NormalizedUrl.empty()) + { + NormalizedUrl.push_back(UrlChar); + } + + LastCharWasSeparator = IsSeparator; + } +} + +int +HttpRequestParser::OnHeadersComplete() +{ + try + { + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + } + + m_KeepAlive = !!http_should_keep_alive(&m_Parser); + + switch (m_Parser.method) + { + case HTTP_GET: + m_RequestVerb = HttpVerb::kGet; + break; + + case HTTP_POST: + m_RequestVerb = HttpVerb::kPost; + break; + + case HTTP_PUT: + m_RequestVerb = HttpVerb::kPut; + break; + + case HTTP_DELETE: + m_RequestVerb = HttpVerb::kDelete; + break; + + case HTTP_HEAD: + m_RequestVerb = HttpVerb::kHead; + break; + + case HTTP_COPY: + m_RequestVerb = HttpVerb::kCopy; + break; + + case HTTP_OPTIONS: + m_RequestVerb = HttpVerb::kOptions; + break; + + default: + ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + break; + } + + std::string_view Url(m_Url, m_UrlLength); + + if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) + { + m_UrlLength = QuerySplit; + m_QueryString = m_Url + QuerySplit + 1; + m_QueryLength = Url.size() - QuerySplit - 1; + } + + NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); + + if (m_ContentLengthHeaderIndex >= 0) + { + std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; + uint64_t ContentLength = 0; + std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); + + if (ContentLength) + { + m_BodyBuffer = IoBuffer(ContentLength); + } + + m_BodyBuffer.SetContentType(ContentType()); + + m_BodyPosition = 0; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("HttpRequestParser::OnHeadersComplete failed. Reason '{}'", Ex.what()); + return -1; + } + return 0; +} + +int +HttpRequestParser::OnBody(const char* Data, size_t Bytes) +{ + if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + { + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); + return 1; + } + memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); + m_BodyPosition += Bytes; + + if (http_body_is_final(&m_Parser)) + { + if (m_BodyPosition != m_BodyBuffer.Size()) + { + ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + return 1; + } + } + + return 0; +} + +void +HttpRequestParser::ResetState() +{ + m_HeaderCursor = m_HeaderBuffer; + m_CurrentHeaderName = nullptr; + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValue = nullptr; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = nullptr; + m_Url = nullptr; + m_UrlLength = 0; + m_QueryString = nullptr; + m_QueryLength = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; + m_Headers.clear(); + m_NormalizedUrl.clear(); +} + +int +HttpRequestParser::OnMessageBegin() +{ + return 0; +} + +int +HttpRequestParser::OnMessageComplete() +{ + m_Connection.HandleRequest(); + + ResetState(); + + return 0; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h new file mode 100644 index 000000000..219ac351d --- /dev/null +++ b/src/zenhttp/servers/httpparser.h @@ -0,0 +1,112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/uid.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <http_parser.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class HttpRequestParserCallbacks +{ +public: + virtual ~HttpRequestParserCallbacks() = default; + virtual void HandleRequest() = 0; + virtual void TerminateConnection() = 0; +}; + +struct HttpRequestParser +{ + explicit HttpRequestParser(HttpRequestParserCallbacks& Connection); + ~HttpRequestParser(); + + size_t ConsumeData(const char* InputData, size_t DataSize); + void ResetState(); + + HttpVerb RequestVerb() const { return m_RequestVerb; } + bool IsKeepAlive() const { return m_KeepAlive; } + std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } + std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } + IoBuffer Body() { return m_BodyBuffer; } + + inline HttpContentType ContentType() + { + if (m_ContentTypeHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); + } + + inline HttpContentType AcceptType() + { + if (m_AcceptHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); + } + + Oid SessionId() const { return m_SessionId; } + int RequestId() const { return m_RequestId; } + + std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } + +private: + struct HeaderEntry + { + HeaderEntry() = default; + + HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} + + std::string_view Name; + std::string_view Value; + }; + + HttpRequestParserCallbacks& m_Connection; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_QueryString = nullptr; + size_t m_QueryLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector<HeaderEntry> m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + char m_HeaderBuffer[1024]; + std::string m_NormalizedUrl; + + void AppendCurrentHeader(); + + int OnMessageBegin(); + int OnUrl(const char* Data, size_t Bytes); + int OnHeader(const char* Data, size_t Bytes); + int OnHeaderValue(const char* Data, size_t Bytes); + int OnHeadersComplete(); + int OnBody(const char* Data, size_t Bytes); + int OnMessageComplete(); + + static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } + static http_parser_settings s_ParserSettings; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp new file mode 100644 index 000000000..2e934473e --- /dev/null +++ b/src/zenhttp/servers/httpplugin.cpp @@ -0,0 +1,781 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +# include "httpparser.h" + +# include <zencore/except.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zencore/workthreadpool.h> +# include <zenhttp/httpserver.h> + +# include <memory> +# include <string_view> + +# if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# endif + +# define PLUGIN_VERBOSE_TRACE 1 + +# if PLUGIN_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +# else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +# endif + +namespace zen { + +struct HttpPluginServerImpl; +struct HttpPluginResponse; + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted +{ + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void OnBytesRead(const void* Buffer, size_t DataSize) override; + + // HttpRequestParserCallbacks + + virtual void HandleRequest() override; + virtual void TerminateConnection() override; + + void Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, // Currently writing response, connection will be re-used + kWritingFinal, // Writing response, connection will be closed + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestParser{*this}; + + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + TransportConnection* m_TransportConnection = nullptr; + HttpPluginServerImpl* m_Server = nullptr; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginServerImpl : public TransportServer +{ + HttpPluginServerImpl(); + ~HttpPluginServerImpl(); + + void AddPlugin(Ref<TransportPlugin> Plugin); + void RemovePlugin(Ref<TransportPlugin> Plugin); + + void Start(); + void Stop(); + + void RegisterService(const char* InUrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; + std::vector<Ref<TransportPlugin>> m_Plugins; + + // TransportServer + + virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) override; +}; + +/** This is the class which request handlers interface with when + generating responses + */ + +class HttpPluginServerRequest : public HttpServerRequest +{ +public: + HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpPluginServerRequest(); + + HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; + HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpPluginResponse> m_Response; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginResponse +{ +public: + HttpPluginResponse() = default; + explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList); + + inline uint16_t ResponseCode() const { return m_ResponseCode; } + inline uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } + void SuppressPayload() { m_ResponseBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_ResponseBuffers; + ExtendableStringBuilder<160> m_Headers; + + std::string_view GetHeaders(); +}; + +void +HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_ResponseBuffers.reserve(ChunkCount + 1); + m_ResponseBuffers.push_back({}); // Placeholder for header + + uint64_t TotalDataSize = 0; + + for (IoBuffer& Buffer : BlobList) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + TotalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + else + { + // Send from memory + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + } + m_ContentLength = TotalDataSize; + + auto Headers = GetHeaders(); + m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size()); +} + +std::string_view +HttpPluginResponse::GetHeaders() +{ + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server) +{ + m_TransportConnection = Transport; + m_Server = &Server; +} + +uint32_t +HttpPluginConnectionHandler::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +HttpPluginConnectionHandler::Release() const +{ + return RefCounted::Release(); +} + +void +HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) +{ + while (AvailableBytes) + { + const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes); + + if (ConsumedBytes == ~0ull) + { + // terminate connection + + return TerminateConnection(); + } + + Buffer = reinterpret_cast<const uint8_t*>(Buffer) + ConsumedBytes; + AvailableBytes -= ConsumedBytes; + } +} + +// HttpRequestParserCallbacks + +void +HttpPluginConnectionHandler::HandleRequest() +{ + if (!m_RequestParser.IsKeepAlive()) + { + // Once response has been written, connection is done + m_RequestState = RequestState::kWritingFinal; + + // We're not going to read any more data from this socket + + const bool Receive = true; + const bool Transmit = false; + m_TransportConnection->Shutdown(Receive, Transmit); + } + else + { + m_RequestState = RequestState::kWriting; + } + + auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t { + const char* Buffer = reinterpret_cast<const char*>(InBuffer.GetData()); + size_t Bytes = InBuffer.GetSize(); + + return m_TransportConnection->WriteBytes(Buffer, Bytes); + }; + + // Generate response + + if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url())) + { + ZEN_TRACE_CPU("http_plugin::HandleRequest"); + + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); + + //// TODO: should cork/uncork for Linux? + + for (const IoBuffer& Buffer : ResponseBuffers) + { + int64_t SentBytes = SendBuffer(Buffer); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } + } + + return; + } + } + + // No route found for request + + std::string_view Response; + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + } + else + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + } + + const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response))); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } +} + +void +HttpPluginConnectionHandler::TerminateConnection() +{ + ZEN_ASSERT(m_TransportConnection); + m_TransportConnection->CloseConnection(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpPluginServerRequest::~HttpPluginServerRequest() +{ +} + +Oid +HttpPluginServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpPluginServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpPluginServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpPluginResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerImpl::HttpPluginServerImpl() +{ +} + +HttpPluginServerImpl::~HttpPluginServerImpl() +{ +} + +TransportServerConnection* +HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) +{ + HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; + Handler->Initialize(Connection, *this); + return Handler; +} + +void +HttpPluginServerImpl::Start() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Initialize(this); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); + } + } +} + +void +HttpPluginServerImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); + } + + Plugin = nullptr; + } + + m_Plugins.clear(); +} + +void +HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Plugins.emplace_back(std::move(Plugin)); +} + +void +HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); + if (It != m_Plugins.end()) + { + m_Plugins.erase(It); + } +} + +void +HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpPluginServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServer::HttpPluginServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new HttpPluginServerImpl) +{ +} + +HttpPluginServer::~HttpPluginServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpPluginServer() called without calling Close() first"); + } +} + +int +HttpPluginServer::Initialize(int BasePort) +{ + try + { + m_Impl->Start(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); + } + + return BasePort; +} + +void +HttpPluginServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); + } + + delete m_Impl; + m_Impl = nullptr; +} + +void +HttpPluginServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +# if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# endif +} + +void +HttpPluginServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpPluginServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +void +HttpPluginServer::AddPlugin(Ref<TransportPlugin> Plugin) +{ + m_Impl->AddPlugin(Plugin); +} + +void +HttpPluginServer::RemovePlugin(Ref<TransportPlugin> Plugin) +{ + m_Impl->RemovePlugin(Plugin); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp new file mode 100644 index 000000000..c1b4717cb --- /dev/null +++ b/src/zenhttp/servers/httpsys.cpp @@ -0,0 +1,2012 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpshared.h> + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <zencore/workthreadpool.h> +# include "iothreadpool.h" + +# include <http.h> + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function by default but system configuration + * can soften this requirement. + * + * See README.md for details. + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(const HttpSysConfig& Config); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + virtual void Close() override; + + WorkerThreadPool& WorkPool(); + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingNewRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + std::unique_ptr<WinIoThreadPool> m_IoThreadPool; + + RwLock m_AsyncWorkPoolInitLock; + WorkerThreadPool* m_AsyncWorkPool = nullptr; + + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; + HttpSysConfig m_InitialConfig; +}; + +} // namespace zen +#endif + +#if ZEN_WITH_HTTPSYS + +# include <conio.h> +# include <mstcpip.h> +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_UTF16(const char* InPtr) +{ + std::wstring OutString; + unsigned int Codepoint; + + while (*InPtr != 0) + { + unsigned char InChar = static_cast<unsigned char>(*InPtr); + + if (InChar <= 0x7f) + Codepoint = InChar; + else if (InChar <= 0xbf) + Codepoint = (Codepoint << 6) | (InChar & 0x3f); + else if (InChar <= 0xdf) + Codepoint = InChar & 0x1f; + else if (InChar <= 0xef) + Codepoint = InChar & 0x0f; + else + Codepoint = InChar & 0x07; + + ++InPtr; + + if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) + { + if (Codepoint > 0xffff) + { + OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10))); + OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff))); + } + else if (Codepoint < 0xd800 || Codepoint >= 0xe000) + { + OutString.append(1, static_cast<wchar_t>(Codepoint)); + } + } + } + + return OutString; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + HttpSysTransaction& Transaction() { return m_Transaction; } + + HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; + HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; + +private: + HttpSysTransaction& m_Transaction; +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpSysServerRequest(const HttpSysServerRequest&) = delete; + HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; + + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + [[nodiscard]] Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); + + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + + HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } + + struct CompletionMutexScope + { + CompletionMutexScope(HttpSysTransaction& Tx) : Lock(Tx.m_CompletionMutex) {} + ~CompletionMutexScope() = default; + + RwLock::ExclusiveLockScope Lock; + }; + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional<HttpSysServerRequest> m_HandlerRequest; + Ref<IHttpPackageHandler> m_PackageHandler; +}; + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector<IoBuffer> m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array<IoBuffer, 0> EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + ZEN_TRACE_CPU("httpsys::InitializeForPayload"); + + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; + + if (m_TotalDataSize == 0 && ResponseCode == 200) + { + // Some HTTP clients really don't like empty responses unless a 204 response is sent + m_ResponseCode = uint16_t(HttpResponseCode::NoContent); + } + else + { + m_ResponseCode = ResponseCode; + } +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult != NO_ERROR) + { + ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Response::IssueRequest"); + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Server header + // + // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header + // + // This is controlled via a registry key 'DisableServerHeader', at: + // + // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters + // + // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether + // (only the latter appears to do anything in my testing, on Windows 10). + // + // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) + // + + PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; + ServerHeader->pRawValue = "Zen"; + ServerHeader->RawValueLength = (USHORT)3; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonString.data(); + HttpResponse.ReasonLength = (USHORT)ReasonString.size(); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { + for (int i = 0; i < ThisRequestChunkCount; ++i) + { + const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i]; + + if (i > 0) + { + ResponseDetails << " + "; + } + + switch (Chunk.DataChunkType) + { + case HttpDataChunkFromMemory: + ResponseDetails << "mem:" << uint64_t(Chunk.FromMemory.BufferLength); + break; + + case HttpDataChunkFromFileHandle: + ResponseDetails << "file:"; + { + ResponseDetails << uint64_t(Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart) << "," + << uint64_t(Chunk.FromFileHandle.ByteRange.Length.QuadPart) << ","; + + std::error_code PathEc; + HANDLE FileHandle = Chunk.FromFileHandle.FileHandle; + std::filesystem::path Path = PathFromHandle(FileHandle, PathEc); + + if (PathEc) + { + ResponseDetails << "bad_file(handle=" << reinterpret_cast<uint64_t>(FileHandle) << ",error=" << PathEc.message() + << ")"; + } + else + { + const uint64_t FileSize = FileSizeFromHandle(FileHandle); + ResponseDetails << Path.u8string() << "(" << FileSize << ") handle=" << reinterpret_cast<uint64_t>(FileHandle); + } + } + break; + + case HttpDataChunkFromFragmentCache: + ResponseDetails << "frag:???"; // We do not use these + break; + + case HttpDataChunkFromFragmentCacheEx: + ResponseDetails << "frax:???"; // We do not use these + break; + +# if 0 // Not available in older Windows SDKs + case HttpDataChunkTrailers: + ResponseDetails << "trls:???"; // We do not use these + break; +# endif + + default: + ResponseDetails << "???: " << Chunk.DataChunkType; + break; + } + } + }; + + if (SendResult == NO_ERROR) + { + // Synchronous completion, but the completion event will still be posted to IOCP + + ErrorCode.clear(); + } + else if (SendResult == ERROR_IO_PENDING) + { + // Asynchronous completion, a completion notification will be posted to IOCP + + ErrorCode.clear(); + } + else + { + ErrorCode = MakeErrorCode(SendResult); + + // An error occurred, no completion will be posted to IOCP + + CancelThreadpoolIo(Iocp); + + // Emit diagnostics + + ExtendableStringBuilder<256> ResponseDetails; + EmitReponseDetails(ResponseDetails); + + ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}", + SendResult, + ErrorCode.message(), + HttpReq->pRawUrl, + Tx.ServerRequest().SessionId(), + HttpReq->RequestId, + ResponseDetails); + } +} + +/** HTTP completion handler for async work + + This is used to allow work to be taken off the request handler threads + and to support posting responses asynchronously. + */ + +class HttpAsyncWorkRequest : public HttpSysRequestHandler +{ +public: + HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response); + ~HttpAsyncWorkRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + +private: + struct AsyncWorkItem : public IWork + { + virtual void Execute() override; + + AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler) + : Tx(InTx) + , Handler(std::move(InHandler)) + { + } + + HttpSysTransaction& Tx; + std::function<void(HttpServerRequest&)> Handler; + }; + + Ref<AsyncWorkItem> m_WorkItem; +}; + +HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response) +: HttpSysRequestHandler(Tx) +{ + m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); +} + +HttpAsyncWorkRequest::~HttpAsyncWorkRequest() +{ +} + +void +HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest"); + ErrorCode.clear(); + + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); +} + +HttpSysRequestHandler* +HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // This ought to not be called since there should be no outstanding I/O request + // when this completion handler is active + + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + + return this; +} + +void +HttpAsyncWorkRequest::AsyncWorkItem::Execute() +{ + ZEN_TRACE_CPU("httpsys::async_execute"); + + try + { + // We need to hold this lock while we're issuing new requests in order to + // prevent race conditions between the thread we are running on and any + // IOCP service threads. Otherwise the IOCP completion handler can end + // up deleting the transaction object before we are done with it! + HttpSysTransaction::CompletionMutexScope _(Tx); + HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); + + ThisRequest.m_NextCompletionHandler = nullptr; + + { + ZEN_TRACE_CPU("httpsys::HandleRequest"); + Handler(ThisRequest); + } + + // TODO: should Handler be destroyed at this point to ensure there + // are no outstanding references into state which could be + // deleted asynchronously as a result of issuing the response? + + if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) + { + return (void)Tx.IssueNextRequest(NextHandler); + } + else if (!ThisRequest.IsHandled()) + { + return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); + } + else + { + // "Handled" but no request handler? Shouldn't ever happen + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); + } + } + catch (std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig) +: m_Log(logging::Get("http")) +, m_RequestLog(logging::Get("http_requests")) +, m_IsRequestLoggingEnabled(InConfig.IsRequestLoggingEnabled) +, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled) +, m_InitialConfig(InConfig) +{ + // Initialize thread pool + + int MinThreadCount; + int MaxThreadCount; + + if (m_InitialConfig.ThreadCount == 0) + { + MinThreadCount = Max(8u, std::thread::hardware_concurrency()); + } + else + { + MinThreadCount = m_InitialConfig.ThreadCount; + } + + MaxThreadCount = MinThreadCount * 2; + + if (m_InitialConfig.IsDedicatedServer) + { + // In order to limit the potential impact of threads stuck + // in locks we allow the thread pool to be oversubscribed + // by a fair amount + + MaxThreadCount *= 2; + } + + m_IoThreadPool = std::make_unique<WinIoThreadPool>(MinThreadCount, MaxThreadCount); + + if (m_InitialConfig.AsyncWorkThreadCount == 0) + { + m_InitialConfig.AsyncWorkThreadCount = 16; + } + + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; + + ZEN_INFO("http.sys server started in {} mode, using {}-{} I/O threads and {} async worker threads", + m_InitialConfig.IsDedicatedServer ? "DEDICATED" : "NORMAL", + MinThreadCount, + MaxThreadCount, + m_InitialConfig.AsyncWorkThreadCount); +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + ZEN_ERROR("~HttpSysServer() called without calling Close() first"); + } + + delete m_AsyncWorkPool; + m_AsyncWorkPool = nullptr; +} + +void +HttpSysServer::Close() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + m_IsHttpInitialized = false; + } +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + int EffectivePort = BasePort; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + // Sharing violation implies the port is being used by another process + for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + WildcardUrlPath.Reset(); + WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + m_BaseUris.clear(); + if (Result == NO_ERROR) + { + m_BaseUris.push_back(WildcardUrlPath.c_str()); + } + else if (Result == ERROR_ACCESS_DENIED) + { + // If we can't register the wildcard path, we fall back to local paths + // This local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh + // during an install process to grant permissions to route public access to the appropriate + // port for the current user. eg: + // netsh http add urlacl url=http://*:8558/ user=<some_user> + + ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + ULONG InternalResult = ERROR_SHARING_VIOLATION; + for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrlPath; + LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; + + InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (InternalResult == NO_ERROR) + { + ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); + + m_BaseUris.push_back(LocalUrlPath.c_str()); + } + else + { + break; + } + } + } + } + + if (m_BaseUris.empty()) + { + ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + WideStringBuilder<64> QueueName; + QueueName << "zenserver_" << EffectivePort; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + // Configure rejection method. Default is to drop the connection, it's better if we + // return an explicit error code when the queue cannot accept more requests + + { + HTTP_503_RESPONSE_VERBOSITY VerbosityInformation = Http503ResponseVerbosityLimited; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, + HttpServer503VerbosityProperty, + &VerbosityInformation, + sizeof VerbosityInformation, + 0, + 0); + } + + // Tune the maximum number of pending requests in the http.sys request queue. By default + // the value is 1000 which is plenty for single user machines but for dedicated servers + // serving many users it makes sense to increase this to a higher number to help smooth + // out intermittent stalls like we might experience when GC is triggered + + if (m_InitialConfig.IsDedicatedServer) + { + ULONG QueueLength = 50000; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, HttpServerQueueLengthProperty, &QueueLength, sizeof QueueLength, 0, 0); + + if (Result != NO_ERROR) + { + ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + } + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + } + else + { + m_IsOk = true; + + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + + // This is not available in all Windows SDK versions so for now we can't use recently + // released functionality. We should investigate how to get more recent SDK releases + // into the build + +# if 0 + if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) + { + ZEN_DEBUG("HTTP3 is available"); + } + else + { + ZEN_DEBUG("HTTP3 is NOT available"); + } +# endif + + return EffectivePort; +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +WorkerThreadPool& +HttpSysServer::WorkPool() +{ + if (!m_AsyncWorkPool) + { + RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); + + if (!m_AsyncWorkPool) + { + m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + } + } + + return *m_AsyncWorkPool; +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool IsInteractive) +{ + if (IsInteractive) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + // int WaitTimeout = -1; + int WaitTimeout = 100; + + if (IsInteractive) + { + WaitTimeout = 1000; + + if (_kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + UpdateLofreqTimerValue(); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingNewRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests. This should be the common case as we check if we need to + // enqueue more requests before exiting the completion handler. + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + + return; + } + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_IoThreadPool->Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +thread_local bool t_IsHttpSysThreadNamed = false; +static std::atomic<int> HttpSysThreadIndex = 0; + +static void +NameCurrentHttpSysThread() +{ + t_IsHttpSysThreadNamed = true; + const int ThreadIndex = ++HttpSysThreadIndex; + zen::ExtendableStringBuilder<16> ThreadName; + ThreadName << "httpio_" << ThreadIndex; + SetCurrentThreadName(ThreadName); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + ZEN_UNUSED(Io, Instance); + + // Assign names to threads for context + + if (!t_IsHttpSysThreadNamed) + { + NameCurrentHttpSysThread(); + } + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } + + // Ensure new requests are enqueued as necessary. We do this here instead + // of inside the transaction completion handler now to avoid spending time + // in unrelated API calls while holding the transaction lock + + if (HttpSysServer* HttpServer = reinterpret_cast<HttpSysServer*>(pContext)) + { + HttpServer->IssueNewRequestMaybe(); + } +} + +bool +HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) +{ + ZEN_TRACE_CPU("httpsys::Transaction::IssueNextRequest"); + + HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; + m_CompletionHandler = NewCompletionHandler; + + auto _ = MakeGuard([this, CurrentHandler] { + if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) + { + delete CurrentHandler; + } + }); + + if (NewCompletionHandler == nullptr) + { + return false; + } + + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (!ErrorCode) + { + return true; + } + + ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + } + catch (std::exception& Ex) + { + ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + } + + // something went wrong, no request is pending + m_CompletionHandler = nullptr; + + return false; +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingNewRequest(); + } + + auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + IsRequestPending = IssueNextRequest(NewCompletionHandler); + } + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + if (m_HttpServer.m_IsRequestLoggingEnabled) + { + if (m_HandlerRequest.has_value()) + { + m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); + } + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + // Default request handling + + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This is overhead which I'd prefer to avoid but for now we just have + // to live with it + + WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + + std::string_view UriSuffix8{m_UriUtf8}; + + m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access + m_Uri = UriSuffix8; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(UriSuffix8.size() + 1); + } + } + } + else + { + m_UriUtf8.Reset(); + m_Uri = {}; + m_UriWithExtension = {}; + } + + if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; // We skip the leading question mark + + WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_QueryString = std::string_view(m_QueryStringUtf8); + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = GetAcceptType(HttpRequestPtr); + } + + if (m_Verb == HttpVerb::kHead) + { + SetSuppressResponseBody(); + } +} + +HttpSysServerRequest::~HttpSysServerRequest() +{ +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + if (m_HttpTx.Server().IsAsyncResponseEnabled()) + { + m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); + } + else + { + ContinuationHandler(m_HttpTx.ServerRequest()); + } +} + +bool +HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; + + return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Request::IssueRequest"); + + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + ErrorCode = MakeErrorCode(HttpApiResult); + + ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + + return; + } + + ErrorCode.clear(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + default: + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: // Insufficient buffer space + case NO_ERROR: + break; + } + + ZEN_TRACE_CPU("httpsys::HandleCompletion"); + + // Route request + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeAuth: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeChannelBind: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslProtocol: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeQuicStats: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset != m_ContentLength) + { + // Body not complete, issue another read request to receive more body data + return this; + } + + // Request body received completely + + m_PayloadBuffer.MakeImmutable(); + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) + { + return Response; + } + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::system_error& SystemError) + { + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); + } + + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); + } + catch (std::bad_alloc& BadAlloc) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +int +HttpSysServer::Initialize(int BasePort) +{ + int EffectivePort = InitializeServer(BasePort); + StartServer(); + return EffectivePort; +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +Ref<HttpServer> +CreateHttpSysServer(HttpSysConfig Config) +{ + return Ref<HttpServer>(new HttpSysServer(Config)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h new file mode 100644 index 000000000..6a6b16525 --- /dev/null +++ b/src/zenhttp/servers/httpsys.h @@ -0,0 +1,28 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +namespace zen { + +struct HttpSysConfig +{ + unsigned int ThreadCount = 0; + unsigned int AsyncWorkThreadCount = 0; + bool IsAsyncResponseEnabled = true; + bool IsRequestLoggingEnabled = false; + bool IsDedicatedServer = false; +}; + +Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config); + +} // namespace zen diff --git a/src/zenhttp/servers/iothreadpool.cpp b/src/zenhttp/servers/iothreadpool.cpp new file mode 100644 index 000000000..da4b42e28 --- /dev/null +++ b/src/zenhttp/servers/iothreadpool.cpp @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "iothreadpool.h" + +#include <zencore/except.h> + +#if ZEN_PLATFORM_WINDOWS + +namespace zen { + +WinIoThreadPool::WinIoThreadPool(int InThreadCount, int InMaxThreadCount) +{ + ZEN_ASSERT(InThreadCount); + + if (InMaxThreadCount < InThreadCount) + { + InMaxThreadCount = InThreadCount; + } + + m_ThreadPool = CreateThreadpool(NULL); + + SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); + SetThreadpoolThreadMaximum(m_ThreadPool, InMaxThreadCount); + + InitializeThreadpoolEnvironment(&m_CallbackEnvironment); + + m_CleanupGroup = CreateThreadpoolCleanupGroup(); + + SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); + + SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); +} + +WinIoThreadPool::~WinIoThreadPool() +{ + CloseThreadpool(m_ThreadPool); +} + +void +WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +{ + ZEN_ASSERT(!m_ThreadPoolIo); + + m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); + + if (!m_ThreadPoolIo) + { + ErrorCode = MakeErrorCodeFromLastError(); + } +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/servers/iothreadpool.h b/src/zenhttp/servers/iothreadpool.h new file mode 100644 index 000000000..e75e95e58 --- /dev/null +++ b/src/zenhttp/servers/iothreadpool.h @@ -0,0 +1,31 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> + +# include <system_error> + +namespace zen { + +class WinIoThreadPool +{ +public: + WinIoThreadPool(int InThreadCount, int InMaxThreadCount); + ~WinIoThreadPool(); + + void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); + inline PTP_IO Iocp() const { return m_ThreadPoolIo; } + +private: + PTP_POOL m_ThreadPool = nullptr; + PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; + PTP_IO m_ThreadPoolIo = nullptr; + TP_CALLBACK_ENVIRON m_CallbackEnvironment; +}; + +} // namespace zen +#endif |