aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/httpsys.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2021-09-10 19:48:31 +0200
committerStefan Boberg <[email protected]>2021-09-10 19:48:31 +0200
commitf63296d7a92023a28a545ff5c34d4adb952d0b1f (patch)
tree3d7b491835e117a8a4108aeb6092f06470443ec2 /zenhttp/httpsys.cpp
parentAdded beginnings of a uWS http front-end (diff)
parentRefactored HTTP request handling to scale better (diff)
downloadzen-f63296d7a92023a28a545ff5c34d4adb952d0b1f.tar.xz
zen-f63296d7a92023a28a545ff5c34d4adb952d0b1f.zip
Merge branch 'cbpackage-update' of https://github.com/EpicGames/zen into cbpackage-update
Diffstat (limited to 'zenhttp/httpsys.cpp')
-rw-r--r--zenhttp/httpsys.cpp704
1 files changed, 395 insertions, 309 deletions
diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp
index 471a8f80a..fd93aa68f 100644
--- a/zenhttp/httpsys.cpp
+++ b/zenhttp/httpsys.cpp
@@ -3,6 +3,7 @@
#include "httpsys.h"
#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
#include <zencore/string.h>
#if ZEN_WITH_HTTPSYS
@@ -54,13 +55,12 @@ UTF8_to_wstring(const char* in)
return out;
}
-//////////////////////////////////////////////////////////////////////////
-//
-// http.sys implementation
-//
-
namespace zen {
+class HttpSysServer;
+class HttpSysTransaction;
+class HttpMessageResponseRequest;
+
using namespace std::literals;
static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv);
@@ -245,14 +245,69 @@ ReasonStringForHttpResultCode(int HttpCode)
}
}
-class HttpSysServer;
-class HttpSysTransaction;
-class HttpMessageResponseRequest;
+HttpVerb
+TranslateHttpVerb(HTTP_VERB ReqVerb)
+{
+ switch (ReqVerb)
+ {
+ case HttpVerbOPTIONS:
+ return HttpVerb::kOptions;
+
+ case HttpVerbGET:
+ return HttpVerb::kGet;
+
+ case HttpVerbHEAD:
+ return HttpVerb::kHead;
+
+ case HttpVerbPOST:
+ return HttpVerb::kPost;
+
+ case HttpVerbPUT:
+ return HttpVerb::kPut;
+
+ case HttpVerbDELETE:
+ return HttpVerb::kDelete;
+
+ case HttpVerbCOPY:
+ return HttpVerb::kCopy;
+
+ default:
+ // TODO: invalid request?
+ return (HttpVerb)0;
+ }
+}
+
+uint64_t
+GetContentLength(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength];
+ std::string_view cl(clh.pRawValue, clh.RawValueLength);
+ uint64_t ContentLength = 0;
+ std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength);
+ return ContentLength;
+};
+
+HttpContentType
+GetContentType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType];
+ return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+
+HttpContentType
+GetAcceptType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept];
+ return MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+/**
+ * @brief Base class for any pending or active HTTP transactions
+ */
class HttpSysRequestHandler
{
public:
- HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {}
+ explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {}
virtual ~HttpSysRequestHandler() = default;
virtual void IssueRequest() = 0;
@@ -261,40 +316,62 @@ public:
HttpSysTransaction& Transaction() { return m_Request; }
private:
- HttpSysTransaction& m_Request; // Outermost HTTP transaction object
+ HttpSysTransaction& m_Request; // Related HTTP transaction object
};
+/**
+ * This is the handler for the initial HTTP I/O request which will receive the headers
+ * and however much of the remaining payload might fit in the embedded request buffer
+ */
struct InitialRequestHandler : public HttpSysRequestHandler
{
- inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; }
+ inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; }
inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; }
+ inline bool IsInitialRequest() const { return m_IsInitialRequest; }
- InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {}
- ~InitialRequestHandler() {}
+ InitialRequestHandler(HttpSysTransaction& InRequest);
+ ~InitialRequestHandler();
- virtual void IssueRequest() override;
+ virtual void IssueRequest() override final;
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
- PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer);
- UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)];
+ bool m_IsInitialRequest = true;
+ uint64_t m_CurrentPayloadOffset = 0;
+ uint64_t m_ContentLength = ~uint64_t(0);
+ IoBuffer m_PayloadBuffer;
+ UCHAR m_RequestBuffer[512 + sizeof(HTTP_REQUEST)];
};
+/**
+ * @brief Handler used to read complete body
+ */
+class HttpPayloadReadRequest : public HttpSysRequestHandler
+{
+public:
+ HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {}
+
+ virtual void IssueRequest() override;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
class HttpSysServerRequest : public HttpServerRequest
{
public:
- HttpSysServerRequest() = default;
- HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service);
+ HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer);
~HttpSysServerRequest() = default;
- virtual void ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponse HttpResponseCode) override;
virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
- bool m_IsInitialized = false;
HttpSysTransaction& m_HttpTx;
HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general
+ IoBuffer m_PayloadBuffer;
};
/** HTTP transaction
@@ -305,9 +382,8 @@ public:
class HttpSysTransaction final
{
public:
- HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {}
-
- virtual ~HttpSysTransaction() {}
+ HttpSysTransaction(HttpSysServer& Server);
+ virtual ~HttpSysTransaction();
enum class Status
{
@@ -329,8 +405,7 @@ public:
HANDLE RequestQueueHandle();
inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
-
- inline PHTTP_REQUEST HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
+ inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
private:
OVERLAPPED m_HttpOverlapped{};
@@ -342,15 +417,6 @@ private:
//////////////////////////////////////////////////////////////////////////
-class HttpPayloadReadRequest : public HttpSysRequestHandler
-{
-public:
- HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {}
-
- virtual void IssueRequest() override;
- virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
-};
-
void
HttpPayloadReadRequest::IssueRequest()
{
@@ -369,9 +435,16 @@ class HttpMessageResponseRequest : public HttpSysRequestHandler
{
public:
HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ const void* Payload,
+ size_t PayloadSize);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> Blobs);
~HttpMessageResponseRequest();
virtual void IssueRequest() override;
@@ -387,6 +460,7 @@ private:
uint32_t m_NextDataChunkOffset = 0; // This is used for responses where the number of chunks exceed the maximum number for one API call
uint32_t m_RemainingChunkCount = 0;
bool m_IsInitialResponse = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
void Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
@@ -396,36 +470,42 @@ private:
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode)
: HttpSysRequestHandler(InRequest)
{
- std::array<IoBuffer, 0> buffers;
+ std::array<IoBuffer, 0> EmptyBufferList;
- Initialize(ResponseCode, buffers);
+ Initialize(ResponseCode, EmptyBufferList);
}
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message)
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message)
: HttpSysRequestHandler(InRequest)
+, m_ContentType(HttpContentType::kText)
{
- IoBuffer MessageBuffer(IoBuffer::Wrap, Message, strlen(Message));
- std::array<IoBuffer, 1> buffers({MessageBuffer});
+ IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
- Initialize(ResponseCode, buffers);
+ Initialize(ResponseCode, SingleBufferList);
}
HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
uint16_t ResponseCode,
+ HttpContentType ContentType,
const void* Payload,
size_t PayloadSize)
: HttpSysRequestHandler(InRequest)
+, m_ContentType(ContentType)
{
IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize);
- std::array<IoBuffer, 1> buffers({MessageBuffer});
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
- Initialize(ResponseCode, buffers);
+ Initialize(ResponseCode, SingleBufferList);
}
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs)
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> BlobList)
: HttpSysRequestHandler(InRequest)
{
- Initialize(ResponseCode, Blobs);
+ Initialize(ResponseCode, BlobList);
}
HttpMessageResponseRequest::~HttpMessageResponseRequest()
@@ -433,16 +513,16 @@ HttpMessageResponseRequest::~HttpMessageResponseRequest()
}
void
-HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs)
+HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
{
m_HttpResponseCode = ResponseCode;
- const uint32_t ChunkCount = (uint32_t)Blobs.size();
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
- m_HttpDataChunks.resize(ChunkCount);
+ m_HttpDataChunks.reserve(ChunkCount);
m_DataBuffers.reserve(ChunkCount);
- for (IoBuffer& Buffer : Blobs)
+ for (IoBuffer& Buffer : BlobList)
{
m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
}
@@ -451,36 +531,55 @@ HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer
uint64_t LocalDataSize = 0;
+ for (IoBuffer& Buffer : m_DataBuffers)
{
- PHTTP_DATA_CHUNK ChunkPtr = m_HttpDataChunks.data();
+ uint64_t BufferDataSize = Buffer.Size();
- for (IoBuffer& Buffer : m_DataBuffers)
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
{
- const ULONG BufferDataSize = (ULONG)Buffer.Size();
+ // Use direct file transfer
- ZEN_ASSERT(BufferDataSize);
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
- IoBufferFileReference FileRef;
- if (Buffer.GetFileReference(/* out */ FileRef))
- {
- ChunkPtr->DataChunkType = HttpDataChunkFromFileHandle;
- ChunkPtr->FromFileHandle.FileHandle = FileRef.FileHandle;
- ChunkPtr->FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
- ChunkPtr->FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize;
- }
- else
+ Chunk.DataChunkType = HttpDataChunkFromFileHandle;
+ Chunk.FromFileHandle.FileHandle = FileRef.FileHandle;
+ Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
+ Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize;
+ }
+ else
+ {
+ // Send from memory, need to make sure we chunk the buffer up since
+ // the underlying data structure only accepts 32-bit chunk sizes for
+ // memory chunks. When this happens the vector will be reallocated,
+ // which is fine since this will be a pretty rare case and sending
+ // the data is going to take a lot longer than a memory allocation :)
+
+ const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data());
+
+ while (BufferDataSize)
{
- ChunkPtr->DataChunkType = HttpDataChunkFromMemory;
- ChunkPtr->FromMemory.pBuffer = (void*)Buffer.Data();
- ChunkPtr->FromMemory.BufferLength = BufferDataSize;
- }
- ++ChunkPtr;
+ const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize));
+
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
+
+ Chunk.DataChunkType = HttpDataChunkFromMemory;
+ Chunk.FromMemory.pBuffer = (void*)WriteCursor;
+ Chunk.FromMemory.BufferLength = ThisChunkSize;
- LocalDataSize += BufferDataSize;
+ BufferDataSize -= ThisChunkSize;
+ WriteCursor += ThisChunkSize;
+ }
}
}
- m_RemainingChunkCount = ChunkCount;
+ m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size());
m_TotalDataSize = LocalDataSize;
}
@@ -557,8 +656,10 @@ HttpMessageResponseRequest::IssueRequest()
PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
- ContentTypeHeader->pRawValue = "application/octet-stream"; /* TODO! We must respect the content type specified */
- ContentTypeHeader->RawValueLength = (USHORT)strlen(ContentTypeHeader->pRawValue);
+ std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+
+ ContentTypeHeader->pRawValue = ContentTypeString.data();
+ ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
HttpResponse.StatusCode = m_HttpResponseCode;
HttpResponse.pReason = ReasonStringForHttpResultCode(m_HttpResponseCode);
@@ -617,7 +718,14 @@ HttpMessageResponseRequest::IssueRequest()
}
}
-//////////////////////////////////////////////////////////////////////////
+/**
+ _________
+ / _____/ ______________ __ ___________
+ \_____ \_/ __ \_ __ \ \/ // __ \_ __ \
+ / \ ___/| | \/\ /\ ___/| | \/
+ /_______ /\___ >__| \_/ \___ >__|
+ \/ \/ \/
+*/
HttpSysServer::HttpSysServer(int ThreadCount) : m_ThreadPool(ThreadCount)
{
@@ -707,9 +815,9 @@ HttpSysServer::Initialize(const wchar_t* UrlPath)
void
HttpSysServer::StartServer()
{
- int RequestCount = 32;
+ const int InitialRequestCount = 32;
- for (int i = 0; i < RequestCount; ++i)
+ for (int i = 0; i < InitialRequestCount; ++i)
{
IssueNewRequestMaybe();
}
@@ -749,9 +857,7 @@ HttpSysServer::Run(bool TestMode)
void
HttpSysServer::OnHandlingRequest()
{
- --m_PendingRequests;
-
- if (m_PendingRequests > m_MinPendingRequests)
+ if (--m_PendingRequests > m_MinPendingRequests)
{
// We have more than the minimum number of requests pending, just let someone else
// enqueue new requests
@@ -807,7 +913,7 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
}
void
-HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service)
+HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
{
ZEN_UNUSED(Service);
@@ -832,218 +938,14 @@ HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service)
//////////////////////////////////////////////////////////////////////////
-HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service) : m_IsInitialized(true), m_HttpTx(Tx)
-{
- PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest();
-
- const int PrefixLength = Service.UriPrefixLength();
- const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t);
-
- if (AbsPathLength >= PrefixLength)
- {
- // We convert the URI immediately because most of the code involved prefers to deal
- // with utf8. This has some performance impact which I'd prefer to avoid but for now
- // we just have to live with it
-
- WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
- m_UriUtf8);
- }
- else
- {
- m_UriUtf8.Reset();
- }
-
- if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
- {
- --QueryStringLength;
-
- WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8);
- }
- else
- {
- m_QueryStringUtf8.Reset();
- }
-
- switch (HttpRequestPtr->Verb)
- {
- case HttpVerbOPTIONS:
- m_Verb = HttpVerb::kOptions;
- break;
-
- case HttpVerbGET:
- m_Verb = HttpVerb::kGet;
- break;
-
- case HttpVerbHEAD:
- m_Verb = HttpVerb::kHead;
- break;
-
- case HttpVerbPOST:
- m_Verb = HttpVerb::kPost;
- break;
-
- case HttpVerbPUT:
- m_Verb = HttpVerb::kPut;
- break;
-
- case HttpVerbDELETE:
- m_Verb = HttpVerb::kDelete;
- break;
-
- case HttpVerbCOPY:
- m_Verb = HttpVerb::kCopy;
- break;
-
- default:
- // TODO: invalid request?
- m_Verb = (HttpVerb)0;
- break;
- }
-
- const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength];
- std::string_view cl(clh.pRawValue, clh.RawValueLength);
- std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength);
-
- const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType];
- m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
-
- const HTTP_KNOWN_HEADER& AcceptHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAccept];
- m_AcceptType = MapContentType({AcceptHdr.pRawValue, AcceptHdr.RawValueLength});
-}
-
-void
-HttpSysServerRequest::ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler)
-{
- ZEN_UNUSED(CompletionHandler);
-}
-
-IoBuffer
-HttpSysServerRequest::ReadPayload()
-{
- // This is presently synchronous for simplicity, but we
- // need to implement an asynchronous version also
-
- HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest();
-
- IoBuffer PayloadBuffer(m_ContentLength);
-
- HttpContentType ContentType = RequestContentType();
- PayloadBuffer.SetContentType(ContentType);
-
- uint64_t BytesToRead = m_ContentLength;
-
- uint8_t* ReadPointer = reinterpret_cast<uint8_t*>(PayloadBuffer.MutableData());
-
- // First deal with any payload which has already been copied
- // into our request buffer
-
- const int EntityChunkCount = HttpReq->EntityChunkCount;
-
- for (int i = 0; i < EntityChunkCount; ++i)
- {
- HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i];
-
- ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory);
-
- const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength;
-
- ZEN_ASSERT(BufferLength <= BytesToRead);
-
- memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength);
-
- ReadPointer += BufferLength;
- BytesToRead -= BufferLength;
- }
-
- if (BytesToRead == 0)
- {
- PayloadBuffer.MakeImmutable();
-
- return PayloadBuffer;
- }
-
- // Call http.sys API to receive the remaining data SYNCHRONOUSLY
-
- static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024;
-
- while (BytesToRead)
- {
- ULONG BytesRead = 0;
-
- const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall);
-
- ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(),
- HttpReq->RequestId,
- 0, /* Flags */
- ReadPointer,
- gsl::narrow<ULONG>(BytesToReadThisCall),
- &BytesRead,
- NULL /* Overlapped */
- );
-
- if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF)
- {
- throw HttpServerException("payload read failed", ApiResult);
- }
-
- BytesToRead -= BytesRead;
- ReadPointer += BytesRead;
- }
-
- PayloadBuffer.MakeImmutable();
-
- return PayloadBuffer;
-}
-
-void
-HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode)
+HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler)
{
- ZEN_ASSERT(m_IsHandled == false);
-
- m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode);
-
- if (m_SuppressBody)
- {
- m_Response->SuppressResponseBody();
- }
-
- m_IsHandled = true;
}
-void
-HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+HttpSysTransaction::~HttpSysTransaction()
{
- ZEN_ASSERT(m_IsHandled == false);
- ZEN_UNUSED(ContentType);
-
- m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs);
-
- if (m_SuppressBody)
- {
- m_Response->SuppressResponseBody();
- }
-
- m_IsHandled = true;
}
-void
-HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
-{
- ZEN_ASSERT(m_IsHandled == false);
- ZEN_UNUSED(ContentType);
-
- m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size());
-
- if (m_SuppressBody)
- {
- m_Response->SuppressResponseBody();
- }
-
- m_IsHandled = true;
-}
-
-//////////////////////////////////////////////////////////////////////////
-
PTP_IO
HttpSysTransaction::Iocp()
{
@@ -1090,14 +992,16 @@ HttpSysTransaction::Status
HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
// We use this to ensure sequential execution of completion handlers
- // for any given transaction.
+ // for any given transaction. It also ensures all member variables are
+ // in a consistent state for the current thread
+
RwLock::ExclusiveLockScope _(m_CompletionMutex);
bool RequestPending = false;
if (HttpSysRequestHandler* CurrentHandler = m_HttpHandler)
{
- const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler);
+ const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest();
if (IsInitialRequest)
{
@@ -1144,37 +1048,158 @@ HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTran
//////////////////////////////////////////////////////////////////////////
+HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer)
+: m_HttpTx(Tx)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest();
+
+ const int PrefixLength = Service.UriPrefixLength();
+ const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t);
+
+ if (AbsPathLength >= PrefixLength)
+ {
+ // We convert the URI immediately because most of the code involved prefers to deal
+ // with utf8. This has some performance impact which I'd prefer to avoid but for now
+ // we just have to live with it
+
+ WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
+ m_UriUtf8);
+ }
+ else
+ {
+ m_UriUtf8.Reset();
+ }
+
+ if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
+ {
+ --QueryStringLength;
+
+ WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8);
+ }
+ else
+ {
+ m_QueryStringUtf8.Reset();
+ }
+
+ m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb);
+ m_ContentLength = GetContentLength(HttpRequestPtr);
+ m_ContentType = GetContentType(HttpRequestPtr);
+ m_AcceptType = GetAcceptType(HttpRequestPtr);
+}
+
+IoBuffer
+HttpSysServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode)
+{
+ ZEN_ASSERT(m_IsHandled == false);
+
+ m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode);
+
+ if (m_SuppressBody)
+ {
+ m_Response->SuppressResponseBody();
+ }
+
+ m_IsHandled = true;
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(m_IsHandled == false);
+
+ m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, Blobs);
+
+ if (m_SuppressBody)
+ {
+ m_Response->SuppressResponseBody();
+ }
+
+ m_IsHandled = true;
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(m_IsHandled == false);
+
+ m_Response =
+ new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ContentType, ResponseString.data(), ResponseString.size());
+
+ if (m_SuppressBody)
+ {
+ m_Response->SuppressResponseBody();
+ }
+
+ m_IsHandled = true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest)
+{
+}
+
+InitialRequestHandler::~InitialRequestHandler()
+{
+}
+
void
InitialRequestHandler::IssueRequest()
{
- PTP_IO Iocp = Transaction().Iocp();
+ HttpSysTransaction& Tx = Transaction();
+ PTP_IO Iocp = Tx.Iocp();
+ HTTP_REQUEST* HttpReq = Tx.HttpRequest();
StartThreadpoolIo(Iocp);
- HttpSysTransaction& Tx = Transaction();
+ ULONG HttpApiResult;
- HTTP_REQUEST* HttpReq = Tx.HttpRequest();
+ if (IsInitialRequest())
+ {
+ HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(),
+ HTTP_NULL_ID,
+ HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
+ HttpReq,
+ RequestBufferSize(),
+ NULL,
+ Tx.Overlapped());
+ }
+ else
+ {
+ static const uint64_t kMaxBytesPerApiCall = 64 * 1024;
- ULONG Result = HttpReceiveHttpRequest(Tx.RequestQueueHandle(),
- HTTP_NULL_ID,
- HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
- HttpReq,
- RequestBufferSize(),
- NULL,
- Tx.Overlapped());
+ uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset;
+ const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall);
+ void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset;
+
+ HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ 0, /* Flags */
+ BufferWriteCursor,
+ gsl::narrow<ULONG>(BytesToReadThisCall),
+ nullptr, // BytesReturned
+ Tx.Overlapped());
+ }
- if (Result != ERROR_IO_PENDING && Result != NO_ERROR)
+ if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR)
{
CancelThreadpoolIo(Iocp);
- if (Result == ERROR_MORE_DATA)
+ if (HttpApiResult == ERROR_MORE_DATA)
{
// ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA);
}
// CleanupHttpIoRequest(pIoRequest);
- spdlog::error("HttpReceiveHttpRequest failed, error {:x}", Result);
+ spdlog::error("HttpReceiveHttpRequest failed, error {:x}", HttpApiResult);
return;
}
@@ -1186,29 +1211,90 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
ZEN_UNUSED(IoResult);
ZEN_UNUSED(NumberOfBytesTransferred);
+ auto _ = MakeGuard([&] { m_IsInitialRequest = false; });
+
// Route requests
try
{
- if (HttpService* Service = reinterpret_cast<HttpService*>(m_HttpRequestPtr->UrlContext))
+ HTTP_REQUEST* HttpReq = HttpRequest();
+
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- HttpSysServerRequest ThisRequest(Transaction(), *Service);
+ if (m_IsInitialRequest)
+ {
+ m_ContentLength = GetContentLength(HttpReq);
+ HttpVerb Verb = TranslateHttpVerb(HttpReq->Verb);
+
+ if (m_ContentLength)
+ {
+ // Handle initial chunk read by copying any payload which has already been copied
+ // into our embedded request buffer
+
+ m_PayloadBuffer = IoBuffer(m_ContentLength);
+
+ HttpContentType ContentType = GetContentType(HttpReq);
+ m_PayloadBuffer.SetContentType(ContentType);
+
+ uint64_t BytesToRead = m_ContentLength;
+ uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData());
+ uint8_t* BufferWriteCursor = BufferBase;
+
+ const int EntityChunkCount = HttpReq->EntityChunkCount;
+
+ for (int i = 0; i < EntityChunkCount; ++i)
+ {
+ HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i];
+
+ ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory);
+
+ const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength;
+
+ ZEN_ASSERT(BufferLength <= BytesToRead);
+
+ memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength);
- Service->HandleRequest(ThisRequest);
+ BufferWriteCursor += BufferLength;
+ BytesToRead -= BufferLength;
+ }
- if (!ThisRequest.IsHandled())
+ m_CurrentPayloadOffset = BufferWriteCursor - BufferBase;
+ }
+ }
+ else
{
- return new HttpMessageResponseRequest(Transaction(), 404, "Not found");
+ m_CurrentPayloadOffset += NumberOfBytesTransferred;
}
- if (ThisRequest.m_Response)
+ if (m_CurrentPayloadOffset == m_ContentLength)
+ {
+ m_PayloadBuffer.MakeImmutable();
+
+ // Body received completely - call request handler
+
+ HttpSysServerRequest ThisRequest(Transaction(), *Service, m_PayloadBuffer);
+
+ Service->HandleRequest(ThisRequest);
+
+ if (!ThisRequest.IsHandled())
+ {
+ return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
+ }
+
+ if (ThisRequest.m_Response)
+ {
+ return ThisRequest.m_Response;
+ }
+ }
+ else
{
- return ThisRequest.m_Response;
+ // Issue a read request for more body data
+ return this;
}
}
// Unable to route
- return new HttpMessageResponseRequest(Transaction(), 404, "Item unknown");
+ return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
}
catch (std::exception& ex)
{