diff options
| author | Stefan Boberg <[email protected]> | 2023-10-13 09:55:27 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-13 09:55:27 +0200 |
| commit | 74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d (patch) | |
| tree | acae59dac67b4d051403f35e580201c214ec4fda /src/zenhttp/servers/httpasio.cpp | |
| parent | faster oplog iteration (#471) (diff) | |
| download | zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.tar.xz zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.zip | |
restructured zenhttp (#472)
separating the http server implementations into a directory and moved diagsvcs into zenserver since it's somewhat hard-coded for it
Diffstat (limited to 'src/zenhttp/servers/httpasio.cpp')
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 1052 |
1 files changed, 1052 insertions, 0 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp new file mode 100644 index 000000000..0c6b189f9 --- /dev/null +++ b/src/zenhttp/servers/httpasio.cpp @@ -0,0 +1,1052 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zencore/trace.h> +#include <zenhttp/httpserver.h> + +#include "httpparser.h" + +#include <deque> +#include <memory> +#include <string_view> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# include <mstcpip.h> +#endif +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#define ASIO_VERBOSE_TRACE 0 + +#if ASIO_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +#else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpAcceptor; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); + // Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + int Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpResponse> m_Response; +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) + { + ZEN_TRACE_CPU("asio::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_DataBuffers; + std::vector<asio::const_buffer> m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~HttpServerConnection(); + + std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + + // HttpConnectionBase implementation + + virtual void TerminateConnection() override; + virtual void HandleRequest() override; + + void HandleNewRequest(); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void CloseConnection(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::atomic<uint32_t> m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + RwLock m_ResponsesLock; + std::deque<std::unique_ptr<HttpResponse>> m_Responses; +}; + +std::atomic<uint32_t> g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ + ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); +} + +void +HttpServerConnection::HandleNewRequest() +{ + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + + m_RequestState = RequestState::kTerminated; + ZEN_ASSERT(m_Socket); + + // Terminating, we don't care about any errors when closing socket + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + return; + } + else + { + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + return TerminateConnection(); + } + } + + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + if (Result == ~0ull) + { + return TerminateConnection(); + } + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) +{ + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + TerminateConnection(); + } + else + { + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + if (!m_RequestData.IsKeepAlive()) + { + CloseConnection(); + } + else + { + if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + + m_RequestCounter.fetch_add(1); + } + } +} + +void +HttpServerConnection::CloseConnection() +{ + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) + { + return; + } + ZEN_ASSERT(m_Socket); + m_RequestState = RequestState::kDone; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + m_Socket->close(Ec); + if (Ec) + { + ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); + } +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + ZEN_TRACE_CPU("asio::HandleRequest"); + + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + { + ZEN_TRACE_CPU("asio::async_write"); + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, true); + }); + } + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + { + m_Acceptor.set_option(asio::ip::v6_only(false)); +#if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; + m_Acceptor.set_option(excluse_address(true)); +#else // ZEN_PLATFORM_WINDOWS + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +#endif // ZEN_PLATFORM_WINDOWS + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + m_Acceptor.listen(); + + ZEN_INFO("Started asio server at port '{}'", EffectivePort); + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + if (CloseEc) + { + ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + } + } + }); + } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +int +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + ZEN_INFO("starting asio http with {} service threads", ThreadCount); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + // This should consist of a set of minimum threads and grow on demand to + // meet concurrency needs? Right now we end up allocating a large number + // of threads even if we never end up using all of them, which seems + // wasteful. It's also not clear how the demand for concurrency should + // be balanced with the engine side - ideally we'd have some kind of + // global scheduling to prevent one side from preventing the other side + // from making progress. Or at the very least, thread priorities should + // be considered. + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName(fmt::format("asio_io_{}", Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } + + ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); + + return m_Acceptor->GetAcceptPort(); +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +class HttpAsioServer : public HttpServer +{ +public: + HttpAsioServer(unsigned int ThreadCount); + ~HttpAsioServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + unsigned int m_ThreadCount = 0; + + std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl; +}; + +HttpAsioServer::HttpAsioServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) +{ + ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); +} + +HttpAsioServer::~HttpAsioServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpAsioServer() called without calling Close() first"); + } +} + +void +HttpAsioServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); + } + m_Impl.reset(); +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +int +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), m_ThreadCount); + return m_BasePort; +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +Ref<HttpServer> +CreateHttpAsioServer(unsigned int ThreadCount) +{ + return Ref<HttpServer>{new HttpAsioServer(ThreadCount)}; +} + +} // namespace zen |