From 49d09922716b216896fe60a92b1a126c9ba8c302 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 10 Sep 2021 18:55:30 +0200 Subject: Refactored HTTP request handling to scale better The new logic simply reads the whole payload up front before dispatching to the endpoint handler. This increases concurrency as fewer threads will be blocked waiting for payloads Similar logic will be added for compact binary package negotiation and ultimately we want to support streaming payloads to a staging directory on disk rather than keeping them all in memory --- zenhttp/httpsys.cpp | 704 +++++++++++++++++++++++++++++----------------------- 1 file changed, 395 insertions(+), 309 deletions(-) (limited to 'zenhttp/httpsys.cpp') diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp index 471a8f80a..fd93aa68f 100644 --- a/zenhttp/httpsys.cpp +++ b/zenhttp/httpsys.cpp @@ -3,6 +3,7 @@ #include "httpsys.h" #include +#include #include #if ZEN_WITH_HTTPSYS @@ -54,13 +55,12 @@ UTF8_to_wstring(const char* in) return out; } -////////////////////////////////////////////////////////////////////////// -// -// http.sys implementation -// - namespace zen { +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + using namespace std::literals; static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); @@ -245,14 +245,69 @@ ReasonStringForHttpResultCode(int HttpCode) } } -class HttpSysServer; -class HttpSysTransaction; -class HttpMessageResponseRequest; +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; +/** + * @brief Base class for any pending or active HTTP transactions + */ class HttpSysRequestHandler { public: - HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} + explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} virtual ~HttpSysRequestHandler() = default; virtual void IssueRequest() = 0; @@ -261,40 +316,62 @@ public: HttpSysTransaction& Transaction() { return m_Request; } private: - HttpSysTransaction& m_Request; // Outermost HTTP transaction object + HttpSysTransaction& m_Request; // Related HTTP transaction object }; +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer + */ struct InitialRequestHandler : public HttpSysRequestHandler { - inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; } + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } - InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - ~InitialRequestHandler() {} + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); - virtual void IssueRequest() override; + virtual void IssueRequest() override final; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer); - UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)]; + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[512 + sizeof(HTTP_REQUEST)]; }; +/** + * @brief Handler used to read complete body + */ +class HttpPayloadReadRequest : public HttpSysRequestHandler +{ +public: + HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} + + virtual void IssueRequest() override; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + class HttpSysServerRequest : public HttpServerRequest { public: - HttpSysServerRequest() = default; - HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service); + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); ~HttpSysServerRequest() = default; - virtual void ReadPayload(std::function&& CompletionHandler) override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponse HttpResponseCode) override; virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span Blobs) override; virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - bool m_IsInitialized = false; HttpSysTransaction& m_HttpTx; HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general + IoBuffer m_PayloadBuffer; }; /** HTTP transaction @@ -305,9 +382,8 @@ public: class HttpSysTransaction final { public: - HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {} - - virtual ~HttpSysTransaction() {} + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); enum class Status { @@ -329,8 +405,7 @@ public: HANDLE RequestQueueHandle(); inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } inline HttpSysServer& Server() { return m_HttpServer; } - - inline PHTTP_REQUEST HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } private: OVERLAPPED m_HttpOverlapped{}; @@ -342,15 +417,6 @@ private: ////////////////////////////////////////////////////////////////////////// -class HttpPayloadReadRequest : public HttpSysRequestHandler -{ -public: - HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - - virtual void IssueRequest() override; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; -}; - void HttpPayloadReadRequest::IssueRequest() { @@ -369,9 +435,16 @@ class HttpMessageResponseRequest : public HttpSysRequestHandler { public: HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span Blobs); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span Blobs); ~HttpMessageResponseRequest(); virtual void IssueRequest() override; @@ -387,6 +460,7 @@ private: uint32_t m_NextDataChunkOffset = 0; // This is used for responses where the number of chunks exceed the maximum number for one API call uint32_t m_RemainingChunkCount = 0; bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; void Initialize(uint16_t ResponseCode, std::span Blobs); @@ -396,36 +470,42 @@ private: HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) : HttpSysRequestHandler(InRequest) { - std::array buffers; + std::array EmptyBufferList; - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, EmptyBufferList); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) : HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) { - IoBuffer MessageBuffer(IoBuffer::Wrap, Message, strlen(Message)); - std::array buffers({MessageBuffer}); + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array SingleBufferList({MessageBuffer}); - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, SingleBufferList); } HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, + HttpContentType ContentType, const void* Payload, size_t PayloadSize) : HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) { IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); - std::array buffers({MessageBuffer}); + std::array SingleBufferList({MessageBuffer}); - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, SingleBufferList); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span Blobs) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span BlobList) : HttpSysRequestHandler(InRequest) { - Initialize(ResponseCode, Blobs); + Initialize(ResponseCode, BlobList); } HttpMessageResponseRequest::~HttpMessageResponseRequest() @@ -433,16 +513,16 @@ HttpMessageResponseRequest::~HttpMessageResponseRequest() } void -HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span Blobs) +HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span BlobList) { m_HttpResponseCode = ResponseCode; - const uint32_t ChunkCount = (uint32_t)Blobs.size(); + const uint32_t ChunkCount = gsl::narrow(BlobList.size()); - m_HttpDataChunks.resize(ChunkCount); + m_HttpDataChunks.reserve(ChunkCount); m_DataBuffers.reserve(ChunkCount); - for (IoBuffer& Buffer : Blobs) + for (IoBuffer& Buffer : BlobList) { m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); } @@ -451,36 +531,55 @@ HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::spanDataChunkType = HttpDataChunkFromFileHandle; - ChunkPtr->FromFileHandle.FileHandle = FileRef.FileHandle; - ChunkPtr->FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; - ChunkPtr->FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; - } - else + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast(Buffer.Data()); + + while (BufferDataSize) { - ChunkPtr->DataChunkType = HttpDataChunkFromMemory; - ChunkPtr->FromMemory.pBuffer = (void*)Buffer.Data(); - ChunkPtr->FromMemory.BufferLength = BufferDataSize; - } - ++ChunkPtr; + const ULONG ThisChunkSize = gsl::narrow(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; - LocalDataSize += BufferDataSize; + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } } } - m_RemainingChunkCount = ChunkCount; + m_RemainingChunkCount = gsl::narrow(m_HttpDataChunks.size()); m_TotalDataSize = LocalDataSize; } @@ -557,8 +656,10 @@ HttpMessageResponseRequest::IssueRequest() PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - ContentTypeHeader->pRawValue = "application/octet-stream"; /* TODO! We must respect the content type specified */ - ContentTypeHeader->RawValueLength = (USHORT)strlen(ContentTypeHeader->pRawValue); + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); HttpResponse.StatusCode = m_HttpResponseCode; HttpResponse.pReason = ReasonStringForHttpResultCode(m_HttpResponseCode); @@ -617,7 +718,14 @@ HttpMessageResponseRequest::IssueRequest() } } -////////////////////////////////////////////////////////////////////////// +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ HttpSysServer::HttpSysServer(int ThreadCount) : m_ThreadPool(ThreadCount) { @@ -707,9 +815,9 @@ HttpSysServer::Initialize(const wchar_t* UrlPath) void HttpSysServer::StartServer() { - int RequestCount = 32; + const int InitialRequestCount = 32; - for (int i = 0; i < RequestCount; ++i) + for (int i = 0; i < InitialRequestCount; ++i) { IssueNewRequestMaybe(); } @@ -749,9 +857,7 @@ HttpSysServer::Run(bool TestMode) void HttpSysServer::OnHandlingRequest() { - --m_PendingRequests; - - if (m_PendingRequests > m_MinPendingRequests) + if (--m_PendingRequests > m_MinPendingRequests) { // We have more than the minimum number of requests pending, just let someone else // enqueue new requests @@ -807,7 +913,7 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) } void -HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) { ZEN_UNUSED(Service); @@ -832,218 +938,14 @@ HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) ////////////////////////////////////////////////////////////////////////// -HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service) : m_IsInitialized(true), m_HttpTx(Tx) -{ - PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest(); - - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); - - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This has some performance impact which I'd prefer to avoid but for now - // we just have to live with it - - WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow(AbsPathLength - PrefixLength)}, - m_UriUtf8); - } - else - { - m_UriUtf8.Reset(); - } - - if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; - - WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); - } - else - { - m_QueryStringUtf8.Reset(); - } - - switch (HttpRequestPtr->Verb) - { - case HttpVerbOPTIONS: - m_Verb = HttpVerb::kOptions; - break; - - case HttpVerbGET: - m_Verb = HttpVerb::kGet; - break; - - case HttpVerbHEAD: - m_Verb = HttpVerb::kHead; - break; - - case HttpVerbPOST: - m_Verb = HttpVerb::kPost; - break; - - case HttpVerbPUT: - m_Verb = HttpVerb::kPut; - break; - - case HttpVerbDELETE: - m_Verb = HttpVerb::kDelete; - break; - - case HttpVerbCOPY: - m_Verb = HttpVerb::kCopy; - break; - - default: - // TODO: invalid request? - m_Verb = (HttpVerb)0; - break; - } - - const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength); - - const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType]; - m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); - - const HTTP_KNOWN_HEADER& AcceptHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAccept]; - m_AcceptType = MapContentType({AcceptHdr.pRawValue, AcceptHdr.RawValueLength}); -} - -void -HttpSysServerRequest::ReadPayload(std::function&& CompletionHandler) -{ - ZEN_UNUSED(CompletionHandler); -} - -IoBuffer -HttpSysServerRequest::ReadPayload() -{ - // This is presently synchronous for simplicity, but we - // need to implement an asynchronous version also - - HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest(); - - IoBuffer PayloadBuffer(m_ContentLength); - - HttpContentType ContentType = RequestContentType(); - PayloadBuffer.SetContentType(ContentType); - - uint64_t BytesToRead = m_ContentLength; - - uint8_t* ReadPointer = reinterpret_cast(PayloadBuffer.MutableData()); - - // First deal with any payload which has already been copied - // into our request buffer - - const int EntityChunkCount = HttpReq->EntityChunkCount; - - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - - ZEN_ASSERT(BufferLength <= BytesToRead); - - memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength); - - ReadPointer += BufferLength; - BytesToRead -= BufferLength; - } - - if (BytesToRead == 0) - { - PayloadBuffer.MakeImmutable(); - - return PayloadBuffer; - } - - // Call http.sys API to receive the remaining data SYNCHRONOUSLY - - static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024; - - while (BytesToRead) - { - ULONG BytesRead = 0; - - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - - ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - ReadPointer, - gsl::narrow(BytesToReadThisCall), - &BytesRead, - NULL /* Overlapped */ - ); - - if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF) - { - throw HttpServerException("payload read failed", ApiResult); - } - - BytesToRead -= BytesRead; - ReadPointer += BytesRead; - } - - PayloadBuffer.MakeImmutable(); - - return PayloadBuffer; -} - -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode) +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) { - ZEN_ASSERT(m_IsHandled == false); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; } -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span Blobs) +HttpSysTransaction::~HttpSysTransaction() { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; } -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size()); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; -} - -////////////////////////////////////////////////////////////////////////// - PTP_IO HttpSysTransaction::Iocp() { @@ -1090,14 +992,16 @@ HttpSysTransaction::Status HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // We use this to ensure sequential execution of completion handlers - // for any given transaction. + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + RwLock::ExclusiveLockScope _(m_CompletionMutex); bool RequestPending = false; if (HttpSysRequestHandler* CurrentHandler = m_HttpHandler) { - const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler); + const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest(); if (IsInitialRequest) { @@ -1144,37 +1048,158 @@ HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTran ////////////////////////////////////////////////////////////////////////// +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This has some performance impact which I'd prefer to avoid but for now + // we just have to live with it + + WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow(AbsPathLength - PrefixLength)}, + m_UriUtf8); + } + else + { + m_UriUtf8.Reset(); + } + + if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; + + WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + m_AcceptType = GetAcceptType(HttpRequestPtr); +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span Blobs) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, Blobs); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + void InitialRequestHandler::IssueRequest() { - PTP_IO Iocp = Transaction().Iocp(); + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); StartThreadpoolIo(Iocp); - HttpSysTransaction& Tx = Transaction(); + ULONG HttpApiResult; - HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + static const uint64_t kMaxBytesPerApiCall = 64 * 1024; - ULONG Result = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), - HTTP_NULL_ID, - HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, - HttpReq, - RequestBufferSize(), - NULL, - Tx.Overlapped()); + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } - if (Result != ERROR_IO_PENDING && Result != NO_ERROR) + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) { CancelThreadpoolIo(Iocp); - if (Result == ERROR_MORE_DATA) + if (HttpApiResult == ERROR_MORE_DATA) { // ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA); } // CleanupHttpIoRequest(pIoRequest); - spdlog::error("HttpReceiveHttpRequest failed, error {:x}", Result); + spdlog::error("HttpReceiveHttpRequest failed, error {:x}", HttpApiResult); return; } @@ -1186,29 +1211,90 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_UNUSED(IoResult); ZEN_UNUSED(NumberOfBytesTransferred); + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + // Route requests try { - if (HttpService* Service = reinterpret_cast(m_HttpRequestPtr->UrlContext)) + HTTP_REQUEST* HttpReq = HttpRequest(); + + if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) { - HttpSysServerRequest ThisRequest(Transaction(), *Service); + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + HttpVerb Verb = TranslateHttpVerb(HttpReq->Verb); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + + HttpContentType ContentType = GetContentType(HttpReq); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); - Service->HandleRequest(ThisRequest); + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } - if (!ThisRequest.IsHandled()) + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else { - return new HttpMessageResponseRequest(Transaction(), 404, "Not found"); + m_CurrentPayloadOffset += NumberOfBytesTransferred; } - if (ThisRequest.m_Response) + if (m_CurrentPayloadOffset == m_ContentLength) + { + m_PayloadBuffer.MakeImmutable(); + + // Body received completely - call request handler + + HttpSysServerRequest ThisRequest(Transaction(), *Service, m_PayloadBuffer); + + Service->HandleRequest(ThisRequest); + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + + if (ThisRequest.m_Response) + { + return ThisRequest.m_Response; + } + } + else { - return ThisRequest.m_Response; + // Issue a read request for more body data + return this; } } // Unable to route - return new HttpMessageResponseRequest(Transaction(), 404, "Item unknown"); + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); } catch (std::exception& ex) { -- cgit v1.2.3