aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers
diff options
context:
space:
mode:
authorLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
committerLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
commitd1abc50ee9d4fb72efc646e17decafea741caa34 (patch)
treee4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zenhttp/servers
parentAllow requests with invalid content-types unless specified in command line or... (diff)
parentupdated chunk–block analyser (#818) (diff)
downloadzen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz
zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zenhttp/servers')
-rw-r--r--src/zenhttp/servers/httpasio.cpp427
-rw-r--r--src/zenhttp/servers/httpasio.h2
-rw-r--r--src/zenhttp/servers/httpmulti.cpp31
-rw-r--r--src/zenhttp/servers/httpmulti.h12
-rw-r--r--src/zenhttp/servers/httpnull.cpp18
-rw-r--r--src/zenhttp/servers/httpnull.h1
-rw-r--r--src/zenhttp/servers/httpparser.cpp155
-rw-r--r--src/zenhttp/servers/httpparser.h12
-rw-r--r--src/zenhttp/servers/httpplugin.cpp140
-rw-r--r--src/zenhttp/servers/httpsys.cpp556
-rw-r--r--src/zenhttp/servers/httpsys_iocontext.h40
-rw-r--r--src/zenhttp/servers/httptracer.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp311
-rw-r--r--src/zenhttp/servers/wsasio.h77
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp236
-rw-r--r--src/zenhttp/servers/wsframecodec.h74
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp485
-rw-r--r--src/zenhttp/servers/wshttpsys.h107
-rw-r--r--src/zenhttp/servers/wstest.cpp925
19 files changed, 3262 insertions, 351 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 18a0f6a40..f5178ebe8 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -7,12 +7,15 @@
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory/llm.h>
+#include <zencore/system.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zencore/windows.h>
#include <zenhttp/httpserver.h>
#include "httpparser.h"
+#include "wsasio.h"
+#include "wsframecodec.h"
#include <EASTL/fixed_vector.h>
@@ -89,15 +92,19 @@ IsIPv6AvailableSysctl(void)
char buf[16];
if (fgets(buf, sizeof(buf), f))
{
- fclose(f);
// 0 means IPv6 enabled, 1 means disabled
val = atoi(buf);
}
+ fclose(f);
}
return val == 0;
}
+#endif // ZEN_PLATFORM_LINUX
+namespace zen {
+
+#if ZEN_PLATFORM_LINUX
bool
IsIPv6Capable()
{
@@ -121,8 +128,6 @@ IsIPv6Capable()
}
#endif
-namespace zen {
-
const FLLMTag&
GetHttpasioTag()
{
@@ -145,7 +150,7 @@ inline LoggerRef
InitLogger()
{
LoggerRef Logger = logging::Get("asio");
- // Logger.set_level(spdlog::level::trace);
+ // Logger.SetLogLevel(logging::Trace);
return Logger;
}
@@ -496,16 +501,21 @@ public:
HttpAsioServerImpl();
~HttpAsioServerImpl();
- void Initialize(std::filesystem::path DataDir);
- int Start(uint16_t Port, const AsioConfig& Config);
- void Stop();
- void RegisterService(const char* UrlPath, HttpService& Service);
- HttpService* RouteRequest(std::string_view Url);
+ void Initialize(std::filesystem::path DataDir);
+ int Start(uint16_t Port, const AsioConfig& Config);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
+
+ bool IsLoopbackOnly() const;
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -518,6 +528,11 @@ public:
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
+
+ HttpServer* m_HttpServer = nullptr;
};
/**
@@ -527,12 +542,21 @@ public:
class HttpAsioServerRequest : public HttpServerRequest
{
public:
- HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber);
+ HttpAsioServerRequest(HttpRequestParser& Request,
+ HttpService& Service,
+ IoBuffer PayloadBuffer,
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress);
~HttpAsioServerRequest();
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -548,6 +572,8 @@ public:
HttpRequestParser& m_Request;
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
+ bool m_IsLocalMachineRequest;
+ std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -925,6 +951,7 @@ private:
void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop);
void CloseConnection();
+ void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {});
HttpAsioServerImpl& m_Server;
asio::streambuf m_RequestBuffer;
@@ -1025,6 +1052,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+ m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
m_RequestCounter.load(std::memory_order_relaxed),
@@ -1078,6 +1107,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
return;
}
+ m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
RequestNumber,
@@ -1139,10 +1170,91 @@ HttpServerConnection::CloseConnection()
}
void
+HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
+{
+ ExtendableStringBuilder<256> ResponseBuilder;
+ ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
+ if (!Headers.empty())
+ {
+ ResponseBuilder << Headers;
+ }
+ if (!m_RequestData.IsKeepAlive())
+ {
+ ResponseBuilder << "Connection: close\r\n";
+ }
+ ResponseBuilder << "\r\n";
+ if (!Body.empty())
+ {
+ ResponseBuilder << Body;
+ }
+ auto ResponseView = ResponseBuilder.ToView();
+ IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
+ auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
+ asio::async_write(
+ *m_Socket.get(),
+ Buffer,
+ [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
+ });
+}
+
+void
HttpServerConnection::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ // WebSocket upgrade detection must happen before the keep-alive check below,
+ // because Upgrade requests have "Connection: Upgrade" which the HTTP parser
+ // treats as non-keep-alive, causing a premature shutdown of the receive side.
+ if (m_RequestData.IsWebSocketUpgrade())
+ {
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service);
+ if (WsHandler && !m_RequestData.SecWebSocketKey().empty())
+ {
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey());
+
+ auto ResponseStr = std::make_shared<std::string>();
+ ResponseStr->reserve(256);
+ ResponseStr->append(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ");
+ ResponseStr->append(AcceptKey);
+ ResponseStr->append("\r\n\r\n");
+
+ // Send the 101 response on the current socket, then hand the socket off
+ // to a WsAsioConnection for the WebSocket protocol.
+ asio::async_write(*m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ Ref<WsAsioConnection> WsConn(
+ new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
+
+ m_RequestState = RequestState::kDone;
+ return;
+ }
+ }
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
@@ -1166,14 +1278,24 @@ HttpServerConnection::HandleRequest()
{
ZEN_TRACE_CPU("asio::HandleRequest");
- HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber);
+ m_Server.m_HttpServer->MarkRequest();
+
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+
+ HttpAsioServerRequest Request(m_RequestData,
+ *Service,
+ m_RequestData.Body(),
+ RequestNumber,
+ IsLocalConnection,
+ RemoteEndpoint.address().to_string());
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server.m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server.m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -1188,56 +1310,73 @@ HttpServerConnection::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
{
+ if (Request.ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength()));
+ }
+
// Transmit the response
if (m_RequestData.RequestVerb() == HttpVerb::kHead)
@@ -1278,51 +1417,24 @@ HttpServerConnection::HandleRequest()
}
}
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect();
+ if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty()))
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "\r\n"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Connection: close\r\n"
- "\r\n"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ ExtendableStringBuilder<128> Headers;
+ Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n";
+ SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView());
+ }
+ else if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ SendInlineResponse(RequestNumber, "404 NOT FOUND"sv);
}
else
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "\r\n"
- "No suitable route found"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "Connection: close\r\n"
- "\r\n"
- "No suitable route found"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ SendInlineResponse(RequestNumber,
+ "404 NOT FOUND"sv,
+ "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv,
+ "No suitable route found"sv);
}
}
@@ -1348,8 +1460,11 @@ struct HttpAcceptor
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
#else // ZEN_PLATFORM_WINDOWS
- m_Acceptor.set_option(asio::socket_base::reuse_address(false));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false));
+ // Allow binding to a port in TIME_WAIT so the server can restart immediately
+ // after a previous instance exits. On Linux this does not allow two processes
+ // to actively listen on the same port simultaneously.
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true));
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
@@ -1512,7 +1627,7 @@ struct HttpAcceptor
{
ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
- return 0;
+ return {};
}
if (EffectivePort != BasePort)
@@ -1569,7 +1684,8 @@ struct HttpAcceptor
void StopAccepting() { m_IsStopped = true; }
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+ int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
bool IsValid() const { return m_IsValid; }
@@ -1632,11 +1748,15 @@ private:
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
HttpService& Service,
IoBuffer PayloadBuffer,
- uint32_t RequestNumber)
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress)
: HttpServerRequest(Service)
, m_Request(Request)
, m_RequestNumber(RequestNumber)
, m_PayloadBuffer(std::move(PayloadBuffer))
+, m_IsLocalMachineRequest(IsLocalMachineRequest)
+, m_RemoteAddress(std::move(RemoteAddress))
{
const int PrefixLength = Service.UriPrefixLength();
@@ -1708,6 +1828,24 @@ HttpAsioServerRequest::ParseRequestId() const
return m_Request.RequestId();
}
+bool
+HttpAsioServerRequest::IsLocalMachineRequest() const
+{
+ return m_IsLocalMachineRequest;
+}
+
+std::string_view
+HttpAsioServerRequest::GetRemoteAddress() const
+{
+ return m_RemoteAddress;
+}
+
+std::string_view
+HttpAsioServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
IoBuffer
HttpAsioServerRequest::ReadPayload()
{
@@ -1904,6 +2042,37 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+void
+HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+
+ return RequestFilter->FilterRequest(Request);
+}
+
+bool
+HttpAsioServerImpl::IsLoopbackOnly() const
+{
+ return m_Acceptor && m_Acceptor->IsLoopbackOnly();
+}
+
} // namespace zen::asio_http
//////////////////////////////////////////////////////////////////////////
@@ -1916,11 +2085,15 @@ public:
HttpAsioServer(const AsioConfig& Config);
~HttpAsioServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
private:
Event m_ShutdownEvent;
@@ -1934,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config)
: m_InitialConfig(Config)
, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
{
+ m_Impl->m_HttpServer = this;
ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser));
}
@@ -1965,6 +2139,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service)
m_Impl->RegisterService(Service.BaseUri(), Service);
}
+void
+HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ m_Impl->SetHttpRequestFilter(RequestFilter);
+}
+
int
HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -1989,10 +2169,46 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
return m_BasePort;
}
+std::string
+HttpAsioServer::OnGetExternalHost() const
+{
+ if (m_Impl->IsLoopbackOnly())
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesReceived() const
+{
+ return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesSent() const
+{
+ return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpAsioServer::OnRun(bool IsInteractive)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractive)
@@ -2008,12 +2224,13 @@ HttpAsioServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractive)
{
@@ -2022,8 +2239,8 @@ HttpAsioServer::OnRun(bool IsInteractive)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index c483dfc28..3ec1141a7 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -15,4 +15,6 @@ struct AsioConfig
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
+bool IsIPv6Capable();
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 31cb04be5..584e06cbf 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -54,9 +54,19 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
}
void
+HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ for (auto& Server : m_Servers)
+ {
+ Server->SetHttpRequestFilter(RequestFilter);
+ }
+}
+
+void
HttpMultiServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -72,12 +82,13 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -86,8 +97,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
@@ -106,6 +117,16 @@ HttpMultiServer::OnClose()
}
}
+std::string
+HttpMultiServer::OnGetExternalHost() const
+{
+ if (!m_Servers.empty())
+ {
+ return std::string(m_Servers.front()->GetExternalHost());
+ }
+ return HttpServer::OnGetExternalHost();
+}
+
void
HttpMultiServer::AddServer(Ref<HttpServer> Server)
{
diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h
index ae0ed74cf..97699828a 100644
--- a/src/zenhttp/servers/httpmulti.h
+++ b/src/zenhttp/servers/httpmulti.h
@@ -15,11 +15,13 @@ public:
HttpMultiServer();
~HttpMultiServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
void AddServer(Ref<HttpServer> Server);
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index 0ec1cb3c4..9bb7ef3bc 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service)
ZEN_UNUSED(Service);
}
+void
+HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_UNUSED(RequestFilter);
+}
+
int
HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -34,7 +40,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
void
HttpNullServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -50,12 +57,13 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -64,8 +72,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h
index ce7230938..52838f012 100644
--- a/src/zenhttp/servers/httpnull.h
+++ b/src/zenhttp/servers/httpnull.h
@@ -18,6 +18,7 @@ public:
~HttpNullServer();
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 93094e21b..918b55dc6 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -12,13 +12,17 @@ namespace zen {
using namespace std::literals;
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv);
+static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv);
+static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -142,41 +146,62 @@ HttpRequestParser::ParseCurrentHeader()
const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
- if (HeaderHash == HashContentLength)
+ switch (HeaderHash)
{
- m_ContentLengthHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::TryFromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = CurrentHeaderIndex;
+ case HashContentLength:
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAccept:
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashContentType:
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAuthorization:
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSession:
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
+ break;
+
+ case HashRequest:
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ break;
+
+ case HashExpect:
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ break;
+
+ case HashRange:
+ m_RangeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashUpgrade:
+ m_UpgradeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketKey:
+ m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketVersion:
+ m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ default:
+ break;
}
}
@@ -220,11 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl)
NormalizedUrl.reserve(UrlLength);
NormalizedUrl.append(Url, UrlIndex);
}
-
- if (!LastCharWasSeparator)
- {
- NormalizedUrl.push_back('/');
- }
}
else if (!NormalizedUrl.empty())
{
@@ -305,6 +325,7 @@ HttpRequestParser::OnHeadersComplete()
if (ContentLength)
{
+ // TODO: should sanity-check content length here
m_BodyBuffer = IoBuffer(ContentLength);
}
@@ -324,9 +345,9 @@ HttpRequestParser::OnHeadersComplete()
int
HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
- if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size())
{
- ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
return 1;
}
@@ -337,7 +358,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
if (m_BodyPosition != m_BodyBuffer.Size())
{
- ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
return 1;
}
}
@@ -353,13 +374,18 @@ HttpRequestParser::ResetState()
m_HeaderEntries.clear();
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
+ m_UpgradeHeaderIndex = -1;
+ m_SecWebSocketKeyHeaderIndex = -1;
+ m_SecWebSocketVersionHeaderIndex = -1;
+ m_RequestVerb = HttpVerb::kGet;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
m_HeaderData.clear();
m_NormalizedUrl.clear();
@@ -416,4 +442,21 @@ HttpRequestParser::OnMessageComplete()
}
}
+bool
+HttpRequestParser::IsWebSocketUpgrade() const
+{
+ std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex);
+ if (Upgrade.empty())
+ {
+ return false;
+ }
+
+ // Case-insensitive check for "websocket"
+ if (Upgrade.size() != 9)
+ {
+ return false;
+ }
+ return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
+}
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 0d2664ec5..23ad9d8fb 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -46,6 +46,12 @@ struct HttpRequestParser
std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); }
+ std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+
+ std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); }
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); }
+ bool IsWebSocketUpgrade() const;
+
private:
struct HeaderRange
{
@@ -83,7 +89,11 @@ private:
int8_t m_AcceptHeaderIndex;
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
+ int8_t m_AuthorizationHeaderIndex;
+ int8_t m_UpgradeHeaderIndex;
+ int8_t m_SecWebSocketKeyHeaderIndex;
+ int8_t m_SecWebSocketVersionHeaderIndex;
+ HttpVerb m_RequestVerb = HttpVerb::kGet;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
int m_RequestId = -1;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index b9217ed87..4bf8c61bb 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
// HttpPluginServer
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
@@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
virtual void AddPlugin(Ref<TransportPlugin> Plugin) override;
virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override;
- HttpService* RouteRequest(std::string_view Url);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
struct ServiceEntry
{
@@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* Service;
};
- bool m_IsInitialized = false;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool m_IsInitialized = false;
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
std::vector<Ref<TransportPlugin>> m_Plugins;
@@ -120,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- int m_BasePort;
+ int m_BasePort = 0;
HttpServerTracer m_RequestTracer;
@@ -143,8 +146,11 @@ public:
HttpPluginServerRequest(const HttpPluginServerRequest&) = delete;
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
+ // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection
+ virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug
ConnectionName = "anonymous";
}
- ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName);
+ ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName);
}
uint32_t
@@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest()
{
ZEN_TRACE_CPU("http_plugin::HandleRequest");
+ m_Server->MarkRequest();
+
HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server->m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
{
@@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest()
const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
ResponseBuffers);
@@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest()
{
}
+std::string_view
+HttpPluginServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
Oid
HttpPluginServerRequest::ParseSessionId() const
{
@@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir
}
void
+HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+void
HttpPluginServerImpl::OnClose()
{
if (!m_IsInitialized)
@@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+IHttpRequestFilter::Result
+HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 54cc0c22d..dfe6bb6aa 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
OutString.Append("unknown");
}
+class HttpSysServerRequest;
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -90,17 +97,23 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+ IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);
+
private:
int InitializeServer(int BasePort);
void Cleanup();
@@ -124,8 +137,8 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
@@ -137,6 +150,12 @@ private:
int32_t m_MaxPendingRequests = 128;
Event m_ShutdownEvent;
HttpSysConfig m_InitialConfig;
+
+ RwLock m_RequestFilterLock;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -144,6 +163,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -313,6 +336,10 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -320,16 +347,19 @@ public:
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
virtual bool TryGetRanges(HttpRanges& Ranges) override;
+ void LogRequest(HttpMessageResponseRequest* Response);
+
using HttpServerRequest::WriteResponse;
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -363,7 +393,7 @@ public:
PTP_IO Iocp();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -380,8 +410,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -418,7 +448,10 @@ public:
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
void SuppressResponseBody(); // typically used for HEAD requests
- inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ inline uint16_t GetResponseCode() const { return m_ResponseCode; }
+ inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -429,6 +462,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
HTTP_CACHE_POLICY CachePolicy;
- CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.Policy = HttpCachePolicyNocache;
CachePolicy.SecondsToLive = 0;
// Initial response API call
- SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- &HttpResponse,
- &CachePolicy,
- NULL,
- NULL,
- 0,
- Tx.Overlapped(),
- NULL);
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
+ &HttpResponse, // HttpResponse
+ &CachePolicy, // CachePolicy
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
m_IsInitialResponse = false;
}
@@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
{
// Subsequent response API calls
- SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
(USHORT)ThisRequestChunkCount, // EntityChunkCount
&m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
NULL, // BytesSent
@@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
return 0;
}
@@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort)
// port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
- ZEN_WARN(
- "Unable to register handler using '{}' - falling back to local-only. "
- "Please ensure the appropriate netsh URL reservation configuration "
- "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
- WideToUtf8(WildcardUrlPath));
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation configuration "
+ "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
+ WideToUtf8(WildcardUrlPath));
+ }
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (m_BaseUris.empty())
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
@@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort)
ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
-
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
- {
- ZEN_DEBUG("HTTP3 is available");
- }
- else
- {
- ZEN_DEBUG("HTTP3 is NOT available");
- }
-# endif
-
return EffectivePort;
}
@@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
- m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async");
+ m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive)
ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit");
}
+ bool ShutdownRequested = false;
do
{
- // int WaitTimeout = -1;
int WaitTimeout = 100;
if (IsInteractive)
@@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
UpdateLofreqTimerValue();
- } while (!IsApplicationExitRequested());
+ } while (!ShutdownRequested);
}
void
@@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
{
@@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator);
# endif
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
+ {
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+ }
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ ThisRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
{
- Service.HandleRequest(ThisRequest);
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
}
return ThisRequest;
@@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const
return 0;
}
+bool
+HttpSysServerRequest::IsLocalMachineRequest() const
+{
+ const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress;
+ const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ if (LocalAddress->sa_family != RemoteAddress->sa_family)
+ {
+ return false;
+ }
+ if (LocalAddress->sa_family == AF_INET)
+ {
+ const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress);
+ const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress);
+ return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr;
+ }
+ else if (LocalAddress->sa_family == AF_INET6)
+ {
+ const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress);
+ const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress);
+ return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
+HttpSysServerRequest::GetAuthorizationHeader() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization];
+ return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength);
+}
+
IoBuffer
HttpSysServerRequest::ReadPayload()
{
@@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
if (SuppressBody())
{
@@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
if (SuppressBody())
{
@@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
+}
+
+void
+HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response)
+{
+ if (ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}",
+ ToString(RequestVerb()),
+ m_UriUtf8.c_str(),
+ Response->GetResponseCode(),
+ NiceBytes(Response->GetResponseBodySize()));
+ }
}
void
@@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
- ZEN_INFO("");
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
+
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{
RegisterService(Service.BaseUri(), Service);
}
+void
+HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_RequestFilterLock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpSysServer::FilterRequest(HttpSysServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_RequestFilterLock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{
diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h
new file mode 100644
index 000000000..4f8a97012
--- /dev/null
+++ b/src/zenhttp/servers/httpsys_iocontext.h
@@ -0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+
+# include <cstdint>
+
+namespace zen {
+
+/**
+ * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch
+ *
+ * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection
+ * (for WebSocket read/write) embed this struct. The single IoCompletionCallback
+ * bound to the request queue uses the ContextType tag to dispatch to the correct
+ * handler.
+ *
+ * The Overlapped member must be first so that CONTAINING_RECORD works to recover
+ * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool.
+ */
+struct HttpSysIoContext
+{
+ OVERLAPPED Overlapped{};
+
+ enum class Type : uint8_t
+ {
+ kTransaction,
+ kWebSocketRead,
+ kWebSocketWrite,
+ } ContextType = Type::kTransaction;
+
+ void* Owner = nullptr;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h
index da72c79c9..a9a45f162 100644
--- a/src/zenhttp/servers/httptracer.h
+++ b/src/zenhttp/servers/httptracer.h
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenhttp/httpserver.h>
-
#pragma once
+#include <zenhttp/httpserver.h>
+
namespace zen {
/** Helper class for HTTP server implementations
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
new file mode 100644
index 000000000..b2543277a
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -0,0 +1,311 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsasio.h"
+#include "wsframecodec.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen::asio_http {
+
+static LoggerRef
+WsLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+: m_Socket(std::move(Socket))
+, m_Handler(Handler)
+, m_HttpServer(Server)
+{
+}
+
+WsAsioConnection::~WsAsioConnection()
+{
+ m_IsOpen.store(false);
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+void
+WsAsioConnection::Start()
+{
+ EnqueueRead();
+}
+
+bool
+WsAsioConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Read loop
+//
+
+void
+WsAsioConnection::EnqueueRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
+ Self->OnDataReceived(Ec, ByteCount);
+ });
+}
+
+void
+WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+}
+
+void
+WsAsioConnection::ProcessReceivedData()
+{
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Shut down the socket
+ std::error_code ShutdownEc;
+ m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
+ m_Socket->close(ShutdownEc);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Write queue
+//
+
+void
+WsAsioConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+}
+
+void
+WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsAsioConnection::FlushWriteQueue()
+{
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ // Move Frame into a shared_ptr so we can create the buffer and capture ownership
+ // in the same async_write call without evaluation order issues.
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
+}
+
+void
+WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
new file mode 100644
index 000000000..e8bb3b1d2
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.h
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+namespace zen {
+class HttpServer;
+} // namespace zen
+
+namespace zen::asio_http {
+
+/**
+ * WebSocket connection over an ASIO TCP socket
+ *
+ * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * and runs an async read/write loop to exchange WebSocket frames.
+ *
+ * Lifetime is managed solely through intrusive reference counting (RefCounted).
+ * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
+ * connection alive for the duration of the async operation. The service layer
+ * also holds a Ref<WebSocketConnection>.
+ */
+
+class WsAsioConnection : public WebSocketConnection
+{
+public:
+ WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and the 101 response has been sent.
+ */
+ void Start();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+private:
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void ProcessReceivedData();
+
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount);
+
+ void DoClose(uint16_t Code, std::string_view Reason);
+
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
+
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
new file mode 100644
index 000000000..e452141fe
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/sha1.h>
+
+#include <cstring>
+#include <random>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+WsFrameParseResult
+WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
+{
+ // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
+ if (Size < 2)
+ {
+ return {};
+ }
+
+ const bool Fin = (Data[0] & 0x80) != 0;
+ const uint8_t OpcodeRaw = Data[0] & 0x0F;
+ const bool Masked = (Data[1] & 0x80) != 0;
+ uint64_t PayloadLen = Data[1] & 0x7F;
+
+ size_t HeaderSize = 2;
+
+ if (PayloadLen == 126)
+ {
+ if (Size < 4)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]);
+ HeaderSize = 4;
+ }
+ else if (PayloadLen == 127)
+ {
+ if (Size < 10)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) |
+ (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]);
+ HeaderSize = 10;
+ }
+
+ // Reject frames with unreasonable payload sizes to prevent OOM
+ static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB
+ if (PayloadLen > kMaxPayloadSize)
+ {
+ return {};
+ }
+
+ const size_t MaskSize = Masked ? 4 : 0;
+ const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen;
+
+ if (Size < TotalFrame)
+ {
+ return {};
+ }
+
+ const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr;
+ const uint8_t* PayloadData = Data + HeaderSize + MaskSize;
+
+ WsFrameParseResult Result;
+ Result.IsValid = true;
+ Result.BytesConsumed = TotalFrame;
+ Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
+ Result.Fin = Fin;
+
+ Result.Payload.resize(static_cast<size_t>(PayloadLen));
+ if (PayloadLen > 0)
+ {
+ std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen));
+
+ if (Masked)
+ {
+ for (size_t i = 0; i < Result.Payload.size(); ++i)
+ {
+ Result.Payload[i] ^= MaskKey[i & 3];
+ }
+ }
+ }
+
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (server-to-client, no masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length (no mask bit for server frames)
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ Frame.insert(Frame.end(), Payload.begin(), Payload.end());
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (client-to-server, with masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Generate random 4-byte mask key
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ uint32_t MaskValue = s_Rng();
+ uint8_t MaskKey[4];
+ std::memcpy(MaskKey, &MaskValue, 4);
+
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < PayloadLen; ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2)
+//
+
+static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+std::string
+WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey)
+{
+ // Concatenate client key with the magic GUID
+ std::string Combined;
+ Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size());
+ Combined.append(ClientKey);
+ Combined.append(kWebSocketMagicGuid);
+
+ // SHA1 hash
+ SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size());
+
+ // Base64 encode the 20-byte hash
+ char Base64Buf[Base64::GetEncodedDataSize(20) + 1];
+ uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf);
+ Base64Buf[EncodedLen] = '\0';
+
+ return std::string(Base64Buf, EncodedLen);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
new file mode 100644
index 000000000..2d90b6fa1
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/**
+ * Result of attempting to parse a single WebSocket frame from a byte buffer
+ */
+struct WsFrameParseResult
+{
+ bool IsValid = false; // true if a complete frame was successfully parsed
+ size_t BytesConsumed = 0; // number of bytes consumed from the input buffer
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ bool Fin = false;
+ std::vector<uint8_t> Payload;
+};
+
+/**
+ * RFC 6455 WebSocket frame codec
+ *
+ * Provides static helpers for parsing client-to-server frames (which are
+ * always masked) and building server-to-client frames (which are never masked).
+ */
+struct WsFrameCodec
+{
+ /**
+ * Try to parse one complete frame from the front of the buffer.
+ *
+ * Returns a result with IsValid == false and BytesConsumed == 0 when
+ * there is not enough data yet. The caller should accumulate more data
+ * and retry.
+ */
+ static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);
+
+ /**
+ * Build a server-to-client frame (no masking)
+ */
+ static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a close frame with a status code and optional reason string
+ */
+ static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Build a client-to-server frame (with masking per RFC 6455)
+ */
+ static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a masked close frame with status code and optional reason
+ */
+ static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2
+ *
+ * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
+ */
+ static std::string ComputeAcceptKey(std::string_view ClientKey);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
new file mode 100644
index 000000000..af320172d
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -0,0 +1,485 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wshttpsys.h"
+
+#if ZEN_WITH_HTTPSYS
+
+# include "wsframecodec.h"
+
+# include <zencore/logging.h>
+# include <zenhttp/httpserver.h>
+
+namespace zen {
+
+static LoggerRef
+WsHttpSysLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws_httpsys");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle,
+ HTTP_REQUEST_ID RequestId,
+ IWebSocketHandler& Handler,
+ PTP_IO Iocp,
+ HttpServer* Server)
+: m_RequestQueueHandle(RequestQueueHandle)
+, m_RequestId(RequestId)
+, m_Handler(Handler)
+, m_Iocp(Iocp)
+, m_HttpServer(Server)
+, m_ReadBuffer(8192)
+{
+ m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead;
+ m_ReadIoContext.Owner = this;
+ m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite;
+ m_WriteIoContext.Owner = this;
+}
+
+WsHttpSysConnection::~WsHttpSysConnection()
+{
+ ZEN_ASSERT(m_OutstandingOps.load() == 0);
+
+ if (m_IsOpen.exchange(false))
+ {
+ Disconnect();
+ }
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+void
+WsHttpSysConnection::Start()
+{
+ m_SelfRef = Ref<WsHttpSysConnection>(this);
+ IssueAsyncRead();
+}
+
+void
+WsHttpSysConnection::Shutdown()
+{
+ m_ShutdownRequested.store(true, std::memory_order_relaxed);
+
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+bool
+WsHttpSysConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async read path
+//
+
+void
+WsHttpSysConnection::IssueAsyncRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed))
+ {
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ 0, // Flags
+ m_ReadBuffer.data(),
+ (ULONG)m_ReadBuffer.size(),
+ nullptr, // BytesRead (ignored for async)
+ &m_ReadIoContext.Overlapped);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "read issue failed");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef
+ Ref<WsHttpSysConnection> Guard(this);
+
+ if (IoResult != NO_ERROR)
+ {
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ if (IoResult == ERROR_HANDLE_EOF)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection closed");
+ }
+ else if (IoResult != ERROR_OPERATION_ABORTED)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ if (NumberOfBytesTransferred > 0)
+ {
+ m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred);
+ ProcessReceivedData();
+ }
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ IssueAsyncRead();
+ }
+ else
+ {
+ MaybeReleaseSelfRef();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+void
+WsHttpSysConnection::ProcessReceivedData()
+{
+ while (!m_Accumulated.empty())
+ {
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ // Remove consumed bytes
+ m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+ Disconnect();
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async write path
+//
+
+void
+WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.push_back(std::move(Frame));
+
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ }
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsHttpSysConnection::FlushWriteQueue()
+{
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+
+ m_CurrentWriteBuffer = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk));
+ m_WriteChunk.DataChunkType = HttpDataChunkFromMemory;
+ m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data();
+ m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size();
+
+ ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_MORE_DATA,
+ 1,
+ &m_WriteChunk,
+ nullptr,
+ nullptr,
+ 0,
+ &m_WriteIoContext.Overlapped,
+ nullptr);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+ m_CurrentWriteBuffer.clear();
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ // Hold a transient ref to prevent mid-callback destruction
+ Ref<WsHttpSysConnection> Guard(this);
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+ m_CurrentWriteBuffer.clear();
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Send interface
+//
+
+void
+WsHttpSysConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Lifetime management
+//
+
+void
+WsHttpSysConnection::MaybeReleaseSelfRef()
+{
+ if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ m_SelfRef = nullptr;
+ }
+}
+
+void
+WsHttpSysConnection::Disconnect()
+{
+ // Send final empty body with DISCONNECT to tell http.sys the connection is done
+ HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_DISCONNECT,
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ nullptr,
+ nullptr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
new file mode 100644
index 000000000..6015e3873
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include "httpsys_iocontext.h"
+
+#include <zencore/thread.h>
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <http.h>
+
+# include <atomic>
+# include <deque>
+# include <vector>
+
+namespace zen {
+
+class HttpServer;
+
+/**
+ * WebSocket connection over an http.sys opaque-mode connection
+ *
+ * After the 101 Switching Protocols response is sent with
+ * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the
+ * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody /
+ * HttpSendResponseEntityBody using the original RequestId.
+ *
+ * All I/O is performed asynchronously via the same IOCP threadpool used
+ * for normal http.sys traffic, eliminating per-connection threads.
+ *
+ * Lifetime is managed through intrusive reference counting (RefCounted).
+ * A self-reference (m_SelfRef) is held from Start() until all outstanding
+ * async operations have drained, preventing premature destruction.
+ */
+class WsHttpSysConnection : public WebSocketConnection
+{
+public:
+ WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server);
+ ~WsHttpSysConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and after the 101 response has been sent.
+ */
+ void Start();
+
+ /**
+ * Shut down the connection. Cancels pending I/O; IOCP completions
+ * will fire with ERROR_OPERATION_ABORTED and drain naturally.
+ */
+ void Shutdown();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+ // Called from IoCompletionCallback via tagged dispatch
+ void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+ void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+private:
+ void IssueAsyncRead();
+ void ProcessReceivedData();
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void DoClose(uint16_t Code, std::string_view Reason);
+ void Disconnect();
+ void MaybeReleaseSelfRef();
+
+ HANDLE m_RequestQueueHandle;
+ HTTP_REQUEST_ID m_RequestId;
+ IWebSocketHandler& m_Handler;
+ PTP_IO m_Iocp;
+ HttpServer* m_HttpServer;
+
+ // Tagged OVERLAPPED contexts for concurrent read and write
+ HttpSysIoContext m_ReadIoContext{};
+ HttpSysIoContext m_WriteIoContext{};
+
+ // Read state
+ std::vector<uint8_t> m_ReadBuffer;
+ std::vector<uint8_t> m_Accumulated;
+
+ // Write state
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ std::vector<uint8_t> m_CurrentWriteBuffer;
+ HTTP_DATA_CHUNK m_WriteChunk{};
+ bool m_IsWriting = false;
+
+ // Lifetime management
+ std::atomic<int32_t> m_OutstandingOps{0};
+ Ref<WsHttpSysConnection> m_SelfRef;
+ std::atomic<bool> m_ShutdownRequested{false};
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
new file mode 100644
index 000000000..2134e4ff1
--- /dev/null
+++ b/src/zenhttp/servers/wstest.cpp
@@ -0,0 +1,925 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenhttp/websocket.h>
+
+# include "httpasio.h"
+# include "wsframecodec.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_PLATFORM_WINDOWS
+# include <winsock2.h>
+# else
+# include <poll.h>
+# include <sys/socket.h>
+# endif
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+# include <atomic>
+# include <chrono>
+# include <cstring>
+# include <random>
+# include <string>
+# include <string_view>
+# include <thread>
+# include <vector>
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests: WsFrameCodec
+//
+
+TEST_SUITE_BEGIN("http.wstest");
+
+TEST_CASE("websocket.framecodec")
+{
+ SUBCASE("ComputeAcceptKey RFC 6455 test vector")
+ {
+ // RFC 6455 section 4.2.2 example
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ // Server frames are unmasked — TryParseFrame should handle them
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+
+ SUBCASE("TryParseFrame - partial data returns invalid")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ // Pass only 1 byte — not enough for a frame header
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
+ CHECK_FALSE(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, 0u);
+ }
+
+ SUBCASE("TryParseFrame - empty payload")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK(Result.Payload.empty());
+ }
+
+ SUBCASE("TryParseFrame - masked client frame")
+ {
+ // Build a masked frame manually as a client would send
+ // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
+ uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
+ uint8_t MaskedPayload[5] = {};
+ const char* Original = "Hello";
+ for (int i = 0; i < 5; ++i)
+ {
+ MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4];
+ }
+
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x81); // FIN + text
+ Frame.push_back(0x85); // MASK + len=5
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), 5u);
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, masked WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ // Verify mask bit is set
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildMaskedCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: WebSocket over ASIO
+//
+
+namespace {
+
+ /**
+ * Helper: Build a masked client-to-server frame per RFC 6455
+ */
+ std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Mask key (use a fixed key for deterministic tests)
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
+ std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
+ {
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ }
+
+ std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
+ {
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+ }
+
+ /**
+ * Test service that implements IWebSocketHandler
+ */
+ struct WsTestService : public HttpService, public IWebSocketHandler
+ {
+ const char* BaseUri() const override { return "/wstest/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
+ }
+
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+ m_OpenCount.fetch_add(1);
+
+ m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+
+ // Echo the message back
+ Conn.SendText(Text);
+ }
+ }
+
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+
+ m_ConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_Connections.erase(It, m_Connections.end());
+ });
+ }
+
+ void SendToAll(std::string_view Text)
+ {
+ RwLock::SharedLockScope _(m_ConnectionsLock);
+ for (auto& Conn : m_Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Text);
+ }
+ }
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+
+ RwLock m_ConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_Connections;
+ };
+
+ /**
+ * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
+ *
+ * Returns true on success (101 response), false otherwise.
+ */
+ bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
+ {
+ // Send HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << Path << " HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ // Read the response (look for "101")
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ return Response.find("101") != std::string::npos;
+ }
+
+ /**
+ * Helper: Read a single server-to-client frame from a socket
+ *
+ * Uses a background thread with a synchronous ASIO read and a timeout.
+ */
+ WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000)
+ {
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Result;
+ std::atomic<bool> Done{false};
+
+ std::thread Reader([&] {
+ while (!Done.load())
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ break;
+ }
+
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (Frame.IsValid)
+ {
+ Result = std::move(Frame);
+ Done.store(true);
+ return;
+ }
+ }
+ });
+
+ auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs);
+ while (!Done.load() && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ if (!Done.load())
+ {
+ // Timeout — cancel the read
+ asio::error_code Ec;
+ Sock.cancel(Ec);
+ }
+
+ if (Reader.joinable())
+ {
+ Reader.join();
+ }
+
+ return Result;
+ }
+
+} // anonymous namespace
+
+TEST_CASE("websocket.integration")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ // Give server a moment to start accepting
+ Sleep(100);
+
+ SUBCASE("handshake succeeds with 101")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ CHECK(Ok);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_OpenCount.load(), 1);
+
+ Sock.close();
+ }
+
+ SUBCASE("normal HTTP still works alongside WebSocket service")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Send a normal HTTP GET (not upgrade)
+ std::string HttpReq = fmt::format(
+ "GET /wstest/hello HTTP/1.1\r\n"
+ "Host: 127.0.0.1:{}\r\n"
+ "Connection: close\r\n"
+ "\r\n",
+ Port);
+
+ asio::write(Sock, asio::buffer(HttpReq));
+
+ asio::streambuf ResponseBuf;
+ asio::error_code Ec;
+ asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec);
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+ CHECK(Response.find("200") != std::string::npos);
+ }
+
+ SUBCASE("echo message roundtrip")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a text message (masked, as client)
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
+ asio::write(Sock, asio::buffer(Frame));
+
+ // Read the echo reply
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "ping test"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "ping test");
+
+ Sock.close();
+ }
+
+ SUBCASE("server push to client")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Server pushes a message
+ TestService.SendToAll("server says hello");
+
+ WsFrameParseResult Frame = ReadOneFrame(Sock);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "server says hello"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("client close handshake")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send close frame
+ std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
+ asio::write(Sock, asio::buffer(CloseFrame));
+
+ // Server should echo close back
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_CloseCount.load(), 1);
+ CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple concurrent connections")
+ {
+ constexpr int NumClients = 5;
+
+ asio::io_context IoCtx;
+ std::vector<asio::ip::tcp::socket> Sockets;
+
+ for (int i = 0; i < NumClients; ++i)
+ {
+ Sockets.emplace_back(IoCtx);
+ Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
+ REQUIRE(Ok);
+ }
+
+ Sleep(100);
+ CHECK_EQ(TestService.m_OpenCount.load(), NumClients);
+
+ // Broadcast from server
+ TestService.SendToAll("broadcast");
+
+ // Each client should receive the message
+ for (int i = 0; i < NumClients; ++i)
+ {
+ WsFrameParseResult Frame = ReadOneFrame(Sockets[i]);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "broadcast"sv);
+ }
+
+ // Close all
+ for (auto& S : Sockets)
+ {
+ S.close();
+ }
+ }
+
+ SUBCASE("service without IWebSocketHandler rejects upgrade")
+ {
+ // Register a plain HTTP service (no WebSocket)
+ struct PlainService : public HttpService
+ {
+ const char* BaseUri() const override { return "/plain/"; }
+ void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
+ };
+
+ PlainService Plain;
+ Server->RegisterService(Plain);
+
+ Sleep(50);
+
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Attempt WebSocket upgrade on the plain service
+ ExtendableStringBuilder<512> Request;
+ Request << "GET /plain/ws HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ // Should NOT get 101 — should fall through to normal request handling
+ CHECK(Response.find("101") == std::string::npos);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping/pong auto-response")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a ping frame with payload "test"
+ std::string_view PingPayload = "test";
+ std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());
+ std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
+ asio::write(Sock, asio::buffer(PingFrame));
+
+ // Should receive a pong with the same payload
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
+ CHECK_EQ(Reply.Payload.size(), 4u);
+ std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(PongText, "test"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple messages in sequence")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ std::string Msg = fmt::format("message {}", i);
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
+ asio::write(Sock, asio::buffer(Frame));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, Msg);
+ }
+
+ CHECK_EQ(TestService.m_MessageCount.load(), 10);
+
+ Sock.close();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: HttpWsClient
+//
+
+namespace {
+
+ struct TestWsClientHandler : public IWsClientHandler
+ {
+ void OnWsOpen() override { m_OpenCount.fetch_add(1); }
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+ }
+ m_MessageCount.fetch_add(1);
+ }
+
+ void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+ };
+
+} // anonymous namespace
+
+TEST_CASE("websocket.client")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7576, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello from client");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello from client");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ // The server echoes the close frame, which triggers OnWsClose on the client side
+ // with the server's close code. Allow the connection to settle.
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+
+ SUBCASE("connect to bad port")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = "ws://127.0.0.1:1/wstest/ws";
+
+ HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ }
+
+ SUBCASE("server-initiated close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+
+ // Copy connections then close them outside the lock to avoid deadlocking
+ // with OnWebSocketClose which acquires an exclusive lock
+ std::vector<Ref<WebSocketConnection>> Conns;
+ TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; });
+ for (auto& Conn : Conns)
+ {
+ Conn->Close(1001, "going away");
+ }
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+TEST_SUITE_END();
+
+void
+websocket_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_TESTS