diff options
| author | Stefan Boberg <[email protected]> | 2023-10-13 09:55:27 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-13 09:55:27 +0200 |
| commit | 74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d (patch) | |
| tree | acae59dac67b4d051403f35e580201c214ec4fda /src/zenhttp/servers/httpsys.cpp | |
| parent | faster oplog iteration (#471) (diff) | |
| download | zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.tar.xz zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.zip | |
restructured zenhttp (#472)
separating the http server implementations into a directory and moved diagsvcs into zenserver since it's somewhat hard-coded for it
Diffstat (limited to 'src/zenhttp/servers/httpsys.cpp')
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 2012 |
1 files changed, 2012 insertions, 0 deletions
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp new file mode 100644 index 000000000..c1b4717cb --- /dev/null +++ b/src/zenhttp/servers/httpsys.cpp @@ -0,0 +1,2012 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpshared.h> + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <zencore/workthreadpool.h> +# include "iothreadpool.h" + +# include <http.h> + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function by default but system configuration + * can soften this requirement. + * + * See README.md for details. + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(const HttpSysConfig& Config); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + virtual void Close() override; + + WorkerThreadPool& WorkPool(); + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingNewRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + std::unique_ptr<WinIoThreadPool> m_IoThreadPool; + + RwLock m_AsyncWorkPoolInitLock; + WorkerThreadPool* m_AsyncWorkPool = nullptr; + + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; + HttpSysConfig m_InitialConfig; +}; + +} // namespace zen +#endif + +#if ZEN_WITH_HTTPSYS + +# include <conio.h> +# include <mstcpip.h> +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_UTF16(const char* InPtr) +{ + std::wstring OutString; + unsigned int Codepoint; + + while (*InPtr != 0) + { + unsigned char InChar = static_cast<unsigned char>(*InPtr); + + if (InChar <= 0x7f) + Codepoint = InChar; + else if (InChar <= 0xbf) + Codepoint = (Codepoint << 6) | (InChar & 0x3f); + else if (InChar <= 0xdf) + Codepoint = InChar & 0x1f; + else if (InChar <= 0xef) + Codepoint = InChar & 0x0f; + else + Codepoint = InChar & 0x07; + + ++InPtr; + + if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) + { + if (Codepoint > 0xffff) + { + OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10))); + OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff))); + } + else if (Codepoint < 0xd800 || Codepoint >= 0xe000) + { + OutString.append(1, static_cast<wchar_t>(Codepoint)); + } + } + } + + return OutString; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + HttpSysTransaction& Transaction() { return m_Transaction; } + + HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; + HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; + +private: + HttpSysTransaction& m_Transaction; +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpSysServerRequest(const HttpSysServerRequest&) = delete; + HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; + + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + [[nodiscard]] Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); + + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + + HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } + + struct CompletionMutexScope + { + CompletionMutexScope(HttpSysTransaction& Tx) : Lock(Tx.m_CompletionMutex) {} + ~CompletionMutexScope() = default; + + RwLock::ExclusiveLockScope Lock; + }; + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional<HttpSysServerRequest> m_HandlerRequest; + Ref<IHttpPackageHandler> m_PackageHandler; +}; + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector<IoBuffer> m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array<IoBuffer, 0> EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + ZEN_TRACE_CPU("httpsys::InitializeForPayload"); + + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + auto& Chunk = m_HttpDataChunks.emplace_back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; + + if (m_TotalDataSize == 0 && ResponseCode == 200) + { + // Some HTTP clients really don't like empty responses unless a 204 response is sent + m_ResponseCode = uint16_t(HttpResponseCode::NoContent); + } + else + { + m_ResponseCode = ResponseCode; + } +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult != NO_ERROR) + { + ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Response::IssueRequest"); + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Server header + // + // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header + // + // This is controlled via a registry key 'DisableServerHeader', at: + // + // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters + // + // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether + // (only the latter appears to do anything in my testing, on Windows 10). + // + // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) + // + + PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; + ServerHeader->pRawValue = "Zen"; + ServerHeader->RawValueLength = (USHORT)3; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonString.data(); + HttpResponse.ReasonLength = (USHORT)ReasonString.size(); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { + for (int i = 0; i < ThisRequestChunkCount; ++i) + { + const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i]; + + if (i > 0) + { + ResponseDetails << " + "; + } + + switch (Chunk.DataChunkType) + { + case HttpDataChunkFromMemory: + ResponseDetails << "mem:" << uint64_t(Chunk.FromMemory.BufferLength); + break; + + case HttpDataChunkFromFileHandle: + ResponseDetails << "file:"; + { + ResponseDetails << uint64_t(Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart) << "," + << uint64_t(Chunk.FromFileHandle.ByteRange.Length.QuadPart) << ","; + + std::error_code PathEc; + HANDLE FileHandle = Chunk.FromFileHandle.FileHandle; + std::filesystem::path Path = PathFromHandle(FileHandle, PathEc); + + if (PathEc) + { + ResponseDetails << "bad_file(handle=" << reinterpret_cast<uint64_t>(FileHandle) << ",error=" << PathEc.message() + << ")"; + } + else + { + const uint64_t FileSize = FileSizeFromHandle(FileHandle); + ResponseDetails << Path.u8string() << "(" << FileSize << ") handle=" << reinterpret_cast<uint64_t>(FileHandle); + } + } + break; + + case HttpDataChunkFromFragmentCache: + ResponseDetails << "frag:???"; // We do not use these + break; + + case HttpDataChunkFromFragmentCacheEx: + ResponseDetails << "frax:???"; // We do not use these + break; + +# if 0 // Not available in older Windows SDKs + case HttpDataChunkTrailers: + ResponseDetails << "trls:???"; // We do not use these + break; +# endif + + default: + ResponseDetails << "???: " << Chunk.DataChunkType; + break; + } + } + }; + + if (SendResult == NO_ERROR) + { + // Synchronous completion, but the completion event will still be posted to IOCP + + ErrorCode.clear(); + } + else if (SendResult == ERROR_IO_PENDING) + { + // Asynchronous completion, a completion notification will be posted to IOCP + + ErrorCode.clear(); + } + else + { + ErrorCode = MakeErrorCode(SendResult); + + // An error occurred, no completion will be posted to IOCP + + CancelThreadpoolIo(Iocp); + + // Emit diagnostics + + ExtendableStringBuilder<256> ResponseDetails; + EmitReponseDetails(ResponseDetails); + + ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}", + SendResult, + ErrorCode.message(), + HttpReq->pRawUrl, + Tx.ServerRequest().SessionId(), + HttpReq->RequestId, + ResponseDetails); + } +} + +/** HTTP completion handler for async work + + This is used to allow work to be taken off the request handler threads + and to support posting responses asynchronously. + */ + +class HttpAsyncWorkRequest : public HttpSysRequestHandler +{ +public: + HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response); + ~HttpAsyncWorkRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + +private: + struct AsyncWorkItem : public IWork + { + virtual void Execute() override; + + AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler) + : Tx(InTx) + , Handler(std::move(InHandler)) + { + } + + HttpSysTransaction& Tx; + std::function<void(HttpServerRequest&)> Handler; + }; + + Ref<AsyncWorkItem> m_WorkItem; +}; + +HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response) +: HttpSysRequestHandler(Tx) +{ + m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); +} + +HttpAsyncWorkRequest::~HttpAsyncWorkRequest() +{ +} + +void +HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest"); + ErrorCode.clear(); + + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); +} + +HttpSysRequestHandler* +HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // This ought to not be called since there should be no outstanding I/O request + // when this completion handler is active + + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + + return this; +} + +void +HttpAsyncWorkRequest::AsyncWorkItem::Execute() +{ + ZEN_TRACE_CPU("httpsys::async_execute"); + + try + { + // We need to hold this lock while we're issuing new requests in order to + // prevent race conditions between the thread we are running on and any + // IOCP service threads. Otherwise the IOCP completion handler can end + // up deleting the transaction object before we are done with it! + HttpSysTransaction::CompletionMutexScope _(Tx); + HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); + + ThisRequest.m_NextCompletionHandler = nullptr; + + { + ZEN_TRACE_CPU("httpsys::HandleRequest"); + Handler(ThisRequest); + } + + // TODO: should Handler be destroyed at this point to ensure there + // are no outstanding references into state which could be + // deleted asynchronously as a result of issuing the response? + + if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) + { + return (void)Tx.IssueNextRequest(NextHandler); + } + else if (!ThisRequest.IsHandled()) + { + return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); + } + else + { + // "Handled" but no request handler? Shouldn't ever happen + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); + } + } + catch (std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig) +: m_Log(logging::Get("http")) +, m_RequestLog(logging::Get("http_requests")) +, m_IsRequestLoggingEnabled(InConfig.IsRequestLoggingEnabled) +, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled) +, m_InitialConfig(InConfig) +{ + // Initialize thread pool + + int MinThreadCount; + int MaxThreadCount; + + if (m_InitialConfig.ThreadCount == 0) + { + MinThreadCount = Max(8u, std::thread::hardware_concurrency()); + } + else + { + MinThreadCount = m_InitialConfig.ThreadCount; + } + + MaxThreadCount = MinThreadCount * 2; + + if (m_InitialConfig.IsDedicatedServer) + { + // In order to limit the potential impact of threads stuck + // in locks we allow the thread pool to be oversubscribed + // by a fair amount + + MaxThreadCount *= 2; + } + + m_IoThreadPool = std::make_unique<WinIoThreadPool>(MinThreadCount, MaxThreadCount); + + if (m_InitialConfig.AsyncWorkThreadCount == 0) + { + m_InitialConfig.AsyncWorkThreadCount = 16; + } + + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; + + ZEN_INFO("http.sys server started in {} mode, using {}-{} I/O threads and {} async worker threads", + m_InitialConfig.IsDedicatedServer ? "DEDICATED" : "NORMAL", + MinThreadCount, + MaxThreadCount, + m_InitialConfig.AsyncWorkThreadCount); +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + ZEN_ERROR("~HttpSysServer() called without calling Close() first"); + } + + delete m_AsyncWorkPool; + m_AsyncWorkPool = nullptr; +} + +void +HttpSysServer::Close() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + m_IsHttpInitialized = false; + } +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + int EffectivePort = BasePort; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + // Sharing violation implies the port is being used by another process + for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + WildcardUrlPath.Reset(); + WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + m_BaseUris.clear(); + if (Result == NO_ERROR) + { + m_BaseUris.push_back(WildcardUrlPath.c_str()); + } + else if (Result == ERROR_ACCESS_DENIED) + { + // If we can't register the wildcard path, we fall back to local paths + // This local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh + // during an install process to grant permissions to route public access to the appropriate + // port for the current user. eg: + // netsh http add urlacl url=http://*:8558/ user=<some_user> + + ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + ULONG InternalResult = ERROR_SHARING_VIOLATION; + for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrlPath; + LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; + + InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (InternalResult == NO_ERROR) + { + ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); + + m_BaseUris.push_back(LocalUrlPath.c_str()); + } + else + { + break; + } + } + } + } + + if (m_BaseUris.empty()) + { + ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + WideStringBuilder<64> QueueName; + QueueName << "zenserver_" << EffectivePort; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + // Configure rejection method. Default is to drop the connection, it's better if we + // return an explicit error code when the queue cannot accept more requests + + { + HTTP_503_RESPONSE_VERBOSITY VerbosityInformation = Http503ResponseVerbosityLimited; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, + HttpServer503VerbosityProperty, + &VerbosityInformation, + sizeof VerbosityInformation, + 0, + 0); + } + + // Tune the maximum number of pending requests in the http.sys request queue. By default + // the value is 1000 which is plenty for single user machines but for dedicated servers + // serving many users it makes sense to increase this to a higher number to help smooth + // out intermittent stalls like we might experience when GC is triggered + + if (m_InitialConfig.IsDedicatedServer) + { + ULONG QueueLength = 50000; + + Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, HttpServerQueueLengthProperty, &QueueLength, sizeof QueueLength, 0, 0); + + if (Result != NO_ERROR) + { + ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + } + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + } + else + { + m_IsOk = true; + + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + + // This is not available in all Windows SDK versions so for now we can't use recently + // released functionality. We should investigate how to get more recent SDK releases + // into the build + +# if 0 + if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) + { + ZEN_DEBUG("HTTP3 is available"); + } + else + { + ZEN_DEBUG("HTTP3 is NOT available"); + } +# endif + + return EffectivePort; +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +WorkerThreadPool& +HttpSysServer::WorkPool() +{ + if (!m_AsyncWorkPool) + { + RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); + + if (!m_AsyncWorkPool) + { + m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + } + } + + return *m_AsyncWorkPool; +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool IsInteractive) +{ + if (IsInteractive) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + // int WaitTimeout = -1; + int WaitTimeout = 100; + + if (IsInteractive) + { + WaitTimeout = 1000; + + if (_kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + UpdateLofreqTimerValue(); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingNewRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests. This should be the common case as we check if we need to + // enqueue more requests before exiting the completion handler. + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + + return; + } + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_IoThreadPool->Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +thread_local bool t_IsHttpSysThreadNamed = false; +static std::atomic<int> HttpSysThreadIndex = 0; + +static void +NameCurrentHttpSysThread() +{ + t_IsHttpSysThreadNamed = true; + const int ThreadIndex = ++HttpSysThreadIndex; + zen::ExtendableStringBuilder<16> ThreadName; + ThreadName << "httpio_" << ThreadIndex; + SetCurrentThreadName(ThreadName); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + ZEN_UNUSED(Io, Instance); + + // Assign names to threads for context + + if (!t_IsHttpSysThreadNamed) + { + NameCurrentHttpSysThread(); + } + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } + + // Ensure new requests are enqueued as necessary. We do this here instead + // of inside the transaction completion handler now to avoid spending time + // in unrelated API calls while holding the transaction lock + + if (HttpSysServer* HttpServer = reinterpret_cast<HttpSysServer*>(pContext)) + { + HttpServer->IssueNewRequestMaybe(); + } +} + +bool +HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) +{ + ZEN_TRACE_CPU("httpsys::Transaction::IssueNextRequest"); + + HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; + m_CompletionHandler = NewCompletionHandler; + + auto _ = MakeGuard([this, CurrentHandler] { + if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) + { + delete CurrentHandler; + } + }); + + if (NewCompletionHandler == nullptr) + { + return false; + } + + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (!ErrorCode) + { + return true; + } + + ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + } + catch (std::exception& Ex) + { + ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + } + + // something went wrong, no request is pending + m_CompletionHandler = nullptr; + + return false; +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingNewRequest(); + } + + auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + IsRequestPending = IssueNextRequest(NewCompletionHandler); + } + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + if (m_HttpServer.m_IsRequestLoggingEnabled) + { + if (m_HandlerRequest.has_value()) + { + m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); + } + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + // Default request handling + + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This is overhead which I'd prefer to avoid but for now we just have + // to live with it + + WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + + std::string_view UriSuffix8{m_UriUtf8}; + + m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access + m_Uri = UriSuffix8; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(UriSuffix8.size() + 1); + } + } + } + else + { + m_UriUtf8.Reset(); + m_Uri = {}; + m_UriWithExtension = {}; + } + + if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; // We skip the leading question mark + + WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_QueryString = std::string_view(m_QueryStringUtf8); + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = GetAcceptType(HttpRequestPtr); + } + + if (m_Verb == HttpVerb::kHead) + { + SetSuppressResponseBody(); + } +} + +HttpSysServerRequest::~HttpSysServerRequest() +{ +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + if (m_HttpTx.Server().IsAsyncResponseEnabled()) + { + m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); + } + else + { + ContinuationHandler(m_HttpTx.ServerRequest()); + } +} + +bool +HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; + + return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + ZEN_TRACE_CPU("httpsys::Request::IssueRequest"); + + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + ErrorCode = MakeErrorCode(HttpApiResult); + + ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + + return; + } + + ErrorCode.clear(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + default: + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: // Insufficient buffer space + case NO_ERROR: + break; + } + + ZEN_TRACE_CPU("httpsys::HandleCompletion"); + + // Route request + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeAuth: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeChannelBind: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslProtocol: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeQuicStats: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset != m_ContentLength) + { + // Body not complete, issue another read request to receive more body data + return this; + } + + // Request body received completely + + m_PayloadBuffer.MakeImmutable(); + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) + { + return Response; + } + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::system_error& SystemError) + { + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); + } + + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); + } + catch (std::bad_alloc& BadAlloc) + { + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +int +HttpSysServer::Initialize(int BasePort) +{ + int EffectivePort = InitializeServer(BasePort); + StartServer(); + return EffectivePort; +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +Ref<HttpServer> +CreateHttpSysServer(HttpSysConfig Config) +{ + return Ref<HttpServer>(new HttpSysServer(Config)); +} + +} // namespace zen +#endif |