aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers/httpsys.cpp
diff options
context:
space:
mode:
authorLiam Mitchell <[email protected]>2026-03-09 18:40:40 -0700
committerLiam Mitchell <[email protected]>2026-03-09 18:40:40 -0700
commit97aa4e5c48305647a5d8f09da5f24bc1ce5540f3 (patch)
tree11062e72f4342aeb2f16ac19d6af20ac0e4acd78 /src/zenhttp/servers/httpsys.cpp
parentMerge branch 'main' into lm/oidctoken-exe-path (diff)
parentupdated chunk–block analyser (#818) (diff)
downloadzen-97aa4e5c48305647a5d8f09da5f24bc1ce5540f3.tar.xz
zen-97aa4e5c48305647a5d8f09da5f24bc1ce5540f3.zip
Merge branch 'main' into lm/oidctoken-exe-path
Diffstat (limited to 'src/zenhttp/servers/httpsys.cpp')
-rw-r--r--src/zenhttp/servers/httpsys.cpp556
1 files changed, 429 insertions, 127 deletions
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 54cc0c22d..dfe6bb6aa 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
OutString.Append("unknown");
}
+class HttpSysServerRequest;
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -90,17 +97,23 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+ IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);
+
private:
int InitializeServer(int BasePort);
void Cleanup();
@@ -124,8 +137,8 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
@@ -137,6 +150,12 @@ private:
int32_t m_MaxPendingRequests = 128;
Event m_ShutdownEvent;
HttpSysConfig m_InitialConfig;
+
+ RwLock m_RequestFilterLock;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -144,6 +163,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -313,6 +336,10 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -320,16 +347,19 @@ public:
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
virtual bool TryGetRanges(HttpRanges& Ranges) override;
+ void LogRequest(HttpMessageResponseRequest* Response);
+
using HttpServerRequest::WriteResponse;
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -363,7 +393,7 @@ public:
PTP_IO Iocp();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -380,8 +410,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -418,7 +448,10 @@ public:
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
void SuppressResponseBody(); // typically used for HEAD requests
- inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ inline uint16_t GetResponseCode() const { return m_ResponseCode; }
+ inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -429,6 +462,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
HTTP_CACHE_POLICY CachePolicy;
- CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.Policy = HttpCachePolicyNocache;
CachePolicy.SecondsToLive = 0;
// Initial response API call
- SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- &HttpResponse,
- &CachePolicy,
- NULL,
- NULL,
- 0,
- Tx.Overlapped(),
- NULL);
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
+ &HttpResponse, // HttpResponse
+ &CachePolicy, // CachePolicy
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
m_IsInitialResponse = false;
}
@@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
{
// Subsequent response API calls
- SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
(USHORT)ThisRequestChunkCount, // EntityChunkCount
&m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
NULL, // BytesSent
@@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
return 0;
}
@@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort)
// port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
- ZEN_WARN(
- "Unable to register handler using '{}' - falling back to local-only. "
- "Please ensure the appropriate netsh URL reservation configuration "
- "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
- WideToUtf8(WildcardUrlPath));
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation configuration "
+ "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
+ WideToUtf8(WildcardUrlPath));
+ }
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (m_BaseUris.empty())
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
@@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort)
ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
-
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
- {
- ZEN_DEBUG("HTTP3 is available");
- }
- else
- {
- ZEN_DEBUG("HTTP3 is NOT available");
- }
-# endif
-
return EffectivePort;
}
@@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
- m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async");
+ m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive)
ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit");
}
+ bool ShutdownRequested = false;
do
{
- // int WaitTimeout = -1;
int WaitTimeout = 100;
if (IsInteractive)
@@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
UpdateLofreqTimerValue();
- } while (!IsApplicationExitRequested());
+ } while (!ShutdownRequested);
}
void
@@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
{
@@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator);
# endif
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
+ {
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+ }
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ ThisRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
{
- Service.HandleRequest(ThisRequest);
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
}
return ThisRequest;
@@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const
return 0;
}
+bool
+HttpSysServerRequest::IsLocalMachineRequest() const
+{
+ const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress;
+ const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ if (LocalAddress->sa_family != RemoteAddress->sa_family)
+ {
+ return false;
+ }
+ if (LocalAddress->sa_family == AF_INET)
+ {
+ const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress);
+ const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress);
+ return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr;
+ }
+ else if (LocalAddress->sa_family == AF_INET6)
+ {
+ const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress);
+ const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress);
+ return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
+HttpSysServerRequest::GetAuthorizationHeader() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization];
+ return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength);
+}
+
IoBuffer
HttpSysServerRequest::ReadPayload()
{
@@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
if (SuppressBody())
{
@@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
if (SuppressBody())
{
@@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
+}
+
+void
+HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response)
+{
+ if (ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}",
+ ToString(RequestVerb()),
+ m_UriUtf8.c_str(),
+ Response->GetResponseCode(),
+ NiceBytes(Response->GetResponseBodySize()));
+ }
}
void
@@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
- ZEN_INFO("");
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
+
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{
RegisterService(Service.BaseUri(), Service);
}
+void
+HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_RequestFilterLock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpSysServer::FilterRequest(HttpSysServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_RequestFilterLock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{