aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/servers')
-rw-r--r--src/zenhttp/servers/httpasio.cpp42
-rw-r--r--src/zenhttp/servers/httpmulti.cpp4
-rw-r--r--src/zenhttp/servers/httpnull.cpp4
-rw-r--r--src/zenhttp/servers/httpparser.cpp414
-rw-r--r--src/zenhttp/servers/httpparser.h8
-rw-r--r--src/zenhttp/servers/httpplugin.cpp42
-rw-r--r--src/zenhttp/servers/httpsys.cpp47
-rw-r--r--src/zenhttp/servers/wsasio.cpp2
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp6
-rw-r--r--src/zenhttp/servers/wstest.cpp37
10 files changed, 520 insertions, 86 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 7972777b8..b624c3a29 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -625,6 +625,8 @@ public:
void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
/**
* Initialize the response for sending a payload made up of multiple blobs
@@ -768,10 +770,18 @@ public:
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -898,7 +908,9 @@ private:
bool m_AllowZeroCopyFileSend = true;
State m_State = State::kUninitialized;
HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
asio::buffer(ResponseStr->data(), ResponseStr->size()),
asio::bind_executor(
m_Strand,
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()](
+ const asio::error_code& Ec,
+ std::size_t) {
if (Ec)
{
ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
@@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ std::string_view FullUrl = Conn->m_RequestData.Url();
+ std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
}));
@@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest()
return;
}
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
if (!m_RequestData.IsKeepAlive())
@@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -2590,7 +2618,7 @@ HttpAsioServer::OnRun(bool IsInteractive)
}
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#else
if (IsInteractive)
{
@@ -2600,7 +2628,7 @@ HttpAsioServer::OnRun(bool IsInteractive)
do
{
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#endif
}
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 584e06cbf..196c0c142 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -88,7 +88,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
}
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#else
if (IsInteractiveSession)
{
@@ -98,7 +98,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
do
{
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#endif
}
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index 9bb7ef3bc..d698bcb9d 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -63,7 +63,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
}
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#else
if (IsInteractiveSession)
{
@@ -73,7 +73,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
do
{
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
#endif
}
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 918b55dc6..8b07c7905 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -8,6 +8,13 @@
#include <limits>
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <cstring>
+# include <string>
+# include <string_view>
+#endif
+
namespace zen {
using namespace std::literals;
@@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W
// HttpRequestParser
//
-http_parser_settings HttpRequestParser::s_ParserSettings{
- .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
- .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
- .on_status =
- [](http_parser* p, const char* Data, size_t ByteCount) {
- ZEN_UNUSED(p, Data, ByteCount);
- return 0;
- },
- .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
- .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
+// clang-format off
+llhttp_settings_t HttpRequestParser::s_ParserSettings = []() {
+ llhttp_settings_t S;
+ llhttp_settings_init(&S);
+ S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); };
+ S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); };
+ S.on_status = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); };
+ S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); };
+ S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); };
+ S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); };
+ S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); };
+ return S;
+}();
+// clang-format on
HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection)
{
- http_parser_init(&m_Parser, HTTP_REQUEST);
+ llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings);
m_Parser.data = this;
ResetState();
@@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser()
size_t
HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize)
{
- const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
-
- http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
-
- if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize);
+ if (Err == HPE_OK)
{
- ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
- return ~0ull;
+ return DataSize;
}
- return ConsumedBytes;
+ if (Err == HPE_PAUSED_UPGRADE)
+ {
+ return DataSize;
+ }
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser));
+ return ~0ull;
}
int
@@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_UrlRange.Length == 0)
@@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_HeaderEntries.empty())
@@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
ZEN_ASSERT_SLOW(!m_HeaderEntries.empty());
@@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete()
}
}
- m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+ m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser);
- switch (m_Parser.method)
+ switch (llhttp_get_method(&m_Parser))
{
case HTTP_GET:
m_RequestVerb = HttpVerb::kGet;
@@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete()
break;
default:
- ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser))));
break;
}
@@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
- return 1;
+ return -1;
}
memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
m_BodyPosition += Bytes;
- if (http_body_is_final(&m_Parser))
- {
- if (m_BodyPosition != m_BodyBuffer.Size())
- {
- ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
- return 1;
- }
- }
-
return 0;
}
@@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete()
catch (const AssertException& AssertEx)
{
ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription());
- return 1;
+ return -1;
}
catch (const std::system_error& SystemError)
{
@@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete()
ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value());
}
ResetState();
- return 1;
+ return -1;
}
catch (const std::bad_alloc& BadAlloc)
{
ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what());
ResetState();
- return 1;
+ return -1;
}
catch (const std::exception& Ex)
{
ZEN_ERROR("failed processing http request: '{}'", Ex.what());
ResetState();
- return 1;
+ return -1;
}
}
@@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const
return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
}
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace {
+
+ struct MockCallbacks : HttpRequestParserCallbacks
+ {
+ int HandleRequestCount = 0;
+ int TerminateCount = 0;
+
+ HttpRequestParser* Parser = nullptr;
+
+ HttpVerb LastVerb{};
+ std::string LastUrl;
+ std::string LastQueryString;
+ std::string LastBody;
+ bool LastKeepAlive = false;
+ bool LastIsWebSocketUpgrade = false;
+ std::string LastSecWebSocketKey;
+ std::string LastUpgradeHeader;
+ HttpContentType LastContentType{};
+
+ void HandleRequest() override
+ {
+ ++HandleRequestCount;
+ if (Parser)
+ {
+ LastVerb = Parser->RequestVerb();
+ LastUrl = std::string(Parser->Url());
+ LastQueryString = std::string(Parser->QueryString());
+ LastKeepAlive = Parser->IsKeepAlive();
+ LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade();
+ LastSecWebSocketKey = std::string(Parser->SecWebSocketKey());
+ LastUpgradeHeader = std::string(Parser->UpgradeHeader());
+ LastContentType = Parser->ContentType();
+
+ IoBuffer Body = Parser->Body();
+ if (Body.Size() > 0)
+ {
+ LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size());
+ }
+ else
+ {
+ LastBody.clear();
+ }
+ }
+ }
+
+ void TerminateConnection() override { ++TerminateCount; }
+ };
+
+} // anonymous namespace
+
+TEST_SUITE_BEGIN("http.httpparser");
+
+TEST_CASE("httpparser.basic_get")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kGet);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK(Mock.LastKeepAlive);
+}
+
+TEST_CASE("httpparser.post_with_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 13\r\n"
+ "Content-Type: application/json\r\n"
+ "\r\n"
+ "{\"key\":\"val\"}";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kPost);
+ CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}");
+ CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON);
+}
+
+TEST_CASE("httpparser.pipelined_requests")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n"
+ "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 2);
+ CHECK_EQ(Mock.LastUrl, "/second");
+}
+
+TEST_CASE("httpparser.partial_header")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc";
+ std::string Chunk2 = "alhost\r\n\r\n";
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, Chunk2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+}
+
+TEST_CASE("httpparser.partial_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Headers =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 10\r\n"
+ "\r\n";
+ std::string BodyPart1 = "hello";
+ std::string BodyPart2 = "world";
+
+ std::string Chunk1 = Headers + BodyPart1;
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, BodyPart2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "helloworld");
+}
+
+TEST_CASE("httpparser.invalid_request")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Garbage = "NOT_HTTP garbage data\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+}
+
+TEST_CASE("httpparser.body_overflow")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes,
+ // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY"
+ // as a new HTTP request which fails.
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 2\r\n"
+ "\r\n"
+ "TOO_LONG_BODY";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "TO");
+}
+
+TEST_CASE("httpparser.websocket_upgrade")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+ CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(Mock.LastUpgradeHeader, "websocket");
+}
+
+TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string HttpPart =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ // Append fake WebSocket frame bytes after the HTTP message
+ std::string Request = HttpPart;
+ Request.push_back('\x81');
+ Request.push_back('\x05');
+ Request.append("hello");
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_NE(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+}
+
+TEST_CASE("httpparser.keep_alive_detection")
+{
+ SUBCASE("HTTP/1.1 default keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK(Mock.LastKeepAlive);
+ }
+
+ SUBCASE("Connection: close disables keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_FALSE(Mock.LastKeepAlive);
+ }
+}
+
+TEST_CASE("httpparser.all_verbs")
+{
+ struct VerbTest
+ {
+ const char* Method;
+ HttpVerb Expected;
+ };
+
+ VerbTest Tests[] = {
+ {"GET", HttpVerb::kGet},
+ {"POST", HttpVerb::kPost},
+ {"PUT", HttpVerb::kPut},
+ {"DELETE", HttpVerb::kDelete},
+ {"HEAD", HttpVerb::kHead},
+ {"COPY", HttpVerb::kCopy},
+ {"OPTIONS", HttpVerb::kOptions},
+ };
+
+ for (const VerbTest& Test : Tests)
+ {
+ CAPTURE(Test.Method);
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, Test.Expected);
+ }
+}
+
+TEST_CASE("httpparser.query_string")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK_EQ(Mock.LastQueryString, "key=val&other=123");
+}
+
+TEST_SUITE_END();
+
+void
+httpparser_forcelink()
+{
+}
+
+#endif // ZEN_WITH_TESTS
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 23ad9d8fb..4ff216248 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -8,7 +8,7 @@
#include <EASTL/fixed_vector.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <http_parser.h>
+#include <llhttp.h>
ZEN_THIRD_PARTY_INCLUDES_END
#include <atomic>
@@ -100,7 +100,7 @@ private:
Oid m_SessionId{};
IoBuffer m_BodyBuffer;
uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
+ llhttp_t m_Parser;
eastl::fixed_vector<char, 512> m_HeaderData;
std::string m_NormalizedUrl;
@@ -114,8 +114,8 @@ private:
int OnBody(const char* Data, size_t Bytes);
int OnMessageComplete();
- static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
- static http_parser_settings s_ParserSettings;
+ static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
+ static llhttp_settings_t s_ParserSettings;
};
} // namespace zen
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 31b0315d4..ad7ed259a 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -185,13 +185,17 @@ public:
const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; }
void SuppressPayload() { m_ResponseBuffers.resize(1); }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
std::string_view GetHeaders();
private:
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
uint64_t m_ContentLength = 0;
std::vector<IoBuffer> m_ResponseBuffers;
ExtendableStringBuilder<160> m_Headers;
@@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders()
if (m_Headers.Size() == 0)
{
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary));
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(ContentType));
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -831,6 +855,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
ZEN_CONSOLE("Zen Server running (plugin HTTP). Press ESC or Q to quit");
}
+ bool ShutdownRequested = false;
do
{
if (IsInteractive && _kbhit() != 0)
@@ -844,18 +869,19 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
# else
if (IsInteractive)
{
ZEN_CONSOLE("Zen Server running (plugin HTTP). Ctrl-C to quit");
}
+ bool ShutdownRequested = false;
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
# endif
}
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 2cad97725..c1b426bea 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -464,6 +464,8 @@ public:
inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -473,6 +475,8 @@ private:
uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
std::string m_LocationHeader;
@@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
- std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+ std::string_view ContentTypeString =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
@@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
}
+ // Content-Range header (for 206 Partial Content single-range responses)
+
+ if (!m_ContentRangeHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange];
+ ContentRangeHeader->pRawValue = m_ContentRangeHeader.data();
+ ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort)
else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
{
// Port may be owned by another process's wildcard registration (access denied)
- // or actively in use (sharing violation) — retry on a different port
+ // or actively in use (sharing violation) - retry on a different port
ShouldRetryNextPort = true;
}
else
@@ -1713,7 +1727,7 @@ HttpSysServer::OnRun(bool IsInteractive)
ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
UpdateLofreqTimerValue();
- } while (!ShutdownRequested);
+ } while (!ShutdownRequested && !IsApplicationExitRequested());
}
void
@@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ if (!m_ContentTypeOverride.empty())
+ {
+ Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
&Transaction().Server()));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ ExtendableStringBuilder<128> UrlUtf8;
+ WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath,
+ gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))},
+ UrlUtf8);
+ int PrefixLen = Service->UriPrefixLength();
+ std::string_view RelativeUri{UrlUtf8.ToView()};
+ RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
return nullptr;
@@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
- // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // WebSocket upgrade failed - return nullptr since ServerRequest()
// was never populated (no InvokeRequestHandler call)
return nullptr;
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
}
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index 5ae48f5b3..078c21ea1 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
index af320172d..8520e9f60 100644
--- a/src/zenhttp/servers/wshttpsys.cpp
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown()
return;
}
- // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
@@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
@@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
- // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 59c46a418..a58037fec 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -5,6 +5,7 @@
# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zencore/timer.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/httpwsclient.h>
@@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec")
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
- // Server frames are unmasked — TryParseFrame should handle them
+ // Server frames are unmasked - TryParseFrame should handle them
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
CHECK(Result.IsValid);
@@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec")
{
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
- // Pass only 1 byte — not enough for a frame header
+ // Pass only 1 byte - not enough for a frame header
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
CHECK_FALSE(Result.IsValid);
CHECK_EQ(Result.BytesConsumed, 0u);
@@ -335,8 +336,9 @@ namespace {
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_OpenCount.fetch_add(1);
m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
@@ -463,7 +465,7 @@ namespace {
if (!Done.load())
{
- // Timeout — cancel the read
+ // Timeout - cancel the read
asio::error_code Ec;
Sock.cancel(Ec);
}
@@ -476,6 +478,23 @@ namespace {
return Result;
}
+ static void WaitForServerListening(int Port)
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Probe(IoCtx);
+ asio::error_code Ec;
+ Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec);
+ if (!Ec)
+ {
+ return;
+ }
+ Sleep(10);
+ }
+ }
+
} // anonymous namespace
TEST_CASE("websocket.integration")
@@ -501,8 +520,8 @@ TEST_CASE("websocket.integration")
Server->Close();
});
- // Give server a moment to start accepting
- Sleep(100);
+ // Wait for server to start accepting
+ WaitForServerListening(Port);
SUBCASE("handshake succeeds with 101")
{
@@ -692,7 +711,7 @@ TEST_CASE("websocket.integration")
std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
- // Should NOT get 101 — should fall through to normal request handling
+ // Should NOT get 101 - should fall through to normal request handling
CHECK(Response.find("101") == std::string::npos);
Sock.close();
@@ -813,7 +832,7 @@ TEST_CASE("websocket.client")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close")
{
@@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close over unix socket")
{