diff options
Diffstat (limited to 'src/zenhttp/servers')
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 42 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.cpp | 4 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.cpp | 4 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.cpp | 414 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.h | 8 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 42 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 47 | ||||
| -rw-r--r-- | src/zenhttp/servers/wsasio.cpp | 2 | ||||
| -rw-r--r-- | src/zenhttp/servers/wshttpsys.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/servers/wstest.cpp | 37 |
10 files changed, 520 insertions, 86 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 7972777b8..b624c3a29 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -625,6 +625,8 @@ public: void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } /** * Initialize the response for sending a payload made up of multiple blobs @@ -768,10 +770,18 @@ public: { ZEN_MEMSCOPE(GetHttpasioTag()); + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -898,7 +908,9 @@ private: bool m_AllowZeroCopyFileSend = true; State m_State = State::kUninitialized; HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::buffer(ResponseStr->data(), ResponseStr->size()), asio::bind_executor( m_Strand, - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()]( + const asio::error_code& Ec, + std::size_t) { if (Ec) { ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); @@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + std::string_view FullUrl = Conn->m_RequestData.Url(); + std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); })); @@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() return; } } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } if (!m_RequestData.IsKeepAlive()) @@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -2590,7 +2618,7 @@ HttpAsioServer::OnRun(bool IsInteractive) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractive) { @@ -2600,7 +2628,7 @@ HttpAsioServer::OnRun(bool IsInteractive) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 584e06cbf..196c0c142 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -88,7 +88,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractiveSession) { @@ -98,7 +98,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index 9bb7ef3bc..d698bcb9d 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -63,7 +63,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractiveSession) { @@ -73,7 +73,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 918b55dc6..8b07c7905 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -8,6 +8,13 @@ #include <limits> +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <cstring> +# include <string> +# include <string_view> +#endif + namespace zen { using namespace std::literals; @@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W // HttpRequestParser // -http_parser_settings HttpRequestParser::s_ParserSettings{ - .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; +// clang-format off +llhttp_settings_t HttpRequestParser::s_ParserSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); }; + S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); +// clang-format on HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) { - http_parser_init(&m_Parser, HTTP_REQUEST); + llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings); m_Parser.data = this; ResetState(); @@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser() size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize); + if (Err == HPE_OK) { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; + return DataSize; } - return ConsumedBytes; + if (Err == HPE_PAUSED_UPGRADE) + { + return DataSize; + } + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser)); + return ~0ull; } int @@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_UrlRange.Length == 0) @@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_HeaderEntries.empty()) @@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); @@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete() } } - m_KeepAlive = !!http_should_keep_alive(&m_Parser); + m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser); - switch (m_Parser.method) + switch (llhttp_get_method(&m_Parser)) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; @@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete() break; default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser)))); break; } @@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; + return -1; } memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - return 0; } @@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete() catch (const AssertException& AssertEx) { ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); - return 1; + return -1; } catch (const std::system_error& SystemError) { @@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete() ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); - return 1; + return -1; } catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); - return 1; + return -1; } catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); - return 1; + return -1; } } @@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; } +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace { + + struct MockCallbacks : HttpRequestParserCallbacks + { + int HandleRequestCount = 0; + int TerminateCount = 0; + + HttpRequestParser* Parser = nullptr; + + HttpVerb LastVerb{}; + std::string LastUrl; + std::string LastQueryString; + std::string LastBody; + bool LastKeepAlive = false; + bool LastIsWebSocketUpgrade = false; + std::string LastSecWebSocketKey; + std::string LastUpgradeHeader; + HttpContentType LastContentType{}; + + void HandleRequest() override + { + ++HandleRequestCount; + if (Parser) + { + LastVerb = Parser->RequestVerb(); + LastUrl = std::string(Parser->Url()); + LastQueryString = std::string(Parser->QueryString()); + LastKeepAlive = Parser->IsKeepAlive(); + LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade(); + LastSecWebSocketKey = std::string(Parser->SecWebSocketKey()); + LastUpgradeHeader = std::string(Parser->UpgradeHeader()); + LastContentType = Parser->ContentType(); + + IoBuffer Body = Parser->Body(); + if (Body.Size() > 0) + { + LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size()); + } + else + { + LastBody.clear(); + } + } + } + + void TerminateConnection() override { ++TerminateCount; } + }; + +} // anonymous namespace + +TEST_SUITE_BEGIN("http.httpparser"); + +TEST_CASE("httpparser.basic_get") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kGet); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK(Mock.LastKeepAlive); +} + +TEST_CASE("httpparser.post_with_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 13\r\n" + "Content-Type: application/json\r\n" + "\r\n" + "{\"key\":\"val\"}"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kPost); + CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}"); + CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON); +} + +TEST_CASE("httpparser.pipelined_requests") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n" + "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 2); + CHECK_EQ(Mock.LastUrl, "/second"); +} + +TEST_CASE("httpparser.partial_header") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc"; + std::string Chunk2 = "alhost\r\n\r\n"; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, Chunk2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); +} + +TEST_CASE("httpparser.partial_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Headers = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 10\r\n" + "\r\n"; + std::string BodyPart1 = "hello"; + std::string BodyPart2 = "world"; + + std::string Chunk1 = Headers + BodyPart1; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, BodyPart2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "helloworld"); +} + +TEST_CASE("httpparser.invalid_request") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Garbage = "NOT_HTTP garbage data\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 0); +} + +TEST_CASE("httpparser.body_overflow") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes, + // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY" + // as a new HTTP request which fails. + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 2\r\n" + "\r\n" + "TOO_LONG_BODY"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "TO"); +} + +TEST_CASE("httpparser.websocket_upgrade") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); + CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(Mock.LastUpgradeHeader, "websocket"); +} + +TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string HttpPart = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + // Append fake WebSocket frame bytes after the HTTP message + std::string Request = HttpPart; + Request.push_back('\x81'); + Request.push_back('\x05'); + Request.append("hello"); + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_NE(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); +} + +TEST_CASE("httpparser.keep_alive_detection") +{ + SUBCASE("HTTP/1.1 default keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK(Mock.LastKeepAlive); + } + + SUBCASE("Connection: close disables keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK_FALSE(Mock.LastKeepAlive); + } +} + +TEST_CASE("httpparser.all_verbs") +{ + struct VerbTest + { + const char* Method; + HttpVerb Expected; + }; + + VerbTest Tests[] = { + {"GET", HttpVerb::kGet}, + {"POST", HttpVerb::kPost}, + {"PUT", HttpVerb::kPut}, + {"DELETE", HttpVerb::kDelete}, + {"HEAD", HttpVerb::kHead}, + {"COPY", HttpVerb::kCopy}, + {"OPTIONS", HttpVerb::kOptions}, + }; + + for (const VerbTest& Test : Tests) + { + CAPTURE(Test.Method); + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, Test.Expected); + } +} + +TEST_CASE("httpparser.query_string") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK_EQ(Mock.LastQueryString, "key=val&other=123"); +} + +TEST_SUITE_END(); + +void +httpparser_forcelink() +{ +} + +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 23ad9d8fb..4ff216248 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -8,7 +8,7 @@ #include <EASTL/fixed_vector.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <http_parser.h> +#include <llhttp.h> ZEN_THIRD_PARTY_INCLUDES_END #include <atomic> @@ -100,7 +100,7 @@ private: Oid m_SessionId{}; IoBuffer m_BodyBuffer; uint64_t m_BodyPosition = 0; - http_parser m_Parser; + llhttp_t m_Parser; eastl::fixed_vector<char, 512> m_HeaderData; std::string m_NormalizedUrl; @@ -114,8 +114,8 @@ private: int OnBody(const char* Data, size_t Bytes); int OnMessageComplete(); - static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } - static http_parser_settings s_ParserSettings; + static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } + static llhttp_settings_t s_ParserSettings; }; } // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 31b0315d4..ad7ed259a 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -185,13 +185,17 @@ public: const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } void SuppressPayload() { m_ResponseBuffers.resize(1); } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } std::string_view GetHeaders(); private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; uint64_t m_ContentLength = 0; std::vector<IoBuffer> m_ResponseBuffers; ExtendableStringBuilder<160> m_Headers; @@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders() if (m_Headers.Size() == 0) { + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(ContentType)); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -831,6 +855,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (plugin HTTP). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { if (IsInteractive && _kbhit() != 0) @@ -844,18 +869,19 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested && !IsApplicationExitRequested()); # else if (IsInteractive) { ZEN_CONSOLE("Zen Server running (plugin HTTP). Ctrl-C to quit"); } + bool ShutdownRequested = false; do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested && !IsApplicationExitRequested()); # endif } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 2cad97725..c1b426bea 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -464,6 +464,8 @@ public: inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -473,6 +475,8 @@ private: uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; std::string m_LocationHeader; @@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + std::string_view ContentTypeString = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); @@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); } + // Content-Range header (for 206 Partial Content single-range responses) + + if (!m_ContentRangeHeader.empty()) + { + PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange]; + ContentRangeHeader->pRawValue = m_ContentRangeHeader.data(); + ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) { // Port may be owned by another process's wildcard registration (access denied) - // or actively in use (sharing violation) — retry on a different port + // or actively in use (sharing violation) - retry on a different port ShouldRetryNextPort = true; } else @@ -1713,7 +1727,7 @@ HttpSysServer::OnRun(bool IsInteractive) ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); } void @@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + if (!m_ContentTypeOverride.empty()) + { + Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT &Transaction().Server())); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + ExtendableStringBuilder<128> UrlUtf8; + WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath, + gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))}, + UrlUtf8); + int PrefixLen = Service->UriPrefixLength(); + std::string_view RelativeUri{UrlUtf8.ToView()}; + RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); return nullptr; @@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - // WebSocket upgrade failed — return nullptr since ServerRequest() + // WebSocket upgrade failed - return nullptr since ServerRequest() // was never populated (no InvokeRequestHandler call) return nullptr; } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } } diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 5ae48f5b3..078c21ea1 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index af320172d..8520e9f60 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown() return; } - // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } @@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: @@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); - // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 59c46a418..a58037fec 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -5,6 +5,7 @@ # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zencore/timer.h> # include <zenhttp/httpserver.h> # include <zenhttp/httpwsclient.h> @@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec") std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); - // Server frames are unmasked — TryParseFrame should handle them + // Server frames are unmasked - TryParseFrame should handle them WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); CHECK(Result.IsValid); @@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec") { std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); - // Pass only 1 byte — not enough for a frame header + // Pass only 1 byte - not enough for a frame header WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); CHECK_FALSE(Result.IsValid); CHECK_EQ(Result.BytesConsumed, 0u); @@ -335,8 +336,9 @@ namespace { } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_OpenCount.fetch_add(1); m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); @@ -463,7 +465,7 @@ namespace { if (!Done.load()) { - // Timeout — cancel the read + // Timeout - cancel the read asio::error_code Ec; Sock.cancel(Ec); } @@ -476,6 +478,23 @@ namespace { return Result; } + static void WaitForServerListening(int Port) + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + asio::io_context IoCtx; + asio::ip::tcp::socket Probe(IoCtx); + asio::error_code Ec; + Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec); + if (!Ec) + { + return; + } + Sleep(10); + } + } + } // anonymous namespace TEST_CASE("websocket.integration") @@ -501,8 +520,8 @@ TEST_CASE("websocket.integration") Server->Close(); }); - // Give server a moment to start accepting - Sleep(100); + // Wait for server to start accepting + WaitForServerListening(Port); SUBCASE("handshake succeeds with 101") { @@ -692,7 +711,7 @@ TEST_CASE("websocket.integration") std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); - // Should NOT get 101 — should fall through to normal request handling + // Should NOT get 101 - should fall through to normal request handling CHECK(Response.find("101") == std::string::npos); Sock.close(); @@ -813,7 +832,7 @@ TEST_CASE("websocket.client") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close") { @@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close over unix socket") { |