aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers/httpsys.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-15 20:42:36 +0100
committerStefan Boberg <[email protected]>2026-03-15 20:42:36 +0100
commit9c724efbf6b38466a9b6bfde37236369f1e85cb8 (patch)
tree214e1ec00c5bfca0704ce52789017ade734fd054 /src/zenhttp/servers/httpsys.cpp
parentreduced WaitForThreads time to see how it behaves with explicit thread pools (diff)
parentadd buildid updates to oplog and builds test scripts (#838) (diff)
downloadarchived-zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.tar.xz
archived-zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.zip
Merge remote-tracking branch 'origin/main' into sb/threadpool
Diffstat (limited to 'src/zenhttp/servers/httpsys.cpp')
-rw-r--r--src/zenhttp/servers/httpsys.cpp786
1 files changed, 640 insertions, 146 deletions
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 4406d0619..eaf080960 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -85,6 +88,8 @@ class HttpSysServerRequest;
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -92,12 +97,15 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
@@ -108,6 +116,12 @@ public:
private:
int InitializeServer(int BasePort);
+ bool CreateSessionAndUrlGroup();
+ bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris);
+ int RegisterHttpUrls(int BasePort);
+ bool RegisterHttpsUrls();
+ bool CreateRequestQueue(int EffectivePort);
+ bool SetupIoCompletionPort();
void Cleanup();
void StartServer();
@@ -117,6 +131,9 @@ private:
void RegisterService(const char* Endpoint, HttpService& Service);
void UnregisterService(const char* Endpoint, HttpService& Service);
+ bool BindSslCertificate(int Port);
+ void UnbindSslCertificate();
+
private:
LoggerRef m_Log;
LoggerRef m_RequestLog;
@@ -130,10 +147,13 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
bool m_IoThreadPoolIsWinTp = true;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/
+ bool m_DidAutoBindCert = false;
+ int m_HttpsPort = 0;
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
HANDLE m_RequestQueueHandle = 0;
@@ -146,6 +166,9 @@ private:
RwLock m_RequestFilterLock;
std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -153,6 +176,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -322,8 +349,9 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
- virtual bool IsLocalMachineRequest() const;
+ virtual bool IsLocalMachineRequest() const override;
virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -339,11 +367,12 @@ public:
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -378,7 +407,7 @@ public:
void StartIo();
void CancelIo();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -395,8 +424,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -436,6 +465,8 @@ public:
inline uint16_t GetResponseCode() const { return m_ResponseCode; }
inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
+
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
@@ -445,6 +476,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -585,7 +617,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -699,6 +731,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -900,7 +941,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1035,8 +1079,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1051,36 +1097,63 @@ HttpSysServer::OnClose()
}
}
-int
-HttpSysServer::InitializeServer(int BasePort)
+bool
+HttpSysServer::CreateSessionAndUrlGroup()
{
- ZEN_MEMSCOPE(GetHttpsysTag());
-
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
+ return true;
+}
+
+bool
+HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris)
+{
+ using namespace std::literals;
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrl;
+ LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (Result == NO_ERROR)
+ {
+ ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl));
+ OutUris.push_back(LocalUrl.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return !OutUris.empty();
+}
+
+int
+HttpSysServer::RegisterHttpUrls(int BasePort)
+{
+ using namespace std::literals;
+
m_BaseUris.clear();
const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
@@ -1088,6 +1161,11 @@ HttpSysServer::InitializeServer(int BasePort)
int EffectivePort = BasePort;
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ ULONG Result;
+
if (m_InitialConfig.ForceLoopback)
{
// Force trigger of opening using local port
@@ -1100,7 +1178,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1122,7 +1202,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1139,11 +1221,11 @@ HttpSysServer::InitializeServer(int BasePort)
{
if (AllowLocalOnly)
{
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
+ // If we can't register the wildcard path, we fall back to local paths.
+ // Local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
// during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
+ // port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
if (!m_InitialConfig.ForceLoopback)
@@ -1157,17 +1239,18 @@ HttpSysServer::InitializeServer(int BasePort)
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1175,11 +1258,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1193,29 +1290,123 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- if (m_BaseUris.empty())
+ if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0)
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
+ return EffectivePort;
+}
+
+bool
+HttpSysServer::RegisterHttpsUrls()
+{
+ using namespace std::literals;
+
+ const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer;
+ const int HttpsPort = m_InitialConfig.HttpsPort;
+
+ // If HTTPS-only mode, remove HTTP URLs and clear base URIs
+ if (m_InitialConfig.HttpsOnly)
+ {
+ for (const std::wstring& Uri : m_BaseUris)
+ {
+ HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0);
+ }
+ m_BaseUris.clear();
+ }
+
+ // Auto-bind certificate if thumbprint is provided
+ if (!m_InitialConfig.CertThumbprint.empty())
+ {
+ if (!BindSslCertificate(HttpsPort))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort);
+ }
+
+ // Register HTTPS URLs using same pattern as HTTP
+
+ WideStringBuilder<64> HttpsWildcard;
+ HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv;
+
+ ULONG HttpsResult = NO_ERROR;
+
+ if (m_InitialConfig.ForceLoopback)
+ {
+ HttpsResult = ERROR_ACCESS_DENIED;
+ }
+ else
+ {
+ HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ if (HttpsResult == NO_ERROR)
+ {
+ m_HttpsBaseUris.push_back(HttpsWildcard.c_str());
+ }
+ else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly)
+ {
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register HTTPS handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.",
+ WideToUtf8(HttpsWildcard));
+ }
+
+ RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris);
+ }
+ else if (HttpsResult != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})",
+ WideToUtf8(HttpsWildcard),
+ GetSystemErrorAsString(HttpsResult),
+ HttpsResult);
+ return false;
+ }
+
+ if (m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort);
+ return false;
+ }
+
+ m_HttpsPort = HttpsPort;
+ return true;
+}
+
+bool
+HttpSysServer::CreateRequestQueue(int EffectivePort)
+{
HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
WideStringBuilder<64> QueueName;
QueueName << "zenserver_" << EffectivePort;
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ QueueName.c_str(),
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
+ ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ QueueName.c_str(),
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
- return 0;
+ return false;
}
HttpBindingInfo.Flags.Present = 1;
@@ -1225,9 +1416,12 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
- return 0;
+ return false;
}
// Configure rejection method. Default is to drop the connection, it's better if we
@@ -1257,42 +1451,82 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
- // Create I/O completion port
+ return true;
+}
+bool
+HttpSysServer::SetupIoCompletionPort()
+{
std::error_code ErrorCode;
m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
if (ErrorCode)
{
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message());
+ return false;
+ }
+
+ m_IsOk = true;
+
+ if (!m_BaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+ if (!m_HttpsBaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front()));
+ }
+
+ return true;
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
+ m_IsOk = false;
+ if (!CreateSessionAndUrlGroup())
+ {
return 0;
}
- else
+
+ int EffectivePort = RegisterHttpUrls(BasePort);
+
+ if (m_InitialConfig.HttpsPort > 0)
{
- m_IsOk = true;
+ if (!RegisterHttpsUrls())
+ {
+ return 0;
+ }
+ }
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ if (m_BaseUris.empty() && m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("No HTTP or HTTPS listeners could be registered");
+ return 0;
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
+ if (!CreateRequestQueue(EffectivePort))
+ {
+ return 0;
+ }
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
+ if (!SetupIoCompletionPort())
{
- ZEN_DEBUG("HTTP3 is available");
+ return 0;
}
- else
+
+ // When HTTPS-only, return the HTTPS port as the effective port
+ if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0)
{
- ZEN_DEBUG("HTTP3 is NOT available");
+ return m_HttpsPort;
}
-# endif
return EffectivePort;
}
@@ -1302,6 +1536,8 @@ HttpSysServer::Cleanup()
{
++m_IsShuttingDown;
+ UnbindSslCertificate();
+
if (m_RequestQueueHandle)
{
HttpCloseRequestQueue(m_RequestQueueHandle);
@@ -1321,23 +1557,122 @@ HttpSysServer::Cleanup()
}
}
+// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings
+static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}};
+
+bool
+HttpSysServer::BindSslCertificate(int Port)
+{
+ const std::string& Thumbprint = m_InitialConfig.CertThumbprint;
+ if (Thumbprint.size() != 40)
+ {
+ ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size());
+ return false;
+ }
+
+ BYTE CertHash[20] = {};
+ if (!ParseHexBytes(Thumbprint, CertHash))
+ {
+ ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters");
+ return false;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(Port));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str());
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+ SslConfig.ParamDesc.pSslHash = CertHash;
+ SslConfig.ParamDesc.SslHashLength = sizeof(CertHash);
+ SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str());
+ SslConfig.ParamDesc.AppId = ZenServerSslAppId;
+
+ ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result == ERROR_ALREADY_EXISTS)
+ {
+ // Remove existing binding and retry
+ HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {};
+ DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr);
+
+ Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+ }
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR(
+ "Failed to bind SSL certificate to port {}: {} ({:#x}). "
+ "This operation may require running as administrator.",
+ Port,
+ GetSystemErrorAsString(Result),
+ Result);
+ return false;
+ }
+
+ m_DidAutoBindCert = true;
+ m_HttpsPort = Port;
+
+ ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})",
+ Port,
+ Thumbprint.substr(0, 8),
+ m_InitialConfig.CertStoreName);
+
+ return true;
+}
+
+void
+HttpSysServer::UnbindSslCertificate()
+{
+ if (!m_DidAutoBindCert)
+ {
+ return;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result);
+ }
+ else
+ {
+ ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort);
+ }
+
+ m_DidAutoBindCert = false;
+}
+
WorkerThreadPool&
HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
m_AsyncWorkPool =
new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async", m_InitialConfig.UseExplicitIoThreadPool);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1449,19 +1784,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
// Convert to wide string
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
+ auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
{
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ std::wstring Url16 = BaseUri + PathUtf16;
- return;
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ return;
+ }
}
- }
+ };
+
+ RegisterWithBaseUris(m_BaseUris);
+ RegisterWithBaseUris(m_HttpsBaseUris);
}
void
@@ -1476,19 +1815,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
+ auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
}
- }
+ };
+
+ UnregisterFromBaseUris(m_BaseUris);
+ UnregisterFromBaseUris(m_HttpsBaseUris);
}
//////////////////////////////////////////////////////////////////////////
@@ -1551,7 +1893,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
// Assign names to threads for context (only needed when using Windows' native
// thread pool)
@@ -1675,6 +2033,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1884,6 +2244,17 @@ HttpSysServerRequest::IsLocalMachineRequest() const
}
std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
HttpSysServerRequest::GetAuthorizationHeader() const
{
const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
@@ -2111,6 +2482,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2119,64 +2492,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
+
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
- ZEN_INFO("");
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2242,6 +2673,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2285,6 +2728,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
+ if (m_HttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(m_HttpsPort);
+ }
+
StartServer();
return EffectivePort;
@@ -2301,6 +2749,52 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_context IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{