diff options
| author | Stefan Boberg <[email protected]> | 2021-09-10 19:48:31 +0200 |
|---|---|---|
| committer | Stefan Boberg <[email protected]> | 2021-09-10 19:48:31 +0200 |
| commit | f63296d7a92023a28a545ff5c34d4adb952d0b1f (patch) | |
| tree | 3d7b491835e117a8a4108aeb6092f06470443ec2 /zenhttp/httpsys.cpp | |
| parent | Added beginnings of a uWS http front-end (diff) | |
| parent | Refactored HTTP request handling to scale better (diff) | |
| download | zen-f63296d7a92023a28a545ff5c34d4adb952d0b1f.tar.xz zen-f63296d7a92023a28a545ff5c34d4adb952d0b1f.zip | |
Merge branch 'cbpackage-update' of https://github.com/EpicGames/zen into cbpackage-update
Diffstat (limited to 'zenhttp/httpsys.cpp')
| -rw-r--r-- | zenhttp/httpsys.cpp | 704 |
1 files changed, 395 insertions, 309 deletions
diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp index 471a8f80a..fd93aa68f 100644 --- a/zenhttp/httpsys.cpp +++ b/zenhttp/httpsys.cpp @@ -3,6 +3,7 @@ #include "httpsys.h" #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/string.h> #if ZEN_WITH_HTTPSYS @@ -54,13 +55,12 @@ UTF8_to_wstring(const char* in) return out; } -////////////////////////////////////////////////////////////////////////// -// -// http.sys implementation -// - namespace zen { +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + using namespace std::literals; static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); @@ -245,14 +245,69 @@ ReasonStringForHttpResultCode(int HttpCode) } } -class HttpSysServer; -class HttpSysTransaction; -class HttpMessageResponseRequest; +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; +/** + * @brief Base class for any pending or active HTTP transactions + */ class HttpSysRequestHandler { public: - HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} + explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} virtual ~HttpSysRequestHandler() = default; virtual void IssueRequest() = 0; @@ -261,40 +316,62 @@ public: HttpSysTransaction& Transaction() { return m_Request; } private: - HttpSysTransaction& m_Request; // Outermost HTTP transaction object + HttpSysTransaction& m_Request; // Related HTTP transaction object }; +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer + */ struct InitialRequestHandler : public HttpSysRequestHandler { - inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; } + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } - InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - ~InitialRequestHandler() {} + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); - virtual void IssueRequest() override; + virtual void IssueRequest() override final; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer); - UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)]; + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[512 + sizeof(HTTP_REQUEST)]; }; +/** + * @brief Handler used to read complete body + */ +class HttpPayloadReadRequest : public HttpSysRequestHandler +{ +public: + HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} + + virtual void IssueRequest() override; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + class HttpSysServerRequest : public HttpServerRequest { public: - HttpSysServerRequest() = default; - HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service); + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); ~HttpSysServerRequest() = default; - virtual void ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponse HttpResponseCode) override; virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - bool m_IsInitialized = false; HttpSysTransaction& m_HttpTx; HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general + IoBuffer m_PayloadBuffer; }; /** HTTP transaction @@ -305,9 +382,8 @@ public: class HttpSysTransaction final { public: - HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {} - - virtual ~HttpSysTransaction() {} + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); enum class Status { @@ -329,8 +405,7 @@ public: HANDLE RequestQueueHandle(); inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } inline HttpSysServer& Server() { return m_HttpServer; } - - inline PHTTP_REQUEST HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } private: OVERLAPPED m_HttpOverlapped{}; @@ -342,15 +417,6 @@ private: ////////////////////////////////////////////////////////////////////////// -class HttpPayloadReadRequest : public HttpSysRequestHandler -{ -public: - HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - - virtual void IssueRequest() override; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; -}; - void HttpPayloadReadRequest::IssueRequest() { @@ -369,9 +435,16 @@ class HttpMessageResponseRequest : public HttpSysRequestHandler { public: HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); ~HttpMessageResponseRequest(); virtual void IssueRequest() override; @@ -387,6 +460,7 @@ private: uint32_t m_NextDataChunkOffset = 0; // This is used for responses where the number of chunks exceed the maximum number for one API call uint32_t m_RemainingChunkCount = 0; bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; void Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs); @@ -396,36 +470,42 @@ private: HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) : HttpSysRequestHandler(InRequest) { - std::array<IoBuffer, 0> buffers; + std::array<IoBuffer, 0> EmptyBufferList; - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, EmptyBufferList); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) : HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) { - IoBuffer MessageBuffer(IoBuffer::Wrap, Message, strlen(Message)); - std::array<IoBuffer, 1> buffers({MessageBuffer}); + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, SingleBufferList); } HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, + HttpContentType ContentType, const void* Payload, size_t PayloadSize) : HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) { IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); - std::array<IoBuffer, 1> buffers({MessageBuffer}); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); - Initialize(ResponseCode, buffers); + Initialize(ResponseCode, SingleBufferList); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) : HttpSysRequestHandler(InRequest) { - Initialize(ResponseCode, Blobs); + Initialize(ResponseCode, BlobList); } HttpMessageResponseRequest::~HttpMessageResponseRequest() @@ -433,16 +513,16 @@ HttpMessageResponseRequest::~HttpMessageResponseRequest() } void -HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs) +HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer> BlobList) { m_HttpResponseCode = ResponseCode; - const uint32_t ChunkCount = (uint32_t)Blobs.size(); + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); - m_HttpDataChunks.resize(ChunkCount); + m_HttpDataChunks.reserve(ChunkCount); m_DataBuffers.reserve(ChunkCount); - for (IoBuffer& Buffer : Blobs) + for (IoBuffer& Buffer : BlobList) { m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); } @@ -451,36 +531,55 @@ HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer uint64_t LocalDataSize = 0; + for (IoBuffer& Buffer : m_DataBuffers) { - PHTTP_DATA_CHUNK ChunkPtr = m_HttpDataChunks.data(); + uint64_t BufferDataSize = Buffer.Size(); - for (IoBuffer& Buffer : m_DataBuffers) + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) { - const ULONG BufferDataSize = (ULONG)Buffer.Size(); + // Use direct file transfer - ZEN_ASSERT(BufferDataSize); + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - ChunkPtr->DataChunkType = HttpDataChunkFromFileHandle; - ChunkPtr->FromFileHandle.FileHandle = FileRef.FileHandle; - ChunkPtr->FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; - ChunkPtr->FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; - } - else + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) { - ChunkPtr->DataChunkType = HttpDataChunkFromMemory; - ChunkPtr->FromMemory.pBuffer = (void*)Buffer.Data(); - ChunkPtr->FromMemory.BufferLength = BufferDataSize; - } - ++ChunkPtr; + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; - LocalDataSize += BufferDataSize; + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } } } - m_RemainingChunkCount = ChunkCount; + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); m_TotalDataSize = LocalDataSize; } @@ -557,8 +656,10 @@ HttpMessageResponseRequest::IssueRequest() PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - ContentTypeHeader->pRawValue = "application/octet-stream"; /* TODO! We must respect the content type specified */ - ContentTypeHeader->RawValueLength = (USHORT)strlen(ContentTypeHeader->pRawValue); + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); HttpResponse.StatusCode = m_HttpResponseCode; HttpResponse.pReason = ReasonStringForHttpResultCode(m_HttpResponseCode); @@ -617,7 +718,14 @@ HttpMessageResponseRequest::IssueRequest() } } -////////////////////////////////////////////////////////////////////////// +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ HttpSysServer::HttpSysServer(int ThreadCount) : m_ThreadPool(ThreadCount) { @@ -707,9 +815,9 @@ HttpSysServer::Initialize(const wchar_t* UrlPath) void HttpSysServer::StartServer() { - int RequestCount = 32; + const int InitialRequestCount = 32; - for (int i = 0; i < RequestCount; ++i) + for (int i = 0; i < InitialRequestCount; ++i) { IssueNewRequestMaybe(); } @@ -749,9 +857,7 @@ HttpSysServer::Run(bool TestMode) void HttpSysServer::OnHandlingRequest() { - --m_PendingRequests; - - if (m_PendingRequests > m_MinPendingRequests) + if (--m_PendingRequests > m_MinPendingRequests) { // We have more than the minimum number of requests pending, just let someone else // enqueue new requests @@ -807,7 +913,7 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) } void -HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) { ZEN_UNUSED(Service); @@ -832,218 +938,14 @@ HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) ////////////////////////////////////////////////////////////////////////// -HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service) : m_IsInitialized(true), m_HttpTx(Tx) -{ - PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest(); - - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); - - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This has some performance impact which I'd prefer to avoid but for now - // we just have to live with it - - WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, - m_UriUtf8); - } - else - { - m_UriUtf8.Reset(); - } - - if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; - - WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); - } - else - { - m_QueryStringUtf8.Reset(); - } - - switch (HttpRequestPtr->Verb) - { - case HttpVerbOPTIONS: - m_Verb = HttpVerb::kOptions; - break; - - case HttpVerbGET: - m_Verb = HttpVerb::kGet; - break; - - case HttpVerbHEAD: - m_Verb = HttpVerb::kHead; - break; - - case HttpVerbPOST: - m_Verb = HttpVerb::kPost; - break; - - case HttpVerbPUT: - m_Verb = HttpVerb::kPut; - break; - - case HttpVerbDELETE: - m_Verb = HttpVerb::kDelete; - break; - - case HttpVerbCOPY: - m_Verb = HttpVerb::kCopy; - break; - - default: - // TODO: invalid request? - m_Verb = (HttpVerb)0; - break; - } - - const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength); - - const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType]; - m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); - - const HTTP_KNOWN_HEADER& AcceptHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAccept]; - m_AcceptType = MapContentType({AcceptHdr.pRawValue, AcceptHdr.RawValueLength}); -} - -void -HttpSysServerRequest::ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) -{ - ZEN_UNUSED(CompletionHandler); -} - -IoBuffer -HttpSysServerRequest::ReadPayload() -{ - // This is presently synchronous for simplicity, but we - // need to implement an asynchronous version also - - HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest(); - - IoBuffer PayloadBuffer(m_ContentLength); - - HttpContentType ContentType = RequestContentType(); - PayloadBuffer.SetContentType(ContentType); - - uint64_t BytesToRead = m_ContentLength; - - uint8_t* ReadPointer = reinterpret_cast<uint8_t*>(PayloadBuffer.MutableData()); - - // First deal with any payload which has already been copied - // into our request buffer - - const int EntityChunkCount = HttpReq->EntityChunkCount; - - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - - ZEN_ASSERT(BufferLength <= BytesToRead); - - memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength); - - ReadPointer += BufferLength; - BytesToRead -= BufferLength; - } - - if (BytesToRead == 0) - { - PayloadBuffer.MakeImmutable(); - - return PayloadBuffer; - } - - // Call http.sys API to receive the remaining data SYNCHRONOUSLY - - static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024; - - while (BytesToRead) - { - ULONG BytesRead = 0; - - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - - ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - ReadPointer, - gsl::narrow<ULONG>(BytesToReadThisCall), - &BytesRead, - NULL /* Overlapped */ - ); - - if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF) - { - throw HttpServerException("payload read failed", ApiResult); - } - - BytesToRead -= BytesRead; - ReadPointer += BytesRead; - } - - PayloadBuffer.MakeImmutable(); - - return PayloadBuffer; -} - -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode) +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) { - ZEN_ASSERT(m_IsHandled == false); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; } -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +HttpSysTransaction::~HttpSysTransaction() { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; } -void -HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size()); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; -} - -////////////////////////////////////////////////////////////////////////// - PTP_IO HttpSysTransaction::Iocp() { @@ -1090,14 +992,16 @@ HttpSysTransaction::Status HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // We use this to ensure sequential execution of completion handlers - // for any given transaction. + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + RwLock::ExclusiveLockScope _(m_CompletionMutex); bool RequestPending = false; if (HttpSysRequestHandler* CurrentHandler = m_HttpHandler) { - const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler); + const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest(); if (IsInitialRequest) { @@ -1144,37 +1048,158 @@ HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTran ////////////////////////////////////////////////////////////////////////// +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This has some performance impact which I'd prefer to avoid but for now + // we just have to live with it + + WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + } + else + { + m_UriUtf8.Reset(); + } + + if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; + + WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + m_AcceptType = GetAcceptType(HttpRequestPtr); +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, Blobs); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); + } + + m_IsHandled = true; +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + void InitialRequestHandler::IssueRequest() { - PTP_IO Iocp = Transaction().Iocp(); + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); StartThreadpoolIo(Iocp); - HttpSysTransaction& Tx = Transaction(); + ULONG HttpApiResult; - HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + static const uint64_t kMaxBytesPerApiCall = 64 * 1024; - ULONG Result = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), - HTTP_NULL_ID, - HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, - HttpReq, - RequestBufferSize(), - NULL, - Tx.Overlapped()); + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } - if (Result != ERROR_IO_PENDING && Result != NO_ERROR) + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) { CancelThreadpoolIo(Iocp); - if (Result == ERROR_MORE_DATA) + if (HttpApiResult == ERROR_MORE_DATA) { // ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA); } // CleanupHttpIoRequest(pIoRequest); - spdlog::error("HttpReceiveHttpRequest failed, error {:x}", Result); + spdlog::error("HttpReceiveHttpRequest failed, error {:x}", HttpApiResult); return; } @@ -1186,29 +1211,90 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_UNUSED(IoResult); ZEN_UNUSED(NumberOfBytesTransferred); + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + // Route requests try { - if (HttpService* Service = reinterpret_cast<HttpService*>(m_HttpRequestPtr->UrlContext)) + HTTP_REQUEST* HttpReq = HttpRequest(); + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - HttpSysServerRequest ThisRequest(Transaction(), *Service); + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + HttpVerb Verb = TranslateHttpVerb(HttpReq->Verb); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + + HttpContentType ContentType = GetContentType(HttpReq); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); - Service->HandleRequest(ThisRequest); + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } - if (!ThisRequest.IsHandled()) + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else { - return new HttpMessageResponseRequest(Transaction(), 404, "Not found"); + m_CurrentPayloadOffset += NumberOfBytesTransferred; } - if (ThisRequest.m_Response) + if (m_CurrentPayloadOffset == m_ContentLength) + { + m_PayloadBuffer.MakeImmutable(); + + // Body received completely - call request handler + + HttpSysServerRequest ThisRequest(Transaction(), *Service, m_PayloadBuffer); + + Service->HandleRequest(ThisRequest); + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + + if (ThisRequest.m_Response) + { + return ThisRequest.m_Response; + } + } + else { - return ThisRequest.m_Response; + // Issue a read request for more body data + return this; } } // Unable to route - return new HttpMessageResponseRequest(Transaction(), 404, "Item unknown"); + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); } catch (std::exception& ex) { |