diff options
| author | Stefan Boberg <[email protected]> | 2026-03-15 20:42:36 +0100 |
|---|---|---|
| committer | Stefan Boberg <[email protected]> | 2026-03-15 20:42:36 +0100 |
| commit | 9c724efbf6b38466a9b6bfde37236369f1e85cb8 (patch) | |
| tree | 214e1ec00c5bfca0704ce52789017ade734fd054 /src/zenhttp/servers/httpsys.cpp | |
| parent | reduced WaitForThreads time to see how it behaves with explicit thread pools (diff) | |
| parent | add buildid updates to oplog and builds test scripts (#838) (diff) | |
| download | archived-zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.tar.xz archived-zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.zip | |
Merge remote-tracking branch 'origin/main' into sb/threadpool
Diffstat (limited to 'src/zenhttp/servers/httpsys.cpp')
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 786 |
1 files changed, 640 insertions, 146 deletions
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4406d0619..eaf080960 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/packageformat.h> @@ -25,7 +26,9 @@ # include <zencore/workthreadpool.h> # include "iothreadpool.h" +# include <atomic> # include <http.h> +# include <asio.hpp> // for resolving addresses for GetExternalHost namespace zen { @@ -85,6 +88,8 @@ class HttpSysServerRequest; class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -92,12 +97,15 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); @@ -108,6 +116,12 @@ public: private: int InitializeServer(int BasePort); + bool CreateSessionAndUrlGroup(); + bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris); + int RegisterHttpUrls(int BasePort); + bool RegisterHttpsUrls(); + bool CreateRequestQueue(int EffectivePort); + bool SetupIoCompletionPort(); void Cleanup(); void StartServer(); @@ -117,6 +131,9 @@ private: void RegisterService(const char* Endpoint, HttpService& Service); void UnregisterService(const char* Endpoint, HttpService& Service); + bool BindSslCertificate(int Port); + void UnbindSslCertificate(); + private: LoggerRef m_Log; LoggerRef m_RequestLog; @@ -130,10 +147,13 @@ private: std::unique_ptr<WinIoThreadPool> m_IoThreadPool; bool m_IoThreadPoolIsWinTp = true; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; - std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/ + bool m_DidAutoBindCert = false; + int m_HttpsPort = 0; HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; HANDLE m_RequestQueueHandle = 0; @@ -146,6 +166,9 @@ private: RwLock m_RequestFilterLock; std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; }; } // namespace zen @@ -153,6 +176,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include <conio.h> # include <mstcpip.h> # pragma comment(lib, "httpapi.lib") @@ -322,8 +349,9 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const; + virtual bool IsLocalMachineRequest() const override; virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -339,11 +367,12 @@ public: HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -378,7 +407,7 @@ public: void StartIo(); void CancelIo(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -395,8 +424,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -436,6 +465,8 @@ public: inline uint16_t GetResponseCode() const { return m_ResponseCode; } inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes @@ -445,6 +476,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; + std::string m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -585,7 +617,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -699,6 +731,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -900,7 +941,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1035,8 +1079,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1051,36 +1097,63 @@ HttpSysServer::OnClose() } } -int -HttpSysServer::InitializeServer(int BasePort) +bool +HttpSysServer::CreateSessionAndUrlGroup() { - ZEN_MEMSCOPE(GetHttpsysTag()); - - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } + return true; +} + +bool +HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris) +{ + using namespace std::literals; + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrl; + LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (Result == NO_ERROR) + { + ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl)); + OutUris.push_back(LocalUrl.c_str()); + } + else + { + break; + } + } + + return !OutUris.empty(); +} + +int +HttpSysServer::RegisterHttpUrls(int BasePort) +{ + using namespace std::literals; + m_BaseUris.clear(); const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer; @@ -1088,6 +1161,11 @@ HttpSysServer::InitializeServer(int BasePort) int EffectivePort = BasePort; + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + ULONG Result; + if (m_InitialConfig.ForceLoopback) { // Force trigger of opening using local port @@ -1100,7 +1178,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1122,7 +1202,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1139,11 +1221,11 @@ HttpSysServer::InitializeServer(int BasePort) { if (AllowLocalOnly) { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh + // If we can't register the wildcard path, we fall back to local paths. + // Local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: + // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> if (!m_InitialConfig.ForceLoopback) @@ -1157,17 +1239,18 @@ HttpSysServer::InitializeServer(int BasePort) const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1175,11 +1258,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1193,29 +1290,123 @@ HttpSysServer::InitializeServer(int BasePort) } } - if (m_BaseUris.empty()) + if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } + return EffectivePort; +} + +bool +HttpSysServer::RegisterHttpsUrls() +{ + using namespace std::literals; + + const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer; + const int HttpsPort = m_InitialConfig.HttpsPort; + + // If HTTPS-only mode, remove HTTP URLs and clear base URIs + if (m_InitialConfig.HttpsOnly) + { + for (const std::wstring& Uri : m_BaseUris) + { + HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0); + } + m_BaseUris.clear(); + } + + // Auto-bind certificate if thumbprint is provided + if (!m_InitialConfig.CertThumbprint.empty()) + { + if (!BindSslCertificate(HttpsPort)) + { + return false; + } + } + else + { + ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort); + } + + // Register HTTPS URLs using same pattern as HTTP + + WideStringBuilder<64> HttpsWildcard; + HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv; + + ULONG HttpsResult = NO_ERROR; + + if (m_InitialConfig.ForceLoopback) + { + HttpsResult = ERROR_ACCESS_DENIED; + } + else + { + HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + if (HttpsResult == NO_ERROR) + { + m_HttpsBaseUris.push_back(HttpsWildcard.c_str()); + } + else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly) + { + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register HTTPS handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.", + WideToUtf8(HttpsWildcard)); + } + + RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris); + } + else if (HttpsResult != NO_ERROR) + { + ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})", + WideToUtf8(HttpsWildcard), + GetSystemErrorAsString(HttpsResult), + HttpsResult); + return false; + } + + if (m_HttpsBaseUris.empty()) + { + ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort); + return false; + } + + m_HttpsPort = HttpsPort; + return true; +} + +bool +HttpSysServer::CreateRequestQueue(int EffectivePort) +{ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; WideStringBuilder<64> QueueName; QueueName << "zenserver_" << EffectivePort; - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ QueueName.c_str(), - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); + ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); - return 0; + return false; } HttpBindingInfo.Flags.Present = 1; @@ -1225,9 +1416,12 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); - return 0; + return false; } // Configure rejection method. Default is to drop the connection, it's better if we @@ -1257,42 +1451,82 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } - // Create I/O completion port + return true; +} +bool +HttpSysServer::SetupIoCompletionPort() +{ std::error_code ErrorCode; m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); if (ErrorCode) { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message()); + return false; + } + + m_IsOk = true; + + if (!m_BaseUris.empty()) + { + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + if (!m_HttpsBaseUris.empty()) + { + ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front())); + } + + return true; +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + ZEN_MEMSCOPE(GetHttpsysTag()); + + m_IsOk = false; + if (!CreateSessionAndUrlGroup()) + { return 0; } - else + + int EffectivePort = RegisterHttpUrls(BasePort); + + if (m_InitialConfig.HttpsPort > 0) { - m_IsOk = true; + if (!RegisterHttpsUrls()) + { + return 0; + } + } - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + if (m_BaseUris.empty() && m_HttpsBaseUris.empty()) + { + ZEN_ERROR("No HTTP or HTTPS listeners could be registered"); + return 0; } - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build + if (!CreateRequestQueue(EffectivePort)) + { + return 0; + } -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) + if (!SetupIoCompletionPort()) { - ZEN_DEBUG("HTTP3 is available"); + return 0; } - else + + // When HTTPS-only, return the HTTPS port as the effective port + if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0) { - ZEN_DEBUG("HTTP3 is NOT available"); + return m_HttpsPort; } -# endif return EffectivePort; } @@ -1302,6 +1536,8 @@ HttpSysServer::Cleanup() { ++m_IsShuttingDown; + UnbindSslCertificate(); + if (m_RequestQueueHandle) { HttpCloseRequestQueue(m_RequestQueueHandle); @@ -1321,23 +1557,122 @@ HttpSysServer::Cleanup() } } +// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings +static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}}; + +bool +HttpSysServer::BindSslCertificate(int Port) +{ + const std::string& Thumbprint = m_InitialConfig.CertThumbprint; + if (Thumbprint.size() != 40) + { + ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size()); + return false; + } + + BYTE CertHash[20] = {}; + if (!ParseHexBytes(Thumbprint, CertHash)) + { + ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters"); + return false; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(Port)); + Address.sin_addr.s_addr = INADDR_ANY; + + const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str()); + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + SslConfig.ParamDesc.pSslHash = CertHash; + SslConfig.ParamDesc.SslHashLength = sizeof(CertHash); + SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str()); + SslConfig.ParamDesc.AppId = ZenServerSslAppId; + + ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result == ERROR_ALREADY_EXISTS) + { + // Remove existing binding and retry + HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {}; + DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr); + + Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + } + + if (Result != NO_ERROR) + { + ZEN_ERROR( + "Failed to bind SSL certificate to port {}: {} ({:#x}). " + "This operation may require running as administrator.", + Port, + GetSystemErrorAsString(Result), + Result); + return false; + } + + m_DidAutoBindCert = true; + m_HttpsPort = Port; + + ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})", + Port, + Thumbprint.substr(0, 8), + m_InitialConfig.CertStoreName); + + return true; +} + +void +HttpSysServer::UnbindSslCertificate() +{ + if (!m_DidAutoBindCert) + { + return; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort)); + Address.sin_addr.s_addr = INADDR_ANY; + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result != NO_ERROR) + { + ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result); + } + else + { + ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort); + } + + m_DidAutoBindCert = false; +} + WorkerThreadPool& HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async", m_InitialConfig.UseExplicitIoThreadPool); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void @@ -1449,19 +1784,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) // Convert to wide string - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) + auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + std::wstring Url16 = BaseUri + PathUtf16; - return; + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + return; + } } - } + }; + + RegisterWithBaseUris(m_BaseUris); + RegisterWithBaseUris(m_HttpsBaseUris); } void @@ -1476,19 +1815,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; + auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } } - } + }; + + UnregisterFromBaseUris(m_BaseUris); + UnregisterFromBaseUris(m_HttpsBaseUris); } ////////////////////////////////////////////////////////////////////////// @@ -1551,7 +1893,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); // Assign names to threads for context (only needed when using Windows' native // thread pool) @@ -1675,6 +2033,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -1884,6 +2244,17 @@ HttpSysServerRequest::IsLocalMachineRequest() const } std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view HttpSysServerRequest::GetAuthorizationHeader() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); @@ -2111,6 +2482,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2119,64 +2492,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders; + + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. + OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Transaction().Server().OnWebSocketConnectionOpened(); + Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } + + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - ZEN_INFO(""); + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); @@ -2242,6 +2673,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty())) + { + auto* Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); @@ -2285,6 +2728,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir) ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { + if (m_HttpsPort > 0) + { + SetEffectiveHttpsPort(m_HttpsPort); + } + StartServer(); return EffectivePort; @@ -2301,6 +2749,52 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_context IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { |