// Copyright Epic Games, Inc. All Rights Reserved. #include "httpparser.h" #include #include #include #include #if ZEN_WITH_TESTS # include # include # include # include #endif namespace zen { using namespace std::literals; static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // // HttpRequestParser // // clang-format off llhttp_settings_t HttpRequestParser::s_ParserSettings = []() { llhttp_settings_t S; llhttp_settings_init(&S); S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); }; S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }; S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }; S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }; S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }; S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; return S; }(); // clang-format on HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) { llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings); m_Parser.data = this; ResetState(); } HttpRequestParser::~HttpRequestParser() { } size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize); if (Err == HPE_OK) { return DataSize; } if (Err == HPE_PAUSED_UPGRADE) { return DataSize; } ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser)); return ~0ull; } int HttpRequestParser::OnUrl(const char* Data, size_t Bytes) { const size_t RemainingBufferSpace = std::numeric_limits::max() - m_HeaderData.size(); if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); return -1; } if (m_UrlRange.Length == 0) { ZEN_ASSERT_SLOW(m_UrlRange.Offset == 0); m_UrlRange.Offset = (uint32_t)m_HeaderData.size(); } m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); m_UrlRange.Length += (uint32_t)Bytes; return 0; } int HttpRequestParser::OnHeader(const char* Data, size_t Bytes) { const size_t RemainingBufferSpace = std::numeric_limits::max() - m_HeaderData.size(); if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); return -1; } if (m_HeaderEntries.empty()) { m_HeaderEntries.resize(1); } HeaderEntry* CurrentHeaderEntry = &m_HeaderEntries.back(); if (CurrentHeaderEntry->ValueRange.Length) { ParseCurrentHeader(); m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}}); CurrentHeaderEntry = &m_HeaderEntries.back(); } else if (CurrentHeaderEntry->NameRange.Length == 0) { m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}}); CurrentHeaderEntry = &m_HeaderEntries.back(); } m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); CurrentHeaderEntry->NameRange.Length += (uint32_t)Bytes; return 0; } void HttpRequestParser::ParseCurrentHeader() { ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); const HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); const size_t CurrentHeaderCount = m_HeaderEntries.size(); const std::string_view HeaderName(GetHeaderSubString(CurrentHeaderEntry.NameRange)); if (CurrentHeaderCount > std::numeric_limits::max()) { ZEN_WARN("HttpRequestParser parser only supports up to {} headers, can't store header '{}'. Dropping it.", std::numeric_limits::max(), HeaderName); return; } const std::string_view HeaderValue(GetHeaderSubString(CurrentHeaderEntry.ValueRange)); const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); switch (HeaderHash) { case HashContentLength: m_ContentLengthHeaderIndex = CurrentHeaderIndex; break; case HashAccept: m_AcceptHeaderIndex = CurrentHeaderIndex; break; case HashContentType: m_ContentTypeHeaderIndex = CurrentHeaderIndex; break; case HashAuthorization: m_AuthorizationHeaderIndex = CurrentHeaderIndex; break; case HashSession: m_SessionId = Oid::TryFromHexString(HeaderValue); break; case HashRequest: std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); break; case HashExpect: if (HeaderValue == "100-continue"sv) { // We don't currently do anything with this m_Expect100Continue = true; } else { ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); } break; case HashRange: m_RangeHeaderIndex = CurrentHeaderIndex; break; case HashUpgrade: m_UpgradeHeaderIndex = CurrentHeaderIndex; break; case HashSecWebSocketKey: m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; break; case HashSecWebSocketVersion: m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; break; default: break; } } int HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) { const size_t RemainingBufferSpace = std::numeric_limits::max() - m_HeaderData.size(); if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); return -1; } ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); if (CurrentHeaderEntry.ValueRange.Length == 0) { CurrentHeaderEntry.ValueRange.Offset = (uint32_t)m_HeaderData.size(); } m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); CurrentHeaderEntry.ValueRange.Length += (uint32_t)Bytes; return 0; } static void NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) { bool LastCharWasSeparator = false; const char* Url = InUrl.data(); const size_t UrlLength = InUrl.length(); for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) { const char UrlChar = Url[UrlIndex]; const bool IsSeparator = (UrlChar == '/'); if (IsSeparator && LastCharWasSeparator) { if (NormalizedUrl.empty()) { NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } } else if (!NormalizedUrl.empty()) { NormalizedUrl.push_back(UrlChar); } LastCharWasSeparator = IsSeparator; } } int HttpRequestParser::OnHeadersComplete() { try { if (!m_HeaderEntries.empty()) { HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); if (CurrentHeaderEntry.NameRange.Length) { ParseCurrentHeader(); } } m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser); switch (llhttp_get_method(&m_Parser)) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; break; case HTTP_POST: m_RequestVerb = HttpVerb::kPost; break; case HTTP_PUT: m_RequestVerb = HttpVerb::kPut; break; case HTTP_DELETE: m_RequestVerb = HttpVerb::kDelete; break; case HTTP_HEAD: m_RequestVerb = HttpVerb::kHead; break; case HTTP_COPY: m_RequestVerb = HttpVerb::kCopy; break; case HTTP_OPTIONS: m_RequestVerb = HttpVerb::kOptions; break; default: ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast(llhttp_get_method(&m_Parser)))); break; } std::string_view FullUrl(GetHeaderSubString(m_UrlRange)); if (auto QuerySplit = FullUrl.find_first_of('?'); QuerySplit != std::string_view::npos) { m_UrlRange.Length = uint32_t(QuerySplit); m_QueryStringRange = {.Offset = uint32_t(m_UrlRange.Offset + QuerySplit + 1), .Length = uint32_t(FullUrl.size() - QuerySplit - 1)}; } NormalizeUrlPath(FullUrl, m_NormalizedUrl); std::string_view Value = GetHeaderValue(m_ContentLengthHeaderIndex); if (!Value.empty()) { uint64_t ContentLength = 0; std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); if (ContentLength) { // TODO: should sanity-check content length here m_BodyBuffer = IoBuffer(ContentLength); } m_BodyBuffer.SetContentType(ContentType()); m_BodyPosition = 0; } } catch (const std::exception& Ex) { ZEN_WARN("HttpRequestParser::OnHeadersComplete failed. Reason '{}'", Ex.what()); return -1; } return 0; } int HttpRequestParser::OnBody(const char* Data, size_t Bytes) { if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size()) { ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); return -1; } memcpy(reinterpret_cast(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; return 0; } void HttpRequestParser::ResetState() { m_UrlRange = {}; m_QueryStringRange = {}; m_HeaderEntries.clear(); m_ContentLengthHeaderIndex = -1; m_AcceptHeaderIndex = -1; m_ContentTypeHeaderIndex = -1; m_RangeHeaderIndex = -1; m_AuthorizationHeaderIndex = -1; m_UpgradeHeaderIndex = -1; m_SecWebSocketKeyHeaderIndex = -1; m_SecWebSocketVersionHeaderIndex = -1; m_RequestVerb = HttpVerb::kGet; m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); } int HttpRequestParser::OnMessageBegin() { return 0; } int HttpRequestParser::OnMessageComplete() { try { m_Connection.HandleRequest(); ResetState(); return 0; } catch (const AssertException& AssertEx) { ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); return -1; } catch (const std::system_error& SystemError) { if (IsOOM(SystemError.code())) { ZEN_WARN("out of memory when processing http request: '{}'", SystemError.what()); } else if (IsOOD(SystemError.code())) { ZEN_WARN("out of disk space when processing http request: '{}'", SystemError.what()); } else { ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); return -1; } catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); return -1; } catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); return -1; } } bool HttpRequestParser::IsWebSocketUpgrade() const { std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); if (Upgrade.empty()) { return false; } // Case-insensitive check for "websocket" if (Upgrade.size() != 9) { return false; } return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS namespace { struct MockCallbacks : HttpRequestParserCallbacks { int HandleRequestCount = 0; int TerminateCount = 0; HttpRequestParser* Parser = nullptr; HttpVerb LastVerb{}; std::string LastUrl; std::string LastQueryString; std::string LastBody; bool LastKeepAlive = false; bool LastIsWebSocketUpgrade = false; std::string LastSecWebSocketKey; std::string LastUpgradeHeader; HttpContentType LastContentType{}; void HandleRequest() override { ++HandleRequestCount; if (Parser) { LastVerb = Parser->RequestVerb(); LastUrl = std::string(Parser->Url()); LastQueryString = std::string(Parser->QueryString()); LastKeepAlive = Parser->IsKeepAlive(); LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade(); LastSecWebSocketKey = std::string(Parser->SecWebSocketKey()); LastUpgradeHeader = std::string(Parser->UpgradeHeader()); LastContentType = Parser->ContentType(); IoBuffer Body = Parser->Body(); if (Body.Size() > 0) { LastBody.assign(reinterpret_cast(Body.Data()), Body.Size()); } else { LastBody.clear(); } } } void TerminateConnection() override { ++TerminateCount; } }; } // anonymous namespace TEST_SUITE_BEGIN("http.httpparser"); TEST_CASE("httpparser.basic_get") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastVerb, HttpVerb::kGet); CHECK_EQ(Mock.LastUrl, "/path"); CHECK(Mock.LastKeepAlive); } TEST_CASE("httpparser.post_with_body") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "POST /api HTTP/1.1\r\n" "Host: localhost\r\n" "Content-Length: 13\r\n" "Content-Type: application/json\r\n" "\r\n" "{\"key\":\"val\"}"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastVerb, HttpVerb::kPost); CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}"); CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON); } TEST_CASE("httpparser.pipelined_requests") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n" "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 2); CHECK_EQ(Mock.LastUrl, "/second"); } TEST_CASE("httpparser.partial_header") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc"; std::string Chunk2 = "alhost\r\n\r\n"; size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); CHECK_NE(Consumed1, ~0ull); CHECK_EQ(Consumed1, Chunk1.size()); CHECK_EQ(Mock.HandleRequestCount, 0); size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size()); CHECK_NE(Consumed2, ~0ull); CHECK_EQ(Consumed2, Chunk2.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastUrl, "/path"); } TEST_CASE("httpparser.partial_body") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Headers = "POST /api HTTP/1.1\r\n" "Host: localhost\r\n" "Content-Length: 10\r\n" "\r\n"; std::string BodyPart1 = "hello"; std::string BodyPart2 = "world"; std::string Chunk1 = Headers + BodyPart1; size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); CHECK_NE(Consumed1, ~0ull); CHECK_EQ(Consumed1, Chunk1.size()); CHECK_EQ(Mock.HandleRequestCount, 0); size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size()); CHECK_NE(Consumed2, ~0ull); CHECK_EQ(Consumed2, BodyPart2.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastBody, "helloworld"); } TEST_CASE("httpparser.invalid_request") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Garbage = "NOT_HTTP garbage data\r\n\r\n"; size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size()); CHECK_EQ(Consumed, ~0ull); CHECK_EQ(Mock.HandleRequestCount, 0); } TEST_CASE("httpparser.body_overflow") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes, // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY" // as a new HTTP request which fails. std::string Request = "POST /api HTTP/1.1\r\n" "Host: localhost\r\n" "Content-Length: 2\r\n" "\r\n" "TOO_LONG_BODY"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, ~0ull); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastBody, "TO"); } TEST_CASE("httpparser.websocket_upgrade") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /ws HTTP/1.1\r\n" "Host: localhost\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" "Sec-WebSocket-Version: 13\r\n" "\r\n"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK(Mock.LastIsWebSocketUpgrade); CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); CHECK_EQ(Mock.LastUpgradeHeader, "websocket"); } TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string HttpPart = "GET /ws HTTP/1.1\r\n" "Host: localhost\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" "Sec-WebSocket-Version: 13\r\n" "\r\n"; // Append fake WebSocket frame bytes after the HTTP message std::string Request = HttpPart; Request.push_back('\x81'); Request.push_back('\x05'); Request.append("hello"); size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_NE(Consumed, ~0ull); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK(Mock.LastIsWebSocketUpgrade); } TEST_CASE("httpparser.keep_alive_detection") { SUBCASE("HTTP/1.1 default keep-alive") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; Parser.ConsumeData(Request.data(), Request.size()); CHECK(Mock.LastKeepAlive); } SUBCASE("Connection: close disables keep-alive") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; Parser.ConsumeData(Request.data(), Request.size()); CHECK_FALSE(Mock.LastKeepAlive); } } TEST_CASE("httpparser.all_verbs") { struct VerbTest { const char* Method; HttpVerb Expected; }; VerbTest Tests[] = { {"GET", HttpVerb::kGet}, {"POST", HttpVerb::kPost}, {"PUT", HttpVerb::kPut}, {"DELETE", HttpVerb::kDelete}, {"HEAD", HttpVerb::kHead}, {"COPY", HttpVerb::kCopy}, {"OPTIONS", HttpVerb::kOptions}, }; for (const VerbTest& Test : Tests) { CAPTURE(Test.Method); MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastVerb, Test.Expected); } } TEST_CASE("httpparser.query_string") { MockCallbacks Mock; HttpRequestParser Parser(Mock); Mock.Parser = &Parser; std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); CHECK_EQ(Consumed, Request.size()); CHECK_EQ(Mock.HandleRequestCount, 1); CHECK_EQ(Mock.LastUrl, "/path"); CHECK_EQ(Mock.LastQueryString, "key=val&other=123"); } TEST_SUITE_END(); void httpparser_forcelink() { } #endif // ZEN_WITH_TESTS } // namespace zen