aboutsummaryrefslogtreecommitdiff
path: root/zenhttp/httpsys.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenhttp/httpsys.cpp')
-rw-r--r--zenhttp/httpsys.cpp299
1 files changed, 242 insertions, 57 deletions
diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp
index 9b2e7f832..f88563097 100644
--- a/zenhttp/httpsys.cpp
+++ b/zenhttp/httpsys.cpp
@@ -129,16 +129,18 @@ GetAcceptType(const HTTP_REQUEST* HttpRequest)
class HttpSysRequestHandler
{
public:
- explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {}
+ explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {}
virtual ~HttpSysRequestHandler() = default;
virtual void IssueRequest(std::error_code& ErrorCode) = 0;
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0;
+ HttpSysTransaction& Transaction() { return m_Transaction; }
- HttpSysTransaction& Transaction() { return m_Request; }
+ HttpSysRequestHandler(const HttpSysRequestHandler&) = delete;
+ HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete;
private:
- HttpSysTransaction& m_Request; // Related HTTP transaction object
+ HttpSysTransaction& m_Transaction;
};
/**
@@ -184,12 +186,16 @@ public:
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
using HttpServerRequest::WriteResponse;
- HttpSysTransaction& m_HttpTx;
- HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general
- IoBuffer m_PayloadBuffer;
+ HttpSysServerRequest(const HttpSysServerRequest&) = delete;
+ HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
+
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
};
/** HTTP transaction
@@ -218,7 +224,9 @@ public:
ULONG_PTR NumberOfBytesTransferred,
PTP_IO Io);
- void IssueInitialRequest(std::error_code& ErrorCode);
+ void IssueInitialRequest(std::error_code& ErrorCode);
+ bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler);
+
PTP_IO Iocp();
HANDLE RequestQueueHandle();
inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
@@ -227,6 +235,8 @@ public:
HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload);
+ HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); }
+
private:
OVERLAPPED m_HttpOverlapped{};
HttpSysServer& m_HttpServer;
@@ -239,8 +249,6 @@ private:
Ref<IHttpPackageHandler> m_PackageHandler;
};
-//////////////////////////////////////////////////////////////////////////
-
/**
* @brief HTTP request response I/O request handler
*
@@ -588,6 +596,108 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
}
}
+/** HTTP completion handler for async work
+
+ This is used to allow work to be taken off the request handler threads
+ and to support posting responses asynchronously.
+ */
+
+class HttpAsyncWorkRequest : public HttpSysRequestHandler
+{
+public:
+ HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response);
+ ~HttpAsyncWorkRequest();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+
+private:
+ struct AsyncWorkItem : public IWork
+ {
+ virtual void Execute() override;
+
+ AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler)
+ : Tx(InTx)
+ , Handler(std::move(InHandler))
+ {
+ }
+
+ HttpSysTransaction& Tx;
+ std::function<void(HttpServerRequest&)> Handler;
+ };
+
+ Ref<AsyncWorkItem> m_WorkItem;
+};
+
+HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response)
+: HttpSysRequestHandler(Tx)
+{
+ m_WorkItem = new AsyncWorkItem(Tx, std::move(Response));
+}
+
+HttpAsyncWorkRequest::~HttpAsyncWorkRequest()
+{
+}
+
+void
+HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
+{
+ ErrorCode.clear();
+
+ Transaction().Server().WorkPool().ScheduleWork(m_WorkItem);
+}
+
+HttpSysRequestHandler*
+HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // This ought to not be called since there should be no outstanding I/O request
+ // when this completion handler is active
+
+ ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
+
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+
+ return this;
+}
+
+void
+HttpAsyncWorkRequest::AsyncWorkItem::Execute()
+{
+ using namespace fmt::literals;
+
+ try
+ {
+ HttpSysServerRequest& ThisRequest = Tx.ServerRequest();
+
+ ThisRequest.m_NextCompletionHandler = nullptr;
+
+ Handler(ThisRequest);
+
+ // TODO: should Handler be destroyed at this point to ensure there
+ // are no outstanding references into state which could be
+ // deleted asynchronously as a result of issuing the response?
+
+ if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler)
+ {
+ return (void)Tx.IssueNextRequest(NextHandler);
+ }
+ else if (!ThisRequest.IsHandled())
+ {
+ return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv));
+ }
+ else
+ {
+ // "Handled" but no request handler? Shouldn't ever happen
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
+ }
+ }
+ catch (std::exception& Ex)
+ {
+ return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 500, "Exception thrown in async work: '{}'"_format(Ex.what())));
+ }
+}
+
/**
_________
/ _____/ ______________ __ ___________
@@ -597,10 +707,11 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
\/ \/ \/
*/
-HttpSysServer::HttpSysServer(unsigned int ThreadCount)
+HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount)
: m_Log(logging::Get("http"))
, m_RequestLog(logging::Get("http_requests"))
, m_ThreadPool(ThreadCount)
+, m_AsyncWorkPool(AsyncWorkThreadCount)
{
ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr);
@@ -611,6 +722,8 @@ HttpSysServer::HttpSysServer(unsigned int ThreadCount)
m_IsHttpInitialized = true;
m_IsOk = true;
+
+ ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount);
}
HttpSysServer::~HttpSysServer()
@@ -915,6 +1028,47 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
}
}
+bool
+HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler)
+{
+ HttpSysRequestHandler* CurrentHandler = m_CompletionHandler;
+ m_CompletionHandler = NewCompletionHandler;
+
+ auto _ = MakeGuard([this, CurrentHandler] {
+ if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler))
+ {
+ delete CurrentHandler;
+ }
+ });
+
+ if (NewCompletionHandler == nullptr)
+ {
+ return false;
+ }
+
+ try
+ {
+ std::error_code ErrorCode;
+ m_CompletionHandler->IssueRequest(ErrorCode);
+
+ if (!ErrorCode)
+ {
+ return true;
+ }
+
+ ZEN_ERROR("IssueRequest() failed: '{}'", ErrorCode.message());
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what());
+ }
+
+ // something went wrong, no request is pending
+ m_CompletionHandler = nullptr;
+
+ return false;
+}
+
HttpSysTransaction::Status
HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
@@ -934,38 +1088,9 @@ HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTran
m_HttpServer.OnHandlingRequest();
}
- m_CompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred);
+ auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred);
- if (m_CompletionHandler)
- {
- try
- {
- std::error_code ErrorCode;
- m_CompletionHandler->IssueRequest(ErrorCode);
-
- if (ErrorCode)
- {
- ZEN_ERROR("IssueRequest() failed {}", ErrorCode.message());
- }
- else
- {
- IsRequestPending = true;
- }
- }
- catch (std::exception& Ex)
- {
- ZEN_ERROR("exception caught from IssueRequest(): {}", Ex.what());
-
- // something went wrong, no request is pending
- }
- }
- else
- {
- if (CurrentHandler != &m_InitialHttpHandler)
- {
- delete CurrentHandler;
- }
- }
+ IsRequestPending = IssueNextRequest(NewCompletionHandler);
}
// Ensure new requests are enqueued as necessary
@@ -1086,23 +1211,48 @@ HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService&
const int PrefixLength = Service.UriPrefixLength();
const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t);
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
if (AbsPathLength >= PrefixLength)
{
// We convert the URI immediately because most of the code involved prefers to deal
- // with utf8. This has some performance impact which I'd prefer to avoid but for now
- // we just have to live with it
+ // with utf8. This is overhead which I'd prefer to avoid but for now we just have
+ // to live with it
WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
m_UriUtf8);
+
+ std::string_view UriSuffix8{m_UriUtf8};
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_UriUtf8.RemoveSuffix(uint32_t(m_UriUtf8.Size() - LastComponentIndex - LastDotIndex - 1));
+ }
+ }
}
else
{
m_UriUtf8.Reset();
}
- if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
+ if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
{
- --QueryStringLength;
+ --QueryStringLength; // We skip the leading question mark
WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8);
}
@@ -1114,7 +1264,23 @@ HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService&
m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb);
m_ContentLength = GetContentLength(HttpRequestPtr);
m_ContentType = GetContentType(HttpRequestPtr);
- m_AcceptType = GetAcceptType(HttpRequestPtr);
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = GetAcceptType(HttpRequestPtr);
+ }
+
+ if (m_Verb == HttpVerb::kHead)
+ {
+ SetSuppressResponseBody();
+ }
}
Oid
@@ -1172,13 +1338,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
{
ZEN_ASSERT(IsHandled() == false);
- m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
if (SuppressBody())
{
- m_Response->SuppressResponseBody();
+ Response->SuppressResponseBody();
}
+ m_NextCompletionHandler = Response;
+
SetIsHandled();
}
@@ -1187,13 +1355,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
{
ZEN_ASSERT(IsHandled() == false);
- m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
if (SuppressBody())
{
- m_Response->SuppressResponseBody();
+ Response->SuppressResponseBody();
}
+ m_NextCompletionHandler = Response;
+
SetIsHandled();
}
@@ -1202,17 +1372,32 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
{
ZEN_ASSERT(IsHandled() == false);
- m_Response =
+ auto Response =
new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size());
if (SuppressBody())
{
- m_Response->SuppressResponseBody();
+ Response->SuppressResponseBody();
}
+ m_NextCompletionHandler = Response;
+
SetIsHandled();
}
+void
+HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ if (m_HttpTx.Server().IsAsyncResponseEnabled())
+ {
+ ContinuationHandler(m_HttpTx.ServerRequest());
+ }
+ else
+ {
+ m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler));
+ }
+}
+
//////////////////////////////////////////////////////////////////////////
InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest)
@@ -1411,14 +1596,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer);
- if (!ThisRequest.IsHandled())
+ if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler)
{
- return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
+ return Response;
}
- if (HttpMessageResponseRequest* Response = ThisRequest.m_Response)
+ if (!ThisRequest.IsHandled())
{
- return Response;
+ return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
@@ -1462,4 +1647,4 @@ HttpSysServer::RegisterService(HttpService& Service)
}
} // namespace zen
-#endif \ No newline at end of file
+#endif