diff options
Diffstat (limited to 'src/zenhttp/servers/httpsys.cpp')
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 556 |
1 files changed, 429 insertions, 127 deletions
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 54cc0c22d..dfe6bb6aa 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/packageformat.h> @@ -25,7 +26,9 @@ # include <zencore/workthreadpool.h> # include "iothreadpool.h" +# include <atomic> # include <http.h> +# include <asio.hpp> // for resolving addresses for GetExternalHost namespace zen { @@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In OutString.Append("unknown"); } +class HttpSysServerRequest; + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -90,17 +97,23 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -124,8 +137,8 @@ private: std::unique_ptr<WinIoThreadPool> m_IoThreadPool; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; @@ -137,6 +150,12 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; }; } // namespace zen @@ -144,6 +163,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include <conio.h> # include <mstcpip.h> # pragma comment(lib, "httpapi.lib") @@ -313,6 +336,10 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; @@ -320,16 +347,19 @@ public: virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; virtual bool TryGetRanges(HttpRanges& Ranges) override; + void LogRequest(HttpMessageResponseRequest* Response); + using HttpServerRequest::WriteResponse; HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -363,7 +393,7 @@ public: PTP_IO Iocp(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -380,8 +410,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -418,7 +448,10 @@ public: virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; void SuppressResponseBody(); // typically used for HEAD requests - inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + inline uint16_t GetResponseCode() const { return m_ResponseCode; } + inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + + void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -429,6 +462,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; + std::string m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) HTTP_CACHE_POLICY CachePolicy; - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.Policy = HttpCachePolicyNocache; CachePolicy.SecondsToLive = 0; // Initial response API call - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags + &HttpResponse, // HttpResponse + &CachePolicy, // CachePolicy + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); m_IsInitialResponse = false; } @@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { // Subsequent response API calls - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // BytesSent @@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); return 0; } @@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort) // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> - ZEN_WARN( - "Unable to register handler using '{}' - falling back to local-only. " - "Please ensure the appropriate netsh URL reservation configuration " - "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", - WideToUtf8(WildcardUrlPath)); + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation configuration " + "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", + WideToUtf8(WildcardUrlPath)); + } const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort) if (m_BaseUris.empty()) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } @@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort) ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); } - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build - -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) - { - ZEN_DEBUG("HTTP3 is available"); - } - else - { - ZEN_DEBUG("HTTP3 is NOT available"); - } -# endif - return EffectivePort; } @@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { - m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void @@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { - // int WaitTimeout = -1; int WaitTimeout = 100; if (IsInteractive) @@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } } - m_ShutdownEvent.Wait(WaitTimeout); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); + } while (!ShutdownRequested); } void @@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { @@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else { - Service.HandleRequest(ThisRequest); + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const return 0; } +bool +HttpSysServerRequest::IsLocalMachineRequest() const +{ + const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress; + const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + if (LocalAddress->sa_family != RemoteAddress->sa_family) + { + return false; + } + if (LocalAddress->sa_family == AF_INET) + { + const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress); + const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress); + return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr; + } + else if (LocalAddress->sa_family == AF_INET6) + { + const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress); + const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress); + return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0; + } + else + { + return false; + } +} + +std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view +HttpSysServerRequest::GetAuthorizationHeader() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; + return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength); +} + IoBuffer HttpSysServerRequest::ReadPayload() { @@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); if (SuppressBody()) { @@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) # endif SetIsHandled(); + LogRequest(Response); } void @@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (SuppressBody()) { @@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); +} + +void +HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response) +{ + if (ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", + ToString(RequestVerb()), + m_UriUtf8.c_str(), + Response->GetResponseCode(), + NiceBytes(Response->GetResponseBodySize())); + } } void @@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); } void @@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders; - ZEN_INFO(""); + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. + OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Transaction().Server().OnWebSocketConnectionOpened(); + Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } + + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); + + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); @@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty())) + { + auto* Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); @@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpSysServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + return RequestFilter->FilterRequest(Request); +} + Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { |