// Copyright Epic Games, Inc. All Rights Reserved.

// NOTE(review): in this copy the targets of the bare #include directives below, and the
// template argument lists throughout the file (e.g. `std::unique_ptr m_IoThreadPool`,
// `static_cast(*InPtr)`, `gsl::narrow(...)`), appear to have been stripped together with
// their angle brackets. Restore them from source control before building.

#include "httpsys.h"

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#if ZEN_WITH_HTTPSYS
#    define _WINSOCKAPI_
#    include
#    include
#    include "iothreadpool.h"
#    include

namespace zen {

// Returns the LLM memory tag ("httpsys") passed to ZEN_MEMSCOPE throughout this file.
// The tag is a function-local static, so it is constructed on first use.
const FLLMTag&
GetHttpsysTag()
{
    static FLLMTag HttpsysTag("httpsys");
    return HttpsysTag;
}

// Appends a printable form of SockAddr to OutString.
//
// IncludePort == true  : formats with WSAAddressToStringA (address and port).
// IncludePort == false : formats with getnameinfo(NI_NUMERICHOST) (numeric address only).
// Appends "unknown" if SockAddr is null or the chosen formatter fails.
// Any family other than AF_INET is sized as SOCKADDR_IN6.
static void
GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool IncludePort = true)
{
    if (SockAddr)
    {
        // When a port is desired, use WSAAddressToStringA which includes the port.
        if (IncludePort)
        {
            CHAR        AddrBuf[64] = {};
            DWORD       AddrBufLen  = sizeof(AddrBuf);
            const DWORD AddrLen     = (SockAddr->sa_family == AF_INET) ? sizeof(SOCKADDR_IN) : sizeof(SOCKADDR_IN6);

            if (WSAAddressToStringA((LPSOCKADDR)SockAddr, AddrLen, nullptr, AddrBuf, &AddrBufLen) == 0)
            {
                // NOTE(review): per the WSAAddressToStringA documentation, the output length
                // includes the terminating NUL. This call therefore likely appends a trailing
                // '\0' to OutString. Consider AddrBufLen - 1 or strlen(AddrBuf).
                OutString.Append(AddrBuf, AddrBufLen);
                return;
            }
        }
        else
        {
            // When port should be omitted, use getnameinfo with NI_NUMERICHOST to get only the numeric address.
            CHAR      HostBuf[64] = {};
            const int SockLen     = (SockAddr->sa_family == AF_INET) ? sizeof(SOCKADDR_IN) : sizeof(SOCKADDR_IN6);
            if (getnameinfo((const SOCKADDR*)SockAddr, SockLen, HostBuf, (DWORD)sizeof(HostBuf), nullptr, 0, NI_NUMERICHOST) == 0)
            {
                OutString.Append(HostBuf, (DWORD)strlen(HostBuf));
                return;
            }
        }
    }

    OutString.Append("unknown");
}

class HttpSysServerRequest;

/**
 * @brief Windows implementation of HTTP server based on http.sys
 *
 * This requires elevation to function by default but system configuration
 * can soften this requirement.
 *
 * See README.md for details.
 */
class HttpSysServer : public HttpServer
{
    friend class HttpSysTransaction;

public:
    explicit HttpSysServer(const HttpSysConfig& Config);
    ~HttpSysServer();

    // HttpServer interface implementation

    virtual int  OnInitialize(int BasePort, std::filesystem::path DataDir) override;
    virtual void OnRun(bool TestMode) override;
    virtual void OnRequestExit() override;
    virtual void OnRegisterService(HttpService& Service) override;
    virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
    virtual void OnClose() override;

    // Lazily created pool used to run asynchronous response work (see HttpAsyncWorkRequest)
    WorkerThreadPool& WorkPool();

    inline bool IsOk() const { return m_IsOk; }
    inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }

    IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);

private:
    // Returns the port actually bound, or 0 on failure
    int  InitializeServer(int BasePort);
    void Cleanup();

    void StartServer();
    void OnHandlingNewRequest();
    void IssueNewRequestMaybe();

    void RegisterService(const char* Endpoint, HttpService& Service);
    void UnregisterService(const char* Endpoint, HttpService& Service);

private:
    LoggerRef m_Log;
    LoggerRef m_RequestLog;
    LoggerRef Log() { return m_Log; }

    bool m_IsOk                    = false;
    bool m_IsHttpInitialized       = false;  // true between a successful HttpInitialize and OnClose
    bool m_IsRequestLoggingEnabled = false;
    bool m_IsAsyncResponseEnabled  = true;

    std::unique_ptr m_IoThreadPool;

    RwLock            m_AsyncWorkPoolInitLock;
    WorkerThreadPool* m_AsyncWorkPool = nullptr;  // owned; created on demand by WorkPool()

    std::vector m_BaseUris;  // eg: http://*:nnnn/
    HTTP_SERVER_SESSION_ID m_HttpSessionId      = 0;
    HTTP_URL_GROUP_ID      m_HttpUrlGroupId     = 0;
    HANDLE                 m_RequestQueueHandle = 0;
    std::atomic_int32_t    m_PendingRequests{0};
    std::atomic_int32_t    m_IsShuttingDown{0};
    int32_t                m_MinPendingRequests = 16;
    int32_t                m_MaxPendingRequests = 128;
    Event                  m_ShutdownEvent;
    HttpSysConfig          m_InitialConfig;

    RwLock      m_RequestFilterLock;
    std::atomic m_HttpRequestFilter = nullptr;
};

}  // namespace zen
#endif

#if ZEN_WITH_HTTPSYS

#    include
#    include

#    pragma comment(lib, "httpapi.lib")

// Decodes a NUL-terminated UTF-8 string into a UTF-16 std::wstring.
//
// Each lead byte starts a new code point and continuation bytes (0x80-0xbf) shift in
// 6 more bits. A code point is emitted once the next byte is not a continuation byte.
// Values above 0x10ffff and UTF-8-encoded surrogates (0xd800-0xdfff) are silently
// dropped. Malformed input is not reported.
//
// NOTE(review): two defects are visible here.
//  1. Codepoint is uninitialised. If the input begins with a continuation byte, the
//     `(Codepoint << 6)` branch reads an indeterminate value.
//  2. The high surrogate is computed as 0xd800 + (Codepoint >> 10). UTF-16 requires
//     0xd800 + ((Codepoint - 0x10000) >> 10), i.e. 0xd7c0 + (Codepoint >> 10).
//     Example: U+1F600 currently yields 0xD87D instead of 0xD83D.
std::wstring
UTF8_to_UTF16(const char* InPtr)
{
    std::wstring OutString;
    unsigned int Codepoint;

    while (*InPtr != 0)
    {
        unsigned char InChar = static_cast(*InPtr);

        if (InChar <= 0x7f)
            Codepoint = InChar;
        else if (InChar <= 0xbf)
            Codepoint = (Codepoint << 6) | (InChar & 0x3f);
        else if (InChar <= 0xdf)
            Codepoint = InChar & 0x1f;
        else if (InChar <= 0xef)
            Codepoint = InChar & 0x0f;
        else
            Codepoint = InChar & 0x07;

        ++InPtr;

        // Emit once the sequence is complete (the next byte is not a continuation byte)
        if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff))
        {
            if (Codepoint > 0xffff)
            {
                OutString.append(1, static_cast(0xd800 + (Codepoint >> 10)));
                OutString.append(1, static_cast(0xdc00 + (Codepoint & 0x03ff)));
            }
            else if (Codepoint < 0xd800 || Codepoint >= 0xe000)
            {
                OutString.append(1, static_cast(Codepoint));
            }
        }
    }

    return OutString;
}

namespace zen {

using namespace std::literals;

class HttpSysServer;
class HttpSysTransaction;
class HttpMessageResponseRequest;

//////////////////////////////////////////////////////////////////////////

// Maps an http.sys verb to HttpVerb. Verbs not listed map to (HttpVerb)0.
HttpVerb
TranslateHttpVerb(HTTP_VERB ReqVerb)
{
    switch (ReqVerb)
    {
        case HttpVerbOPTIONS:
            return HttpVerb::kOptions;

        case HttpVerbGET:
            return HttpVerb::kGet;

        case HttpVerbHEAD:
            return HttpVerb::kHead;

        case HttpVerbPOST:
            return HttpVerb::kPost;

        case HttpVerbPUT:
            return HttpVerb::kPut;

        case HttpVerbDELETE:
            return HttpVerb::kDelete;

        case HttpVerbCOPY:
            return HttpVerb::kCopy;

        default:
            // TODO: invalid request?
            return (HttpVerb)0;
    }
}

// Parses the raw Content-Length header as a decimal uint64_t.
// Returns 0 if the header is empty or not a number, because std::from_chars leaves
// ContentLength untouched on failure.
uint64_t
GetContentLength(const HTTP_REQUEST* HttpRequest)
{
    const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength];
    std::string_view         cl(clh.pRawValue, clh.RawValueLength);
    uint64_t                 ContentLength = 0;
    std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength);
    return ContentLength;
};

// Classifies the raw Content-Type header via ParseContentType
HttpContentType
GetContentType(const HTTP_REQUEST* HttpRequest)
{
    const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType];

    return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
};

// Classifies the raw Accept header via ParseContentType
HttpContentType
GetAcceptType(const HTTP_REQUEST* HttpRequest)
{
    const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept];

    return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
};

/**
 * @brief Base class for any pending or active HTTP transactions
 *
 * IssueRequest starts one I/O stage for the owning transaction. HandleCompletion
 * consumes that stage's completion and returns the handler for the next completion.
 * The handlers visible in this file return nullptr when they have nothing more to do.
 * Non-copyable.
 */
class HttpSysRequestHandler
{
public:
    explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {}
    virtual ~HttpSysRequestHandler() = default;

    virtual void                   IssueRequest(std::error_code& ErrorCode)                             = 0;
    virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0;
    HttpSysTransaction&            Transaction() { return m_Transaction; }

    HttpSysRequestHandler(const HttpSysRequestHandler&) = delete;
    HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete;

private:
    HttpSysTransaction& m_Transaction;
};

/**
 * This is the handler for the initial HTTP I/O request which will receive the headers
 * and however much of the remaining payload might fit in the embedded request buffer.
 *
 * It is also used to receive any entity body data relating to the request
 *
 */
struct InitialRequestHandler : public HttpSysRequestHandler
{
    // The HTTP_REQUEST structure lives at the start of the embedded m_RequestBuffer
    inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; }
    inline uint32_t      RequestBufferSize() const { return sizeof m_RequestBuffer; }
    inline bool          IsInitialRequest() const { return m_IsInitialRequest; }

    InitialRequestHandler(HttpSysTransaction& InRequest);
    ~InitialRequestHandler();

    virtual void                   IssueRequest(std::error_code& ErrorCode) override final;
    virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;

    bool     m_IsInitialRequest     = true;
    uint64_t m_CurrentPayloadOffset = 0;
    uint64_t m_ContentLength        = ~uint64_t(0);  // all-ones sentinel until set
    IoBuffer m_PayloadBuffer;
    UCHAR    m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)];
};

/**
 * This is the class which request handlers use to interact with the server instance
 */
class HttpSysServerRequest : public HttpServerRequest
{
public:
    HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer);
    ~HttpSysServerRequest();

    virtual Oid      ParseSessionId() const override;
    virtual uint32_t ParseRequestId() const override;
    // NOTE(review): unlike its siblings this lacks `override` - confirm whether
    // HttpServerRequest declares IsLocalMachineRequest.
    virtual bool             IsLocalMachineRequest() const;
    virtual std::string_view GetAuthorizationHeader() const override;

    virtual IoBuffer ReadPayload() override;
    virtual void     WriteResponse(HttpResponseCode ResponseCode) override;
    virtual void     WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override;
    virtual void     WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
    virtual void     WriteResponseAsync(std::function&& ContinuationHandler) override;
    virtual bool     TryGetRanges(HttpRanges& Ranges) override;

    void LogRequest(HttpMessageResponseRequest* Response);

    using HttpServerRequest::WriteResponse;

    HttpSysServerRequest(const HttpSysServerRequest&) = delete;
    HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;

    HttpSysTransaction& m_HttpTx;
    // Set by a request handler to schedule the next I/O stage. AsyncWorkItem::Execute
    // resets it before invoking the handler and reads it afterwards.
    HttpSysRequestHandler*       m_NextCompletionHandler = nullptr;
    IoBuffer                     m_PayloadBuffer;
    ExtendableStringBuilder<128> m_UriUtf8;
    ExtendableStringBuilder<128> m_QueryStringUtf8;
};

/** HTTP transaction

    There will be an instance of this per pending and in-flight HTTP transaction

 */
class HttpSysTransaction final
{
public:
    HttpSysTransaction(HttpSysServer& Server);
    virtual ~HttpSysTransaction();

    enum class Status
    {
        kDone,
        kRequestPending
    };

    [[nodiscard]] Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);

    // Threadpool I/O callback registered through CreateIocp in HttpSysServer::InitializeServer
    static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
                                               PVOID                 pContext /* HttpSysServer */,
                                               PVOID                 pOverlapped,
                                               ULONG                 IoResult,
                                               ULONG_PTR             NumberOfBytesTransferred,
                                               PTP_IO                Io);

    void IssueInitialRequest(std::error_code& ErrorCode);
    bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler);

    PTP_IO                Iocp();
    HANDLE                RequestQueueHandle();
    inline OVERLAPPED*    Overlapped() { return &m_HttpOverlapped; }
    inline HttpSysServer& Server() { return m_HttpServer; }
    inline HTTP_REQUEST*  HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }

    HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload);

    // Uses optional::value(), so calling this before a request object exists throws
    HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); }

    // Holds m_CompletionMutex exclusively for the lifetime of the scope
    struct CompletionMutexScope
    {
        CompletionMutexScope(HttpSysTransaction& Tx) : Lock(Tx.m_CompletionMutex) {}
        ~CompletionMutexScope() = default;

        RwLock::ExclusiveLockScope Lock;
    };

private:
    OVERLAPPED     m_HttpOverlapped{};
    HttpSysServer& m_HttpServer;

    // Tracks which handler is due to handle the next I/O completion event
    HttpSysRequestHandler* m_CompletionHandler = nullptr;
    RwLock                 m_CompletionMutex;
    InitialRequestHandler  m_InitialHttpHandler{*this};
    std::optional          m_HandlerRequest;
    Ref                    m_PackageHandler;
};

/**
 * @brief HTTP request response I/O request handler
 *
 * Asynchronously streams out a response to an HTTP request via compound
 * responses from memory or directly from file
 *
 * All constructors funnel into InitializeForPayload. That function takes ownership of
 * (MakeOwned) every payload buffer. The response is sent in batches of at most 9999
 * data chunks per http.sys call.
 */
class HttpMessageResponseRequest : public HttpSysRequestHandler
{
public:
    HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode);
    HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message);
    HttpMessageResponseRequest(HttpSysTransaction& InRequest,
                               uint16_t            ResponseCode,
                               HttpContentType     ContentType,
                               const void*         Payload,
                               size_t              PayloadSize);
    HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, HttpContentType ContentType, std::span Blobs);
    ~HttpMessageResponseRequest();

    virtual void                   IssueRequest(std::error_code& ErrorCode) override final;
    virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
    void                           SuppressResponseBody();  // typically used for HEAD requests

    inline uint16_t GetResponseCode() const { return m_ResponseCode; }
    inline int64_t  GetResponseBodySize() const { return m_TotalDataSize; }

private:
    eastl::fixed_vector m_HttpDataChunks;
    uint64_t            m_TotalDataSize       = 0;  // Sum of all chunk sizes
    uint16_t            m_ResponseCode        = 0;
    uint32_t            m_NextDataChunkOffset = 0;  // Cursor used for very large chunk lists
    uint32_t            m_RemainingChunkCount = 0;  // Backlog for multi-call sends
    bool                m_IsInitialResponse   = true;
    HttpContentType     m_ContentType         = HttpContentType::kBinary;
    eastl::fixed_vector m_DataBuffers;

    void InitializeForPayload(uint16_t ResponseCode, std::span Blobs);
};

// Response without a body
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode)
: HttpSysRequestHandler(InRequest)
{
    std::array EmptyBufferList;

    InitializeForPayload(ResponseCode, EmptyBufferList);
}

// Text response. Message is only wrapped here; InitializeForPayload copies it (MakeOwned),
// so the caller's storage need not outlive this object.
// NOTE(review): an empty Message presumably yields a zero-sized buffer, which would trip
// ZEN_ASSERT(BufferDataSize) in InitializeForPayload - confirm.
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message)
: HttpSysRequestHandler(InRequest)
, m_ContentType(HttpContentType::kText)
{
    IoBuffer   MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size());
    std::array SingleBufferList({MessageBuffer});

    InitializeForPayload(ResponseCode, SingleBufferList);
}

// Response from a raw memory range (wrapped, then copied by InitializeForPayload)
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
                                                       uint16_t            ResponseCode,
                                                       HttpContentType     ContentType,
                                                       const void*         Payload,
                                                       size_t              PayloadSize)
: HttpSysRequestHandler(InRequest)
, m_ContentType(ContentType)
{
    IoBuffer   MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize);
    std::array SingleBufferList({MessageBuffer});

    InitializeForPayload(ResponseCode, SingleBufferList);
}

// Response from a list of buffers. The buffers are moved from by InitializeForPayload.
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
                                                       uint16_t            ResponseCode,
                                                       HttpContentType     ContentType,
                                                       std::span           BlobList)
: HttpSysRequestHandler(InRequest)
, m_ContentType(ContentType)
{
    InitializeForPayload(ResponseCode, BlobList);
}

HttpMessageResponseRequest::~HttpMessageResponseRequest()
{
}

// Builds the HTTP_DATA_CHUNK list for the response.
//
// - Moves every buffer in BlobList into m_DataBuffers and makes it owned.
// - File-backed buffers become HttpDataChunkFromFileHandle chunks.
// - Memory buffers are split into HttpDataChunkFromMemory chunks of at most 1 GiB each.
// - A 200 with no payload is downgraded to 204 No Content.
//
// Every buffer must be non-empty (ZEN_ASSERT), so the 204 path is only reached with an
// empty BlobList.
void
HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span BlobList)
{
    ZEN_TRACE_CPU("httpsys::InitializeForPayload");

    const uint32_t ChunkCount = gsl::narrow(BlobList.size());

    m_HttpDataChunks.reserve(ChunkCount);
    m_DataBuffers.reserve(ChunkCount);

    for (IoBuffer& Buffer : BlobList)
    {
        m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
    }

    // Initialize the full array up front

    uint64_t LocalDataSize = 0;

    for (IoBuffer& Buffer : m_DataBuffers)
    {
        uint64_t BufferDataSize = Buffer.Size();

        ZEN_ASSERT(BufferDataSize);

        LocalDataSize += BufferDataSize;

        IoBufferFileReference FileRef;
        if (Buffer.GetFileReference(/* out */ FileRef))
        {
            // Use direct file transfer

            auto& Chunk = m_HttpDataChunks.emplace_back();

            Chunk.DataChunkType                                    = HttpDataChunkFromFileHandle;
            Chunk.FromFileHandle.FileHandle                        = FileRef.FileHandle;
            Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
            Chunk.FromFileHandle.ByteRange.Length.QuadPart         = BufferDataSize;
        }
        else
        {
            // Send from memory, need to make sure we chunk the buffer up since
            // the underlying data structure only accepts 32-bit chunk sizes for
            // memory chunks. When this happens the vector will be reallocated,
            // which is fine since this will be a pretty rare case and sending
            // the data is going to take a lot longer than a memory allocation :)

            const uint8_t* WriteCursor = reinterpret_cast(Buffer.Data());

            while (BufferDataSize)
            {
                const ULONG ThisChunkSize = gsl::narrow(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize));

                auto& Chunk = m_HttpDataChunks.emplace_back();

                Chunk.DataChunkType           = HttpDataChunkFromMemory;
                Chunk.FromMemory.pBuffer      = (void*)WriteCursor;
                Chunk.FromMemory.BufferLength = ThisChunkSize;

                BufferDataSize -= ThisChunkSize;
                WriteCursor += ThisChunkSize;
            }
        }
    }

    m_RemainingChunkCount = gsl::narrow(m_HttpDataChunks.size());
    m_TotalDataSize       = LocalDataSize;

    if (m_TotalDataSize == 0 && ResponseCode == 200)
    {
        // Some HTTP clients really don't like empty responses unless a 204 response is sent
        m_ResponseCode = uint16_t(HttpResponseCode::NoContent);
    }
    else
    {
        m_ResponseCode = ResponseCode;
    }
}

// Drops all body chunks and buffers. m_TotalDataSize is left unchanged, so the
// Content-Length header sent by IssueRequest still reports the full body size.
void
HttpMessageResponseRequest::SuppressResponseBody()
{
    m_RemainingChunkCount = 0;
    m_HttpDataChunks.clear();
    m_DataBuffers.clear();
}

// Completion of one send call.
// Returns nullptr on failure or once all chunks are sent, and `this` while chunks remain.
// NOTE(review): NumberOfBytesTransferred covers only the completed call, so for multi-call
// sends the "{} out of {} bytes" warning is not a cumulative figure.
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
    ZEN_UNUSED(NumberOfBytesTransferred);

    if (IoResult != NO_ERROR)
    {
        ZEN_WARN("response '{}' ({}) aborted after transfering '{}', {} out of {} bytes, reason: {} ({})",
                 ReasonStringForHttpResultCode(m_ResponseCode),
                 m_ResponseCode,
                 ToString(m_ContentType),
                 NumberOfBytesTransferred,
                 m_TotalDataSize,
                 GetSystemErrorAsString(IoResult),
                 IoResult);

        // if one transmit failed there's really no need to go on
        return nullptr;
    }

    if (m_RemainingChunkCount == 0)
    {
        return nullptr;  // All done
    }

    return this;
}

// Issues the next send call for this response.
//
// The first call uses HttpSendHttpResponse with the status line and the Server,
// Content-Length and Content-Type headers. Later calls use HttpSendResponseEntityBody.
// StartThreadpoolIo is armed before each call. If the call fails with anything other
// than NO_ERROR / ERROR_IO_PENDING, it is cancelled, ErrorCode is set and diagnostics
// are logged.
void
HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
{
    ZEN_TRACE_CPU("httpsys::Response::IssueRequest");
    HttpSysTransaction& Tx      = Transaction();
    HTTP_REQUEST* const HttpReq = Tx.HttpRequest();
    PTP_IO const        Iocp    = Tx.Iocp();

    StartThreadpoolIo(Iocp);

    // Split payload into batches to play well with the underlying API

    const int MaxChunksPerCall = 9999;

    const int ThisRequestChunkCount  = std::min(m_RemainingChunkCount, MaxChunksPerCall);
    const int ThisRequestChunkOffset = m_NextDataChunkOffset;

    m_RemainingChunkCount -= ThisRequestChunkCount;
    m_NextDataChunkOffset += ThisRequestChunkCount;

    /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA?

       From the docs:

       This flag enables buffering of data in the kernel on a per-response basis. It should
       be used by an application doing synchronous I/O, or by a an application doing
       asynchronous I/O with no more than one send outstanding at a time.

       Applications using asynchronous I/O which may have more than one send outstanding at
       a time should not use this flag.

       When this flag is set, it should be used consistently in calls to the
       HttpSendHttpResponse function as well.

       NOTE(review): the code below already sets this flag on every call, so the question
       above appears stale. Each call here is issued only after the previous one completes,
       so there is one send outstanding at a time.
     */

    ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA;

    if (m_RemainingChunkCount)
    {
        // We need to make more calls to send the full amount of data
        SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
    }

    ULONG SendResult = 0;

    if (m_IsInitialResponse)
    {
        // Populate response structure

        HTTP_RESPONSE HttpResponse = {};

        HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount);
        HttpResponse.pEntityChunks    = m_HttpDataChunks.data() + ThisRequestChunkOffset;

        // Server header
        //
        // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header
        //
        // This is controlled via a registry key 'DisableServerHeader', at:
        //
        // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters
        //
        // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether
        // (only the latter appears to do anything in my testing, on Windows 10).
        //
        // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header)
        //

        PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer];
        ServerHeader->pRawValue         = "Zen";
        ServerHeader->RawValueLength    = (USHORT)3;

        // Content-length header

        char ContentLengthString[32];
        _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10);

        PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength];
        ContentLengthHeader->pRawValue         = ContentLengthString;
        ContentLengthHeader->RawValueLength    = (USHORT)strlen(ContentLengthString);

        // Content-type header

        PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];

        std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);

        ContentTypeHeader->pRawValue      = ContentTypeString.data();
        ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();

        std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);

        HttpResponse.StatusCode   = m_ResponseCode;
        HttpResponse.pReason      = ReasonString.data();
        HttpResponse.ReasonLength = (USHORT)ReasonString.size();

        // Cache policy

        HTTP_CACHE_POLICY CachePolicy;

        CachePolicy.Policy        = HttpCachePolicyNocache;
        CachePolicy.SecondsToLive = 0;

        // Initial response API call

        SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),  // RequestQueueHandle
                                          HttpReq->RequestId,       // RequestId
                                          SendFlags,                // Flags
                                          &HttpResponse,            // HttpResponse
                                          &CachePolicy,             // CachePolicy
                                          NULL,                     // BytesSent
                                          NULL,                     // Reserved1
                                          0,                        // Reserved2
                                          Tx.Overlapped(),          // Overlapped
                                          NULL                      // LogData
        );

        m_IsInitialResponse = false;
    }
    else
    {
        // Subsequent response API calls

        SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),                    // RequestQueueHandle
                                                HttpReq->RequestId,                         // RequestId
                                                SendFlags,                                  // Flags
                                                (USHORT)ThisRequestChunkCount,              // EntityChunkCount
                                                &m_HttpDataChunks[ThisRequestChunkOffset],  // EntityChunks
                                                NULL,                                       // BytesSent
                                                NULL,                                       // Reserved1
                                                0,                                          // Reserved2
                                                Tx.Overlapped(),                            // Overlapped
                                                NULL                                        // LogData
        );
    }

    // Describes the chunks used by this call. Only invoked on the failure path below.
    auto EmitResponseDetails = [&](StringBuilderBase& ResponseDetails) -> void {
        for (int i = 0; i < ThisRequestChunkCount; ++i)
        {
            const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i];

            if (i > 0)
            {
                ResponseDetails << " + ";
            }

            switch (Chunk.DataChunkType)
            {
                case HttpDataChunkFromMemory:
                    ResponseDetails << "mem:" << uint64_t(Chunk.FromMemory.BufferLength);
                    break;

                case HttpDataChunkFromFileHandle:
                    ResponseDetails << "file:";
                    {
                        ResponseDetails << uint64_t(Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart) << ","
                                        << uint64_t(Chunk.FromFileHandle.ByteRange.Length.QuadPart) << ",";

                        std::error_code       PathEc;
                        HANDLE                FileHandle = Chunk.FromFileHandle.FileHandle;
                        std::filesystem::path Path       = PathFromHandle(FileHandle, PathEc);

                        if (PathEc)
                        {
                            ResponseDetails << "bad_file(handle=" << reinterpret_cast(FileHandle) << ",error=" << PathEc.message() << ")";
                        }
                        else
                        {
                            const uint64_t FileSize = FileSizeFromHandle(FileHandle);
                            ResponseDetails << Path.u8string() << "(" << FileSize << ") handle=" << reinterpret_cast(FileHandle);
                        }
                    }
                    break;

                case HttpDataChunkFromFragmentCache:
                    ResponseDetails << "frag:???";  // We do not use these
                    break;

                case HttpDataChunkFromFragmentCacheEx:
                    ResponseDetails << "frax:???";  // We do not use these
                    break;

#    if 0  // Not available in older Windows SDKs
                case HttpDataChunkTrailers:
                    ResponseDetails << "trls:???";  // We do not use these
                    break;
#    endif

                default:
                    ResponseDetails << "???: " << Chunk.DataChunkType;
                    break;
            }
        }
    };

    if (SendResult == NO_ERROR)
    {
        // Synchronous completion, but the completion event will still be posted to IOCP

        ErrorCode.clear();
    }
    else if (SendResult == ERROR_IO_PENDING)
    {
        // Asynchronous completion, a completion notification will be posted to IOCP

        ErrorCode.clear();
    }
    else
    {
        ErrorCode = MakeErrorCode(SendResult);

        // An error occurred, no completion will be posted to IOCP

        CancelThreadpoolIo(Iocp);

        // Emit diagnostics

        ExtendableStringBuilder<256> ResponseDetails;
        EmitResponseDetails(ResponseDetails);

        ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}",
                 SendResult,
                 ErrorCode.message(),
                 HttpReq->pRawUrl,
                 Tx.ServerRequest().SessionId(),
                 HttpReq->RequestId,
                 ResponseDetails);
    }
}

/** HTTP completion handler for async work

    This is used to allow work to be taken off the request handler threads
    and to support posting responses asynchronously.

    IssueRequest performs no I/O. It schedules an AsyncWorkItem on the server's
    WorkPool(), and that item runs the handler and issues the next request stage.
 */

class HttpAsyncWorkRequest : public HttpSysRequestHandler
{
public:
    HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response);
    ~HttpAsyncWorkRequest();

    virtual void                   IssueRequest(std::error_code& ErrorCode) override final;
    virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;

private:
    struct AsyncWorkItem : public IWork
    {
        virtual void Execute() override;

        AsyncWorkItem(HttpSysTransaction& InTx, std::function&& InHandler)
        : Tx(InTx)
        , Handler(std::move(InHandler))
        {
        }

        HttpSysTransaction& Tx;
        std::function       Handler;
    };

    Ref m_WorkItem;
};

HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function&& Response)
: HttpSysRequestHandler(Tx)
{
    m_WorkItem = new AsyncWorkItem(Tx, std::move(Response));
}

HttpAsyncWorkRequest::~HttpAsyncWorkRequest()
{
}

// Schedules the work item on the async pool (backlog enabled). Never reports an error.
void
HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
{
    ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest");
    ErrorCode.clear();

    Transaction().Server().WorkPool().ScheduleWork(m_WorkItem, WorkerThreadPool::EMode::EnableBacklog);
}

HttpSysRequestHandler*
HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
    // This ought to not be called since there should be no outstanding I/O request
    // when this completion handler is active

    ZEN_UNUSED(IoResult, NumberOfBytesTransferred);

    ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);

    return this;
}

// Runs the user handler on a worker thread and then issues the next stage.
// The next stage is one of:
// - the handler the user scheduled,
// - a 404 if the request was not handled,
// - a 500 on inconsistency or exception.
//
// NOTE(review): the CompletionMutexScope below is declared inside the try block, so it is
// already released when either catch handler runs. Those IssueNextRequest calls happen
// without the lock that the comment says is needed to keep the transaction alive.
//
// NOTE(review): the assert message format "'{}" is missing its closing quote.
void
HttpAsyncWorkRequest::AsyncWorkItem::Execute()
{
    ZEN_MEMSCOPE(GetHttpsysTag());
    ZEN_TRACE_CPU("httpsys::async_execute");

    try
    {
        // We need to hold this lock while we're issuing new requests in order to
        // prevent race conditions between the thread we are running on and any
        // IOCP service threads. Otherwise the IOCP completion handler can end
        // up deleting the transaction object before we are done with it!
        HttpSysTransaction::CompletionMutexScope _(Tx);

        HttpSysServerRequest& ThisRequest = Tx.ServerRequest();

        ThisRequest.m_NextCompletionHandler = nullptr;

        {
            ZEN_TRACE_CPU("httpsys::HandleRequest");
            Handler(ThisRequest);
        }

        // TODO: should Handler be destroyed at this point to ensure there
        //       are no outstanding references into state which could be
        //       deleted asynchronously as a result of issuing the response?

        if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler)
        {
            return (void)Tx.IssueNextRequest(NextHandler);
        }
        else if (!ThisRequest.IsHandled())
        {
            return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv));
        }
        else
        {
            // "Handled" but no request handler? Shouldn't ever happen
            return (void)Tx.IssueNextRequest(
                new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
        }
    }
    catch (const AssertException& AssertEx)
    {
        return (void)Tx.IssueNextRequest(
            new HttpMessageResponseRequest(Tx, 500, fmt::format("Assert thrown in async work: '{}", AssertEx.FullDescription())));
    }
    catch (const std::exception& Ex)
    {
        return (void)Tx.IssueNextRequest(
            new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: {}", Ex.what())));
    }
}

//////////////////////////////////////////////////////////////////////////
//
// HttpSysServer
//

// Configures the I/O thread pool and initialises the HTTP Server API.
//
// Thread pool sizing:
// - I/O threads default to Max(8, hardware concurrency) when ThreadCount == 0.
// - The maximum thread count is twice the minimum, doubled again for dedicated servers
//   that use the default count.
// - Async worker threads default to 16.
//
// m_IsOk is set only if HttpInitialize succeeds.
// NOTE(review): an HttpInitialize failure returns silently; nothing is logged.
HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig)
: m_Log(logging::Get("httpsys"))
, m_RequestLog(logging::Get("http_requests"))
, m_IsRequestLoggingEnabled(InConfig.IsRequestLoggingEnabled)
, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled)
, m_InitialConfig(InConfig)
{
    ZEN_MEMSCOPE(GetHttpsysTag());

    // Initialize thread pool

    int MinThreadCount;
    int MaxThreadCount;

    if (m_InitialConfig.ThreadCount == 0)
    {
        MinThreadCount = Max(8u, GetHardwareConcurrency());
    }
    else
    {
        MinThreadCount = m_InitialConfig.ThreadCount;
    }

    MaxThreadCount = MinThreadCount * 2;

    if (m_InitialConfig.IsDedicatedServer && m_InitialConfig.ThreadCount == 0)
    {
        // In order to limit the potential impact of threads stuck
        // in locks we allow the thread pool to be oversubscribed
        // by a fair amount

        MaxThreadCount *= 2;
    }

    m_IoThreadPool = std::make_unique(MinThreadCount, MaxThreadCount);

    if (m_InitialConfig.AsyncWorkThreadCount == 0)
    {
        m_InitialConfig.AsyncWorkThreadCount = 16;
    }

    ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr);

    if (Result != NO_ERROR)
    {
        return;
    }

    m_IsHttpInitialized = true;
    m_IsOk              = true;

    ZEN_INFO("http.sys server started in {} mode, using {}-{} I/O threads and {} async worker threads",
             m_InitialConfig.IsDedicatedServer ? "DEDICATED" : "NORMAL",
             MinThreadCount,
             MaxThreadCount,
             m_InitialConfig.AsyncWorkThreadCount);
}

// Logs an error if OnClose() was not called, but does not clean up http.sys state itself.
// Deletes the lazily created async work pool.
HttpSysServer::~HttpSysServer()
{
    if (m_IsHttpInitialized)
    {
        ZEN_ERROR("~HttpSysServer() called without calling Close() first");
    }

    delete m_AsyncWorkPool;
    m_AsyncWorkPool = nullptr;
}

// Closes queue/URL group/session (Cleanup) and terminates the HTTP Server API.
// Safe to call repeatedly: it only acts while m_IsHttpInitialized is set.
void
HttpSysServer::OnClose()
{
    if (m_IsHttpInitialized)
    {
        Cleanup();

        HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr);
        m_IsHttpInitialized = false;
    }
}

// Creates the server session, URL group and request queue, and the IOCP binding.
// Returns the port actually used, or 0 on failure.
//
// Port selection:
// - The wildcard URL http://*:BasePort/ is tried first.
// - On ERROR_SHARING_VIOLATION it retries once after 500 ms. Non-dedicated servers then
//   probe BasePort + 100 * k for k = 1..9. Dedicated servers retry the same port up to
//   3 more times.
// - On ERROR_ACCESS_DENIED, or when ForceLoopback is set, non-dedicated servers fall back
//   to loopback-only URLs ([::1], localhost, 127.0.0.1), probing ports the same way.
//
// NOTE(review): in the loopback loop, if [::1] registers at one port but a later host hits
// ERROR_SHARING_VIOLATION, the loop moves to the next port. The earlier registration stays
// in the URL group and in m_BaseUris. m_BaseUris.front() (used in later log messages) can
// then name a port other than EffectivePort.
int
HttpSysServer::InitializeServer(int BasePort)
{
    ZEN_MEMSCOPE(GetHttpsysTag());

    using namespace std::literals;

    WideStringBuilder<64> WildcardUrlPath;
    WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;

    m_IsOk = false;

    ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);

    if (Result != NO_ERROR)
    {
        ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);

        return 0;
    }

    Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);

    if (Result != NO_ERROR)
    {
        ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);

        return 0;
    }

    m_BaseUris.clear();

    const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
    const bool AllowLocalOnly   = !m_InitialConfig.IsDedicatedServer;

    int EffectivePort = BasePort;

    if (m_InitialConfig.ForceLoopback)
    {
        // Force trigger of opening using local port
        ZEN_ASSERT(AllowLocalOnly);
        Result = ERROR_ACCESS_DENIED;
    }
    else
    {
        Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);

        if ((Result == ERROR_SHARING_VIOLATION))
        {
            ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
            Sleep(500);

            Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);

            if (Result == ERROR_SHARING_VIOLATION)
            {
                if (AllowPortProbing)
                {
                    // Sharing violation implies the port is being used by another process
                    for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
                    {
                        EffectivePort = BasePort + (PortOffset * 100);
                        WildcardUrlPath.Reset();
                        WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv;

                        Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
                    }
                }
                else
                {
                    for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
                    {
                        ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
                        Sleep(500);
                        Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
                    }
                }
            }
        }
    }

    if (Result == NO_ERROR)
    {
        m_BaseUris.push_back(WildcardUrlPath.c_str());
    }
    else if (Result == ERROR_ACCESS_DENIED)
    {
        if (AllowLocalOnly)
        {
            // If we can't register the wildcard path, we fall back to local paths.
            // These local paths allow requests originating locally to function, but will not allow
            // remote origin requests to function. This can be remedied by using netsh
            // during an install process to grant permissions to route public access to the appropriate
            // port for the current user. eg:
            // netsh http add urlacl url=http://*:8558/ user=

            if (!m_InitialConfig.ForceLoopback)
            {
                ZEN_WARN(
                    "Unable to register handler using '{}' - falling back to local-only. "
                    "Please ensure the appropriate netsh URL reservation configuration "
                    "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
                    WideToUtf8(WildcardUrlPath));
            }

            const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};

            ULONG InternalResult = ERROR_SHARING_VIOLATION;
            for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
            {
                EffectivePort = BasePort + (PortOffset * 100);

                for (const std::u8string_view Host : Hosts)
                {
                    WideStringBuilder<64> LocalUrlPath;
                    LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;

                    InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);

                    if (InternalResult == NO_ERROR)
                    {
                        ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrlPath));

                        m_BaseUris.push_back(LocalUrlPath.c_str());
                    }
                    else
                    {
                        break;
                    }
                }
            }
        }
        else
        {
            ZEN_ERROR(
                "Unable to register URL handler for '{}' - access denied. Please ensure the appropriate netsh URL reservation "
                "configuration is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
                WideToUtf8(WildcardUrlPath));

            return 0;
        }
    }

    if (m_BaseUris.empty())
    {
        ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);

        return 0;
    }

    HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};

    WideStringBuilder<64> QueueName;
    QueueName << "zenserver_" << EffectivePort;

    Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
                                    /* Name */ QueueName.c_str(),
                                    /* SecurityAttributes */ nullptr,
                                    /* Flags */ 0,
                                    &m_RequestQueueHandle);

    if (Result != NO_ERROR)
    {
        ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);

        return 0;
    }

    HttpBindingInfo.Flags.Present      = 1;
    HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle;

    Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo));

    if (Result != NO_ERROR)
    {
        ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);

        return 0;
    }

    // Configure rejection method. Default is to drop the connection, it's better if we
    // return an explicit error code when the queue cannot accept more requests
    // NOTE(review): the Result of this call is not checked (it is overwritten below).

    {
        HTTP_503_RESPONSE_VERBOSITY VerbosityInformation = Http503ResponseVerbosityLimited;

        Result = HttpSetRequestQueueProperty(m_RequestQueueHandle,
                                             HttpServer503VerbosityProperty,
                                             &VerbosityInformation,
                                             sizeof VerbosityInformation,
                                             0,
                                             0);
    }

    // Tune the maximum number of pending requests in the http.sys request queue. By default
    // the value is 1000 which is plenty for single user machines but for dedicated servers
    // serving many users it makes sense to increase this to a higher number to help smooth
    // out intermittent stalls like we might experience when GC is triggered

    if (m_InitialConfig.IsDedicatedServer)
    {
        ULONG QueueLength = 50000;

        Result = HttpSetRequestQueueProperty(m_RequestQueueHandle, HttpServerQueueLengthProperty, &QueueLength, sizeof QueueLength, 0, 0);

        if (Result != NO_ERROR)
        {
            ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
        }
    }

    // Create I/O completion port

    std::error_code ErrorCode;
    m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);

    if (ErrorCode)
    {
        ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());

        return 0;
    }
    else
    {
        m_IsOk = true;

        ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
    }

    // This is not available in all Windows SDK versions so for now we can't use recently
    // released functionality. We should investigate how to get more recent SDK releases
    // into the build

#    if 0
    if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
    {
        ZEN_DEBUG("HTTP3 is available");
    }
    else
    {
        ZEN_DEBUG("HTTP3 is NOT available");
    }
#    endif

    return EffectivePort;
}

// Marks the server as shutting down, then closes the request queue, URL group and
// session (in that order), zeroing each handle/id after closing it.
void
HttpSysServer::Cleanup()
{
    ++m_IsShuttingDown;

    if (m_RequestQueueHandle)
    {
        HttpCloseRequestQueue(m_RequestQueueHandle);
        m_RequestQueueHandle = nullptr;
    }

    if (m_HttpUrlGroupId)
    {
        HttpCloseUrlGroup(m_HttpUrlGroupId);
        m_HttpUrlGroupId = 0;
    }

    if (m_HttpSessionId)
    {
        HttpCloseServerSession(m_HttpSessionId);
        m_HttpSessionId = 0;
    }
}

// Returns the async work pool, creating it on first use with AsyncWorkThreadCount threads.
// NOTE(review): m_AsyncWorkPool is a plain pointer, yet it is read outside
// m_AsyncWorkPoolInitLock in the first check. Concurrent first calls therefore race on it
// (a data race in the C++ memory model). Making the pointer std::atomic, or using
// std::call_once, would make this check well-defined.
WorkerThreadPool&
HttpSysServer::WorkPool()
{
    ZEN_MEMSCOPE(GetHttpsysTag());

    if (!m_AsyncWorkPool)
    {
        RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);

        if (!m_AsyncWorkPool)
        {
            m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async");
        }
    }

    return *m_AsyncWorkPool;
}

// Primes the request queue by calling IssueNewRequestMaybe 32 times
void
HttpSysServer::StartServer()
{
    const int InitialRequestCount = 32;

    for (int i = 0; i < InitialRequestCount; ++i)
    {
        IssueNewRequestMaybe();
    }
}

// Blocks until m_ShutdownEvent is set.
// - Each iteration waits 100 ms (1000 ms when interactive) and then refreshes the
//   low-frequency timer.
// - In interactive mode, ESC / 'Q' / 'q' sets the shutdown event and requests
//   application exit.
// NOTE(review): the parameter is named TestMode in the class declaration but IsInteractive
// here; confirm which meaning is intended.
void
HttpSysServer::OnRun(bool IsInteractive)
{
    if (IsInteractive)
    {
        ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit");
    }

    bool ShutdownRequested = false;
    do
    {
        int WaitTimeout = 100;

        if (IsInteractive)
        {
            WaitTimeout = 1000;

            if (_kbhit() != 0)
            {
                char c = (char)_getch();

                if (c == 27 || c == 'Q' || c == 'q')
                {
                    m_ShutdownEvent.Set();
                    RequestApplicationExit(0);
                }
            }
        }

        ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
        UpdateLofreqTimerValue();
    } while (!ShutdownRequested);
}

void
HttpSysServer::OnHandlingNewRequest()
{
    if (--m_PendingRequests > m_MinPendingRequests)
    {
        // We have more than the minimum number of requests pending, just let someone else
        // enqueue new requests. This should be the common case as we check if we need to
        // enqueue more requests before exiting the completion handler.
return; } IssueNewRequestMaybe(); } void HttpSysServer::IssueNewRequestMaybe() { if (m_IsShuttingDown.load(std::memory_order::acquire)) { return; } if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) { return; } ZEN_MEMSCOPE(GetHttpsysTag()); std::unique_ptr Request = std::make_unique(*this); std::error_code ErrorCode; Request->IssueInitialRequest(ErrorCode); if (ErrorCode) { // No request was actually issued. What is the appropriate response? return; } // This may end up exceeding the MaxPendingRequests limit, but it's not // really a problem. I'm doing it this way mostly to avoid dealing with // exceptions here ++m_PendingRequests; Request.release(); } void HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) { ZEN_MEMSCOPE(GetHttpsysTag()); if (UrlPath[0] == '/') { ++UrlPath; } const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); // Convert to wide string for (const std::wstring& BaseUri : m_BaseUris) { std::wstring Url16 = BaseUri + PathUtf16; ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); if (Result != NO_ERROR) { ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); return; } } } void HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) { ZEN_UNUSED(Service); if (UrlPath[0] == '/') { ++UrlPath; } const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); // Convert to wide string for (const std::wstring& BaseUri : m_BaseUris) { std::wstring Url16 = BaseUri + PathUtf16; ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); if (Result != NO_ERROR) { ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); } } } ////////////////////////////////////////////////////////////////////////// HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : 
m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) { } HttpSysTransaction::~HttpSysTransaction() { } PTP_IO HttpSysTransaction::Iocp() { return m_HttpServer.m_IoThreadPool->Iocp(); } HANDLE HttpSysTransaction::RequestQueueHandle() { return m_HttpServer.m_RequestQueueHandle; } void HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) { m_InitialHttpHandler.IssueRequest(ErrorCode); } thread_local bool t_IsHttpSysThreadNamed = false; static std::atomic HttpSysThreadIndex = 0; static void NameCurrentHttpSysThread() { t_IsHttpSysThreadNamed = true; const int ThreadIndex = ++HttpSysThreadIndex; zen::ExtendableStringBuilder<16> ThreadName; ThreadName << "httpio_" << ThreadIndex; SetCurrentThreadName(ThreadName); } void HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, PVOID pContext /* HttpSysServer */, PVOID pOverlapped, ULONG IoResult, ULONG_PTR NumberOfBytesTransferred, PTP_IO Io) { ZEN_UNUSED(Io, Instance); // Assign names to threads for context if (!t_IsHttpSysThreadNamed) { NameCurrentHttpSysThread(); } // Note that for a given transaction we may be in this completion function on more // than one thread at any given moment. This means we need to be careful about what // happens in here HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { delete Transaction; } // Ensure new requests are enqueued as necessary. 
We do this here instead // of inside the transaction completion handler now to avoid spending time // in unrelated API calls while holding the transaction lock if (HttpSysServer* HttpServer = reinterpret_cast(pContext)) { HttpServer->IssueNewRequestMaybe(); } } bool HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) { ZEN_TRACE_CPU("httpsys::Transaction::IssueNextRequest"); HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; m_CompletionHandler = NewCompletionHandler; auto _ = MakeGuard([this, CurrentHandler] { if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) { delete CurrentHandler; } }); if (NewCompletionHandler == nullptr) { return false; } try { std::error_code ErrorCode; m_CompletionHandler->IssueRequest(ErrorCode); if (!ErrorCode) { return true; } ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); } catch (const AssertException& AssertEx) { ZEN_ERROR("Assert thrown in IssueNextRequest(): {}", AssertEx.FullDescription()); } catch (const std::exception& Ex) { ZEN_ERROR("exception caught in IssueNextRequest(): {}", Ex.what()); } // something went wrong, no request is pending m_CompletionHandler = nullptr; return false; } HttpSysTransaction::Status HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // We use this to ensure sequential execution of completion handlers // for any given transaction. 
It also ensures all member variables are // in a consistent state for the current thread RwLock::ExclusiveLockScope _(m_CompletionMutex); bool IsRequestPending = false; if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) { if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) { // Ensure we have a sufficient number of pending requests outstanding m_HttpServer.OnHandlingNewRequest(); } auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); IsRequestPending = IssueNextRequest(NewCompletionHandler); } if (IsRequestPending) { // There is another request pending on this transaction, so it needs to remain valid return Status::kRequestPending; } if (m_HttpServer.m_IsRequestLoggingEnabled) { if (m_HandlerRequest.has_value()) { ZEN_LOG_INFO(m_HttpServer.m_RequestLog, "{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); } } // Transaction done, caller should clean up (delete) this instance return Status::kDone; } HttpSysServerRequest& HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); // Default request handling # if ZEN_WITH_OTEL std::string_view Verb = ToString(ThisRequest.RequestVerb()); std::string_view Uri = ThisRequest.m_UriUtf8.ToView(); auto SpanNamer = [&](StringBuilderBase& SpanName) { SpanName << Verb << " " << Uri; }; auto SpanAnnotator = [&](otel::Span& Span) { Span.AddAttribute("http.request.method"sv, Verb); Span.AddAttribute("url.path"sv, Uri); // FIXME: should be total size including headers etc according to spec Span.AddAttribute("http.request.size"sv, static_cast(ThisRequest.ContentLength())); Span.SetKind(otel::Span::Kind::kServer); // client.address const SOCKADDR* SockAddr = ThisRequest.m_HttpTx.HttpRequest()->Address.pRemoteAddress; ExtendableStringBuilder<64> ClientAddr; GetAddressString(ClientAddr, SockAddr, /* 
IncludePort */ false); Span.AddAttribute("client.address"sv, ClientAddr.ToView()); }; otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); if (FilterResult == IHttpRequestFilter::Result::Accepted) { if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) { Service.HandleRequest(ThisRequest); } } else if (FilterResult == IHttpRequestFilter::Result::Forbidden) { ThisRequest.WriteResponse(HttpResponseCode::Forbidden); } else { ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; } ////////////////////////////////////////////////////////////////////////// HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) : HttpServerRequest(Service) , m_HttpTx(Tx) , m_PayloadBuffer(std::move(PayloadBuffer)) { const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); const int PrefixLength = Service.UriPrefixLength(); const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath, gsl::narrow(AbsPathLength)}, m_UriUtf8); if (AbsPathLength >= PrefixLength) { // We convert the URI immediately because most of the code involved prefers to deal // with utf8. 
This is overhead which I'd prefer to avoid but for now we just have // to live with it std::string_view UriSuffix8{m_UriUtf8}; UriSuffix8.remove_prefix(PrefixLength); m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access m_Uri = UriSuffix8; const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); if (LastComponentIndex != std::string_view::npos) { UriSuffix8.remove_prefix(LastComponentIndex); } const size_t LastDotIndex = UriSuffix8.find_last_of('.'); if (LastDotIndex != std::string_view::npos) { UriSuffix8.remove_prefix(LastDotIndex + 1); AcceptContentType = ParseContentType(UriSuffix8); if (AcceptContentType != HttpContentType::kUnknownContentType) { m_Uri.remove_suffix(UriSuffix8.size() + 1); } } } else { m_Uri = {}; m_UriWithExtension = {}; } if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) { --QueryStringLength; // We skip the leading question mark WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); } else { m_QueryStringUtf8.Reset(); } m_QueryString = std::string_view(m_QueryStringUtf8); m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); m_ContentLength = GetContentLength(HttpRequestPtr); m_ContentType = GetContentType(HttpRequestPtr); // It an explicit content type extension was specified then we'll use that over any // Accept: header value that may be present if (AcceptContentType != HttpContentType::kUnknownContentType) { m_AcceptType = AcceptContentType; } else { m_AcceptType = GetAcceptType(HttpRequestPtr); } if (m_Verb == HttpVerb::kHead) { SetSuppressResponseBody(); } } HttpSysServerRequest::~HttpSysServerRequest() { } Oid HttpSysServerRequest::ParseSessionId() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) { HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; std::string_view HeaderName{Header.pName, 
Header.NameLength}; if (HeaderName == "UE-Session"sv) { if (Header.RawValueLength == Oid::StringLength) { return Oid::TryFromHexString({Header.pRawValue, Header.RawValueLength}); } } } return {}; } uint32_t HttpSysServerRequest::ParseRequestId() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) { HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; std::string_view HeaderName{Header.pName, Header.NameLength}; if (HeaderName == "UE-Request"sv) { std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; uint32_t RequestId = 0; std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); return RequestId; } } return 0; } bool HttpSysServerRequest::IsLocalMachineRequest() const { const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress; const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress; if (LocalAddress->sa_family != RemoteAddress->sa_family) { return false; } if (LocalAddress->sa_family == AF_INET) { const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress); const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress); return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr; } else if (LocalAddress->sa_family == AF_INET6) { const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress); const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress); return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0; } else { return false; } } std::string_view HttpSysServerRequest::GetAuthorizationHeader() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; return std::string_view(AuthorizationHeader.pRawValue, 
AuthorizationHeader.RawValueLength); } IoBuffer HttpSysServerRequest::ReadPayload() { return m_PayloadBuffer; } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) { ZEN_MEMSCOPE(GetHttpsysTag()); ZEN_ASSERT(IsHandled() == false); HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); if (SuppressBody()) { Response->SuppressResponseBody(); } m_NextCompletionHandler = Response; # if ZEN_WITH_OTEL if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ActiveSpan->AddAttribute("http.response.body.size"sv, Response->GetResponseBodySize()); ActiveSpan->AddAttribute("http.response.status_code"sv, static_cast(ResponseCode)); } # endif SetIsHandled(); LogRequest(Response); } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) { ZEN_MEMSCOPE(GetHttpsysTag()); ZEN_ASSERT(IsHandled() == false); HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (SuppressBody()) { Response->SuppressResponseBody(); } m_NextCompletionHandler = Response; # if ZEN_WITH_OTEL if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ActiveSpan->AddAttribute("http.response.body.size"sv, Response->GetResponseBodySize()); ActiveSpan->AddAttribute("http.response.status_code"sv, static_cast(ResponseCode)); } # endif SetIsHandled(); LogRequest(Response); } void HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response) { if (ShouldLogRequest()) { ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb()), m_UriUtf8.c_str(), Response->GetResponseCode(), NiceBytes(Response->GetResponseBodySize())); } } void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { ZEN_MEMSCOPE(GetHttpsysTag()); ZEN_ASSERT(IsHandled() == false); auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, 
ContentType, ResponseString.data(), ResponseString.size()); if (SuppressBody()) { Response->SuppressResponseBody(); } m_NextCompletionHandler = Response; # if ZEN_WITH_OTEL if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ActiveSpan->AddAttribute("http.response.body.size"sv, Response->GetResponseBodySize()); ActiveSpan->AddAttribute("http.response.status_code"sv, static_cast(ResponseCode)); } # endif SetIsHandled(); LogRequest(Response); } void HttpSysServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) { ZEN_MEMSCOPE(GetHttpsysTag()); if (m_HttpTx.Server().IsAsyncResponseEnabled()) { m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); } else { ContinuationHandler(m_HttpTx.ServerRequest()); } } bool HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) { HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); } ////////////////////////////////////////////////////////////////////////// InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) { } InitialRequestHandler::~InitialRequestHandler() { } void InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) { ZEN_TRACE_CPU("httpsys::Request::IssueRequest"); HttpSysTransaction& Tx = Transaction(); PTP_IO Iocp = Tx.Iocp(); HTTP_REQUEST* HttpReq = Tx.HttpRequest(); StartThreadpoolIo(Iocp); ULONG HttpApiResult; if (IsInitialRequest()) { HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), HTTP_NULL_ID, HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, HttpReq, RequestBufferSize(), NULL, Tx.Overlapped()); } else { // The http.sys team recommends limiting the size to 128KB static const uint64_t kMaxBytesPerApiCall = 128 * 1024; uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; const uint64_t BytesToReadThisCall = 
zen::Min(BytesToRead, kMaxBytesPerApiCall); void* BufferWriteCursor = reinterpret_cast(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), HttpReq->RequestId, 0, /* Flags */ BufferWriteCursor, gsl::narrow(BytesToReadThisCall), nullptr, // BytesReturned Tx.Overlapped()); } if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) { CancelThreadpoolIo(Iocp); ErrorCode = MakeErrorCode(HttpApiResult); if (IsInitialRequest()) { ZEN_WARN("initial HttpReceiveHttpRequest failed: '{}'", ErrorCode.message()); } else { ZEN_WARN("HttpReceiveHttpRequest (offset: {}, content-length: {}) failed: '{}'", m_CurrentPayloadOffset, m_PayloadBuffer.GetSize(), ErrorCode.message()); } return; } ErrorCode.clear(); } HttpSysRequestHandler* InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { ZEN_MEMSCOPE(GetHttpsysTag()); auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); switch (IoResult) { default: case ERROR_OPERATION_ABORTED: return nullptr; case ERROR_MORE_DATA: // Insufficient buffer space case NO_ERROR: break; } ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request try { HTTP_REQUEST* HttpReq = HttpRequest(); # if 0 for (int i = 0; i < HttpReq->RequestInfoCount; ++i) { auto& ReqInfo = HttpReq->pRequestInfo[i]; switch (ReqInfo.InfoType) { case HttpRequestInfoTypeRequestTiming: { const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); ZEN_INFO(""); } break; case HttpRequestInfoTypeAuth: ZEN_INFO(""); break; case HttpRequestInfoTypeChannelBind: ZEN_INFO(""); break; case HttpRequestInfoTypeSslProtocol: ZEN_INFO(""); break; case HttpRequestInfoTypeSslTokenBindingDraft: ZEN_INFO(""); break; case HttpRequestInfoTypeSslTokenBinding: ZEN_INFO(""); break; case HttpRequestInfoTypeTcpInfoV0: { const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); ZEN_INFO(""); } break; case HttpRequestInfoTypeRequestSizing: { const 
HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); ZEN_INFO(""); } break; case HttpRequestInfoTypeQuicStats: ZEN_INFO(""); break; case HttpRequestInfoTypeTcpInfoV1: { const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); ZEN_INFO(""); } break; } } # endif if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); const HttpContentType ContentType = GetContentType(HttpReq); if (m_ContentLength) { // Handle initial chunk read by copying any payload which has already been copied // into our embedded request buffer m_PayloadBuffer = IoBuffer(m_ContentLength); m_PayloadBuffer.SetContentType(ContentType); uint64_t BytesToRead = m_ContentLength; uint8_t* const BufferBase = reinterpret_cast(m_PayloadBuffer.MutableData()); uint8_t* BufferWriteCursor = BufferBase; const int EntityChunkCount = HttpReq->EntityChunkCount; for (int i = 0; i < EntityChunkCount; ++i) { HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; ZEN_ASSERT(BufferLength <= BytesToRead); memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); BufferWriteCursor += BufferLength; BytesToRead -= BufferLength; } m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; } } else { m_CurrentPayloadOffset += NumberOfBytesTransferred; } if (m_CurrentPayloadOffset != m_ContentLength) { // Body not complete, issue another read request to receive more body data return this; } // Request body received completely m_PayloadBuffer.MakeImmutable(); HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) { return Response; } if (!ThisRequest.IsHandled()) { return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } // Unable to route 
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); } catch (const AssertException& AssertEx) { ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, AssertEx.FullDescription()); } catch (const std::system_error& SystemError) { if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) { return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); } ZEN_WARN("Caught system error exception while handling request: {}. ({})", SystemError.what(), SystemError.code().value()); return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); } catch (const std::bad_alloc& BadAlloc) { return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); } catch (const std::exception& ex) { ZEN_WARN("Caught exception while handling request: '{}'", ex.what()); return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); } } ////////////////////////////////////////////////////////////////////////// // // HttpServer interface implementation // int HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { ZEN_TRACE_CPU("HttpSysServer::Initialize"); ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { StartServer(); return EffectivePort; } ZEN_WARN("http.sys server was not initialized"); return 0; } void HttpSysServer::OnRequestExit() { m_ShutdownEvent.Set(); } void HttpSysServer::OnRegisterService(HttpService& Service) { RegisterService(Service.BaseUri(), Service); } void HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) { RwLock::ExclusiveLockScope _(m_RequestFilterLock); m_HttpRequestFilter.store(RequestFilter); } 
IHttpRequestFilter::Result HttpSysServer::FilterRequest(HttpSysServerRequest& Request) { if (!m_HttpRequestFilter.load()) { return IHttpRequestFilter::Result::Accepted; } RwLock::SharedLockScope _(m_RequestFilterLock); IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); if (!RequestFilter) { return IHttpRequestFilter::Result::Accepted; } return RequestFilter->FilterRequest(Request); } Ref CreateHttpSysServer(HttpSysConfig Config) { ZEN_TRACE_CPU("CreateHttpSysServer"); ZEN_MEMSCOPE(GetHttpsysTag()); return Ref(new HttpSysServer(Config)); } } // namespace zen #endif