// Copyright Epic Games, Inc. All Rights Reserved. #include "httpsys.h" #include #include #include #include #if ZEN_WITH_HTTPSYS # include # include # pragma comment(lib, "httpapi.lib") std::wstring UTF8_to_wstring(const char* in) { std::wstring out; unsigned int codepoint; while (*in != 0) { unsigned char ch = static_cast(*in); if (ch <= 0x7f) codepoint = ch; else if (ch <= 0xbf) codepoint = (codepoint << 6) | (ch & 0x3f); else if (ch <= 0xdf) codepoint = ch & 0x1f; else if (ch <= 0xef) codepoint = ch & 0x0f; else codepoint = ch & 0x07; ++in; if (((*in & 0xc0) != 0x80) && (codepoint <= 0x10ffff)) { if (sizeof(wchar_t) > 2) { out.append(1, static_cast(codepoint)); } else if (codepoint > 0xffff) { out.append(1, static_cast(0xd800 + (codepoint >> 10))); out.append(1, static_cast(0xdc00 + (codepoint & 0x03ff))); } else if (codepoint < 0xd800 || codepoint >= 0xe000) { out.append(1, static_cast(codepoint)); } } } return out; } namespace zen { class HttpSysServer; class HttpSysTransaction; class HttpMessageResponseRequest; ////////////////////////////////////////////////////////////////////////// using namespace std::literals; static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); static const uint32_t HashJson = HashStringDjb2("application/json"sv); static const uint32_t HashYaml = HashStringDjb2("text/yaml"sv); static const uint32_t HashText = HashStringDjb2("text/plain"sv); static const uint32_t HashCompactBinary = HashStringDjb2("application/x-ue-cb"sv); static const uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); HttpContentType MapContentType(const std::string_view& ContentTypeString) { if (!ContentTypeString.empty()) { const uint32_t CtHash = HashStringDjb2(ContentTypeString); if (CtHash == HashBinary) { return HttpContentType::kBinary; } else if (CtHash == HashCompactBinary) { return HttpContentType::kCbObject; } else if (CtHash == HashCompactBinaryPackage) { return HttpContentType::kCbPackage; } else if (CtHash == HashJson) { return HttpContentType::kJSON; } else if (CtHash == HashYaml) { return HttpContentType::kYAML; } else if (CtHash == HashText) { return HttpContentType::kText; } } return HttpContentType::kUnknownContentType; } ////////////////////////////////////////////////////////////////////////// const char* ReasonStringForHttpResultCode(int HttpCode) { switch (HttpCode) { // 1xx Informational case 100: return "Continue"; case 101: return "Switching Protocols"; // 2xx Success case 200: return "OK"; case 201: return "Created"; case 202: return "Accepted"; case 204: return "No Content"; case 205: return "Reset Content"; case 206: return "Partial Content"; // 3xx Redirection case 300: return "Multiple Choices"; case 301: return "Moved Permanently"; case 302: return "Found"; case 303: return "See Other"; case 304: return "Not Modified"; case 305: return "Use Proxy"; case 306: return "Switch Proxy"; case 307: return "Temporary Redirect"; case 308: return "Permanent Redirect"; // 4xx Client errors case 400: return "Bad Request"; case 401: return "Unauthorized"; case 402: return "Payment Required"; case 403: return "Forbidden"; case 404: return "Not Found"; case 405: return "Method Not Allowed"; case 406: return "Not Acceptable"; case 407: return "Proxy Authentication Required"; case 408: return "Request Timeout"; case 409: return "Conflict"; case 410: return "Gone"; case 411: return "Length Required"; case 412: return "Precondition Failed"; case 413: return "Payload Too Large"; case 414: return "URI Too Long"; case 415: return "Unsupported Media Type"; case 416: return "Range Not Satisifiable"; case 417: return "Expectation Failed"; case 418: return "I'm a teapot"; case 421: return "Misdirected Request"; case 422: return "Unprocessable Entity"; case 423: return "Locked"; case 424: return "Failed Dependency"; case 425: return "Too Early"; case 426: return "Upgrade Required"; case 428: return "Precondition Required"; case 429: return "Too Many Requests"; case 431: return "Request Header Fields Too Large"; // 5xx Server errors case 500: return "Internal Server Error"; case 501: return "Not Implemented"; case 502: return "Bad Gateway"; case 503: return "Service Unavailable"; case 504: return "Gateway Timeout"; case 505: return "HTTP Version Not Supported"; case 506: return "Variant Also Negotiates"; case 507: return "Insufficient Storage"; case 508: return "Loop Detected"; case 510: return "Not Extended"; case 511: return "Network Authentication Required"; default: return "Unknown Result"; } } HttpVerb TranslateHttpVerb(HTTP_VERB ReqVerb) { switch (ReqVerb) { case HttpVerbOPTIONS: return HttpVerb::kOptions; case HttpVerbGET: return HttpVerb::kGet; case HttpVerbHEAD: return HttpVerb::kHead; case HttpVerbPOST: return HttpVerb::kPost; case HttpVerbPUT: return HttpVerb::kPut; case HttpVerbDELETE: return HttpVerb::kDelete; case HttpVerbCOPY: return HttpVerb::kCopy; default: // TODO: invalid request? return (HttpVerb)0; } } uint64_t GetContentLength(const HTTP_REQUEST* HttpRequest) { const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; std::string_view cl(clh.pRawValue, clh.RawValueLength); uint64_t ContentLength = 0; std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); return ContentLength; }; HttpContentType GetContentType(const HTTP_REQUEST* HttpRequest) { const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); }; HttpContentType GetAcceptType(const HTTP_REQUEST* HttpRequest) { const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); }; /** * @brief Base class for any pending or active HTTP transactions */ class HttpSysRequestHandler { public: explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} virtual ~HttpSysRequestHandler() = default; virtual void IssueRequest(std::error_code& ErrorCode) = 0; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; HttpSysTransaction& Transaction() { return m_Request; } private: HttpSysTransaction& m_Request; // Related HTTP transaction object }; /** * This is the handler for the initial HTTP I/O request which will receive the headers * and however much of the remaining payload might fit in the embedded request buffer */ struct InitialRequestHandler : public HttpSysRequestHandler { inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } inline bool IsInitialRequest() const { return m_IsInitialRequest; } InitialRequestHandler(HttpSysTransaction& InRequest); ~InitialRequestHandler(); virtual void IssueRequest(std::error_code& ErrorCode) override final; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; bool m_IsInitialRequest = true; uint64_t m_CurrentPayloadOffset = 0; uint64_t m_ContentLength = ~uint64_t(0); IoBuffer m_PayloadBuffer; UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; }; /** * This is the class which request handlers use to interact with the server instance */ class HttpSysServerRequest : public HttpServerRequest { public: HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); ~HttpSysServerRequest() = default; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; HttpSysTransaction& m_HttpTx; HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general IoBuffer m_PayloadBuffer; }; /** HTTP transaction There will be an instance of this per pending and in-flight HTTP transaction */ class HttpSysTransaction final { public: HttpSysTransaction(HttpSysServer& Server); virtual ~HttpSysTransaction(); enum class Status { kDone, kRequestPending }; Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, PVOID pContext /* HttpSysServer */, PVOID pOverlapped, ULONG IoResult, ULONG_PTR NumberOfBytesTransferred, PTP_IO Io); void IssueInitialRequest(std::error_code& ErrorCode); PTP_IO Iocp(); HANDLE RequestQueueHandle(); inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } private: OVERLAPPED m_HttpOverlapped{}; HttpSysServer& m_HttpServer; HttpSysRequestHandler* m_CompletionHandler{nullptr}; // Tracks which handler is due to handle the next I/O completion event RwLock m_CompletionMutex; InitialRequestHandler m_InitialHttpHandler{*this}; }; ////////////////////////////////////////////////////////////////////////// /** * @brief HTTP request response I/O request handler * * Asynchronously streams out a response to an HTTP request via compound * responses from memory or directly from file */ class HttpMessageResponseRequest : public HttpSysRequestHandler { public: HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, HttpContentType ContentType, const void* Payload, size_t PayloadSize); HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, HttpContentType ContentType, std::span Blobs); ~HttpMessageResponseRequest(); virtual void IssueRequest(std::error_code& ErrorCode) override final; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; void SuppressResponseBody(); private: std::vector m_HttpDataChunks; uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes uint16_t m_ResponseCode = 0; uint32_t m_NextDataChunkOffset = 0; // This is used for responses where the number of chunks exceed the maximum number for one API call uint32_t m_RemainingChunkCount = 0; bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; void Initialize(uint16_t ResponseCode, std::span Blobs); std::vector m_DataBuffers; }; HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) : HttpSysRequestHandler(InRequest) { std::array EmptyBufferList; Initialize(ResponseCode, EmptyBufferList); } HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) : HttpSysRequestHandler(InRequest) , m_ContentType(HttpContentType::kText) { IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); std::array SingleBufferList({MessageBuffer}); Initialize(ResponseCode, SingleBufferList); } HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, HttpContentType ContentType, const void* Payload, size_t PayloadSize) : HttpSysRequestHandler(InRequest) , m_ContentType(ContentType) { IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); std::array SingleBufferList({MessageBuffer}); Initialize(ResponseCode, SingleBufferList); } HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, HttpContentType ContentType, std::span BlobList) : HttpSysRequestHandler(InRequest) { Initialize(ResponseCode, BlobList); ZEN_UNUSED(ContentType); } HttpMessageResponseRequest::~HttpMessageResponseRequest() { } void HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span BlobList) { m_ResponseCode = ResponseCode; const uint32_t ChunkCount = gsl::narrow(BlobList.size()); m_HttpDataChunks.reserve(ChunkCount); m_DataBuffers.reserve(ChunkCount); for (IoBuffer& Buffer : BlobList) { m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); } // Initialize the full array up front uint64_t LocalDataSize = 0; for (IoBuffer& Buffer : m_DataBuffers) { uint64_t BufferDataSize = Buffer.Size(); ZEN_ASSERT(BufferDataSize); LocalDataSize += BufferDataSize; IoBufferFileReference FileRef; if (Buffer.GetFileReference(/* out */ FileRef)) { // Use direct file transfer m_HttpDataChunks.push_back({}); auto& Chunk = m_HttpDataChunks.back(); Chunk.DataChunkType = HttpDataChunkFromFileHandle; Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; } else { // Send from memory, need to make sure we chunk the buffer up since // the underlying data structure only accepts 32-bit chunk sizes for // memory chunks. When this happens the vector will be reallocated, // which is fine since this will be a pretty rare case and sending // the data is going to take a lot longer than a memory allocation :) const uint8_t* WriteCursor = reinterpret_cast(Buffer.Data()); while (BufferDataSize) { const ULONG ThisChunkSize = gsl::narrow(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); m_HttpDataChunks.push_back({}); auto& Chunk = m_HttpDataChunks.back(); Chunk.DataChunkType = HttpDataChunkFromMemory; Chunk.FromMemory.pBuffer = (void*)WriteCursor; Chunk.FromMemory.BufferLength = ThisChunkSize; BufferDataSize -= ThisChunkSize; WriteCursor += ThisChunkSize; } } } m_RemainingChunkCount = gsl::narrow(m_HttpDataChunks.size()); m_TotalDataSize = LocalDataSize; } void HttpMessageResponseRequest::SuppressResponseBody() { m_RemainingChunkCount = 0; m_HttpDataChunks.clear(); m_DataBuffers.clear(); } HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { ZEN_UNUSED(NumberOfBytesTransferred); ZEN_UNUSED(IoResult); if (m_RemainingChunkCount == 0) { return nullptr; // All done } return this; } void HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { HttpSysTransaction& Tx = Transaction(); HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); PTP_IO const Iocp = Tx.Iocp(); StartThreadpoolIo(Iocp); // Split payload into batches to play well with the underlying API const int MaxChunksPerCall = 9999; const int ThisRequestChunkCount = std::min(m_RemainingChunkCount, MaxChunksPerCall); const int ThisRequestChunkOffset = m_NextDataChunkOffset; m_RemainingChunkCount -= ThisRequestChunkCount; m_NextDataChunkOffset += ThisRequestChunkCount; ULONG SendFlags = 0; if (m_RemainingChunkCount) { // We need to make more calls to send the full amount of data SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; } ULONG SendResult = 0; if (m_IsInitialResponse) { // Populate response structure HTTP_RESPONSE HttpResponse = {}; HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; // Content-length header char ContentLengthString[32]; _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; ContentLengthHeader->pRawValue = ContentLengthString; ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); // Content-type header PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); HttpResponse.StatusCode = m_ResponseCode; HttpResponse.pReason = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.ReasonLength = (USHORT)strlen(HttpResponse.pReason); // Cache policy HTTP_CACHE_POLICY CachePolicy; CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; CachePolicy.SecondsToLive = 0; // Initial response API call SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), HttpReq->RequestId, SendFlags, &HttpResponse, &CachePolicy, NULL, NULL, 0, Tx.Overlapped(), NULL); m_IsInitialResponse = false; } else { // Subsequent response API calls SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), HttpReq->RequestId, SendFlags, (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // BytesSent NULL, // Reserved1 0, // Reserved2 Tx.Overlapped(), // Overlapped NULL // LogData ); } if ((SendResult != NO_ERROR) // Synchronous completion, but the completion event will still be posted to IOCP && (SendResult != ERROR_IO_PENDING) // Asynchronous completion ) { // Some error occurred, no completion will be posted CancelThreadpoolIo(Iocp); spdlog::error("failed to send HTTP response (error: {}) URL: {}"sv, SendResult, HttpReq->pRawUrl); ErrorCode = MakeWin32ErrorCode(SendResult); } } /** _________ / _____/ ______________ __ ___________ \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ / \ ___/| | \/\ /\ ___/| | \/ /_______ /\___ >__| \_/ \___ >__| \/ \/ \/ */ HttpSysServer::HttpSysServer(unsigned int ThreadCount) : m_ThreadPool(ThreadCount) { ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); if (Result != NO_ERROR) { return; } m_IsHttpInitialized = true; m_IsOk = true; } HttpSysServer::~HttpSysServer() { if (m_IsHttpInitialized) { Cleanup(); HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); } } void HttpSysServer::Initialize(const wchar_t* UrlPath) { m_IsOk = false; ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); if (Result != NO_ERROR) { spdlog::error("Failed to create server session for '{}': {x}"sv, WideToUtf8(UrlPath), Result); return; } Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); if (Result != NO_ERROR) { spdlog::error("Failed to create URL group for '{}': {x}"sv, WideToUtf8(UrlPath), Result); return; } m_BaseUri = UrlPath; Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, UrlPath, HTTP_URL_CONTEXT(0), 0); if (Result != NO_ERROR) { spdlog::error("Failed to add base URL to URL group for '{}': {x}"sv, WideToUtf8(UrlPath), Result); return; } HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, /* Name */ nullptr, /* SecurityAttributes */ nullptr, /* Flags */ 0, &m_RequestQueueHandle); if (Result != NO_ERROR) { spdlog::error("Failed to create request queue for '{}': {x}"sv, WideToUtf8(UrlPath), Result); return; } HttpBindingInfo.Flags.Present = 1; HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); if (Result != NO_ERROR) { spdlog::error("Failed to set server binding property for '{}': {x}"sv, WideToUtf8(UrlPath), Result); return; } // Create I/O completion port std::error_code ErrorCode; m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); if (ErrorCode) { spdlog::error("Failed to create IOCP for '{}': {}"sv, WideToUtf8(UrlPath), ErrorCode.message()); } else { m_IsOk = true; } } void HttpSysServer::Cleanup() { ++m_IsShuttingDown; if (m_RequestQueueHandle) { HttpCloseRequestQueue(m_RequestQueueHandle); m_RequestQueueHandle = nullptr; } if (m_HttpUrlGroupId) { HttpCloseUrlGroup(m_HttpUrlGroupId); m_HttpUrlGroupId = 0; } if (m_HttpSessionId) { HttpCloseServerSession(m_HttpSessionId); m_HttpSessionId = 0; } } void HttpSysServer::StartServer() { const int InitialRequestCount = 32; for (int i = 0; i < InitialRequestCount; ++i) { IssueNewRequestMaybe(); } } void HttpSysServer::Run(bool TestMode) { if (TestMode == false) { zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); } do { int WaitTimeout = -1; if (!TestMode) { WaitTimeout = 1000; } if (!TestMode && _kbhit() != 0) { char c = (char)_getch(); if (c == 27 || c == 'Q' || c == 'q') { RequestApplicationExit(0); } } m_ShutdownEvent.Wait(WaitTimeout); } while (!IsApplicationExitRequested()); } void HttpSysServer::OnHandlingRequest() { if (--m_PendingRequests > m_MinPendingRequests) { // We have more than the minimum number of requests pending, just let someone else // enqueue new requests return; } IssueNewRequestMaybe(); } void HttpSysServer::IssueNewRequestMaybe() { if (m_IsShuttingDown.load(std::memory_order::acquire)) { return; } if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) { return; } std::unique_ptr Request = std::make_unique(*this); std::error_code ec; Request->IssueInitialRequest(ec); if (ec) { // No request was actually issued. What is the appropriate response? return; } // This may end up exceeding the MaxPendingRequests limit, but it's not // really a problem. I'm doing it this way mostly to avoid dealing with // exceptions here ++m_PendingRequests; Request.release(); } void HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) { if (UrlPath[0] == '/') { ++UrlPath; } const std::wstring Path16 = UTF8_to_wstring(UrlPath); Service.SetUriPrefixLength(Path16.size() + 1 /* leading slash */); // Convert to wide string std::wstring Url16 = m_BaseUri + Path16; ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); if (Result != NO_ERROR) { spdlog::error("HttpAddUrlToUrlGroup failed with result {}"sv, Result); return; } } void HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) { ZEN_UNUSED(Service); if (UrlPath[0] == '/') { ++UrlPath; } const std::wstring Path16 = UTF8_to_wstring(UrlPath); // Convert to wide string std::wstring Url16 = m_BaseUri + Path16; ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); if (Result != NO_ERROR) { spdlog::error("HttpRemoveUrlFromUrlGroup failed with result {}"sv, Result); } } ////////////////////////////////////////////////////////////////////////// HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) { } HttpSysTransaction::~HttpSysTransaction() { } PTP_IO HttpSysTransaction::Iocp() { return m_HttpServer.m_ThreadPool.Iocp(); } HANDLE HttpSysTransaction::RequestQueueHandle() { return m_HttpServer.m_RequestQueueHandle; } void HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) { m_InitialHttpHandler.IssueRequest(ErrorCode); } void HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, PVOID pContext /* HttpSysServer */, PVOID pOverlapped, ULONG IoResult, ULONG_PTR NumberOfBytesTransferred, PTP_IO Io) { UNREFERENCED_PARAMETER(Io); UNREFERENCED_PARAMETER(Instance); UNREFERENCED_PARAMETER(pContext); // Note that for a given transaction we may be in this completion function on more // than one thread at any given moment. This means we need to be careful about what // happens in here HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { delete Transaction; } } HttpSysTransaction::Status HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // We use this to ensure sequential execution of completion handlers // for any given transaction. It also ensures all member variables are // in a consistent state for the current thread RwLock::ExclusiveLockScope _(m_CompletionMutex); bool IsRequestPending = false; if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) { const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest(); if (IsInitialRequest) { // Ensure we have a sufficient number of pending requests outstanding m_HttpServer.OnHandlingRequest(); } m_CompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); if (m_CompletionHandler) { try { std::error_code ec; m_CompletionHandler->IssueRequest(ec); if (ec) { spdlog::error("IssueRequest() failed {}"sv, ec.message()); } else { IsRequestPending = true; } } catch (std::exception& Ex) { spdlog::error("exception caught from IssueRequest(): {}"sv, Ex.what()); // something went wrong, no request is pending } } else { if (IsInitialRequest == false) { delete CurrentHandler; } } } // Ensure new requests are enqueued as necessary m_HttpServer.IssueNewRequestMaybe(); if (IsRequestPending) { // There is another request pending on this transaction, so it needs to remain valid return Status::kRequestPending; } // Transaction done, caller should clean up (delete) this instance return Status::kDone; } ////////////////////////////////////////////////////////////////////////// HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) : m_HttpTx(Tx) , m_PayloadBuffer(std::move(PayloadBuffer)) { const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); const int PrefixLength = Service.UriPrefixLength(); const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); if (AbsPathLength >= PrefixLength) { // We convert the URI immediately because most of the code involved prefers to deal // with utf8. This has some performance impact which I'd prefer to avoid but for now // we just have to live with it WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow(AbsPathLength - PrefixLength)}, m_UriUtf8); } else { m_UriUtf8.Reset(); } if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) { --QueryStringLength; WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); } else { m_QueryStringUtf8.Reset(); } m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); m_ContentLength = GetContentLength(HttpRequestPtr); m_ContentType = GetContentType(HttpRequestPtr); m_AcceptType = GetAcceptType(HttpRequestPtr); } IoBuffer HttpSysServerRequest::ReadPayload() { return m_PayloadBuffer; } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) { ZEN_ASSERT(m_IsHandled == false); m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); if (m_SuppressBody) { m_Response->SuppressResponseBody(); } m_IsHandled = true; } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) { ZEN_ASSERT(m_IsHandled == false); m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (m_SuppressBody) { m_Response->SuppressResponseBody(); } m_IsHandled = true; } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { ZEN_ASSERT(m_IsHandled == false); m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); if (m_SuppressBody) { m_Response->SuppressResponseBody(); } m_IsHandled = true; } ////////////////////////////////////////////////////////////////////////// InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) { } InitialRequestHandler::~InitialRequestHandler() { } void InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) { HttpSysTransaction& Tx = Transaction(); PTP_IO Iocp = Tx.Iocp(); HTTP_REQUEST* HttpReq = Tx.HttpRequest(); StartThreadpoolIo(Iocp); ULONG HttpApiResult; if (IsInitialRequest()) { HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), HTTP_NULL_ID, HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, HttpReq, RequestBufferSize(), NULL, Tx.Overlapped()); } else { // The http.sys team recommends limiting the size to 128KB static const uint64_t kMaxBytesPerApiCall = 128 * 1024; uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); void* BufferWriteCursor = reinterpret_cast(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), HttpReq->RequestId, 0, /* Flags */ BufferWriteCursor, gsl::narrow(BytesToReadThisCall), nullptr, // BytesReturned Tx.Overlapped()); } if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) { CancelThreadpoolIo(Iocp); if (HttpApiResult == ERROR_MORE_DATA) { // ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA); } // CleanupHttpIoRequest(pIoRequest); ErrorCode = MakeWin32ErrorCode(HttpApiResult); spdlog::error("HttpReceiveHttpRequest failed, error {}", ErrorCode.message()); return; } ErrorCode = std::error_code(); } HttpSysRequestHandler* InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); switch (IoResult) { case ERROR_OPERATION_ABORTED: return nullptr; case ERROR_MORE_DATA: // Insufficient buffer space break; } // Route requests try { HTTP_REQUEST* HttpReq = HttpRequest(); # if 0 for (int i = 0; i < HttpReq->RequestInfoCount; ++i) { auto& ReqInfo = HttpReq->pRequestInfo[i]; switch (ReqInfo.InfoType) { case HttpRequestInfoTypeRequestTiming: { const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); spdlog::info(""); } break; case HttpRequestInfoTypeAuth: spdlog::info(""); break; case HttpRequestInfoTypeChannelBind: spdlog::info(""); break; case HttpRequestInfoTypeSslProtocol: spdlog::info(""); break; case HttpRequestInfoTypeSslTokenBindingDraft: spdlog::info(""); break; case HttpRequestInfoTypeSslTokenBinding: spdlog::info(""); break; case HttpRequestInfoTypeTcpInfoV0: { const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); spdlog::info(""); } break; case HttpRequestInfoTypeRequestSizing: { const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); spdlog::info(""); } break; case HttpRequestInfoTypeQuicStats: spdlog::info(""); break; case HttpRequestInfoTypeTcpInfoV1: { const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); spdlog::info(""); } break; } } # endif if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); if (m_ContentLength) { // Handle initial chunk read by copying any payload which has already been copied // into our embedded request buffer m_PayloadBuffer = IoBuffer(m_ContentLength); HttpContentType ContentType = GetContentType(HttpReq); m_PayloadBuffer.SetContentType(ContentType); uint64_t BytesToRead = m_ContentLength; uint8_t* const BufferBase = reinterpret_cast(m_PayloadBuffer.MutableData()); uint8_t* BufferWriteCursor = BufferBase; const int EntityChunkCount = HttpReq->EntityChunkCount; for (int i = 0; i < EntityChunkCount; ++i) { HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; ZEN_ASSERT(BufferLength <= BytesToRead); memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); BufferWriteCursor += BufferLength; BytesToRead -= BufferLength; } m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; } } else { m_CurrentPayloadOffset += NumberOfBytesTransferred; } if (m_CurrentPayloadOffset == m_ContentLength) { m_PayloadBuffer.MakeImmutable(); // Body received completely - call request handler HttpSysServerRequest ThisRequest(Transaction(), *Service, m_PayloadBuffer); Service->HandleRequest(ThisRequest); if (!ThisRequest.IsHandled()) { return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } if (ThisRequest.m_Response) { return ThisRequest.m_Response; } } else { // Issue a read request for more body data return this; } } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); } catch (std::exception& ex) { // TODO provide more meaningful error output return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); } } ////////////////////////////////////////////////////////////////////////// // // HttpServer interface implementation // void HttpSysServer::Initialize(int BasePort) { using namespace std::literals; WideStringBuilder<64> BaseUri; BaseUri << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; Initialize(BaseUri.c_str()); StartServer(); } void HttpSysServer::RequestExit() { m_ShutdownEvent.Set(); } void HttpSysServer::RegisterService(HttpService& Service) { RegisterService(Service.BaseUri(), Service); } } // namespace zen #endif